Tensor 基本操作1 unsqueeze, squeeze, softmax | PyTorch 深度学习实战

devtools/2025/1/23 18:06:00/

本系列文章 GitHub Repo: https://github.com/hailiang-wang/pytorch-get-started

目录

    • 创建 Tensor
    • 常用操作
      • unsqueeze
      • squeeze
      • Softmax
        • 代码1
        • 代码2
        • 代码3
      • argmax
      • item

创建 Tensor

使用 Torch 接口创建 Tensor

import torch

参考:https://pytorch.org/tutorials/beginner/basics/tensorqs_tutorial.html

常用操作

unsqueeze

将多维数组解套,并嵌入新的一层维度。

    data = [[1, 2],[3, 4]]x_data = torch.tensor(data)print("x_data")print(x_data)x2_data = x_data.unsqueeze(-1)print("x_data>> unsqueeze -1")print(x2_data)x2_data = x_data.unsqueeze(0)print("x_data>> unsqueeze 0")print(x2_data)x2_data = x_data.unsqueeze(1)print("x_data>> unsqueeze 1")print(x2_data)x2_data = x_data.unsqueeze(2)print("x_data>> unsqueeze 2")print(x2_data)

结果:

x_data
tensor([[1, 2],[3, 4]])
x_data>> unsqueeze -1   # -1 代表最内层,将最内层的数用一个新的维度包起来
tensor([[[1],[2]],[[3],[4]]])
x_data>> unsqueeze 0 # 0 代表最外层,将原来的多维数组整个多套一层
tensor([[[1, 2],[3, 4]]])
x_data>> unsqueeze 1 # 代表原来第一维里的每个元素,套一层
tensor([[[1, 2]],[[3, 4]]])
x_data>> unsqueeze 2 # 代表原来第二维里的每个元素,套一层
tensor([[[1],        # 当前一共两维,所以效果和 -1 一样[2]],[[3],[4]]])

squeeze

去掉指定或全部的维度中只有一个元素的多维数组。

比如输入为 Ax1xBxCx1xD 维的数组,输出变成了 AxBxCxD 维的数组。

https://pytorch.org/docs/stable/generated/torch.squeeze.html
在这里插入图片描述

    data = [[1], [2],[3], [4]]x_data = torch.tensor(data)print("x_data")print(x_data)x2_data = x_data.squeeze()print("x_data>> squeeze")print(x2_data)x2_data = x_data.squeeze(1)print("x_data>> squeeze 1")print(x2_data)

结果:

x_data
tensor([[1],[2],[3],[4]])
x_data>> squeeze
tensor([1, 2, 3, 4])
x_data>> squeeze 1
tensor([1, 2, 3, 4])

Softmax

https://pytorch.org/docs/stable/generated/torch.softmax.html

归一化操作。
在这里插入图片描述

代码1
    data = torch.tensor([1,2,3], dtype=torch.float) # 维度 3; 注意,此处 dtype 是 int 或 long 接口报错x_data = torch.softmax(data, 0)print("x_data")print(x_data)

结果:

x_data
tensor([0.0900, 0.2447, 0.6652])  # 维度 3
代码2
    data = torch.tensor([[1],[2],[3]], dtype=torch.float) # 维度 3x1x_data2 = torch.softmax(data, 0)print("x_data2")print(x_data2)

结果:

x_data2  # 维度 3x1
tensor([[0.0900],[0.2447],[0.6652]])
代码3
    data = torch.tensor([[1],[2],[3]], dtype=torch.float) # 维度 3x1x_data2 = torch.softmax(data, 1) # 沿着第一维求print("x_data2")print(x_data2)

结果:

x_data2
tensor([[1.],[1.],[1.]])

此时,每维都是 1 个元素,针对自身求 softmax,所以,结果是 1.

argmax

https://pytorch.org/docs/stable/generated/torch.argmax.html

返回一个多维数组的最大值的索引,如果是多维数组,则返回第一维的索引。

在这里插入图片描述

item

https://pytorch.org/docs/stable/generated/torch.Tensor.item.html
返回一个 Tensor 中携带的 Python Number 对象。该接口只对 Tensor 是一维的有效。

x = torch.tensor([1.0])
x.item()

http://www.ppmy.cn/devtools/152941.html

相关文章

如何实现网页不用刷新也能更新

要实现用户在网页上不用刷新也能到下一题,可以使用 前端和后端交互的技术,比如 AJAX(Asynchronous JavaScript and XML)、Fetch API 或 WebSocket 来实现局部页面更新。以下是一个实现思路: 1. 使用前端 AJAX 或 Fetch…

Linux——线程条件变量(同步)

Linux——多线程的控制-CSDN博客 文章目录 目录 文章目录 前言 一、条件变量是什么? 1、死锁的必要条件 1. 互斥条件(Mutual Exclusion) 2. 请求和保持条件(Hold and Wait) 3. 不可剥夺条件(No Preemption&…

shell-特殊位置变量

目录 1.特殊位置变量 $n 2.特殊位置变量 $0 3.特殊位置变量$ # 4.特殊位置变量$*/$ 4.1 $* 4.2 $ 5.shift 命令 1.特殊位置变量 $n $n:表示传递给脚本或函数的第 n 个参数。 $1:第一个参数$2:第二个参数...$9:第九个参数…

chrome游览器JSON Formatter插件无效问题排查,FastJsonHttpMessageConverter导致Content-Type返回不正确

问题描述 chrome游览器又一款JSON插件叫JSON Formatter,游览器GET请求调用接口时,如果返回的数据是json格式,则会自动格式化展示,类似这样: 但是今天突然发现怎么也格式化不了,打开一个json文件倒是可以格…

微信小程序使用上拉加载onReachBottom。页面拖不动。一直无法触发上拉的事件。

1,可能是原因是你使用了scroll-view的标签,用onReachBottom触发加载事件。这两个是有冲突的。没办法一起使用。如果页面的样式是滚动的是无法去触发页面的onReachBottom的函数的。因此,你使用overflow:auto.来使用页面的某些元素滚动&#xf…

Spring源码04 - AOP深入设计原理

AOP深入设计原理 文章目录 AOP深入设计原理一:JVM下的AOP1:Java程序运行在JVM中的特征2:Java程序执行流3:引入了代理模式的Java程序执行流 二:Spring AOP工作原理 本文整理自csdn博主亦山 -> 《Spring设计思想》AOP…

每日一题——二分法求旋转数组的最小数字

二分法求旋转数组的最小数字 描述数据范围:要求: 示例示例 1示例 2 解题思路算法流程 代码实现关键点解释举例分析示例 1:[3, 4, 5, 1, 2]示例 2:[3, 100, 200, 3] 总结 描述 有一个长度为 n 的非降序数组,例如 [1, 2…

考研408笔记之数据结构(三)——串

数据结构(三)——串 1. 串的定义和基本操作 本节内容很少,重点是串的模式匹配,所以对于串的定义和基本操作,我就简单提一些易错点。另外,串也是一种特殊的线性表,只不过线性表是可以存储任何东…