PyTorch2

ops/2024/11/27 9:01:06/

Tensor的常见操作:

获取元素值:

注意:

  • 和Tensor的维度没有关系,都可以取出来!

  • 如果有多个元素则报错;

import torch
def test002():data = torch.tensor([18])print(data.item())pass
if __name__ == "__main__":test002()
元素值运算:

常见的加减乘除次方取反开方等各种操作,带有_的方法则会替换原始值。

import torch
def test001():data = torch.randint(0, 10, (2, 3))print(data)# 元素级别的加减乘除:不修改原始值print(data.add(1))print(data.sub(1))print(data.mul(2))print(data.div(3))print(data.pow(2))# 元素级别的加减乘除:修改原始值data = data.float()data.add_(1)data.sub_(1)data.mul_(2)data.div_(3.0)data.pow_(2)print(data)
if __name__ == "__main__":test001()
阿达玛积:

阿达玛积指的是矩阵对应位置的元素相乘,可以使用mul函数或者*来实现

import torch
def test001():data1 = torch.tensor([[1, 2, 3], [4, 5, 6]])data2 = torch.tensor([[2, 3, 4], [2, 2, 3]])print(data1 * data2)
def test002():data1 = torch.tensor([[1, 2, 3], [4, 5, 6]])data2 = torch.tensor([[2, 3, 4], [2, 2, 3]])print(data1.mul(data2))
if __name__ == "__main__":test001()test002()
Tensor相乘:

点积运算将两个向量映射为一个标量,是向量之间的基本操作。

点积运算要求如果第一个矩阵的shape是 (N, M),那么第二个矩阵 shape必须是 (M, P),最后两个矩阵点积运算的shape为 (N, P)。

import torch
def test006():data1 = torch.tensor([[1, 2, 3], [4, 5, 6]])data2 = torch.tensor([[3, 2], [2, 3], [5, 3]])print(data1 @ data2)print(data1.matmul(data2))print(data1.mm(data2))
if __name__ == "__main__":test006()
索引操作:

掌握张量的花式索引在处理复杂数据时非常有用。

简单索引:

索引,就是根据指定的下标选取数据。

import torch
def test006():data = torch.randint(0, 10, (3, 4))print(data)# 1. 行索引print("行索引:", data[0])    # 2. 列索引print("列索引:", data[:, 0])# 3. 固定位置索引:2种方式都行print("索引:", data[0, 0], data[0][0])
if __name__ == "__main__":test006()
列表索引:

使用list批量的制定要索引的元素位置~此时注意list的维度

import torch
def test008():data = torch.randint(0, 10, (3, 4))print(data)# 1. 使用列表进行索引:(0, 0), (1, 1), (2, 1)print("列表索引:", data[[0, 1, 2], [0, 1, 1]])# 2. 行级别的列表索引print("行级别列表索引:", data[[[2], [1]], [0, 1, 2]])
if __name__ == "__main__":test008()
布尔索引:

根据条件选择张量中的元素。

import torch
def test009():tensor = torch.tensor([1, 2, 3, 4, 5])mask = tensor > 3print(mask)print(tensor[mask])  # 输出: tensor([4, 5])
if __name__ == "__main__":test009()
索引赋值:

使用索引进行批量元素值修修改

import torch
def test666():data = torch.eye(4)print(data)# 赋值data[:, 1:-1] = 0print(data)
if __name__ == "__main__":test666()
张量拼接:

在 PyTorch 中,cat 和 stack 是两个用于拼接张量的常用操作,但它们的使用方式和结果略有不同:

  • cat:在现有维度上拼接,不会增加新维度。

  • stack:在新维度上堆叠,会增加一个维度。

#拼接
import torch
a = torch.full((2,3),1)
b = torch.full((2,3),2)
c = torch.cat((a,b),dim=0)#0行拼接,1列拼接
print(c)
d = torch.cat((a,b),dim=1)
print(d)
#堆叠
import torch
a = torch.full((2,3), 5)
b = torch.full((2,3), 6)
print(torch.stack([a,b],dim=0))
print(torch.stack([a,b],dim=1))
a = torch.full((2,3,3), 5)
b = torch.full((2,3,3), 6)
print(torch.stack([a,b],dim=2))
a = torch.full((4,808,555), 5)
print(torch.stack([a[0],a[1],a[3]],dim=2).shape)
形状操作:

 在 PyTorch 中,张量的形状操作是非常重要的,因为它允许你灵活地调整张量的维度和结构,以适应不同的计算需求。

