Pytorch API

news/2025/3/6 2:28:40/
  1. torch.squeeze(input, dim=None)
    将给定的 input 这个 tensor 中,大小为 1 的 dim 全部压缩。
    如下例子:

    import torch
    t = torch.tensor([[1], [2], [3]])
    print(t) # tensor([[1], [2], [3]]) shape=(3,1)t = torch.squeeze(t)
    print(t) # tensor([1, 2, 3]) shape=(3,)
    
  2. torch.unsqueeze(input, dim)
    将给定的 input 这个 tensor 中,指定的 dim 扩充一维
    如下例子:

    import torch
    t = torch.tensor([1, 2, 3])
    print(torch.unsqueeze(t, 0)) # tensor([[1, 2, 3]]) shape=(1,3)print(torch.unsqueeze(t, 1)) # tensor([[1], [2], [3]]) shape=(3,1)
    
  3. torch.index_select(input, dim, index, *, out=None)
    在给定的 input 这个 tensor 中,选择维度 dim ,然后在这个维度中选择索引 index 的部分返回。
    如下例子:

    import torcht = torch.arange(1, 13).view(3, 4)print(t)
    """
    tensor([[ 1,  2,  3,  4],[ 5,  6,  7,  8],[ 9, 10, 11, 12]])
    shape=(3, 4)
    """indices = torch.tensor([0, 2])print(torch.index_select(t, 0, indices))
    """
    tensor([[ 1,  2,  3,  4],[ 9, 10, 11, 12]])
    选择 dim=0 ,有 (0, 1, 2) 三个,选择第 0 行和第 2 行
    """print(torch.index_select(t, 1, indices))
    """
    tensor([[ 1,  3],[ 5,  7],[ 9, 11]])
    选择 dim=1 ,有 (0, 1, 2, 3) 四个,选择第 0 列和第 4 列
    """
    
  4. torch.norm(input, p='fro', dim=None, keepdim=False, out=None, dtype=None)
    将给定的 input 这个 tensor 按照给定的 dim 计算范数,具体计算的是什么范数由 p 决定
    p=1 表示第一范数,即 tensor 中每个元素绝对值之和
    p=2 表示第二范数,即 tensor 中每个元素平方和的和,再开根号
    其他表示无穷范数,即 tensor 中绝对值最大的元素

    import torch
    """inputs.shape = (3, 3, 4)
    """
    inputs = torch.tensor([[[ 1.,  2.,  3.,  4.],[ 2.,  4.,  6.,  8.],[ 3.,  6.,  9., 12.]],[[ 1.,  2.,  3.,  4.],[ 2.,  4.,  6.,  8.],[ 3.,  6.,  9., 12.]],[[ 1.,  2.,  3.,  4.],[ 2.,  4.,  6.,  8.],[ 3.,  6.,  9., 12.]]])"""
    inputs1.shape = (1, 3, 4)
    对于 dim=0 进行 L2 范数的计算,就是考虑将 (i, j, k) 其中所有的 i 的元素平方和加起来再开根号
    这里 sqrt((0, 0, 0)^2 + (1, 0, 0)^2 + (2, 0, 0)^2) = sqrt(3) = 1.7321
    tensor([[[ 1.7321,  3.4641,  5.1962,  6.9282],[ 3.4641,  6.9282, 10.3923, 13.8564],[ 5.1962, 10.3923, 15.5885, 20.7846]]])
    """
    inputs1 = torch.norm(inputs, p=2, dim=0, keepdim=True)
    print(inputs1)"""
    inputs2.shape = (3, 1, 4)
    对于 dim=1 进行 L2 范数的计算,就是考虑将 (i, j, k) 其中所有的 j 的元素平方和加起来再开根号
    这里 sqrt((0, 0, 0)^2 + (0, 1, 0)^2 + (0, 2, 0)^2) = sqrt(1+4+9) = 3.7417
    tensor([[[ 3.7417,  7.4833, 11.2250, 14.9666]],[[ 3.7417,  7.4833, 11.2250, 14.9666]],[[ 3.7417,  7.4833, 11.2250, 14.9666]]])
    """
    inputs2 = torch.norm(inputs, p=2, dim=1, keepdim=True)
    print(inputs2)"""
    inputs3.shape = (3, 3, 1)
    对于 dim=2 进行 L2 范数的计算,就是考虑将 (i, j, k) 其中所有的 k 的元素平方和加起来再开根号
    这里 sqrt((0, 0, 0)^2+(0, 0, 1)^2+(0, 0, 2)^2+(0, 0, 3)^2) = sqrt(1+4+9+16) = 5.4772
    tensor([[[ 5.4772],[10.9545],[16.4317]],[[ 5.4772],[10.9545],[16.4317]],[[ 5.4772],[10.9545],[16.4317]]])
    """
    inputs3 = torch.norm(inputs, p=2, dim=2, keepdim=True)
    print(inputs3)
    
  5. torch.chunk(input, chunks, dim=0) → List of Tensors
    input 这个 tensor 分成 chunks 个 tensors ,按照 dim 来划分。

    import torcht = torch.arange(1, 28).view(3, 3, 3)"""
    tensor([[[ 1,  2,  3],[ 4,  5,  6],[ 7,  8,  9]],[[10, 11, 12],[13, 14, 15],[16, 17, 18]],[[19, 20, 21],[22, 23, 24],[25, 26, 27]]])
    shape = (3, 3, 3)
    """
    print(t)"""
    按照 dim=0 划分,那么划分结果为
    (tensor([[[1, 2, 3],[4, 5, 6],[7, 8, 9]]]), tensor([[[10, 11, 12],[13, 14, 15],[16, 17, 18]]]), tensor([[[19, 20, 21],[22, 23, 24],[25, 26, 27]]]))
    """
    print(torch.chunk(t, chunks=3, dim=0))"""
    按照 dim=1 划分,那么划分结果为
    (tensor([[[ 1,  2,  3]],[[10, 11, 12]],[[19, 20, 21]]]), tensor([[[ 4,  5,  6]],[[13, 14, 15]],[[22, 23, 24]]]), tensor([[[ 7,  8,  9]],[[16, 17, 18]],[[25, 26, 27]]]))
    """
    print(torch.chunk(t, chunks=3, dim=1))"""
    按照 dim=2 划分,那么划分结果为
    (tensor([[[ 1],[ 4],[ 7]],[[10],[13],[16]],[[19],[22],[25]]]), 
    tensor([[[ 2],[ 5],[ 8]],[[11],[14],[17]],[[20],[23],[26]]]), 
    tensor([[[ 3],[ 6],[ 9]],[[12],[15],[18]],[[21],[24],[27]]]))
    """
    print(torch.chunk(t, chunks=3, dim=2))
    

