考点是如何优化空间跟时间
mxn * nxp => mxp
A[i][j] * t[j][k] 累加到 res[i][k]
思路是遍历A,对每个非零A[i][j],进行上述累加操作
普通矩阵乘法则是以最后结果矩阵为遍历顺序做点积运算
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24
| class Solution { public: vector<vector<int>> multiply(vector<vector<int>>& A, vector<vector<int>>& B) { int m = size(A), n = size(B), p = size(B[0]); vector<vector<pair<int, int>>> t(n); for (int i = 0; i < n; ++i) { for (int j = 0; j < p; ++j) { if (B[i][j]) { t[i].emplace_back(j, B[i][j]); } } } vector<vector<int>> res(m, vector<int>(p)); for (int i = 0; i < m; ++i) { for (int j = 0; j < n; ++j) { if (A[i][j] == 0) continue; for (auto [k, v] : t[j]) { res[i][k] += A[i][j] * v; } } } return res; } };
|
先把矩阵B变成邻接表C,记录每个元素的列号
因为res[i][j]是A的第i行和B的第j列的点积,所以只需要遍历矩阵A
将A[i][k]和C[k](即原来B的第k行的所有非0元素)的每个元素相乘并累加到对应的res[i][j]上即可
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
| class Solution { public: vector<vector<int>> multiply(vector<vector<int>>& A, vector<vector<int>>& B) { int m = A.size(), n = B.size(), p = B[0].size(); vector<vector<pair<int, int>>> C(n); for (int i = 0; i < n; ++i) { for (int j = 0; j < p; ++j) { if (B[i][j] != 0) { C[i].emplace_back(j, B[i][j]); } } } vector<vector<int>> res(m, vector<int>(p)); for (int k = 0; k < n; ++k) { for (int i = 0; i < m; ++i) { if (A[i][k] != 0) { for (auto [j, val] : C[k]) { res[i][j] += A[i][k] * val; } } } } return res; } };
|