reshape:

可以用于将张量转换为不同的形状,但要确保转换后的形状与原始形状具有相同的元素数量。

view:

view进行形状变换的特征:

  • 张量在内存中是连续的;

  • 返回的是原始张量视图,不重新分配内存,效率更高;

  • 如果张量在内存中不连续,view 将无法执行,并抛出错误。

transpose:

transpose 用于交换张量的两个维度,注意,是2个维度,它返回的是原张量的视图。

permute:

permute 用于改变张量的所有维度顺序。与 transpose 类似,但它可以交换多个维度。

flatten:

用于将向量展平为一维向量

#形状操作
import torch
a = torch.tensor([[1,2,3],[4,5,6]])
b = a.reshape(3,2)
c = b.reshape(2,-1)
a.view(2,-1)#view函数和reshape函数功能相同,但reshape函数返回的是新数组,而view函数返回的是原数组,且要求是连续的
print(b)
print(c)
print(a)
a = torch.tensor([[1,2,3],[4,5,6]])
b = a.transpose(0,1)#0,1表示交换a的维度0和1,只能交换两个维度
c = torch.transpose(a,0,1)
print(b)
a=torch.randint(0,300,(2,3,4,5))
b=a.transpose(1,2)
c=a.permute(1,2,0,3)#交换a的维度.可以交换多个维度
d=a.flatten(start_dim=1,end_dim=2)#将a拉平为一维,从第一个维度开始,到第二个维度结束
print(b.shape)
print(b.shape)
print(c.shape)
print(d)
print(d.shape)
squeeze降维:

用于移除所有大小为 1 的维度,或者移除指定维度的大小为 1 的维度。

unsqueeze升维:

用于在指定位置插入一个大小为 1 的新维度

#升维和降维
import torch
x = torch.randint(10, (1,4, 5,1))
print(x.squeeze().shape)#删除所有维度为1的维度
print(x.squeeze(3).shape)#删除指定为1的维度
a = torch.randint(10, (4,5))
print(a.unsqueeze(0).shape)#在第0维增加一个大小为1的维度
 张量分割:

可以按照指定的大小或者块数进行分割。

#张量分割
import torch
a = torch.randint(1, 10, (3, 3))
b = a.split(2, dim=0)#按行分割 每份2行
c = a.split(2, dim=1)#按列分割 每份2列
d = a.chunk(3, dim=0)#按行分割成三份
print(b)
print(c)
print(d)
广播机制:

广播机制允许在对不同形状的张量进行计算,而无需显式地调整它们的形状。广播机制通过自动扩展较小维度的张量,使其与较大维度的张量兼容,从而实现按元素计算。

#广播机制
#需要满足右对齐
import torch
a = torch.tensor([1,2,3])
b = torch.tensor([[4],[5],[6]])
print(b+a)
a = torch.tensor([[1,2,3],[4,5,6]])
b = torch.full((3,2,3),1)
print(a+b)
数学运算:
#数学运算
import torch
torch.manual_seed(2)
torch.initial_seed()
a = torch.randn(1,6)
print(a)
print(torch.floor(a))#向下取整
print(torch.ceil(a))#向上取整
print(torch.round(a))#四舍五入
print(torch.sign(a))#符号
print(torch.abs(a))#绝对值
print(torch.trunc(a))#取整数部分
print(torch.frac(a))#取小数部分
print(torch.fix(a))#向零取整
print(a%2)#取余数
统计学函数:
#统计学函数
a = torch.tensor([1,2,3,2,2],dtype=torch.float)
print(torch.mean(a))#平均值 要求数据为浮点型
print(torch.median(a))#中位数
print(torch.std(a))#标准差
print(torch.var(a))#方差
print(torch.norm(a))#范数
print(torch.max(a))#最大值
print(torch.min(a))#最小值
print(torch.sum(a))#求和
print(torch.prod(a))#乘积
print(torch.mode(a))#众数
print(torch.topk(a,2))#取前2个最大值
print(torch.sort(a))#排序,给出排序后的索引地址
print(torch.unique(a))#去重
print(torch.topk(a,2,largest=False))#取前2个最小值
b=a.int()
print(torch.bincount(b))#统计每个元素的个数 要求数据为整数
arr = [1,2,3,4,5,11]
arr.sort(key = lambda x:abs(x-10))
print(arr)
print(torch.histc(a,bins=5))
三角函数:
#三角函数
a = torch.tensor([torch.pi, torch.pi / 2])
print(a)
print(torch.sin(a))#正弦函数
print(torch.cos(a))#余弦函数
print(torch.tan(a))#正切函数
print(torch.asin(torch.sin(a)))#反正弦函数
print(torch.acos(torch.cos(a)))#反余弦函数
print(torch.atan(torch.tan(a)))#反正切函数
print(torch.sinh(a))#双曲正弦函数
print(torch.cosh(a))#双曲余弦函数
print(torch.tanh(a))#双曲正切函数
保存和加载:

