4.2 矩阵乘法的Strassen算法

news/2024/10/30 13:34:59/

1.伪代码以及用到的公式

2.代码

package collection;
​
public class StrassenMatrixMultiplication {public static int[][] multiply(int[][] a, int[][] b) {int n = a.length;int[][] result = new int[n][n];
​if (n == 1) {result[0][0] = a[0][0] * b[0][0];} else {int[][] a11 = new int[n / 2][n / 2];int[][] a12 = new int[n / 2][n / 2];int[][] a21 = new int[n / 2][n / 2];int[][] a22 = new int[n / 2][n / 2];int[][] b11 = new int[n / 2][n / 2];int[][] b12 = new int[n / 2][n / 2];int[][] b21 = new int[n / 2][n / 2];int[][] b22 = new int[n / 2][n / 2];
​// Divide matrices into sub-matrices of size n/2 x n/2divide(a, a11, 0, 0);divide(a, a12, 0, n / 2);divide(a, a21, n / 2, 0);divide(a, a22, n / 2, n / 2);divide(b, b11, 0, 0);divide(b, b12, 0, n / 2);divide(b, b21, n / 2, 0);divide(b, b22, n / 2, n / 2);
​// Calculate p1 to p7int[][] p1 = multiply(add(a11, a22), add(b11, b22));int[][] p2 = multiply(add(a21, a22), b11);int[][] p3 = multiply(a11, sub(b12, b22));int[][] p4 = multiply(a22, sub(b21, b11));int[][] p5 = multiply(add(a11, a12), b22);int[][] p6 = multiply(sub(a21, a11), add(b11, b12));int[][] p7 = multiply(sub(a12, a22), add(b21, b22));
​// Calculate sub-matrices of result matrixint[][] c11 = add(sub(add(p1, p4), p5), p7);int[][] c12 = add(p3, p5);int[][] c21 = add(p2, p4);int[][] c22 = add(sub(add(p1, p3), p2), p6);
​// Combine sub-matrices into result matrixcombine(c11, result, 0, 0);combine(c12, result, 0, n / 2);combine(c21, result, n / 2, 0);combine(c22, result, n / 2, n / 2);}return result;}
​// Divide matrix into sub-matricespublic static void divide(int[][] parent, int[][] child, int i, int j) {for (int m = 0, n = i; m < child.length; m++, n++) {for (int p = 0, q = j; p < child.length; p++, q++) {child[m][p] = parent[n][q];}}}
​// Combine sub-matrices into matrixpublic static void combine(int[][] child, int[][] parent, int i, int j) {for (int m = 0, n = i; m < child.length; m++, n++) {for (int p = 0, q = j; p < child.length; p++, q++) {parent[n][q] = child[m][p];}}}
​// Add two matricespublic static int[][] add(int[][] a, int[][] b) {int n = a.length;int[][] result = new int[n][n];for (int i = 0; i < n; i++) {for (int j = 0; j < n; j++) {result[i][j] = a[i][j] + b[i][j];}}return result;}
​// Subtract two matricespublic static int[][] sub(int[][] a, int[][] b) {int n = a.length;int[][] result = new int[n][n];for (int i = 0; i < n; i++) {for (int j = 0; j < n; j++) {result[i][j] = a[i][j] - b[i][j];}}return result;}
}
​
​
​
​

3.原理

  1. 如果 n = 1,则每个矩阵包含一个元素。执行单个标量乘法和单个标量加法,就像 MATRIX-Multiply-RECURSIVE 的第3行那样,计算 Θ (1)的时间,然后返回。否则,将输入矩阵 A、 B 和输出矩阵 C 划分为 n/2 × n/2子矩阵,如方程(4.2)所示。这一步通过索引计算 Θ (1)的时间,就像在 MATRIX-Multiply-RECURSIVE 中一样。

  2. 创建 n/2 × n/2矩阵 S~1~,S~2~,... ,S~10~,每个矩阵都是步骤1中两个子矩阵的和或差。建立并归零七个 n/2 × n/2矩阵 P~1~,P~2~,... ,P~7~的条目以保持七个 n/2 × n/2矩阵乘积。所有17个矩阵都可以在 Θ (n2)时间内创建并初始化 P~i~

