Skip to content

23555: 节省存储的矩阵乘法

matrices, http://cs101.openjudge.cn/practice/23555/

由于矩阵存储非常耗费空间,一个长度n宽度m的矩阵需要花费n*m的存储,因此我们选择用另一种节省空间的方法表示矩阵。一个矩阵X可以表示为三元组的序列,每个三元组代表(行号,列号,元素值),如果元素值是0则我们不存储这个三元组,这样对于0很多的大型矩阵,我们节省了很多存储空间。现在我们有两个用这种方式表示的矩阵X和Y,我们想要计算这两个矩阵的乘积,并且也用三元组形式表达,该如何完成呢。

如果不知道矩阵如何相乘,可以参考:http://cs101.openjudge.cn/practice/18161

输入

输入第一行是三个整数n,m1, m2,两个矩阵X,Y的维度都是n*n,m1是矩阵X中的非0元素数,m2是矩阵Y中的非0元素数。 之后是m1行,每行是一个三元组(行号,列号,元素值),代表X矩阵的元素值,注意行列编号都从0开始。 之后是m2行,每行是一个三元组(行号,列号,元素值),代表Y矩阵的元素值,注意行列编号都从0开始。

输出

输出是m3行,代表X和Y两个矩阵乘积中的非0元素的数目,按照先行号后列号的方式递增排序。 每行仍然是前述的三元组形式。

样例输入

Sample Input1:
3 3 2
0 0 1
1 0 -1
1 2 3
0 0 7
2 2 1

Sample Output1:
0 0 7
1 0 -7
1 2 3

解释:
A = [
  [ 1, 0, 0],
  [-1, 0, 3],
  [0, 0, 0]
]

B = [
  [ 7, 0, 0 ],
  [ 0, 0, 0 ],
  [ 0, 0, 1 ]
]

A*B = [
[7,0,0],
[-7,0,3],
[0,0,0]]

样例输出

Sample Input2:
2 2 4
0 0 1
1 1 1
0 0 2
0 1 3
1 0 4
1 1 5

Sample Output2:
0 0 2
0 1 3
1 0 4
1 1 5

解释:
A = [
[1,0],
[0,1]
]

B = [
[2,3],
[4,5]
]

A*B = [
[2,3],
[4,5]
]

提示:tags: implementation,matrices

来源:2021fall-cs101, hy

问题理解

矩阵 X、Y 是稀疏矩阵(只存储非零项),我们要计算矩阵乘法 Z = X * Y,并以相同稀疏形式输出。

普通矩阵乘法公式:

Z[i][j]=kX[i][k]Y[k][j]

输入中:

  • n 是矩阵的大小(n×n)
  • m1 是 X 的非零项数
  • m2 是 Y 的非零项数 每个三元组 (r, c, v) 表示:矩阵[r][c] = v

输出要:

  • 只输出非零项
  • 先按行号升序、再按列号升序排序

要求用节省内存的方式实现,不能还原矩阵的方式实现。

思路:对第一个矩阵里的某元素和第二个矩阵里的某元素,若这二者能相乘,结果矩阵的对应元素加上这二者的乘积,用字典记录,最后输出所有值不为0的元素

python
from collections import defaultdict
n,m1,m2=map(int,input().split())
d=defaultdict(int)
l1,l2=[],[]
for i in range(m1):
    l1.append(tuple(map(int,input().split())))
for i in range(m2):
    l2.append(tuple(map(int,input().split())))
for i in range(m1):
    for j in range(m2):
        if l1[i][1]==l2[j][0]:
            d[(l1[i][0],l2[j][1])]+=l1[i][2]*l2[j][2]
for i in range(n):
    for j in range(n):
        if d[(i,j)]:
            print(i,j,d[(i,j)])

思路:掌握好矩阵运算法则,把索引搞清楚就行。

python
n,m1,m2 = map(int,input().split())

matrix1 = {}
matrix2 = {}

for _ in range(m1):
    r,c,v = map(int,input().split())
    matrix1[(r,c)] = v
for _ in range(m2):
    r,c,v = map(int,input().split())
    matrix2[(r,c)] = v

result = {}

for (r1,c1),v1 in matrix1.items():
    for (r2,c2),v2 in matrix2.items():
        if c1 == r2:
            if (r1,c2) in result:
                result[(r1,c2)] += v1 * v2
            else:
                result[(r1,c2)] = v1 * v2

for (r,c),v in sorted(result.items()):
    print(r,c,v)
python
# 节省存储的矩阵乘法
# 采用字典,对每一个矩阵,键(i,j)对应值X(i,j)或者Y(i,j), 
# 遍历所有的X(i,k),寻找Y(k,j),乘积累和得到C(i,j)
n,m1,m2 = map(int,input().split())
X = {}
for _ in range (m1):
    i,j,value = map(int,input().split())
    X[(i,j)] = value