http://www.ppmy.cn/news/1143923.html

相关文章

请求的转发和重定向

RequestDispatcher接口实现转发: jsp1上链接到Servlet,Servlet再转发(关键在这里怎么实现转发??) 演示index.html页面---->Servlet1(转发到)------>Servlet2 实现转发流程 1.用HttpServletReques…

运行vue create vue-demo报错:无法加载文件,因为在此系统上禁止运行脚本

运行vue create vue-demo报错:无法加载文件 C:\Users\sun\AppData\Roaming\nvm\nodejs\vue.ps1,因为在此系统上禁止运行脚本。 因为win中段禁止脚本运行,需要调整运行策略,打开PowerShell(只能PowerShell)&#xff0c…

【Leetcode】189. 轮转数组

一、题目 1、题目描述 给定一个整数数组 nums,将数组中的元素向右轮转 k 个位置,其中 k 是非负数。 示例1: 输入: nums = [1,2,3,4,5,6,7], k = 3 输出: [5,6,7,1,2,3,4] 解释: 向右轮转 1 步: [7,1,2,3,4,5,6] 向右轮转 2 步: [6,7,1,2,3,4,5] 向右轮转 3 步: [5,6,7,1…

Spark 9:Spark 新特性

Spark 3.0 新特性 Adaptive Query Execution 自适应查询(SparkSQL) 由于缺乏或者不准确的数据统计信息(元数据)和对成本的错误估算(执行计划调度)导致生成的初始执行计划不理想,在Spark3.x版本提供Adaptive Query Execution自适应查询技术,通过在”运行…

2.1 Qemu系统模拟:简介

目录 1 后端/加速器2 特性简介3 运行 1 后端/加速器 系统模拟主要用于在host设备上运行guest OSQEMU支持多种hypervisors,同时也支持JIT模拟方案(TCG) 例如从上表我们可以看出,运行在x86硬件上的Linux系统支持KVM,Xen,TCG 2 特性简介 提供…

浅谈分散式存储项目MEMO

Memo本质上是互联网项目,应用了一些区块链技术而已,或者叫做包了层区块链皮的互联网项目。 最开始对标Filcoin,后来发现Filcoin也有问题,分布式存储解决方案并不完美,抑或者是自己团队的研发能力无法与IPFS团队PK&…

解决Playwright无法登录Google账号的问题

文章目录 问题描述解决问题免费登录生成代码问题描述 当使用playwright需要登入google帐号的时候,有可能会出现下面的情况:无法登录,提示浏览器不安全(因为我们是脚本使用) 【Python自学笔记】微软自动化测试工具playwright,微软版selenium解决问题 解决上面这个无法登入…

【Linux】基本指令-入门级文件操作(一)

目录 前言 ⭕linux的树状文件结构 ⭕绝对路径和相对路径 ⭕当前路径和上级路径 ⭕隐藏文件 基本指令(重点) 1 pwd 指令 2 mkdir 指令 3 touch 指令 4 ls 指令 4.1 ls只加选项不加文件/目录名,默认查看当前目录下的文件 4.1.1 ls -a…