  3. 使用步骤1中的子矩阵和步骤2中创建的矩阵 S1,S2,... ,S10,递归地计算7个矩阵乘积 P~1~,P~2~,... ,P~7~中的每一个,花费7T (n/2)的时间。

  4. 对结果矩阵 C 的四个子矩阵 C11,C12,C21,C22进行修正,通过加减各种 P~i~ 矩阵来实现,这需要 Θ (n2)的时间。

假定一旦矩阵规模从n变为1,就进行简单的标量乘法计算,正如SQUARE-MATRIX-MULTIPLY­RECURSIVE的第4行那样。当n>l时,步骤1、2和4共花费Θ(n2)时间,步骤3要求进行7次n/2Xn/2矩阵的乘法。因此,我们得到如下描述Strassen算法运行时间T(n)的递归式:

4.数学计算原理

步骤2,创建如下10个矩阵

10个矩阵:​

由于必须进行10次n/2Xn/2矩阵的加减法,因此,该步骤花费8(n^2^)时间。

步骤3,递归地计算7次n/2*n/2矩阵的乘法

步骤4,对步骤3创建的P;矩阵进行加减法运算,计算出C的4个n/2* n/2的子矩阵.

C11 = C11 + P5 + P4-P2 + P6.​

C​~12~​ = ​C​~12~​ + ​P​~1~​ + ​P​~2~​

C​~21~​ = ​C​~21~​ + ​P​~3~​ + ​P​~4~​

C22 = C22 + P5 + P1-P3-P7​​

5.疑问

1.为什么要使用P1 到p7

用加减法来降低乘法的次数

在 Strassen 算法中,使用了七个中间矩阵 P1 到 P7 来计算矩阵乘积,具体原因如下: 为了避免重复计算,需要将输入的两个矩阵分别拆分成四个小矩阵,然后将这些小矩阵用来计算中间矩阵 P1 到 P7。 中间矩阵 P1 到 P7 的计算方式是通过加减法来实现的,因此可以减少乘法的次数,从而减少计算量。 通过使用中间矩阵 P1 到 P7,可以将原来需要进行 8 次矩阵乘法的计算量减少到 7 次。 综上所述,使用中间矩阵 P1 到 P7 可以提高矩阵乘法的效率,减少计算量。

2.可以使用p1到p10吗

在 Strassen 算法中,使用中间矩阵 P1 到 P7 是该算法的核心思想之一,它们用于表示两个矩阵相乘的结果。通常情况下,使用七个中间矩阵是足够的,因为它们已经能够减少矩阵乘法的计算量,从而提高算法的效率。 虽然可以添加更多的中间矩阵,例如 P8、P9、P10 等,但是这样并不一定能够进一步提高算法的效率,反而可能会增加计算量和程序的复杂度。因此,使用七个中间矩阵 P1 到 P7 是 Strassen 算法的常规做法,也是较为有效的优化策略。

3.至于为什么c11,c12 是这样的表达式

在 Strassen 算法中,矩阵乘积的计算采用了递归的方式,将输入矩阵分解成多个子矩阵进行计算,最终将这些子矩阵合并得到矩阵乘积。

具体地说,对于两个 n×n 的矩阵 A 和 B,Strassen 算法将 A 和 B 分解成四个 n/2×n/2 的子矩阵,然后通过以下公式计算矩阵乘积 C = A × B:

C11 = P5 + P4 - P2 + P6
C12 = P1 + P2
C21 = P3 + P4
C22 = P5 + P1 - P3 - P7

其中,P1 到 P7 是中间矩阵,表示为:

P1 = A11 × (B12 - B22)
P2 = (A11 + A12) × B22
P3 = (A21 + A22) × B11
P4 = A22 × (B21 - B11)
P5 = (A11 + A22) × (B11 + B22)
P6 = (A12 - A22) × (B21 + B22)
P7 = (A11 - A21) × (B11 + B12)

在这个公式中,C11 表示乘积的左上角 n/2×n/2 的子矩阵,P5、P4、P2 和 P6 都是中间矩阵,它们通过加减法来计算 C11 的值。具体来说:

P5 表示 (A11 + A22) × (B11 + B22) 的结果,它包含 C11、C12、C21 和 C22 中的所有元素。 P4 表示 A22 × (B21 - B11) 的结果,它包含 C11 和 C21 中的所有元素。 P2 表示 (A11 + A12) × B22 的结果,它包含 C11 和 C12 中的所有元素。 P6 表示 (A12 - A22) × (B21 + B22) 的结果,它包含 C11 和 C21 中的所有元素。 因此,将这些中间矩阵相加减,可以得到 C11 的值。具体来说,C11 = P5 + P4 - P2 + P6。这个公式的含义是,将 P5、P4、P2 和 P6 中包含 C11 的部分相加减,可以得到 C11 的值。


http://www.ppmy.cn/news/47333.html

相关文章

渲染管线介绍

返回目录 大家好&#xff0c;我是阿赵。 渲染管线网上很多人都介绍过&#xff0c;我这个基本上是写给自己的看的一个笔记&#xff0c;各位不用介意。 一、渲染流水线的简单说明&#xff1a; 如果把渲染流水线列出来&#xff0c;大概有这些过程&#xff1a; CPU模型数据-顶点…

AIGC大潮下:入局门槛极低,投资人陷入空前焦虑

创业门槛低、基金不好投。但是金子总会发光&#xff0c;当项目找到合适的痛点&#xff0c;AIGC的能量将会逐渐释放。 “我的朋友在开发一个‘骂人’机器人&#xff0c;用AIGC训练&#xff0c;保证网络上和别人对骂不吃亏&#xff0c;”某位投资人讲道。在我们所了解的项目中&am…

PXE+Kickstart自动化安装操作系统

文章目录 PXEKickstart 完美自动化部署系统理论知识&#xff1a;1、PXE2、DHCP 实践实验&#xff1a;1、DHCP服务器配置2、TFTP服务器配置3、HTTP服务器安装4、PXE配置5、Kickstart实践配置 PXEKickstart 完美自动化部署系统 理论知识&#xff1a; 无人值守原理&#xff1a;K…

大厂对ChatGPT的开发利用和评估案例收录

ChatGPT已经进入各行各业&#xff0c;但是实际在工作中的有哪些应用呢&#xff1f;这里分享互联网一线大厂分享的一些实际使用案例&#xff0c;所有文章收录到 大厂对ChatGPT的开发利用和评估案例收录http://​www.webhub123.com/#/home/detail?projectHashid67792343&own…

初学SSM时做的-IKUN图书管理系统

项目介绍 项目工具:IntelliJ IDEA 2021.2.2 图书后台管理系统&#xff0c;采用SpringBootMybatiusThymeleaf&#xff0c;页面使用Element框架&#xff0c;使用RESTful API风格编写接口。 数据库使用mysql 已实现功能 基本增删改查,联表查询 拦截器登录验证 项目技术栈 Sp…

外卖小程序10

目录 Apache ECharts介绍入门案例实现步骤代码 总结需求1Service层思路代码实现Controller层ReportController Service层ReportServiceReportServiceImpl Mapper层OrderMapperOrderMapper.xml 需求2Service层思路代码实现Controller层ReportController Service层ReportServiceR…

技术分析内核并发消杀器(KCSAN)一文解决!

一、KCSAN介绍 KCSAN(Kernel Concurrency Sanitizer)是一种动态竞态检测器&#xff0c;它依赖于编译时插装&#xff0c;并使用基于观察点的采样方法来检测竞态&#xff0c;其主要目的是检测数据竞争。 KCSAN是一种检测LKMM(Linux内核内存一致性模型)定义的数据竞争(data race…

电路原理-反激式电路

1、1反激式电路是小功率电源(150W以下)当中&#xff0c;最常用的电路&#xff0c;它的工作原理如下。 1、2如图1&#xff0c;变压器T1&#xff0c;标记红点的端&#xff0c;12、3、A为同名端&#xff0c;10、1、B为异名端。 当MOS管导通的时候&#xff0c;初级绕组N1、…