torch.gather(input_tensor, dim=1, index=index_tensor)
dim=0代表按着行的顺序取,即列方向上取;
dim=1代表按着列的顺序取,即行方向上取。
import torch
# 示例输入张量 (2D)
input_tensor = torch.tensor([[10, 20, 30],
[40, 50, 60]])
# 索引张量 (2D)
index_tensor = torch.tensor([[0, 2], # 第 0 行选择第 0 和第 2 个元素
[1, 1]]) # 第 1 行选择第 1 个元素两次
# 使用 torch.gather 从 input_tensor 中根据 index_tensor 选择元素
output_tensor = torch.gather(input_tensor, dim=1, index=index_tensor)
print(output_tensor)
输出
tensor([[10, 30],
[50, 50]])