深度学习torch基础知识

news/2025/1/15 12:05:16/

torch.

  • detach()
  • 拼接函数torch.stack()
  • torch.nn.DataParallel()
  • np.clip()
  • torch.linspace()
  • PyTorch中tensor.repeat()
  • pytorch索引查找 index_select

detach()

detach是截断反向传播的梯度流
将某个node变成不需要梯度的Varibale。因此当反向传播经过这个node时,梯度就不会从这个node往前面传播。

拼接函数torch.stack()

拼接:将多个维度参数相同的张量连接成一个张量

a=torch.tensor([[1,2,3],[4,5,6]])
b=torch.tensor([[10,20,30],[40,50,60]])
c=torch.tensor([[100,200,300],[400,500,600]])
print(torch.stack([a,b,c],dim=0))
print(torch.stack([a,b,c],dim=1))
print(torch.stack([a,b,c],dim=2))
print(torch.stack([a,b,c],dim=0).size())
print(torch.stack([a,b,c],dim=1).size())
print(torch.stack([a,b,c],dim=2).size())
#输出结果为:
tensor([[[  1,   2,   3],[  4,   5,   6]],[[ 10,  20,  30],[ 40,  50,  60]],[[100, 200, 300],[400, 500, 600]]])
tensor([[[  1,   2,   3],[ 10,  20,  30],[100, 200, 300]],[[  4,   5,   6],[ 40,  50,  60],[400, 500, 600]]])
tensor([[[  1,  10, 100],[  2,  20, 200],[  3,  30, 300]],[[  4,  40, 400],[  5,  50, 500],[  6,  60, 600]]])
torch.Size([3, 2, 3])
torch.Size([2, 3, 3])
torch.Size([2, 3, 3])

torch.nn.DataParallel()

torch.nn.DataParallel(module, device_ids=None, output_device=None, dim=0)
module即表示你定义的模型,device_ids表示你训练的device,output_device这个参数表示输出结果的device,而这最后一个参数output_device一般情况下是省略不写的,那么默认就是在device_ids[0],也就是第一块卡上。

np.clip()

a = np.arange(10)
np.clip(a, 1, 8)
array([1, 1, 2, 3, 4, 5, 6, 7, 8, 8]) # a被限制在1-8之间
a
array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) # 没改变a的原值np.clip(a, 3, 6, out=a) # 修剪后的数组存入到a中
array([3, 3, 3, 3, 4, 5, 6, 6, 6, 6])

torch.linspace()

函数的作用是,返回一个一维的tensor,这个张量包含了从start到end,分成steps个线段得到的向量。

torch.linspace(start, end, steps=100, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) 
→ Tensor

例如

import torch
print(torch.linspace(3,10,5))
#tensor([ 3.0000,  4.7500,  6.5000,  8.2500, 10.0000])type=torch.float
print(torch.linspace(-10,10,steps=6,dtype=type))
#tensor([-10.,  -6.,  -2.,   2.,   6.,  10.])

PyTorch中tensor.repeat()

当参数只有两个时,第一个参数表示的是行复制的次数,第二个参数表示列复制的次数;
当参数有三个时,第一个参数表示的是通道复制的次数,第二个参数表示的是行复制的次数,第三个参数表示列复制的次数。
(1). 对于已经存在的维度复制

import torcha = torch.tensor([[1], [2], [3]])  # 3 * 1
b = a.repeat(3, 2)
print('a\n:', a)
print('shape of a', a.size())  # 原始shape = (3,1)
print('b:\n', b)
print('shape of b', b.size())  # 新的shape = (3*3,1*2),新增加的数据通过复制得到'''   运行结果   '''
a:
tensor([[1],[2],[3]])
shape of a torch.Size([3, 1])  注: 原始shape =31)
b:tensor([[1, 1],[2, 2],[3, 3],[1, 1],[2, 2],[3, 3],[1, 1],[2, 2],[3, 3]])
shape of b torch.Size([9, 2])  新的shape =3*31*2

(2). 对于原始不存在的维度数量拓展

