稀疏矩阵乘法

您所在的位置:网站首页 矩阵乘法的时间复杂度最低是多少 稀疏矩阵乘法

稀疏矩阵乘法

2024-04-30 20:00| 来源: 网络整理| 查看: 265

给定两个 稀疏矩阵 A 和 B,返回AB的结果。您可以假设A的列数等于B的行数。

题目地址:https://www.jiuzhang.com/solution/sparse-matrix-multiplication/#tag-other

本参考程序来自九章算法,由 @Roger 提供。

题目解法:

时间复杂度分析:假设矩阵A,B均为 n x n 的矩阵,矩阵A的稀疏系数为a,矩阵B的稀疏系数为b,a,b∈[0, 1],矩阵越稀疏,系数越小。

方法一:暴力,不考虑稀疏性Time (n^2 * (1 + n)) = O(n^2 + n^3)Space O(1)

方法二:改进,仅考虑A的稀疏性Time O(n^2 * (1 + a * n) = O(n^2 + a * n^3)Space O(1)

方法三(最优):进一步改进,考虑A与B的稀疏性Time O(n^2 * (1 + a * b * n)) = O(n^2 + a * b * n^3)Space O(b * n^2)

方法四:另外一种思路,将矩阵A, B非0元素的坐标抽出,对非0元素进行运算和结果累加Time O(2 * n^2 + a * b * n^4) = O(n^2 + a * b * n^4)Space O(a * n^2 + b * n^2)

解读:矩阵乘法的两种形式,假设 A(n, t) * B(t, m) = C(n, m)

// 形式一:外层两个循环遍历C (常规解法) for (int i = 0; i < n; i++) { for (int j = 0; j < m; j++) { for (int k = 0; k < t; k++) { C[i][j] += A[i][k] * B[k][j]; } } } // 或者写成下面这样子 for (int i = 0; i < n; i++) { for (int j = 0; j < m; j++) { int sum = 0; for (int k = 0; k < t; k++) { sum += A[i][k] * B[k][j]; } C[i][j] = sum; } } // 形式二:外层两个循环遍历A for (int i = 0; i < n; i++) { for (int k = 0; k < t; k++) { for (int j = 0; j < m; j++) { C[i][j] += A[i][k] * B[k][j]; } } }

两种方法的区别

代码上的区别(表象):调换了第二三层循环的顺序

核心区别(内在):形式一以C为核心进行遍历,每个C[i][j]只会被计算一次,就是最终答案。形式二以A为核心进行遍历,每个A[i][k] 乘上 B[k][j]之后,会被累加到 C[i][j],每个C[i][j]将被累加t次。

 

举个例子,若A矩阵2x3,B矩阵3x2,C矩阵2x2 A B C a00 , a01 , a02 b00 , b01 c00 , c01 a10 , a11 , a12 b10 , b11 c10 , c11 b20 , b21 形式一的计算过程:遍历C,假设遍历到c00,计算c00 = a00 * b00 + a01 * b10 + a02 * b20 形式二的计算过程:遍历A, 假设遍历到a00,a00 * b00 累加到 c00, a00 * b01 累加到c01; 假设遍历到a01,a01 * b10 累加到 c00, a01 * b11 累加到c01;

 

 再回到本题目,可以发现是否为稀疏矩阵,对于上述形式一来说,并无法进行优化,因为是以C为核心但是对于形式二来说,以A为核心,若A[i][k]为0,那么该元素就不必进行对应相乘并累加的操作了。故方法二,就是基于此进行优化的。

// 方法一 public class Solution { /** * @param A: a sparse matrix * @param B: a sparse matrix * @return: the result of A * B */ public int[][] multiply(int[][] A, int[][] B) { // write your code here // A(n, t) * B(t, m) = C(n, m) int n = A.length; int t = A[0].length; int m = B[0].length; int[][] C = new int[n][m]; for (int i = 0; i < n; i++) { for (int j = 0; j < m; j++) { int sum = 0; for (int k = 0; k < t; k++) { sum += A[i][k] * B[k][j]; } C[i][j] = sum; } } return C; } } // 方法二 public class Solution { /** * @param A: a sparse matrix * @param B: a sparse matrix * @return: the result of A * B */ public int[][] multiply(int[][] A, int[][] B) { // write your code here // A(n, t) * B(t, m) = C(n, m) int n = A.length; int t = A[0].length; int m = B[0].length; int[][] C = new int[n][m]; for (int i = 0; i < n; i++) { for (int k = 0; k < t; k++) { if (A[i][k] == 0) { continue; } for (int j = 0; j < m; j++) { C[i][j] += A[i][k] * B[k][j]; } } } return C; } } // 方法三 public class Solution { /** * @param A: a sparse matrix * @param B: a sparse matrix * @return: the result of A * B */ public int[][] multiply(int[][] A, int[][] B) { // write your code here // A(n, t) * B(t, m) = C(n, m) int n = A.length; int t = A[0].length; int m = B[0].length; int[][] C = new int[n][m]; List B_nonZero_colIndices = new ArrayList(); for (int k = 0; k < t; k++) { List colIndices = new ArrayList(); for (int j = 0; j < m; j++) { if (B[k][j] != 0) { colIndices.add(j); } } B_nonZero_colIndices.add(colIndices); } for (int i = 0; i < n; i++) { for (int k = 0; k < t; k++) { if (A[i][k] == 0) { continue; } for (int colIndex : B_nonZero_colIndices.get(k)) { C[i][colIndex] += A[i][k] * B[k][colIndex]; } } } return C; } } // 方法四 public class Solution { /** * @param A: a sparse matrix * @param B: a sparse matrix * @return: the result of A * B */ public int[][] multiply(int[][] A, int[][] B) { // write your code here // A(n, t) * B(t, m) = C(n, m) int n = A.length; int t = A[0].length; int m = B[0].length; int[][] C = new int[n][m]; List A_Points = getNonZeroPoints(A); List B_Points = getNonZeroPoints(B); for (Point pA : A_Points) { for (Point pB : B_Points) { if (pA.j == pB.i) { C[pA.i][pB.j] += A[pA.i][pA.j] * B[pB.i][pB.j]; } } } return C; } private List getNonZeroPoints(int[][] matrix) { List nonZeroPoints = new ArrayList(); for (int i = 0; i < matrix.length; i++) { for (int j = 0; j < matrix[0].length; j++) { if (matrix[i][j] != 0) { nonZeroPoints.add(new Point(i, j)); } } } return nonZeroPoints; } class Point { int i, j; Point(int i, int j) { this.i = i; this.j = j; } } }

 



【本文地址】


今日新闻


推荐新闻


CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3