深度学习Pytorch常用api详解记录

news/2024/10/21 10:18:30/

深度学习常用的torch函数

  • torch.cat()
  • torch.Tensor.repeat():
  • 持续更新中...

torch.cat()

对象:给定的序列化张量,即Tensor型。
功能:实现两个张量在指定维度上的拼接。
输出:拼接后的张量。
函数以及参数torch.cat(tensor, dim),官方给出的有四个参数,但是我们平时只会用到前两个参数即可。
tensor:有相同形状的张量序列,所有的张量需要有相同的形状才能够拼接,除非是在拼接维度上两个张量可以有不同的尺寸,或者两个张量都是空的。
dim:两个张量或者多个张量拼接的维度。

应用实例1:两个张量形状相同

代码

import torchx = torch.randn(2,4)
y = torch.randn(2,4)print(f'x={x}','\n',f'y={y}')
print(f'z={torch.cat((x,y), 0)}')

输出:

x=tensor([[-1.2870, -0.7040,  0.3016, -0.2970],[-0.8151, -0.5236, -1.7680,  0.7675]]) y=tensor([[-1.4207, -0.2694,  0.2521, -0.7187],[ 0.8776, -0.0352, -0.5094,  0.0602]])
z=tensor([[-1.2870, -0.7040,  0.3016, -0.2970],[-0.8151, -0.5236, -1.7680,  0.7675],[-1.4207, -0.2694,  0.2521, -0.7187],[ 0.8776, -0.0352, -0.5094,  0.0602]])

应用实例2:多个张量形状相同

代码

import torchx = torch.randn(2,4)
y = torch.randn(2,4)print(f'x={x}','\n',f'y={y}')
print(f'z={torch.cat((x,y,x,y), 0)}')

输出

x=tensor([[ 0.4697, -0.4881, -2.0199, -0.8661],[ 0.4911, -0.1259,  1.1939,  0.7730]]) y=tensor([[ 0.8633,  0.4438, -0.6975,  0.5440],[ 0.1554, -1.6358, -1.2234, -0.6597]])
z=tensor([[ 0.4697, -0.4881, -2.0199, -0.8661],[ 0.4911, -0.1259,  1.1939,  0.7730],[ 0.8633,  0.4438, -0.6975,  0.5440],[ 0.1554, -1.6358, -1.2234, -0.6597],[ 0.4697, -0.4881, -2.0199, -0.8661],[ 0.4911, -0.1259,  1.1939,  0.7730],[ 0.8633,  0.4438, -0.6975,  0.5440],[ 0.1554, -1.6358, -1.2234, -0.6597]])

应用实例3:两个张量形状不同,但只在拼接维度上

代码

import torchx = torch.randn(3,4)
y = torch.randn(2,4)print(f'x={x}','\n',f'y={y}')
print(f'z={torch.cat((x,y), 0)}')x_1 = torch.randn(2,3)
y_1 = torch.randn(2,4)print(f'x_1=\n{x_1}','\n',f'y_1=\n{y_1}')
print(f'z_1=\n{torch.cat((x_1,y_1), 1)}')

输出

x=tensor([[-0.1966, -0.9648,  1.2787, -1.4578],[-1.2216,  0.1663,  0.5380, -0.0376],[-1.7365, -0.4151, -1.0336, -0.6732]]) y=tensor([[ 1.4477,  0.3616, -0.1504,  0.4662],[-1.1334,  1.3100,  0.1624,  0.8206]])
z=tensor([[-0.1966, -0.9648,  1.2787, -1.4578],[-1.2216,  0.1663,  0.5380, -0.0376],[-1.7365, -0.4151, -1.0336, -0.6732],[ 1.4477,  0.3616, -0.1504,  0.4662],[-1.1334,  1.3100,  0.1624,  0.8206]])
x_1=
tensor([[ 1.1418,  0.0774,  0.2047],[-0.0673, -1.5794,  0.0131]]) y_1=
tensor([[ 1.4149, -1.9538,  0.1660,  1.1142],[-1.6455,  0.5595, -0.1162,  0.8628]])
z_1=
tensor([[ 1.1418,  0.0774,  0.2047,  1.4149, -1.9538,  0.1660,  1.1142],[-0.0673, -1.5794,  0.0131, -1.6455,  0.5595, -0.1162,  0.8628]])Process finished with exit code 0

torch.Tensor.repeat():

对象:给定的张量,即Tensor型。
功能:在指定的维度上对张量进行重复扩充,也可以用来增加维度。
输出:升维或扩充后的张量。
函数以及参数torch.tensor.repeat(size),size所在的索引表示扩充的维度的索引。
size:表示张量在这个索引维度下的扩充倍数。
注意事项:函数的参数量必须大于等于tensor的维度,如a.shape=(2,3),那么如果我们想扩充2倍a的第0个维度时,应该这么写a.repeat(2,1),对于不扩充的维度则写1。

应用实例1:一维张量扩充

代码

import torchx = torch.randn(3)print(f'x={x}')print(f'x_1={x.repeat(2)}')

输出

x=tensor([-0.1485,  1.8445,  1.4257])
x_1=tensor([-0.1485,  1.8445,  1.4257, -0.1485,  1.8445,  1.4257])

应用实例2:多维张量扩充

代码

import torchx = torch.randn(3, 4, 3)print(f'x={x}')
#在第2个维度上扩充两倍,其他维度保持不变
print(f'x_1={x.repeat(1,1,2)}')

