0%

311. Sparse Matrix Multiplication

考点是如何优化空间跟时间
mxn * nxp => mxp
A[i][j] * t[j][k] 累加到 res[i][k]
思路是遍历A,对每个非零A[i][j],进行上述累加操作
普通矩阵乘法则是以最后结果矩阵为遍历顺序做点积运算

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
class Solution {
public:
vector<vector<int>> multiply(vector<vector<int>>& A, vector<vector<int>>& B) {
int m = size(A), n = size(B), p = size(B[0]);
vector<vector<pair<int, int>>> t(n);
for (int i = 0; i < n; ++i) {
for (int j = 0; j < p; ++j) {
if (B[i][j]) {
t[i].emplace_back(j, B[i][j]);
}
}
}
vector<vector<int>> res(m, vector<int>(p));
for (int i = 0; i < m; ++i) {
for (int j = 0; j < n; ++j) {
if (A[i][j] == 0) continue;
for (auto [k, v] : t[j]) {
res[i][k] += A[i][j] * v;
}
}
}
return res;
}
};

先把矩阵B变成邻接表C,记录每个元素的列号
因为res[i][j]是A的第i行和B的第j列的点积,所以只需要遍历矩阵A
将A[i][k]和C[k](即原来B的第k行的所有非0元素)的每个元素相乘并累加到对应的res[i][j]上即可

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
class Solution {
public:
vector<vector<int>> multiply(vector<vector<int>>& A, vector<vector<int>>& B) {
int m = A.size(), n = B.size(), p = B[0].size();
vector<vector<pair<int, int>>> C(n);
for (int i = 0; i < n; ++i) {
for (int j = 0; j < p; ++j) {
if (B[i][j] != 0) {
C[i].emplace_back(j, B[i][j]);
}
}
}
vector<vector<int>> res(m, vector<int>(p));
for (int k = 0; k < n; ++k) {
for (int i = 0; i < m; ++i) {
if (A[i][k] != 0) {
for (auto [j, val] : C[k]) {
res[i][j] += A[i][k] * val; // res[i][j]A的第i行和B的第j列的点积
}
}
}
}
return res;
}
};