如何理解tensor中张量的维度

devtools/2024/11/28 14:10:23/

理解 dim=0dim=1dim=2 以及 (x, y, z) 的意思,关键在于明确每个维度在张量中的作用。让我们通过具体的例子来详细解释这些概念。

三维张量的维度

一个三维张量可以看作是一个三维数组,通常用形状 (x, y, z) 来表示。这里的 xyz 分别表示张量在三个不同维度上的大小。

  • x 维度:通常称为批处理维度(batch dimension),表示数据的数量或批次。
  • y 维度:通常称为特征维度(feature dimension),表示每个数据点的特征数量。
  • z 维度:通常称为通道维度(channel dimension),表示每个特征的通道数量。

具体例子

假设我们有一个三维张量 tensor,其形状为 (2, 3, 4)。这个张量可以看作是一个包含 2 个批次的数据,每个批次有 3 个特征,每个特征有 4 个通道。

python">import torch# 创建一个形状为 (2, 3, 4) 的三维张量
tensor = torch.randn(2, 3, 4)
print(tensor)

输出可能如下所示:

tensor([[[ 0.1234,  0.5678, -0.9101,  0.2345],[-0.3456,  0.6789,  0.1234, -0.5678],[ 0.7890, -0.1234,  0.5678,  0.9101]],[[-0.2345,  0.3456, -0.4567,  0.5678],[ 0.6789, -0.7890,  0.8901, -0.9012],[-0.1234,  0.2345, -0.3456,  0.4567]]])

拼接操作

现在我们来理解在不同维度上进行拼接操作的意义。

1. 在 dim=0 上拼接
  • 意义:在 dim=0 上拼接意味着在批处理维度上增加数据的数量。也就是说,我们将两个张量在第一个维度上合并,形成一个新的张量。
  • 要求tensor_atensor_bdim=1dim=2 上的大小必须相同。
  • 结果:拼接后的张量形状为 (x1 + x2, y, z)
python">tensor_a = torch.randn(2, 3, 4)  # 形状为 (2, 3, 4)
tensor_b = torch.randn(2, 3, 4)  # 形状为 (2, 3, 4)tensor_c = torch.cat((tensor_a, tensor_b), dim=0)  # 结果形状为 (4, 3, 4)
print("在dim=0上拼接后的形状:", tensor_c.shape)
2. 在 dim=1 上拼接
  • 意义:在 dim=1 上拼接意味着在特征维度上增加特征的数量。也就是说,我们将两个张量在第二个维度上合并,形成一个新的张量。
  • 要求tensor_atensor_bdim=0dim=2 上的大小必须相同。
  • 结果:拼接后的张量形状为 (x, y1 + y2, z)
python">tensor_d = torch.cat((tensor_a, tensor_b), dim=1)  # 结果形状为 (2, 6, 4)
print("在dim=1上拼接后的形状:", tensor_d.shape)
3. 在 dim=2 上拼接
  • 意义:在 dim=2 上拼接意味着在通道维度上增加通道的数量。也就是说,我们将两个张量在第三个维度上合并,形成一个新的张量。
  • 要求tensor_atensor_bdim=0dim=1 上的大小必须相同。
  • 结果:拼接后的张量形状为 (x, y, z1 + z2)
python">tensor_e = torch.cat((tensor_a, tensor_b), dim=2)  # 结果形状为 (2, 3, 8)
print("在dim=2上拼接后的形状:", tensor_e.shape)

图解

假设 tensor_atensor_b 都是形状为 (2, 3, 4) 的张量,可以用以下图解来帮助理解:

tensor_a:
[[[a11, a12, a13, a14],[a21, a22, a23, a24],[a31, a32, a33, a34]],[[a41, a42, a43, a44],[a51, a52, a53, a54],[a61, a62, a63, a64]]
]tensor_b:
[[[b11, b12, b13, b14],[b21, b22, b23, b24],[b31, b32, b33, b34]],[[b41, b42, b43, b44],[b51, b52, b53, b54],[b61, b62, b63, b64]]
]
  • dim=0 上拼接

    [[[a11, a12, a13, a14],[a21, a22, a23, a24],[a31, a32, a33, a34]],[[a41, a42, a43, a44],[a51, a52, a53, a54],[a61, a62, a63, a64]],[[b11, b12, b13, b14],[b21, b22, b23, b24],[b31, b32, b33, b34]],[[b41, b42, b43, b44],[b51, b52, b53, b54],[b61, b62, b63, b64]]
    ]
    
  • dim=1 上拼接

    [[[a11, a12, a13, a14],[a21, a22, a23, a24],[a31, a32, a33, a34],[b11, b12, b13, b14],[b21, b22, b23, b24],[b31, b32, b33, b34]],[[a41, a42, a43, a44],[a51, a52, a53, a54],[a61, a62, a63, a64],[b41, b42, b43, b44],[b51, b52, b53, b54],[b61, b62, b63, b64]]
    ]
    
  • dim=2 上拼接

    [[[a11, a12, a13, a14, b11, b12, b13, b14],[a21, a22, a23, a24, b21, b22, b23, b24],[a31, a32, a33, a34, b31, b32, b33, b34]],[[a41, a42, a43, a44, b41, b42, b43, b44],[a51, a52, a53, a54, b51, b52, b53, b54],[a61, a62, a63, a64, b61, b62, b63, b64]]
    ]
    

