Pytorch中矩阵乘法使用及案例

ops/2025/3/15 0:44:24/

六种矩阵乘法

torch中包含许多矩阵乘法,大致可以分为以下几种:

  • *:即a * b 按位相乘,要求ab的形状必须一致,支持广播操作

  • torch.matmul():最广泛的矩阵乘法

  • @:与torch.matmul()效果一样(等价),即torch.matmul(a, b) == a @ b

  • torch.dot():两个一维向量乘法,不支持广播

  • torch.mm():两个二维矩阵的乘法,不支持广播

  • torch.bmm():两个三维矩阵乘法(批次batch粒度),且两个矩阵必须是三维的,不支持广播操作

其中,torch.matmul()中包含torch.dot()torch.mm()torch.bmm()

代码验证

torch.dot()

a = torch.tensor([2, 3])
b = torch.tensor([2, 1])## 下面四个函数的结果是一样的  结果都是7
a.dot(b)
torch.dot(a, b)
a @ b
torch.matmul(a, b)

输出结果:
在这里插入图片描述

torch.matmul()torch.dot()的主要区别就是,当两个向量(矩阵)的维度不一致时,torch.matmul()会进行广播,而torch.dot()会报错

*

对向量ab进行按位相乘

a = torch.tensor([2, 3])
b = torch.tensor([2, 1])a * b  # [4, 3]

torch.mm()

用于二维矩阵的相乘——第一个向量的和第二个向量的必须相等

mat1 = torch.randn(2, 3)
mat2 = torch.randn(3, 3)## 下面三个输出结果是一样的
torch.mm(mat1, mat2)
mat1.matmul(mat2)
mat1 @ mat2

输出结果:
在这里插入图片描述

torch.matmul()torch.mm()的主要区别就是,当两个矩阵的维度不一致时,torch.matmul()会进行广播,而torch.mm()会报错

torch.bmm()

应用于三维矩阵,要求:

  • 两个矩阵的第一个维度的大小必须相同
  • 必须满足第一个矩阵(b × n × m),第二个矩阵(b × m × p),即第一个矩阵的第三个维度必须和第二个矩阵的第二个维度相同
  • 输出大小:(b × n × p)

该函数相当于分别对每个batch进行二维矩阵相乘

bmat1 = torch.randn(2, 1, 4)
bmat2 = torch.randn(2, 4, 2)## 下面三个输出是一样的
torch.bmm(bmat1, bmat2)
bmat1.matmul(bmat2)
bmat1 @ bmat2

输出结果:
在这里插入图片描述

换一种角度想,torch.bmm()就是相当于按照批次batch进行索引,然后将每个批次内的二维矩阵进行相乘

for i in range(bmat1.shape[0]):  # 索引出来批次bmat1.shape[0]temp =torch.mm(bmat1[i, :, :], bmat2[i, :, :])print(temp)

在这里插入图片描述


http://www.ppmy.cn/ops/165802.html

相关文章

如何用终端运行一个SpringBoot项目

在项目开发阶段,为了能够快速测试一个SpringBoot项目的执行结果,就可以采用终端(黑窗)运行查看,因为我们不能要求每一个客户都安装idea并且适配我们的项目版本。 下面将展示打包运行这两个方面的过程: 创建…

重新安排行程 (leetcode 332

看了一上午题解,还是没明白 targets[result[result.size() - 1]] 是什么意思/(ㄒoㄒ)/~~ 然后搜到了: 对于targets[result[result.size() - 1]]的解释 突然就清楚多了!!

Ktor库使用HTTP编写了一个下载程序

使用 Ktor 库编写一个下载程序也是非常简单的,Ktor 是一个强大的 Kotlin 网络框架,支持 HTTP 请求和响应,适用于构建客户端和服务器应用。 下面是使用 Ktor 库编写的一个简单下载程序,功能是从指定的 URL 下载文件并保存到本地。…

RoboVQA:机器人多模态长范围推理

23 年 11 月来自 Google Deepmind 的论文“RoboVQA: Multimodal Long-Horizon Reasoning for Robotics”。 本文提出一种可扩展、自下而上且本质多样化的数据收集方案,该方案可用于长期和中期的高级推理,与传统的狭窄自上而下的逐步收集相比&#xff0c…

Python学习第十三天

正则表达式 什么是正则表达式:简单来说就是通过特殊符号匹配想要的字符串,正则表达式本身就是基于字符串的一套搜索规则,掌握了正则表达式对于字符串有了更深的把握和理解。 概念 官网概念:正则表达式(Regular Expres…

【QT】-一文读懂抽象类

抽象类(Abstract Class)是面向对象编程中的一个概念,指的是无法被实例化的类,它通常作为其他类的基类。抽象类的作用是定义一个接口(或约定),让派生类(继承自抽象类的类)来实现具体的功能。 抽象类的特点: 包含纯虚函数(Pure Virtual Function): 抽象类通常包含一…

Linux 使用 docker 安装 Gogs 公司私有 Git 仓库

Gogs 简介 Gogs(Go Git Service)是一个用 Go 语言编写的自托管 Git 服务,类似于 GitHub 或 GitLab,但更轻量、易于部署和使用。Gogs 的目标是提供一个简单、快速且低资源占用的 Git 服务,适合个人开发者、小团队或企业…

LeetCode Hot100刷题——对称二叉树

101.对称二叉树 给你一个二叉树的根节点 root , 检查它是否轴对称。 示例 1: 输入:root [1,2,2,3,4,4,3] 输出:true示例 2: 输入:root [1,2,2,null,3,null,3] 输出:false提示: 树…