Python | Pytorch | Tensor知识点总结

embedded/2025/2/1 1:18:21/

如是我闻: Tensor 是我们接触Pytorch了解到的第一个概念,这里是一个关于 PyTorch Tensor 主题的知识点总结,涵盖了 Tensor 的基本概念、创建方式、运算操作、梯度计算和 GPU 加速等内容。


1. Tensor 基本概念

  • Tensor 是 PyTorch 的核心数据结构,类似于 NumPy 的 ndarray,但支持 GPU 加速和自动求导。
  • PyTorch 的 Tensor 具有 动态计算图,可用于深度学习模型的前向传播和反向传播。

PyTorch Tensor vs. NumPy Array

特性PyTorch TensorNumPy Array
支持 GPU
自动求导✅ (requires_grad=True)
兼容性✅ (可转换为 NumPy)✅ (可转换为 Tensor)

2. Tensor 创建方式

2.1 直接创建 Tensor

python">import torch# 从列表创建
a = torch.tensor([1, 2, 3])
b = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float32)print(a, a.dtype)  # 默认 int64
print(b, b.dtype)  # float32

2.2 常见初始化方法

python"># 全零/全一 Tensor
x = torch.zeros((3, 3))
y = torch.ones((2, 2))# 随机初始化
z = torch.rand((3, 3))  # [0, 1) 均匀分布
n = torch.randn((2, 2)) # 标准正态分布# 单位矩阵
I = torch.eye(3)# 创建指定范围的 Tensor
r = torch.arange(0, 10, 2)  # [0, 2, 4, 6, 8]
l = torch.linspace(0, 1, 5) # [0.0, 0.25, 0.5, 0.75, 1.0]

2.3 通过 NumPy 互转

python">import numpy as np# NumPy -> PyTorch
np_array = np.array([[1, 2], [3, 4]])
tensor_from_np = torch.from_numpy(np_array)# PyTorch -> NumPy
tensor = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
np_from_tensor = tensor.numpy()

3. Tensor 形状操作

3.1 维度变换

python">x = torch.randn(2, 3, 4)# 改变形状
y = x.view(6, 4)   # 使用 view 改变形状 (必须保证数据连续存储)
z = x.reshape(6, 4)  # reshape 不受数据存储方式限制# 维度扩展
x_exp = x.unsqueeze(0)  # 在第 0 维添加一个维度
x_squeeze = x_exp.squeeze(0)  # 去除维数为 1 的维度

3.2 维度交换

python">x = torch.rand(2, 3, 4)x_t = x.permute(2, 0, 1)  # 交换维度
x_t2 = x.transpose(1, 2)  # 交换 1 和 2 维

4. Tensor 运算

4.1 逐元素运算

python">x = torch.tensor([1, 2, 3])
y = torch.tensor([4, 5, 6])# 逐元素运算
add = x + y  # 或 torch.add(x, y)
sub = x - y  # 或 torch.sub(x, y)
mul = x * y  # 或 torch.mul(x, y)
div = x / y  # 或 torch.div(x, y)# 指数、对数、幂运算
exp = torch.exp(x)
log = torch.log(y)
pow_2 = x.pow(2)  # 平方

4.2 线性代数运算

python">A = torch.tensor([[1, 2], [3, 4]])
B = torch.tensor([[5, 6], [7, 8]])# 矩阵乘法
C = torch.mm(A, B)  # 矩阵乘法
D = A @ B  # 矩阵乘法 (等价于 mm)# 逆矩阵
A_inv = torch.inverse(A.float())# 计算特征值和特征向量
eigenvalues, eigenvectors = torch.eig(A.float(), eigenvectors=True)

4.3 统计运算

python">x = torch.randn(3, 3)mean_x = x.mean()    # 均值
std_x = x.std()      # 标准差
sum_x = x.sum()      # 总和
max_x = x.max()      # 最大值
argmax_x = x.argmax() # 最大值索引

5. Tensor 计算图和自动求导

5.1 计算梯度

python">x = torch.tensor(2.0, requires_grad=True)y = x**2 + 3*x + 1  # 计算 y
y.backward()  # 计算梯度print(x.grad)  # dy/dx = 2x + 3 -> 2*2 + 3 = 7