理解四维张量的关键在于明确每个维度的作用。四维张量通常用于表示批量的图像数据,其中每个图像有多个通道(例如RGB图像)。让我们详细解释四维张量的各个维度及其含义。

四维张量的维度

假设你有一个四维张量 tensor,其形状为 (N, C, H, W)。这里的 NCHW 分别表示张量在四个不同维度上的大小。

  • N 维度:批处理维度(batch dimension),表示数据的数量或批次。
  • C 维度:通道维度(channel dimension),表示每个图像的通道数量(例如,RGB图像有3个通道)。
  • H 维度:高度维度(height dimension),表示图像的高度。
  • W 维度:宽度维度(width dimension),表示图像的宽度。

具体例子

假设我们有一个四维张量 tensor,其形状为 (2, 3, 4, 4)。这个张量可以看作是一个包含 2 张图像的数据集,每张图像有 3 个通道(RGB),高度为 4,宽度为 4。

python">import torch# 创建一个形状为 (2, 3, 4, 4) 的四维张量
tensor = torch.randn(2, 3, 4, 4)
print(tensor)

输出可能如下所示:

tensor([[[[ 0.1234,  0.5678, -0.9101,  0.2345],[-0.3456,  0.6789,  0.1234, -0.5678],[ 0.7890, -0.1234,  0.5678,  0.9101],[-0.2345,  0.3456, -0.4567,  0.5678]],[[-0.2345,  0.3456, -0.4567,  0.5678],[ 0.6789, -0.7890,  0.8901, -0.9012],[-0.1234,  0.2345, -0.3456,  0.4567],[ 0.7890, -0.1234,  0.5678,  0.9101]],[[ 0.1234,  0.5678, -0.9101,  0.2345],[-0.3456,  0.6789,  0.1234, -0.5678],[ 0.7890, -0.1234,  0.5678,  0.9101],[-0.2345,  0.3456, -0.4567,  0.5678]]],[[[ 0.1234,  0.5678, -0.9101,  0.2345],[-0.3456,  0.6789,  0.1234, -0.5678],[ 0.7890, -0.1234,  0.5678,  0.9101],[-0.2345,  0.3456, -0.4567,  0.5678]],[[-0.2345,  0.3456, -0.4567,  0.5678],[ 0.6789, -0.7890,  0.8901, -0.9012],[-0.1234,  0.2345, -0.3456,  0.4567],[ 0.7890, -0.1234,  0.5678,  0.9101]],[[ 0.1234,  0.5678, -0.9101,  0.2345],[-0.3456,  0.6789,  0.1234, -0.5678],[ 0.7890, -0.1234,  0.5678,  0.9101],[-0.2345,  0.3456, -0.4567,  0.5678]]]])

拼接操作

现在我们来理解在不同维度上进行拼接操作的意义。

1. 在 dim=0 上拼接
  • 意义:在 dim=0 上拼接意味着在批处理维度上增加数据的数量。也就是说,我们将两个张量在第一个维度上合并,形成一个新的张量。
  • 要求tensor_atensor_bdim=1dim=2dim=3 上的大小必须相同。
  • 结果:拼接后的张量形状为 (N1 + N2, C, H, W)
python">tensor_a = torch.randn(2, 3, 4, 4)  # 形状为 (2, 3, 4, 4)
tensor_b = torch.randn(2, 3, 4, 4)  # 形状为 (2, 3, 4, 4)tensor_c = torch.cat((tensor_a, tensor_b), dim=0)  # 结果形状为 (4, 3, 4, 4)
print("在dim=0上拼接后的形状:", tensor_c.shape)
2. 在 dim=1 上拼接
  • 意义:在 dim=1 上拼接意味着在通道维度上增加通道的数量。也就是说,我们将两个张量在第二个维度上合并,形成一个新的张量。
  • 要求tensor_atensor_bdim=0dim=2dim=3 上的大小必须相同。
  • 结果:拼接后的张量形状为 (N, C1 + C2, H, W)
