Pytroch nn.Unfold() 与 nn.Fold()图码详解

news/2024/11/24 6:04:07/

文章目录

    • Unfold()与Fold()的用途
    • nn.Unfold()
      • Unfold()与Fold() 变化模式图解
    • nn.Fold()
      • 单通道 滑动窗口无重叠
      • 模拟图片数据(b,3,9,9),通道数 C 为3,滑动窗口无重叠。
      • 单通道 滑动窗口有重叠。
    • 卷积等价于:Unfold + Matrix Multiplication + Fold (或view()到卷积输出形状)

Unfold()与Fold()的用途

Unfold()Fold()一般成对出现。常用用途有:

  1. 代替卷积计算,Unfold()Fold()不互逆(参数不一样)(卷积本来就不可逆)
  2. 图片patch化,Unfold()Fold()互逆(参数一样,且滑动窗不重叠)

nn.Unfold()

Extracts sliding local blocks from a batched input tensor.
在各滑动窗中按行展开(行向量化),然后转置成列向量, im2col 的批量形式

input : (N, C, ∗)
output : (N, C × ∏(kernel_size), L)

# 滑动窗口有重叠
unfold = nn.Unfold(kernel_size=(2, 3))
input = torch.randn(2, 5, 3, 4)
print("input: \n", input)
output = unfold(input)
# each patch contains 30 values (2x3=6 vectors, each of 5 channels)
# 4 blocks (2x3 kernels) in total in the 3x4 input
# output.size()  # torch.Size([2, 30, 4])
print("output: \n", output)fold = nn.Fold((3,4),(2,3))
fold_output = fold(output)
print("fold_output: \n", fold_output)

输出为:

input: tensor([[[[ 0.4198,  1.0535,  0.1152,  0.3510],[ 1.1664,  0.3376,  1.2207,  0.3575],[-0.2174, -1.2490,  0.3432,  0.3388]],[[-1.4956,  0.9746, -0.5145,  0.1722],[ 1.7041,  0.9645, -0.6937, -1.9037],[-0.1961, -0.3345,  0.3565, -1.2329]],[[-0.9843, -0.8089,  1.8712, -0.2860],[ 0.0960, -1.7501, -0.1226,  0.9383],[-0.1675,  1.1498, -0.4958, -1.2953]],[[-1.2368,  0.5667,  1.4166, -2.2567],[ 0.9414,  0.8189,  1.5604,  0.1422],[-1.6414, -1.5594,  0.6718,  1.2319]],[[-0.4093,  0.6691,  1.4003,  0.7444],[-0.2858, -0.4375, -1.1301,  0.7377],[-0.0956, -0.1844,  0.7697, -0.3077]]],[[[ 0.4264, -0.0700, -1.5600, -0.0491],[ 1.5027,  3.1625,  0.6080, -1.8794],[-0.3148,  0.6377, -0.7242,  0.1692]],[[ 0.2757, -0.5403,  0.7748, -1.1795],[ 0.1504, -0.4671,  0.9355,  1.3050],[-0.4920, -0.8581,  0.0559, -0.0446]],[[ 2.1627,  0.6758, -0.0968,  1.3401],[-0.1105,  0.8299, -0.3827, -1.0687],[-0.2234, -1.0423,  1.2436, -0.6514]],[[ 0.8085, -0.4159,  0.2022,  0.5747],[-0.1265,  0.2828, -1.3530,  0.2831],[-0.1571,  0.9005,  0.4556, -1.4360]],[[-1.2417,  0.1829,  0.3825, -0.8555],[-2.0170,  0.7537,  2.3406,  0.5866],[-1.1704, -1.8986, -0.7958,  0.2652]]]])
output: tensor([[[ 0.4198,  1.0535,  1.1664,  0.3376],[ 1.0535,  0.1152,  0.3376,  1.2207],[ 0.1152,  0.3510,  1.2207,  0.3575],[ 1.1664,  0.3376, -0.2174, -1.2490],[ 0.3376,  1.2207, -1.2490,  0.3432],[ 1.2207,  0.3575,  0.3432,  0.3388],[-1.4956,  0.9746,  1.7041,  0.9645],[ 0.9746, -0.5145,  0.9645, -0.6937],[-0.5145,  0.1722, -0.6937, -1.9037],[ 1.7041,  0.9645, -0.1961, -0.3345],[ 0.9645, -0.6937, -0.3345,  0.3565],[-0.6937, -1.9037,  0.3565, -1.2329],[-0.9843, -0.8089,  0.0960, -1.7501],[-0.8089,  1.8712, -1.7501, -0.1226],[ 1.8712, -0.2860, -0.1226,  0.9383],[ 0.0960, -1.7501, -0.1675,  1.1498],[-1.7501, -0.1226,  1.1498, -0.4958],[-0.1226,  0.9383, -0.4958, -1.2953],[-1.2368,  0.5667,  0.9414,  0.8189],[ 0.5667,  1.4166,  0.8189,  1.5604],[ 1.4166, -2.2567,  1.5604,  0.1422],[ 0.9414,  0.8189, -1.6414, -1.5594],[ 0.8189,  1.5604, -1.5594,  0.6718],[ 1.5604,  0.1422,  0.6718,  1.2319],[-0.4093,  0.6691, -0.2858, -0.4375],[ 0.6691,  1.4003, -0.4375, -1.1301],[ 1.4003,  0.7444, -1.1301,  0.7377],[-0.2858, -0.4375, -0.0956, -0.1844],[-0.4375, -1.1301, -0.1844,  0.7697],[-1.1301,  0.7377,  0.7697, -0.3077]],[[ 0.4264, -0.0700,  1.5027,  3.1625],[-0.0700, -1.5600,  3.1625,  0.6080],[-1.5600, -0.0491,  0.6080, -1.8794],[ 1.5027,  3.1625, -0.3148,  0.6377],[ 3.1625,  0.6080,  0.6377, -0.7242],[ 0.6080, -1.8794, -0.7242,  0.1692],[ 0.2757, -0.5403,  0.1504, -0.4671],[-0.5403,  0.7748, -0.4671,  0.9355],[ 0.7748, -1.1795,  0.9355,  1.3050],[ 0.1504, -0.4671, -0.4920, -0.8581],[-0.4671,  0.9355, -0.8581,  0.0559],[ 0.9355,  1.3050,  0.0559, -0.0446],[ 2.1627,  0.6758, -0.1105,  0.8299],[ 0.6758, -0.0968,  0.8299, -0.3827],[-0.0968,  1.3401, -0.3827, -1.0687],[-0.1105,  0.8299, -0.2234, -1.0423],[ 0.8299, -0.3827, -1.0423,  1.2436],[-0.3827, -1.0687,  1.2436, -0.6514],[ 0.8085, -0.4159, -0.1265,  0.2828],[-0.4159,  0.2022,  0.2828, -1.3530],[ 0.2022,  0.5747, -1.3530,  0.2831],[-0.1265,  0.2828, -0.1571,  0.9005],[ 0.2828, -1.3530,  0.9005,  0.4556],[-1.3530,  0.2831,  0.4556, -1.4360],[-1.2417,  0.1829, -2.0170,  0.7537],[ 0.1829,  0.3825,  0.7537,  2.3406],[ 0.3825, -0.8555,  2.3406,  0.5866],[-2.0170,  0.7537, -1.1704, -1.8986],[ 0.7537,  2.3406, -1.8986, -0.7958],[ 2.3406,  0.5866, -0.7958,  0.2652]]])
fold_output: tensor([[[[ 0.4198,  2.1070,  0.2304,  0.3510],[ 2.3328,  1.3502,  4.8828,  0.7150],[-0.2174, -2.4979,  0.6865,  0.3388]],[[-1.4956,  1.9493, -1.0290,  0.1722],[ 3.4083,  3.8582, -2.7747, -3.8074],[-0.1961, -0.6690,  0.7129, -1.2329]],[[-0.9843, -1.6178,  3.7424, -0.2860],[ 0.1920, -7.0003, -0.4904,  1.8766],[-0.1675,  2.2995, -0.9917, -1.2953]],[[-1.2368,  1.1334,  2.8331, -2.2567],[ 1.8829,  3.2758,  6.2418,  0.2843],[-1.6414, -3.1189,  1.3435,  1.2319]],[[-0.4093,  1.3382,  2.8006,  0.7444],[-0.5717, -1.7500, -4.5204,  1.4754],[-0.0956, -0.3688,  1.5395, -0.3077]]],[[[ 0.4264, -0.1399, -3.1201, -0.0491],[ 3.0053, 12.6500,  2.4319, -3.7589],[-0.3148,  1.2753, -1.4484,  0.1692]],[[ 0.2757, -1.0806,  1.5497, -1.1795],[ 0.3007, -1.8684,  3.7421,  2.6100],[-0.4920, -1.7162,  0.1119, -0.0446]],[[ 2.1627,  1.3515, -0.1936,  1.3401],[-0.2210,  3.3198, -1.5307, -2.1373],[-0.2234, -2.0846,  2.4872, -0.6514]],[[ 0.8085, -0.8318,  0.4044,  0.5747],[-0.2529,  1.1310, -5.4121,  0.5662],[-0.1571,  1.8010,  0.9112, -1.4360]],[[-1.2417,  0.3659,  0.7650, -0.8555],[-4.0339,  3.0146,  9.3626,  1.1733],[-1.1704, -3.7972, -1.5917,  0.2652]]]])