5.2 阻止梯度计算

python">x = torch.tensor(2.0, requires_grad=True)with torch.no_grad():y = x**2 + 3*x + 1  # 计算过程中不记录梯度

6. GPU 计算

6.1 设备选择

python">device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

6.2 在 GPU 上创建 Tensor

python">x = torch.randn(3, 3, device=device)

6.3 在 CPU 和 GPU 之间转换

python">x_cpu = x.to("cpu")  # 移回 CPU
x_gpu = x.to("cuda")  # 移至 GPU

7. 总的来说

主题关键知识点
Tensor 创建torch.tensor()torch.zeros()torch.rand()
NumPy 互转torch.from_numpy().numpy()
形状变换.view().reshape().unsqueeze()
运算逐元素计算、矩阵运算、统计运算
自动求导requires_grad=True.backward()
GPU 加速torch.device("cuda").to("cuda")

以上


http://www.ppmy.cn/embedded/158497.html

相关文章

《DeepSeek R1:开启AI推理新时代》

《DeepSeek R1:开启AI推理新时代》 一、AI 浪潮中的新星诞生二、DeepSeek R1 的技术探秘(一)核心技术架构(二)强化学习的力量(三)多阶段训练策略(四)长序列处理优势 三、…

【fly-iot飞凡物联】(20):2025年总体规划,把物联网整套技术方案和实现并落地,完成项目开发和课程录制。

前言 fly-iot飞凡物联专栏: https://blog.csdn.net/freewebsys/category_12219758.html 1,开源项目地址进行项目开发 https://gitee.com/fly-iot/fly-iot-platform 完成项目开发,接口开发。 把相关内容总结成文档,并录制课程。…

RabbitMQ模块新增消息转换器

文章目录 1.目录结构2.代码1.pom.xml 排除logging2.RabbitMQConfig.java3.RabbitMQAutoConfiguration.java 1.目录结构 2.代码 1.pom.xml 排除logging <?xml version"1.0" encoding"UTF-8"?> <project xmlns"http://maven.apache.org/PO…

SpringBoot Web开发(SpringMVC)

SpringBoot Web开发&#xff08;SpringMVC) MVC 核心组件和调用流程 Spring MVC与许多其他Web框架一样&#xff0c;是围绕前端控制器模式设计的&#xff0c;其中中央 Servlet DispatcherServlet 做整体请求处理调度&#xff01; . 除了DispatcherServletSpringMVC还会提供其他…

PostgreSQL 约束

PostgreSQL 约束 在数据库设计中,约束(Constraint)是一种规则,用于确保数据库中的数据满足特定的条件。PostgreSQL 作为一款功能强大的开源关系型数据库管理系统,提供了多种约束类型,以帮助开发者维护数据的一致性和准确性。本文将详细介绍 PostgreSQL 中常见的约束类型…

MySQL安装教程

一、下载 点开下面的链接&#xff1a;下载地址 点击Download 就可以下载对应的安装包了, 安装包如下: 二、解压 下载完成后我们得到的是一个压缩包&#xff0c;将其解压&#xff0c;我们就可以得到MySQL 8.0.34 的软件本体了(就是一个文件夹)&#xff0c;我们可以把它放在你想…

5.进程基本概念

5.进程基本概念 **1. 进程的基本概念****2. 进程与程序的区别****3. 进程的状态****4. 进程调度****5. 进程相关命令****6. 进程创建与管理****7. 进程的应用场景****8. 练习与作业****9. 进程的地址空间****10. 进程的分类****11. 进程的并发与并行****12. 总结** 1. 进程的基…

UE求职Demo开发日志#12 完善击杀获得物品逻辑和UI

1 实现思路 1.给WarehouseManager添加一个按TArray增加物品的函数 2.Enemy身上一个变量记录掉落物品&#xff0c;死亡时调用增加物品函数 3.同时调用UI显示 2 实现过程 2.1 在WarehouseManager里添加一个AddItemByArray函数 遍历数组调用添加函数 void UWarehouseManage…