六种矩阵乘法
torch
中包含许多矩阵乘法,大致可以分为以下几种:
-
*
:即a * b
按位相乘,要求a
和b
的形状必须一致,支持广播操作 -
torch.matmul()
:最广泛的矩阵乘法 -
@
:与torch.matmul()
效果一样(等价),即torch.matmul(a, b) == a @ b
-
torch.dot()
:两个一维向量乘法,不支持广播 -
torch.mm()
:两个二维矩阵的乘法,不支持广播
其中,torch.matmul()
中包含torch.dot()
、torch.mm()
和torch.bmm()
代码验证
torch.dot()
a = torch.tensor([2, 3])
b = torch.tensor([2, 1])## 下面四个函数的结果是一样的 结果都是7
a.dot(b)
torch.dot(a, b)
a @ b
torch.matmul(a, b)
输出结果:
但torch.matmul()
和torch.dot()
的主要区别就是,当两个向量(矩阵)的维度不一致时,torch.matmul()
会进行广播,而torch.dot()
会报错
*
对向量a
和b
进行按位相乘
a = torch.tensor([2, 3])
b = torch.tensor([2, 1])a * b # [4, 3]
torch.mm()
用于二维矩阵的相乘——第一个向量的列和第二个向量的行必须相等
mat1 = torch.randn(2, 3)
mat2 = torch.randn(3, 3)## 下面三个输出结果是一样的
torch.mm(mat1, mat2)
mat1.matmul(mat2)
mat1 @ mat2
输出结果:
但torch.matmul()
和torch.mm()
的主要区别就是,当两个矩阵的维度不一致时,torch.matmul()
会进行广播,而torch.mm()
会报错
torch.bmm()
应用于三维矩阵,要求:
- 两个矩阵的第一个维度的大小必须相同
- 必须满足第一个矩阵:
(b × n × m)
,第二个矩阵:(b × m × p)
,即第一个矩阵的第三个维度必须和第二个矩阵的第二个维度相同 - 输出大小:
(b × n × p)
该函数相当于分别对每个batch
进行二维矩阵相乘
bmat1 = torch.randn(2, 1, 4)
bmat2 = torch.randn(2, 4, 2)## 下面三个输出是一样的
torch.bmm(bmat1, bmat2)
bmat1.matmul(bmat2)
bmat1 @ bmat2
输出结果:
换一种角度想,torch.bmm()
就是相当于按照批次batch
进行索引,然后将每个批次内的二维矩阵进行相乘
for i in range(bmat1.shape[0]): # 索引出来批次bmat1.shape[0]temp =torch.mm(bmat1[i, :, :], bmat2[i, :, :])print(temp)