python">tensor_d = torch.cat((tensor_a, tensor_b), dim=1)  # 结果形状为 (2, 6, 4, 4)
print("在dim=1上拼接后的形状:", tensor_d.shape)
3. 在 dim=2 上拼接
  • 意义:在 dim=2 上拼接意味着在高度维度上增加高度的数量。也就是说,我们将两个张量在第三个维度上合并,形成一个新的张量。
  • 要求tensor_atensor_bdim=0dim=1dim=3 上的大小必须相同。
  • 结果:拼接后的张量形状为 (N, C, H1 + H2, W)
python">tensor_e = torch.cat((tensor_a, tensor_b), dim=2)  # 结果形状为 (2, 3, 8, 4)
print("在dim=2上拼接后的形状:", tensor_e.shape)
4. 在 dim=3 上拼接
  • 意义:在 dim=3 上拼接意味着在宽度维度上增加宽度的数量。也就是说,我们将两个张量在第四个维度上合并,形成一个新的张量。
  • 要求tensor_atensor_bdim=0dim=1dim=2 上的大小必须相同。
  • 结果:拼接后的张量形状为 (N, C, H, W1 + W2)
python">tensor_f = torch.cat((tensor_a, tensor_b), dim=3)  # 结果形状为 (2, 3, 4, 8)
print("在dim=3上拼接后的形状:", tensor_f.shape)

图解

假设 tensor_atensor_b 都是形状为 (2, 3, 4, 4) 的张量,可以用以下图解来帮助理解:

tensor_a:
[[[[a1111, a1112, a1113, a1114],[a1121, a1122, a1123, a1124],[a1131, a1132, a1133, a1134],[a1141, a1142, a1143, a1144]],[[a1211, a1212, a1213, a1214],[a1221, a1222, a1223, a1224],[a1231, a1232, a1233, a1234],[a1241, a1242, a1243, a1244]],[[a1311, a1312, a1313, a1314],[a1321, a1322, a1323, a1324],[a1331, a1332, a1333, a1334],[a1341, a1342, a1343, a1344]]],[[[a2111, a2112, a2113, a2114],[a2121, a2122, a2123, a2124],[a2131, a2132, a2133, a2134],[a2141, a2142, a2143, a2144]],[[a2211, a2212, a2213, a2214],[a2221, a2222, a2223, a2224],[a2231, a2232, a2233, a2234],[a2241, a2242, a2243, a2244]],[[a2311, a2312, a2313, a2314],[a2321, a2322, a2323, a2324],[a2331, a2332, a2333, a2334],[a2341, a2342, a2343, a2344]]]
]tensor_b:
[[[[b1111, b1112, b1113, b1114],[b1121, b1122, b1123, b1124],[b1131, b1132, b1133, b1134],[b1141, b1142, b1143, b1144]],[[b1211, b1212, b1213, b1214],[b1221, b1222, b1223, b1224],[b1231, b1232, b1233, b1234],[b1241, b1242, b1243, b1244]],[[b1311, b1312, b1313, b1314],[b1321, b1322, b1323, b1324],[b1331, b1332, b1333, b1334],[b1341, b1342, b1343, b1344]]],[[[b2111, b2112, b2113, b2114],[b2121, b2122, b2123, b2124],[b2131, b2132, b2133, b2134],[b2141, b2142, b2143, b2144]],[[b2211, b2212, b2213, b2214],[b2221, b2222, b2223, b2224],[b2231, b2232, b2233, b2234],[b2241, b2242, b2243, b2244]],[[b2311, b2312, b2313, b2314],[b2321, b2322, b2323, b2324],[b2331, b2332, b2333, b2334],[b2341, b2342, b2343, b2344]]]
]
  • dim=0 上拼接
    [[[[a1111, a1112, a1113, a1114],[a1121, a1122, a1123, a1124],[a1131, a1132, a1133, a1134],[a1141, a1142, a1143, a1144]],[[a1211, a1212, a1213, a1214],[a1221, a1222, a1223, a1224],[a1231, a1232, a1233, a1234],[a1241, a1242, a1243, a1244]],[[a1311, a1312, a1313, a1314],[a1321, a1322, a1323, a1324],[a1331, a1332, a1333, a1334],[a1341, a1342, a1343, a1344]]],[[[a2111, a2112, a2113, a2114],[a2121, a2122, a2123, a2124],[a2131, a2132, a2133, a2134],[a2141, a2142, a2143, a2144]],[[a2211, a2212, a2213, a2214],[a2221, a2222, a2223, a2224],[a2231, a2232, a2233, a2234],[a2241, a2242, a2243, a2244]],[[a2311, a2312, a2313, a2314],[a2321, a2322, a2323, a2324],[a2331, a2332, a2333, a2334],[a2341, a2342, a2343, a2344]]],[[[b1111, b1112, b1113, b1114],[b1121, b1122, b1123, b1124],[b1131, b1132, b1133, b1134],[b1141, b1142, b1143, b1144]],[[b1211, b1212, b1213, b1214],[b1221, b1222, b1223, b1224],[b1231, b1232, b1233, b1234],[b1241, b1242, b1243, b1244]],[[b1311, b1312, b1313, b1314],[b1321, b1322, b1323, b1324],[b1331, b1332, b1333, b1334],[b1341, b1342, b1343, b1344]]],[[[b2111, b2112, b2113, b2114],[b2121, b2122, b2123, b2124],[b2131, b2132, b2133, b2134],[b2141, b2142, b2143, b2144]],[[b2211, b2212, b2213, b2214],[b2221, b2222, b2223, b2224],[b2231, b2232, b2233, b2234],[b2241, b2242, b2243, b2244]],[[b2311, b2312, b2313, b2314],[b2321, b2322, b2323, b2324],[b2331, b2332, b2333, b2334],[b2341, b2342, b2343, b2344]]]
    ]
    