张量数据可以保存下来并加载再次使用

#保存和加载
import torch
a = torch.ones(3)
torch.save(a, 'data/tensor.pt')
#a.save('data/tensor.pt')
b = torch.load('data/tensor.pt')
print(b)
并行化:

在 PyTorch 中,你可以查看和设置用于 CPU 运算的线程数。PyTorch 使用多线程来加速 CPU 运算,但有时你可能需要调整线程数来优化性能。

使用 torch.get_num_threads() 来查看当前 PyTorch 使用的线程数:

使用 torch.set_num_threads() 设置 PyTorch 使用的线程数:

#并行化
#查看线程数
import torch
print(torch.get_num_threads())
#设置线程数
torch.set_num_threads(10)


http://www.ppmy.cn/ops/137039.html

相关文章

数字图像处理(6):除法运算、除法器

(1)当除数是常数时,可以先转化为乘法,再右移,乘法的N越大,计算误差越小。 如:计算x/122,可以看成(x * 67)>>13,N13,使用verilog实现: reg …

Docker部署mysql:8.0.31+dbsyncer

Docker部署mysql8.0.31 创建本地mysql配置文件 mkdir -p /opt/mysql/log mkdir -p /opt/mysql/data mkdir -p /opt/mysql/conf cd /opt/mysql/conf touch my.config [mysql] #设置mysql客户端默认字符集 default-character-setUTF8MB4 [mysqld] #设置3306端口 port33…

【贪心算法第五弹——300.最长递增子序列】

目录 1.题目解析 题目来源 测试用例 2.算法原理 3.实战代码 代码解析 注意本题还有一种动态规划的解决方法,贪心的方法就是从动态规划的方法总结而来,各位可以移步博主的另一篇博客先了解一下:动态规划-子序列问题——300.长递增子序列…

基于K8S编排部署EFK日志收集系统

基于K8S编排部署EFK日志收集系统 案例分析 1. 规划节点 节点规划,见表1。 表1 节点规划 IP主机名节点192.168.100.3masterk8s-master192.168.100.4nodek8s-node 2. 基础准备 Kubernete环境已安装完成,将提供的软件包efk-img.tar.gz上传至master节…

ElasticSearch学习笔记六:Springboot整合

一、前言 在前一篇文章中,我们学习了ES中的一部分的搜索功能,作为一名Java工程师,更多时候我们是用代码去操作ES,同时对于Java而言时下最流行的就是Springboot了,所以这里我们将ES和Springboot整合将上一篇文章中的所…

基础入门-Web应用架构搭建域名源码站库分离MVC模型解析受限对应路径

知识点: 1、基础入门-Web应用-域名上的技术要点 2、基础入门-Web应用-源码上的技术要点 3、基础入门-Web应用-数据上的技术要点 4、基础入门-Web应用-解析上的技术要点 5、基础入门-Web应用-平台上的技术要点 一、演示案例-域名差异-主站&分站&端口站&…

在远程服务器和本地同步数据的指南

在远程服务器和本地同步数据的指南 在现代软件开发和数据管理中,保持本地和远程服务器之间的数据同步是至关重要的。无论是代码、配置文件还是其他数据,确保它们在不同环境中的一致性都是高效工作的关键。本文将介绍如何使用 Bash 脚本和 rsync 工具在本…

Unity中动态生成贴图并保存成png图片实现

实现原理&#xff1a; 要生成长x宽y的贴图&#xff0c;就是生成x*y个像素填充到贴图中&#xff0c;如下图&#xff1a; 如果要改变局部颜色&#xff0c;就是从x1到x2(x1<x2),y1到y2(y1<y2)这个范围做处理&#xff0c; 或者要想做圆形就是计算距某个点&#xff08;x1,y1&…