Unfold()与Fold() 变化模式图解

以上面代码输出为例,其实是以如下的格式对原数据进行组织排列的:
在各滑动窗中按行展开(行向量化),然后转置成列向量, 是im2col的批量形式。
在这里插入图片描述
在这里插入图片描述
然后,对Unfold()的结果以相同参数运用Fold()后(Fold()的讲解在下面,这里先给出结果),结果如下:
在这里插入图片描述

nn.Fold()

nn.Fold() 是 nn.Unfold() 函数的逆操作。 (参数相同、滑动窗口没有重叠的情况下,可以完全恢复【真互逆】。滑动窗口有重叠情况下不能恢复到Unfold的输入)

需要注意的是,如果滑动窗口有重叠,那么重叠部分相加【倍数关系】。同时,如果原来的图像不够划分的话就会舍去。在恢复时就会以 0 填充

单通道 滑动窗口无重叠

# 单通道  滑动窗口无重叠
import torch.nn as nn
import torchbatches_img = torch.rand(1,1,6,6)
print("batches_img: ",batches_img)unfold = nn.Unfold(kernel_size=(3,3),stride=3)
patche_img = unfold(batches_img)
print("patche_img.shape: ",patche_img.shape)
print(patche_img)fold = torch.nn.Fold(output_size=(6, 6), kernel_size=(3, 3), stride=3)
inputs_restore = fold(patche_img)
print("inputs_restore:", inputs_restore)

输出:

batches_img:  tensor([[[[0.0174, 0.3919, 0.0073, 0.4660, 0.6537, 0.0584],[0.9763, 0.9982, 0.6250, 0.1332, 0.2123, 0.9500],[0.5482, 0.4291, 0.9430, 0.6837, 0.6975, 0.1992],[0.5275, 0.6800, 0.0490, 0.0350, 0.8571, 0.2449],[0.3719, 0.7484, 0.7677, 0.4164, 0.2151, 0.8875],[0.0784, 0.3839, 0.7567, 0.4217, 0.3208, 0.3025]]]])
patche_img.shape:  torch.Size([1, 9, 4])
tensor([[[0.0174, 0.4660, 0.5275, 0.0350],[0.3919, 0.6537, 0.6800, 0.8571],[0.0073, 0.0584, 0.0490, 0.2449],[0.9763, 0.1332, 0.3719, 0.4164],[0.9982, 0.2123, 0.7484, 0.2151],[0.6250, 0.9500, 0.7677, 0.8875],[0.5482, 0.6837, 0.0784, 0.4217],[0.4291, 0.6975, 0.3839, 0.3208],[0.9430, 0.1992, 0.7567, 0.3025]]])
inputs_restore: tensor([[[[0.0174, 0.3919, 0.0073, 0.4660, 0.6537, 0.0584],[0.9763, 0.9982, 0.6250, 0.1332, 0.2123, 0.9500],[0.5482, 0.4291, 0.9430, 0.6837, 0.6975, 0.1992],[0.5275, 0.6800, 0.0490, 0.0350, 0.8571, 0.2449],[0.3719, 0.7484, 0.7677, 0.4164, 0.2151, 0.8875],[0.0784, 0.3839, 0.7567, 0.4217, 0.3208, 0.3025]]]])