Y = {}
for _ in range (m2):
    i,j,value = map(int,input().split())
    Y[(i,j)] = value
C = {}
'''
这里的.items()用于返回字典中所有键值对的视图对象,视图对
象包含字典中的键值对,并以 (key, value) 的形式组成
一个可迭代的对象,可以在循环中方便地同时获取键和值。在需要对比或更新两个字典的键值对时,
.items() 可以直接获取字典的所有键值对,方便查找或修改。
'''
for (i,k1),value1 in X.items():
    for (k2,j),value2 in Y.items():
        if k1 == k2:
            if (i,j) not in C:
                C[(i,j)] = 0
            C[(i,j)] += value1*value2
'''
题目要求按照行列升序输出
对字典,sorted()默认按照字典的键排序,如果键是高维的,
从第一位开始优先排序,第一位相同,按照第二位排序,以此类推
'''
for (i,j),value in sorted(C.items()):
    print(i,j,value)

18161:矩阵运算。http://cs101.openjudge.cn/practice/18161

矩阵乘法运算必须要前一个矩阵的列数与后一个矩阵的行数相同, 如m行n列的矩阵A与n行p列的矩阵B相乘,可以得到m行p列的矩阵C, 矩阵C的每个元素都由A的对应行中的元素与B的对应列中的元素一一相乘并求和得到, 即C[i][j] = A[i][0]*B[0][j] + A[i][1]*B[1][j] + …… +A[i][n-1]*B[n-1][j]

(C[i][j]表示C矩阵中第i行第j列元素)。

即,cij=Σaikbkj

输入放到矩阵里面就好了,在计算乘法之后一旦有值不是0就可以在遍历中直接把位置和值输出。

python
# 汤伟杰,24信息管理系
n, m1, m2 = map(int, input().split())
a = [[0] * n for _ in range(n)]
b = [[0] * n for _ in range(n)]
for _ in range(m1):
    x, y, v = map(int, input().split())
    a[x][y] = v
for _ in range(m2):
    x, y, v = map(int, input().split())
    b[x][y] = v
c = [[0] * n for _ in range(n)]
for i in range(n):
    for j in range(n):
        c[i][j] = sum(a[i][k] * b[k][j] for k in range(n))
        if c[i][j] != 0:
            print(i, j, c[i][j])

思路:将A数组以第二坐标为索引,B数组以第一坐标为索引,依次相乘

python
import sys
data=list(map(int,sys.stdin.readline().split()))
n=data[0]
m_1=data[1]
m_2=data[2]
matrix_a={}
matrix_b={}
for i in range(n):
    matrix_a[i]=[]
    matrix_b[i]=[]
for i in range(m_1):
    data=list(map(int,sys.stdin.readline().split()))
    matrix_a[data[1]].append((data[0],data[2]))
for i in range(m_2):
    data=list(map(int,sys.stdin.readline().split()))
    matrix_b[data[0]].append((data[1],data[2]))
ans=[]
for i in range(n):
    ans.append([0]*n)
for i in range(n):
    for j in matrix_a[i]:
        for k in matrix_b[i]:
            ans[j[0]][k[0]]+=j[1]*k[1]
for i in range(n):
    for j in range(n):
        if ans[i][j]!=0:
            print(' '.join(map(str,[i,j,ans[i][j]])))
python
def getSparseRepresentation(A):
    posList = []
    m = len(A)
    n = len(A[0])
    for i in range(m):
        for j in range(n):
            if A[i][j] != 0:
                posList.append([i,j, A[i][j]])
    return posList

n, m1, m2 = [int(x) for x in input().split()]
A, B = [], []
for i in range(m1):
    A.append([int(x) for x in input().split()])
for i in range(m2):
    B.append([int(x) for x in input().split()])
    
res = [[0 for i in range(n)] for j in range(n)]
for xA, yA, valA in A:
    for xB, yB, valB in B:
        if yA == xB:
            res[xA][yB] += valA * valB
            
res = getSparseRepresentation(res)

for x in res:
    print(x[0], x[1], x[2])

思路:为了高效计算:

  • 先把 Y 的非零项按行分组:Y_rows[k] = [(j, val), ...]
  • 然后遍历 X 的每个非零项 (i, k, valx)
    • 对 Y 的第 k 行的每个 (j, valy)
      • Z[i][j] += valx * valy
  • 最后只输出非零项。
python
n, m1, m2 = map(int, input().split())

X = [tuple(map(int, input().split())) for _ in range(m1)]
Y = [tuple(map(int, input().split())) for _ in range(m2)]

