Contents
- 1. torch.roll
- 2. PyTorch code
1. torch.roll
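torch.roll shifts the elements of a tensor A along the specified dimension(s) by the specified number of positions. Elements pushed past the last position wrap around to the first position, so no values are lost.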
- Excel illustration (figure not included):
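The wrap-around behaviour is easiest to see on a 1-D tensor. The following is a minimal sketch; the variable names `x`, `right`, and `left` are illustrative and not from the original post.

```python
import torch

x = torch.arange(5)  # tensor([0, 1, 2, 3, 4])

# A positive shift moves elements toward higher indices;
# the last element wraps around to index 0.
right = torch.roll(x, shifts=1, dims=0)

# A negative shift moves elements toward lower indices;
# the first two elements wrap around to the end.
left = torch.roll(x, shifts=-2, dims=0)

print(right.tolist())  # [4, 0, 1, 2, 3]
print(left.tolist())   # [2, 3, 4, 0, 1]
```

torch.roll returns a new tensor, so `x` itself is unchanged.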
2. PyTorch code
import torch

if __name__ == "__main__":
    batch_size = 2
    mat1_w = 3
    mat1_h = 4
    mat1_total = batch_size * mat1_w * mat1_h

    # Values 0..23 arranged as a tensor of shape (2, 3, 4)
    mat1 = torch.arange(mat1_total).reshape(batch_size, mat1_w, mat1_h)
    print(f"mat1=\n{mat1}")

    # Shift by one position along dim 0 (the batch dimension)
    mat1_roll0 = torch.roll(input=mat1, shifts=1, dims=0)
    print(f"mat1_roll0=\n{mat1_roll0}")

    # Shift by one position along dim 1 (the rows of each matrix)
    mat1_roll1 = torch.roll(input=mat1, shifts=1, dims=1)
    print(f"mat1_roll1=\n{mat1_roll1}")

    # Shift by one position along dim 2 (the columns of each matrix)
    mat1_roll2 = torch.roll(input=mat1, shifts=1, dims=2)
    print(f"mat1_roll2=\n{mat1_roll2}")
- Output:
mat1=
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]],

        [[12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]]])
mat1_roll0=
tensor([[[12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]],

        [[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]]])
mat1_roll1=
tensor([[[ 8,  9, 10, 11],
         [ 0,  1,  2,  3],
         [ 4,  5,  6,  7]],

        [[20, 21, 22, 23],
         [12, 13, 14, 15],
         [16, 17, 18, 19]]])
mat1_roll2=
tensor([[[ 3,  0,  1,  2],
         [ 7,  4,  5,  6],
         [11,  8,  9, 10]],

        [[15, 12, 13, 14],
         [19, 16, 17, 18],
         [23, 20, 21, 22]]])
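The runs above each roll a single dimension. torch.roll also accepts tuples for `shifts` and `dims` to roll several dimensions in one call, and when `dims` is omitted it flattens the tensor, rolls it, and restores the original shape. The sketch below uses the same (2, 3, 4) tensor; the names `both`, `step`, and `flat` are illustrative and not from the original post.

```python
import torch

mat1 = torch.arange(24).reshape(2, 3, 4)

# Roll dim 1 by +1 and dim 2 by -1 in a single call.
both = torch.roll(mat1, shifts=(1, -1), dims=(1, 2))

# The same result, obtained with two successive single-dimension rolls.
step = torch.roll(torch.roll(mat1, shifts=1, dims=1), shifts=-1, dims=2)
print(torch.equal(both, step))  # True

# Without dims, the tensor is rolled as one flat sequence of 24 values,
# so the last value (23) moves to the very first position.
flat = torch.roll(mat1, shifts=1)
print(both[0, 0].tolist())  # [9, 10, 11, 8]
print(flat[0, 0].tolist())  # [23, 0, 1, 2]
```

Note that rolling without `dims` moves values across row and batch boundaries, which is usually not what you want for image-like data; passing `dims` explicitly keeps each shift confined to its own axis.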