20250110_ PyTorch中的张量操作

news/2025/1/14 22:22:11/

文章目录

  • 前言
  • 1、torch.cat 函数
  • 2、索引、维度扩展和张量的广播
  • 3、切片操作
    • 3.1、 encoded_first_node
    • 3.2、probs
  • 4、长难代码分析
    • 4.1、selected
      • 4.1.1、multinomial(1)工作原理:
  • 总结


前言


1、torch.cat 函数

torch.cat 函数将两个张量拼接起来,具体地是在第三个维度(dim=2)上进行拼接。注:dim取值范围是0~2

node_xy_demand = torch.cat((node_xy, node_demand[:, :, None]), dim=2)

其中所用参数为:

node_xy = reset_state.node_xy
# shape: (batch, problem, 2)
node_demand = reset_state.node_demand
# shape: (batch, problem)

若要拼接node_xy 与node_demand 需要将node_demand 进行维度拓展node_demand[:, :, None])

node_xy = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
node_demand = torch.tensor([[[10], [20]], [[30], [40]]])
node_xy_demand = torch.tensor([[[ 1,  2, 10], [ 3,  4, 20]],[[ 5,  6, 30], [ 7,  8, 40]]])

2、索引、维度扩展和张量的广播

_ = self.decoder.regret_embedding[None, None, :].expand(encoded_nodes.size(0), 1, self.decoder.regret_embedding.size(-1))
  • self.decoder.regret_embedding是一个张量。
  • self.decoder.regret_embedding[None, None, :]增加regret_embedding的维度。维度扩展成 (1, 1, D)
.expand(encoded_nodes.size(0), 1, self.decoder.regret_embedding.size(-1))
  • expand 用来沿特定维度复制张量,以实现广播。
  • encoded_nodes.size(0) 返回的是 encoded_nodes 张量的第一个维度大小。
  • 1 表示第二个维度的大小。
  • self.decoder.regret_embedding.size(-1) 返回的是 self.decoder.regret_embedding 的最后一个维度的大小,也就是嵌入的维度 D

总结: 将张量建立为所需维度在此为三维,使用expand沿着新建维度进行拓展到所需形状


3、切片操作

3.1、 encoded_first_node

 encoded_first_node = self.encoded_nodes[:, [0], :]

这行代码中的切片操作是从 self.encoded_nodes 中提取特定的数据部分:

  • : 表示选择所有批次的样本,保留第一个维度(batch)。
  • [0] 表示选择每个样本中的第一个节点,因此提取的是第一个节点的嵌入向量。
  • : 表示选择该节点的所有嵌入维度,即保留第三个维度(embedding)的所有值。

最终,经过这些操作,encoded_first_node 的形状为 (batch, 1, embedding),即每个样本只包含第一个节点的嵌入向量,保留了嵌入维度。

3.2、probs

probs[:, :, :-1]
  • 这是对 probs 张量的切片操作,作用是从 probs 的第三个维度(即最后一个维度)中移除最后一列。
selected = probs.argmax(dim=2)
  • argmax(dim=2) 表示在 probs 张量的第3维度(类别维度)上,找到每个样本中概率最大的类别索引。

  • argmax 返回的是最大值的索引,而不是最大值本身。


4、长难代码分析

4.1、selected

selected = probs.reshape(batch_size * pomo_size, -1).multinomial(1).squeeze(dim=1).reshape(batch_size, pomo_size)

prob的shape: (batch, pomo, problem+1)

  • probs.reshape(batch_size * pomo_size, -1)

    • 这一步将 probs 的形状从 (batch, pomo, problem + 1) 转变为 (batch * pomo, problem + 1)。
    • -1:表示自动推算出第二维的大小(即 problem + 1)
    • 新的形状 (batch * pomo, problem + 1)。
  • multinomial(1)

    • multinomial(1) 用于从给定的概率分布中选择一个类别。它会返回一个形状为 (batch_size * pomo_size, 1) 的张量,每一行选择一个元素的索引,代表从 probs 中选择的元素。
  • .squeeze(dim=1)

    • squeeze(dim=1) 是去除第二个维度(索引维度),将形状变为 (batch_size * pomo_size)
  • .reshape(batch_size, pomo_size)

    • 最后,通过 reshape(batch_size, pomo_size) 将张量恢复到原来的形状 (batch_size, pomo_size),即每个批次对应一个选择的元素索引。

4.1.1、multinomial(1)工作原理:

  • 输入:
    multinomial(1) 需要一个形状为 (N, C) 的张量,其中 N 是样本的数量,C 是类别的数量。这个张量表示每个样本在各个类别下的概率分布。

  • 输出:
    multinomial(1) 返回一个形状为 (N, 1) 的张量,每个元素是该样本选择的类别的索引。

具体来说,multinomial(1) 会根据每个类别的概率,从概率分布中选取一个类别。这个选择是随机的,但是会遵循给定的概率分布,即概率较大的类别被选中的几率较高,概率较小的类别被选中的几率较低。


总结


http://www.ppmy.cn/news/1563143.html

相关文章

第27章 汇编语言--- 设备驱动开发基础

汇编语言是低级编程语言的一种,它与特定的计算机架构紧密相关。在设备驱动开发中,汇编语言有时用于编写性能关键的部分或直接操作硬件,因为它是接近机器语言的代码,可以提供对硬件寄存器和指令集的直接访问。 要展开源代码详细叙…

Python单例模式的代码实现和原理

Python单例设计模式的代码实现 import threading import timeclass Singleton(object): def __new__(cls, *args, **kw):if not hasattr(cls, _instance):orig super(Singleton, cls)cls._instance orig.__new__(cls, *args, **kw)return cls._instanceclass Bus(Singleton…

数据链路层-STP

生成树协议STP(Spanning Tree Protocol) 它的实现目标是:在包含有物理环路的网络中,构建出一个能够连通全网各节点的树型无环逻辑拓扑。 选举根交换机: 选举根端口: 选举指定端口: 端口名字&…

2025软件测试面试题大全(含答案)备战“金三银四”

🍅 点击文末小卡片 ,免费获取软件测试全套资料,资料在手,涨薪更快 001、软件的生命周期(prdctrm) 计划阶段(planning)-〉需求分析(requirement)-〉设计阶段(design)-〉编码(coding)->测试(testing)->运行与维护(running m…

Django后端相应类设计

通用的ApiResponse类:用于生成统一的 API 响应格式。每个响应都包含以下字段(每个接口最终的返回数据格式): status_code:HTTP 状态码(如 200、400、500 等)message:响应的描述信息…

LeetCode:39. 组合总和

跟着carl学算法,本系列博客仅做个人记录,建议大家都去看carl本人的博客,写的真的很好的! 代码随想录 LeetCode:39. 组合总和 给你一个 无重复元素 的整数数组 candidates 和一个目标整数 target ,找出 cand…

EdgeOne安全专项实践:上传文件漏洞攻击详解与防范措施

靶场搭建 当我们考虑到攻击他人服务器属于违法行为时,我们需要思考如何更好地保护我们自己的服务器。为了测试和学习,我们可以搭建一个专门的靶场来模拟文件上传漏洞攻击。以下是我搭建靶场的环境和一些参考资料,供大家学习和参考&#xff0…

[人工智能自学] Python包学习-pandas

紧接上篇numpy的学习教程 本篇参考: Pandas 教程|菜鸟教程 官方教程 - 10分钟入门pandas joyful-pandas pandas中文教程 它建立在 NumPy 库的基础之上,提供了高效的数据结构和数据分析工具,使得在 Python 中进行数据操作变得更加容易和高效。…