import torch
a = torch.tensor([[1, 2], [3, 4], [5, 6]])  # 3 * 2
b = a.repeat(3, 2, 1)   # 在原始tensor的0维前拓展一个维度,并把原始tensor的第1维扩张2倍,都是通过复制来完成的
print('a:\n', a)
print('shape of a', a.size())  # 原始维度为 (3,2)
print('b:\n', b)
print('shape of b', b.size())  # 新的维度为 (3,2*2,2*1)=(3,4,2)'''   运行结果   '''
a:tensor([[1, 2],[3, 4],[5, 6]])
shape of a torch.Size([3, 2])   注:原始维度为 (32)
b:tensor([[[1, 2],[3, 4],[5, 6],[1, 2],[3, 4],[5, 6]],[[1, 2],[3, 4],[5, 6],[1, 2],[3, 4],[5, 6]],[[1, 2],[3, 4],[5, 6],[1, 2],[3, 4],[5, 6]]])
shape of b torch.Size([3, 6, 2])   新的维度为 (32*22*1=342

pytorch索引查找 index_select

anchor_w = FloatTensor(scaled_anchors).index_select(1, LongTensor([0]))

先定义了一个tensor,这里用到了linspace和view方法。
第一个参数是索引的对象,第二个参数0表示按行索引,1表示按列进行索引,第三个参数是一个tensor,就是索引的序号,比如b里面tensor[0, 2]表示第0行和第2行,c里面tensor[1, 3]表示第1列和第3列

a = torch.linspace(1, 12, steps=12).view(3, 4)
print(a)
b = torch.index_select(a, 0, torch.tensor([0, 2]))
print(b)
print(a.index_select(0, torch.tensor([0, 2])))
c = torch.index_select(a, 1, torch.tensor([1, 3]))
print(c)-----输出结果-----
tensor([[ 1.,  2.,  3.,  4.],[ 5.,  6.,  7.,  8.],[ 9., 10., 11., 12.]])
tensor([[ 1.,  2.,  3.,  4.],[ 9., 10., 11., 12.]])
tensor([[ 1.,  2.,  3.,  4.],[ 9., 10., 11., 12.]])
tensor([[ 2.,  4.],[ 6.,  8.],[10., 12.]])

http://www.ppmy.cn/news/1006959.html

相关文章

保姆级教程,Linux服务器docker搭建jenkins持续集成一键部署SpringBoot项目(Gradle)

前言: 在后台项目开发过程从Java延伸到Kotlin开发,从maven pom到gradle,IDEA新项目SpringBoot init框架官方推荐kotlingradle,本章以此为jenkins持续集成做项目部署,服务器为Centos,JDK 17,Spr…

centos7 ‘xxx‘ is not in the sudoers file...

如题 执行命令输入密码后时报错: [sudo] password for admin (我的账户)原因,当前用户还没有加入到root的配置文件中。 解决 vim打开配置文件,如下: #切换到root用户 su #编辑配置文件 vim /etc/sudoe…

python系列教程211——map

朋友们,如需转载请标明出处:https://blog.csdn.net/jiangjunshow 声明:在人工智能技术教学期间,不少学生向我提一些python相关的问题,所以为了让同学们掌握更多扩展知识更好地理解AI技术,我让助理负责分享…

android 实现拨打电话号码。

在拨打电话号码之前,预设一个B号码,正常使用电话时,本来输入的是A号码。实际拨打的是B号码。但是接听页面显示的是A号码。是不是比较绕,在android9之前,各厂商的实现不了,android7以下可以实现。但是现在很…

打破疑惑:一次搞懂hasattr()、getattr()、setattr()在Python中的应用

简介 在Python中,hasattr()、getattr()和setattr()是一组内置函数,用于对对象的属性进行操作和查询。这些函数提供了一种方便的方式来检查对象是否具有特定属性,获取属性的值,以及设置属性的值。本文将从入门到精通,全…

第二十二章 原理篇:UP-DETR

最近一直在忙各种各样的面试,顺便重新刷了一遍西瓜书。 感觉自己快八股成精了,但是一到写代码的环节就拉跨,人真是麻了。 许愿搬家前可以拿到offer! 参考教程: https://arxiv.org/pdf/2011.09094.pdf https://zhuanla…

【NLP pytorch】基于BiLSTM-CRF模型医疗数据实体识别实战(项目详解)

基于BiLSTM-CRF模型医疗数据实体识别实战 1数据来源与加载1.1 数据来源1.2 数据类别名称和定义1.3 数据介绍2 模型介绍2 数据预处理2.1 数据读取2.2 数据标注2.3 数据集划分2.4 词表和标签的生成3 Dataset和DataLoader3.1 Dataset3.2 DataLoader4 BiLSTM模型定义5 CRF模型6 模型…

浏览器 判断

浏览器的 类别判断 无非就是从三个方面&#xff1a; 是否是 移动端 判断是否为 微信浏览器 判断浏览器 所在的 系统 判断 <!DOCTYPE html> <html><head><meta charset"utf-8"><meta name"viewport" content"widthdevic…