输出

x=tensor([[[-0.0294,  1.2902,  0.9825],[-0.3032,  1.6733,  0.9163],[ 0.3079, -0.0159,  0.2626],[-0.2934, -0.6076,  0.1593]],[[ 1.7661, -1.0698,  0.4074],[-0.3660, -0.3219,  0.3732],[-1.3314, -0.8263, -1.0793],[ 1.2589,  0.1886,  0.5453]],[[ 0.2520, -0.5695, -0.6685],[ 0.5554,  0.0119, -0.5650],[ 0.9733, -0.3812,  0.1963],[-1.1284,  0.2561,  0.4507]]])
x_1=tensor([[[-0.0294,  1.2902,  0.9825, -0.0294,  1.2902,  0.9825],[-0.3032,  1.6733,  0.9163, -0.3032,  1.6733,  0.9163],[ 0.3079, -0.0159,  0.2626,  0.3079, -0.0159,  0.2626],[-0.2934, -0.6076,  0.1593, -0.2934, -0.6076,  0.1593]],[[ 1.7661, -1.0698,  0.4074,  1.7661, -1.0698,  0.4074],[-0.3660, -0.3219,  0.3732, -0.3660, -0.3219,  0.3732],[-1.3314, -0.8263, -1.0793, -1.3314, -0.8263, -1.0793],[ 1.2589,  0.1886,  0.5453,  1.2589,  0.1886,  0.5453]],[[ 0.2520, -0.5695, -0.6685,  0.2520, -0.5695, -0.6685],[ 0.5554,  0.0119, -0.5650,  0.5554,  0.0119, -0.5650],[ 0.9733, -0.3812,  0.1963,  0.9733, -0.3812,  0.1963],[-1.1284,  0.2561,  0.4507, -1.1284,  0.2561,  0.4507]]])

应用实例3:张量维度扩充

代码

import torchx = torch.randn(1,2)print(f'x={x}')
#将a多扩充一个维度,这个维度扩充的倍数需要写在最前面,如此案例的3
print(f'x_1={x.repeat(3,1,1)}')

输出

x=tensor([[-0.2581, -0.8387]])
x_1=tensor([[[-0.2581, -0.8387]],[[-0.2581, -0.8387]],[[-0.2581, -0.8387]]])

持续更新中…


http://www.ppmy.cn/news/1097806.html

相关文章

【Python程序设计】 工厂模式【07/8】

一、说明 我们探索数据工程中使用的设计模式 - 软件设计中常见问题的可重用解决方案。 以下文章是有关 Python 数据工程系列文章的一部分,旨在帮助数据工程师、数据科学家、数据分析师、机器学习工程师或其他刚接触 Python 的人掌握基础知识。 迄今为止,…

【CMake工具】工具CMake编译轻度使用(C/C++)

目录 CMake编译工具 一、CMake概述 二、CMake的使用 2.1 注释 2.1.1 注释行 2.1.2 注释块 2.2 源文件 2.1.1 共处一室 2.1.2 VIP包房 2.3 私人定制 2.2.1 定义变量 2.2.2 指定使用的C标准 2.2.3 指定输出的路径 2.4 搜索文件 2.3.1 方式1 2.3.2 方式2 2.5 包含…

【Spring面试】BeanFactory与IoC容器的加载

文章目录 Q1、BeanFactory的作用是什么?Q2、BeanDefinition的作用是什么?Q3、BeanFactory和ApplicationContext有什么区别?Q4、BeanFactory和FactoryBean有什么区别?Q5、说下Spring IoC容器的加载过程(※)Q…

Unity的UI面板基类

使用这个组件实现淡入淡出 public abstract class BasePanel : MonoBehaviour {//控制面板透明度 用于淡入淡出private CanvasGroup canvasGroup;//淡入淡出速度private float alphaSpeed 10;//隐藏还是显示public bool isShow false;//隐藏完毕后做的事private UnityAction …

27.方向标

题目 描述 一位木匠收到了一个木制指示牌的订单。每块木板必须与前一块垂直对齐,要么与前一个箭头的基部对齐,要么与相反的一侧对齐,在那里用特制的螺钉固定。两块木板必须重叠。木匠将设计师发送的草图编码成了一个整数序列,但…

Java基础09 —— 字符序列--String、StringBuilder、StringBuffer区别及其方法介绍

Java基础09 —— 字符序列 字符串类型 字符与字符串 字符类型(char)是Java中的基本数据类型,占2个字节16位,默认值是 ‘\u0000’ 。字符是用单引号引住的单个符号. // 字符 char c A; //单引号 char cA 65; //数字 char c1 \u8888; //Unicode码 S…

嵌入式IDE(2):KEIL中SCF分散加载链接文件详解和实例分析

在上一篇文章IAR中ICF链接文件详解和实例分析中,我通过I.MX RT1170的SDK中的内存映射关系,分析了IAR中的ICF链接文件的语法。对于MCU编程所使用的IDE来说,IAR和Keil用得比较多,所以这一篇文章就来分析一下Keil的分散文件.scf(scat…

java中将List数据平均切分成N份

话不多说&#xff0c;直接上代码&#xff0c;直接用 public static <T> List<List<T>> averageList(List<T> source, int n) {List<List<T>> ret new ArrayList<List<T>>();int number source.size() / n;int remainder so…