http://www.ppmy.cn/devtools/137680.html

相关文章

学习使用jquery实现在指定div前面增加内容

学习使用jquery实现在指定div前面增加内容 设计思路代码示例 设计思路 选择要添加内容的指定元素‌: 使用jQuery选择器来选择你希望在其前添加内容的元素。例如,如果你有一个 元素,其ID为qipa250,你可以使用$(‘#qipa250’)来选择…

【carla生成车辆时遇到的问题】carla显示的坐标和carlaworld中提取的坐标y值相反

项目需要重新运行了一下generate_car.py的脚本,发现死活生成不了,研究了半天,发现脚本里面生成车辆的坐标值y和carla_ros_bridge_with_example_ego_vehicle.launch脚本打开的驾驶操控界面里面的y值正好是相反数! y1-y2 因为,我运行…

『VUE』elementUI dialog的子组件created生命周期不刷新(详细图文注释)

目录 1. 测试代码分析令人迷惑的效果 分析原因解决方法 如何在dialog中反复触发created呢?总结 欢迎关注 『VUE』 专栏,持续更新中 欢迎关注 『VUE』 专栏,持续更新中 主要是在做表单的时候想要有一个编辑表单在dialog弹窗中出现,同时dialog调用的封装的…

使用docker搭建hysteria2服务端

源链接:https://github.com/apernet/hysteria/discussions/1248 官网地址:https://v2.hysteria.network/zh/docs/getting-started/Installation/ 首选需要安装docker和docker compose 切换到合适的目录 cd /home创建文件夹 mkdir hysteria创建docke…

neo4j图数据库community-5.50创建多个数据库————————————————

1.找到neo4J中的conf文件,我的路径是:D:\Program Files\neo4j-community-5.5.0-windows\neo4j-community-5.5.0\conf 这里找自己的安装路径, 2.用管理员模式打开conf文件,右键管理员,记事本或者not 3.选中的一行新建一…

通过抓包,使用frida定位加密位置

首先我们抓取一下我们要测试的app的某一个目标api,通过抓api的包,得到关键字。 例如:关键字:x-sap-ri 我们得到想要的关键字后,通过拦截 类,寻找我们的关键字,及找到发包收包的位置&#xff0c…

D 型 GaN HEMT 在功率转换方面的优势

氮化镓 (GaN) 是一种 III-V 族宽带隙半导体,由于在用作横向高电子迁移率晶体管 (HEMT) 时具有卓越的材料和器件性能,因此在功率转换应用中得到越来越多的采用。 HEMT 中产生的高击穿电场 (3.3 MV/cm) 和高二维电子气 (2DEG) 载流子迁移率 (2,000 cm 2 /…

RabbitMQ 安装延迟队列插件 rabbitmq_delayed_message_exchange

前言: RabbitMQ 延迟队列插件(rabbitmq_delayed_message_exchange)是一个社区开发的插件,它为 RabbitMQ 添加了支持延迟消息的功能。通过这个插件,用户可以创建一种特殊的交换机类型 x-delayed-message,该…