在操作张量时,经常要去获取某些元素进行处理或者修改操作,在这里需要了解torch中的索引操作。
准备数据:
python">data = torch.randint(0,10,[4,5])
print('data--->',data)
输出结果:
data---> tensor([[3, 9, 4, 0, 5],[7, 5, 9, 9, 7],[5, 9, 8, 9, 7],[9, 2, 6, 7, 7]])
-
简单行、列索引
python">print('第一行:',data[0]) print('第一列:',data[:,0])
输出结果:
第一行: tensor([3, 9, 4, 0, 5]) 第一列: tensor([3, 7, 5, 9])
-
列表索引
python">print('-----------------返回(0,1)、(1,2) 2个位置的元素------------------') print(data[[0,1],[1,2]]) print('-----------------返回0、1 行的1、2 列共4个元素------------------') print(data[[[0],[1]],[1,2]])
输出结果:
-----------------返回(0,1)、(1,2) 2个位置的元素------------------ tensor([9, 9]) -----------------返回0、1 行的1、2 列共4个元素------------------ tensor([[9, 4],[5, 9]])
-
范围索引
python">print('-----------------前3行、前2列的数据------------------') print(data[:3,:2]) print('-----------------第2行到最后的前2列数据------------------') print(data[2:,:2])
输出结果:
-----------------前3行、前2列的数据------------------ tensor([[3, 9],[7, 5],[5, 9]]) -----------------第2行到最后的前2列数据------------------ tensor([[5, 9],[9, 2]])
-
布尔索引
python">print('-----------------第三列大于5的行数据------------------') print(data[data[:,2] > 5]) print('-----------------第二行大于5的行数据------------------') print(data[:,data[1] > 5])
输出结果:
-----------------第三列大于5的行数据------------------ tensor([[7, 5, 9, 9, 7],[5, 9, 8, 9, 7],[9, 2, 6, 7, 7]]) -----------------第二行大于5的行数据------------------ tensor([[3, 4, 0, 5],[7, 9, 9, 7],[5, 8, 9, 7],[9, 6, 7, 7]])
-
多维索引
python">data = torch.randint(0,10,[3,4,5]) print(data) # 获取0轴上的第一个数据 print(data[0,:,:]) # 获取1轴上的第一个数据 print(data[:,0,:]) # 获取2轴上的第一个数据 print(data[:,:,0])
输出结果:
tensor([[[8, 3, 6, 1, 5],[5, 0, 4, 3, 8],[8, 3, 3, 5, 0],[6, 4, 0, 8, 4]],[[7, 2, 3, 8, 5],[6, 2, 9, 5, 0],[4, 2, 7, 1, 1],[5, 4, 4, 1, 1]],[[2, 4, 7, 2, 5],[6, 1, 4, 5, 6],[9, 2, 3, 1, 0],[2, 1, 2, 7, 9]]]) tensor([[8, 3, 6, 1, 5],[5, 0, 4, 3, 8],[8, 3, 3, 5, 0],[6, 4, 0, 8, 4]]) tensor([[8, 3, 6, 1, 5],[7, 2, 3, 8, 5],[2, 4, 7, 2, 5]]) tensor([[8, 5, 8, 6],[7, 6, 4, 5],[2, 6, 9, 2]])