# 将 Y 按行分组:Y_rows[k] = [(col, val), ...]
Y_rows = {}
for r, c, v in Y:
    Y_rows.setdefault(r, []).append((c, v))

# 存储结果矩阵的非零项
res = {}

# 遍历 X 的非零项
for i, k, vx in X:
    if k not in Y_rows:
        continue
    for j, vy in Y_rows[k]:
        res[(i, j)] = res.get((i, j), 0) + vx * vy

# 过滤掉值为 0 的(可能相加后变为 0)
items = [(i, j, v) for (i, j), v in res.items() if v != 0]

# 按行号、列号排序输出
items.sort()

for i, j, v in items:
    print(i, j, v)

时间复杂度

  • X 中非零项:m1,Y 中非零项:m2
  • 若稀疏性好,整体复杂度约为 O(m1 × 平均非零列数),远优于 O(n³)。

思路是经典稀疏矩阵乘法实现:

只处理非零元素,利用中间索引 j 连接 A 和 B,通过按行存储 B 来快速查找 B[j][k]

虽然乘法涉及“列”,但实际访问模式要求我们按“行”组织 B,这是性能优化的关键所在。

python
from collections import defaultdict

n, m1, m2 = map(int, input().split())

A = defaultdict(list)
for _ in range(m1):
    i, j, v = map(int, input().split())
    A[i].append((j, v))

B = defaultdict(list)
for _ in range(m2):
    i, j, v = map(int, input().split())
    B[i].append((j, v))

C = defaultdict(lambda: defaultdict(int))

for i in A:
    for j, a_val in A[i]:
        if j in B:
            for k, b_val in B[j]:
                C[i][k] += a_val * b_val

result = []
for i in sorted(C.keys()):
    for k in sorted(C[i].keys()):
        if C[i][k] != 0:
            result.append((i, k, C[i][k]))

for r in result:
    print(*r)

初始化结果矩阵 C

python
C = defaultdict(lambda: defaultdict(int))

这是一个嵌套的 defaultdict

  • C[i][k] 表示结果矩阵 $ C = A \times B $ 中 (i, k) 位置的值。
  • 外层:C[i] 是一个 defaultdict
  • 内层:每个 C[i][k] 初始为 0(因为 lambda: defaultdict(int)

💡 这样可以直接写 C[i][k] += x,无需判断键是否存在。

python
n, m1, m2 = map(int, input().split())

X = []
for _ in range(m1):
    r, c, v = map(int, input().split())
    X.append((r, c, v))

Y = []
for _ in range(m2):
    r, c, v = map(int, input().split())
    Y.append((r, c, v))

Y_by_row = [[] for _ in range(n)]
for r, c, v in Y:
    Y_by_row[r].append((c, v))

result = {}
for rx, cx, vx in X:
    for cy, vy in Y_by_row[cx]:
        result[(rx, cy)] = result.get((rx, cy), 0) + vx * vy

output = []
for (r, c), val in result.items():
    if val != 0:
        output.append((r, c, val))

output.sort()

for r, c, v in output:
    print(r, c, v)

思路:defaultdict,左边的矩阵按行储存,右边的矩阵按列储存

python
from collections import defaultdict

n, m1, m2 = map(int, input().split())
A = defaultdict(dict)
B = defaultdict(dict)
for _ in range(m1):
    r, c, v = map(int, input().split())
    A[r][c] = v
for _ in range(m2):
    r, c, v = map(int, input().split())
    B[c][r] = v
ans = []
for r in A.keys():
    for c in B.keys():
        cur = 0
        for k in A[r].keys() & B[c].keys():
            cur += A[r][k] * B[c][k]
        if cur != 0:
            ans.append([r, c, cur])
ans.sort()
for l in ans:
    print(*l)

思路:我一开始没看条件,用还原矩阵的方法先写了一遍,ac了,后来发现有要求。重新写了一下,其实只要搞清楚矩阵的运算规律即可,判断最后哪些部位不是零即可。

python
n,m1,m2=map(int,input().split())

row_a={}
for _ in range(m1):
    i,k,a=map(int,input().split())
    if i not in row_a:
        row_a[i]=[]
    row_a[i].append((k, a))

row_b={}
for _ in range(m2):
    k,j,b=map(int,input().split())
    if k not in row_b:
        row_b[k]=[]
    row_b[k].append((j, b))

result={}
for i in row_a:
    for k,a in row_a[i]:
        if k not in row_b:
            continue
        for j,b in row_b[k]:
            key=(i, j)
            if key in result:
                result[key]+=a * b
            else:
                result[key]=a * b

for (i,j),val in sorted(result.items()):
    print(i,j,val)