模拟图片数据(b,3,9,9),通道数 C 为3,滑动窗口无重叠。

相较于上面的代码,变化仅此

# 模拟图片数据(b,3,9,9),通道数 C 为3,滑动窗口无重叠。 相较于上面的代码,变化仅此
import torch.nn as nn
import torchbatches_img = torch.rand(1,3,6,6)
print("batches_img: ",batches_img)unfold = nn.Unfold(kernel_size=(3,3),stride=3)
patche_img = unfold(batches_img)
print("patche_img.shape: ",patche_img.shape)
print(patche_img)fold = torch.nn.Fold(output_size=(6, 6), kernel_size=(3, 3), stride=3)
inputs_restore = fold(patche_img)
print("inputs_restore:", inputs_restore)

输出为:

batches_img:  tensor([[[[0.6072, 0.9496, 0.4149, 0.1085, 0.6808, 0.3949],[0.9770, 0.4831, 0.3964, 0.6597, 0.1749, 0.7326],[0.4379, 0.0159, 0.2946, 0.4129, 0.1445, 0.5479],[0.1664, 0.6725, 0.5104, 0.4171, 0.6656, 0.3146],[0.5126, 0.2331, 0.8167, 0.2695, 0.6420, 0.8591],[0.2282, 0.6300, 0.9205, 0.6741, 0.6085, 0.7866]],[[0.7943, 0.8348, 0.5379, 0.1951, 0.2629, 0.7281],[0.5726, 0.4912, 0.5636, 0.7816, 0.9746, 0.3764],[0.5440, 0.3434, 0.5914, 0.5925, 0.9556, 0.0455],[0.0810, 0.0730, 0.2580, 0.0785, 0.2483, 0.3810],[0.4182, 0.7024, 0.4904, 0.6935, 0.1789, 0.1015],[0.2571, 0.9138, 0.1987, 0.6266, 0.0760, 0.4618]],[[0.3554, 0.2476, 0.3415, 0.5014, 0.1018, 0.3563],[0.2180, 0.5690, 0.9975, 0.8152, 0.5812, 0.2704],[0.5717, 0.9419, 0.4398, 0.5708, 0.2666, 0.3507],[0.3868, 0.6889, 0.0326, 0.7873, 0.7444, 0.8057],[0.1440, 0.9667, 0.2522, 0.9718, 0.6078, 0.2911],[0.1442, 0.3061, 0.4116, 0.4190, 0.2343, 0.2608]]]])
patche_img.shape:  torch.Size([1, 27, 4])
tensor([[[0.6072, 0.1085, 0.1664, 0.4171],[0.9496, 0.6808, 0.6725, 0.6656],[0.4149, 0.3949, 0.5104, 0.3146],[0.9770, 0.6597, 0.5126, 0.2695],[0.4831, 0.1749, 0.2331, 0.6420],[0.3964, 0.7326, 0.8167, 0.8591],[0.4379, 0.4129, 0.2282, 0.6741],[0.0159, 0.1445, 0.6300, 0.6085],[0.2946, 0.5479, 0.9205, 0.7866],[0.7943, 0.1951, 0.0810, 0.0785],[0.8348, 0.2629, 0.0730, 0.2483],[0.5379, 0.7281, 0.2580, 0.3810],[0.5726, 0.7816, 0.4182, 0.6935],[0.4912, 0.9746, 0.7024, 0.1789],[0.5636, 0.3764, 0.4904, 0.1015],[0.5440, 0.5925, 0.2571, 0.6266],[0.3434, 0.9556, 0.9138, 0.0760],[0.5914, 0.0455, 0.1987, 0.4618],[0.3554, 0.5014, 0.3868, 0.7873],[0.2476, 0.1018, 0.6889, 0.7444],[0.3415, 0.3563, 0.0326, 0.8057],[0.2180, 0.8152, 0.1440, 0.9718],[0.5690, 0.5812, 0.9667, 0.6078],[0.9975, 0.2704, 0.2522, 0.2911],[0.5717, 0.5708, 0.1442, 0.4190],[0.9419, 0.2666, 0.3061, 0.2343],[0.4398, 0.3507, 0.4116, 0.2608]]])
inputs_restore: tensor([[[[0.6072, 0.9496, 0.4149, 0.1085, 0.6808, 0.3949],[0.9770, 0.4831, 0.3964, 0.6597, 0.1749, 0.7326],[0.4379, 0.0159, 0.2946, 0.4129, 0.1445, 0.5479],[0.1664, 0.6725, 0.5104, 0.4171, 0.6656, 0.3146],[0.5126, 0.2331, 0.8167, 0.2695, 0.6420, 0.8591],[0.2282, 0.6300, 0.9205, 0.6741, 0.6085, 0.7866]],[[0.7943, 0.8348, 0.5379, 0.1951, 0.2629, 0.7281],[0.5726, 0.4912, 0.5636, 0.7816, 0.9746, 0.3764],[0.5440, 0.3434, 0.5914, 0.5925, 0.9556, 0.0455],[0.0810, 0.0730, 0.2580, 0.0785, 0.2483, 0.3810],[0.4182, 0.7024, 0.4904, 0.6935, 0.1789, 0.1015],[0.2571, 0.9138, 0.1987, 0.6266, 0.0760, 0.4618]],[[0.3554, 0.2476, 0.3415, 0.5014, 0.1018, 0.3563],[0.2180, 0.5690, 0.9975, 0.8152, 0.5812, 0.2704],[0.5717, 0.9419, 0.4398, 0.5708, 0.2666, 0.3507],[0.3868, 0.6889, 0.0326, 0.7873, 0.7444, 0.8057],[0.1440, 0.9667, 0.2522, 0.9718, 0.6078, 0.2911],[0.1442, 0.3061, 0.4116, 0.4190, 0.2343, 0.2608]]]])

单通道 滑动窗口有重叠。

kernel_size=(3,3),stride=2

# 单通道 滑动窗口有重叠。  kernel_size=(3,3),stride=2
import torch.nn as nn
import torchbatches_img = torch.rand(1,1,6,6)
print("batches_img: \n",batches_img)unfold = nn.Unfold(kernel_size=(3,3),stride=2)
patche_img = unfold(batches_img)
print("patche_img.shape: ",patche_img.shape)
print(patche_img)fold = torch.nn.Fold(output_size=(6, 6), kernel_size=(3, 3), stride=2)
inputs_restore = fold(patche_img)
print("inputs_restore: \n", inputs_restore)

输出为:

batches_img: tensor([[[[0.4171, 0.0129, 0.2183, 0.0610, 0.5242, 0.9530],[0.7112, 0.7892, 0.2548, 0.4604, 0.7200, 0.0294],[0.0754, 0.0451, 0.2892, 0.6765, 0.8671, 0.5574],[0.4220, 0.4499, 0.8946, 0.0149, 0.6790, 0.0719],[0.1529, 0.2815, 0.8502, 0.5781, 0.0339, 0.9916],[0.6900, 0.4843, 0.3190, 0.0676, 0.8558, 0.0060]]]])
patche_img.shape:  torch.Size([1, 9, 4])
tensor([[[0.4171, 0.2183, 0.0754, 0.2892],[0.0129, 0.0610, 0.0451, 0.6765],[0.2183, 0.5242, 0.2892, 0.8671],[0.7112, 0.2548, 0.4220, 0.8946],[0.7892, 0.4604, 0.4499, 0.0149],[0.2548, 0.7200, 0.8946, 0.6790],[0.0754, 0.2892, 0.1529, 0.8502],[0.0451, 0.6765, 0.2815, 0.5781],[0.2892, 0.8671, 0.8502, 0.0339]]])
inputs_restore: tensor([[[[0.4171, 0.0129, 0.4365, 0.0610, 0.5242, 0.0000],[0.7112, 0.7892, 0.5095, 0.4604, 0.7200, 0.0000],[0.1507, 0.0902, 1.1567, 1.3530, 1.7342, 0.0000],[0.4220, 0.4499, 1.7892, 0.0149, 0.6790, 0.0000],[0.1529, 0.2815, 1.7005, 0.5781, 0.0339, 0.0000],[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]]])
# 重复累加次数最多的元素:
batches_img[0,0,2,2]*4  
# 输出:tensor(1.1567)

卷积等价于:Unfold + Matrix Multiplication + Fold (或view()到卷积输出形状)

注: 使用 Unfold + Matrix Multiplication + Fold 来代替卷积时,Fold 中的 kernel size 需要为 (1,1)

inp = torch.randn(1, 3, 10, 12)
w = torch.randn(2, 3, 4, 5)
inp_unf = torch.nn.functional.unfold(inp, (4, 5))
out_unf = inp_unf.transpose(1, 2).matmul(w.view(w.size(0), -1).t()).transpose(1, 2)
out = torch.nn.functional.fold(out_unf, (7, 8), (1, 1))
# or equivalently (and avoiding a copy),
# out = out_unf.view(1, 2, 7, 8)
(torch.nn.functional.conv2d(inp, w) - out).abs().max()  # tensor(1.9073e-06)

http://www.ppmy.cn/news/69560.html

相关文章

spring 容器结构/机制debug分析--Spring 学习的核心内容和几个重要概念--IOC 的开发模式--综合解图

目录 Spring Spring 学习的核心内容 解读上图: Spring 几个重要概念 ● 传统的开发模式 解读上图 ● IOC 的开发模式 解读上图 代码示例—入门 xml代码 注意事项和细节 1、说明 2、解释一下类加载路径 3、debug 看看 spring 容器结构/机制 综合解图 Spring Spr…

Android不基于第三发依赖包解析shp文件(2)

接着上篇文章继续 2)Point (点)   一个 Point 由一对双精度坐标组成,存储顺序为 X,Y。    /*** PointGeometry记录读取* */static Geometry renderPointGeometry(byte[] recordContent,GeometryFactory geometryFactory) {int shapetype2

离散数学_九章:关系(6)

🪐9.6 偏序 1、⛺偏序关系和偏序集⛲偏序关系⛲偏序(关系)的例子 a. “大于或等于” 关系b. “整除” 关系c. “包含” 关系 🎬偏序集🎬可比性(comparability) " ≼ " 符号a. 可比 &a…

谁说的前端已死后端已亡?你给我站出来

前言 今年莫名其妙的网上就冒出来一句话:前端已死后端已亡,然后就会有很多人说工作不好找,要求高等等,本人也是在编程领域混迹了很多年,今天我们就来客观的分析一下,现在互联网到底是一个什么情况。 互联…

【JavaSE】数组复制的几种方式

找准方向&#xff0c;全力出击 文章目录 1. Arrays.copyOf()2. System.arraycopy()3. for循环4. clone() 1. Arrays.copyOf() 它的语法为 public static <T> T[] copyOf(T[] original, int newLength)方法返回值为一个数组&#xff0c;newLength为数组的新长度&#xf…

C++ Primer第五版_第十六章习题答案(61~67)

文章目录 练习16.61练习16.62Sales_data.hex62.cpp 练习16.63练习16.64练习16.65练习16.66练习16.67 练习16.61 定义你自己版本的 make_shared。 template <typename T, typename ... Args> auto make_shared(Args&&... args) -> std::shared_ptr<T> {r…

实时更新天气微信小程序开发

1.新建一个天气weather项目 2.在app.json中创建一个路由页面 当我们点击保存的时候&#xff0c;微信小程序会自动的帮我们创建好页面 3.在weather页面上书写我们的骨架 4.此时我们的页面很怪&#xff0c;因为没有给它添加样式和值。此时我们给它一个样式。&#xff08;样式写在…

事件循环Event Loop

什么是事件循环&#xff08;event loop&#xff09; 主线程不断的从消息队列中获取消息&#xff0c;执行消息&#xff0c;这个过程被称为事件循环&#xff0c;在javaScript中就是采用事件循环来解决单线程带来的问题 线程和进程 进程&#xff1a;计算机已经运行的程序&#…