【深度学习】6-3 卷积神经网络 - 卷积层和池化层的实现

news/2024/11/29 6:51:05/

卷积层和池化层的实现
如前所述,CNN 中各层间传递的数据是 4 维数据。所谓 4 维数据,比如数据的形状是 (10, 1, 28, 28),则它对应 10 个高为 28、长为 28、通道为 1 的数据。用 Python 来实现的话,如下所示

>>> x = np.random.rand(10, 1, 28, 28) # 随机生成数据
>>> x.shape
(10, 1, 28, 28)

这里,如果要访问第 1 个数据,只要写 x[0]就可以了

如果要访问第 1 个数据的第 1 个通道的空间数据,可以写成下面这样。

>>> x[0, 0] # 或者x[0][0]

像这样,CNN 中处理的是 4 维数据,因此卷积运算的实现看上去会很复杂,但是通过使用下面要介绍的 im2col这个技巧,问题就会变得很简单。

基于 im2col 的展开
如果老老实实地实现卷积运算,估计要重复好几层的 for语句。这样的实现有点麻烦。这里不使用 for语句,而是使用 im2col(image to column)这个便利的函数进行简单的实现。

im2col是一个函数,将输入数据展开以适合滤波器(权重)。
如下图所示,对 3 维的输入数据应用 im2col后,数据转换为 2 维矩阵(正确地讲,是把包含批数量的 4 维数据转换成了 2 维数据)。
在这里插入图片描述
具体地说,如下图所示,对于输入数据,将应用滤波器的区域(3 维方块)横向展开为 1 列。im2col会在所有应用滤波器的地方进行这个展开处理。

在这里插入图片描述
在上图中,为了便于观察,将步幅设置得很大,以使滤波器的应用区域不重叠。而在实际的卷积运算中,滤波器的应用区域几乎都是重叠的。在滤波器的应用区域重叠的情况下,使用 im2col展开后,展开后的元素个数会多于原方块的元素个数。因此,使用 im2col的实现存在比普通的实现消耗更多内存的缺点。但是,汇总成一个大的矩阵进行计算,对计算机的计算颇有益处。比如,在矩阵计算的库(线性代数库)等中,矩阵计算的实现已被高度最优化,可以高速地进行大矩阵的乘法运算。因此,通过归结到矩阵计算上,可以有效地利用线性代数库。

如下图所示,基于 im2col方式的输出结果是 2 维矩阵。因为 CNN 中数据会保存为 4 维数组,所以要将 2 维输出数据转换为合适的形状。以上就是卷积层的实现流程。
在这里插入图片描述
卷积运算的滤波器处理的细节:将滤波器纵向展开为 1 列,并计算和 im2col 展开的数据的矩阵乘积,最后转换(reshape)为输出数据的大小

卷积层的实现
im2col的实现如下

def im2col(input_data, filter_h, filter_w, stride=1, pad=0):"""Parameters----------input_data :(数据量, 通道,,)4维数组构成的输入数据filter_h : 滤波器的高filter_w : 滤波器的长stride : 步幅pad : 填充Returns-------col : 2维数组"""N, C, H, W = input_data.shapeout_h = (H + 2*pad - filter_h)//stride + 1out_w = (W + 2*pad - filter_w)//stride + 1img = np.pad(input_data, [(0,0), (0,0), (pad, pad), (pad, pad)], 'constant')col = np.zeros((N, C, filter_h, filter_w, out_h, out_w))for y in range(filter_h):y_max = y + stride*out_hfor x in range(filter_w):x_max = x + stride*out_wcol[:, :, y, x, :, :] = img[:, :, y:y_max:stride, x:x_max:stride]col = col.transpose(0, 4, 5, 1, 2, 3).reshape(N*out_h*out_w, -1)return col

input_data——由(数据量,通道,高,长)的 4 维数组构成的输入数据
filter_h——滤波器的高
filter_w——滤波器的长
stride——步幅
pad——填充

下面使用im2col来实现卷积层,这里将卷积层实现为名为Convolution的类

class Convolution:def __init__(self, W, b, stride=1,pad=0):self.W = Wself.b = bself.stride = strideself.pad = paddef forward(self, x):FN, C, FH, FW = self.W.shapeN, C, H, W = x.shapeout_h = init(1+ (H + 2*self.pad - FH) / self.stride)out_w = init(1+ (W + 2*self.pad - FW) / self.stride)col = im2col(x, FH, FW, self.stride, self.pad)col_W = self.W.reshape(FN, -1).T  # 滤波器的展开 out = np.dot(col, col_W) + self.bout = out.reshape(N,out_h,out_w, -1).transpose(0, 3, 1, 2)return out

度上的元素个数,以使多维数组的元素个数前后一致。比如,(10, 3, 5, 5) 形状的数组的元素个数共有 750 个,指定 reshape(10,-1)后,就会转换成 (10, 75) 形状的数组。

展开滤波器的部分将各个滤波器的方块纵向展开为 1 列。这里通过 reshape(FN,-1)将参数指定为 -1,这是 reshape的一个便利的功能。通过在 reshape时指定为 -1,reshape函数会自动计算 -1维度上的元素个数,以使多维数组的元素个数前后一致。比如,(10, 3, 5, 5) 形状的数组的元素个数共有 750 个,指定 reshape(10,-1)后,就会转换成 (10, 75) 形状的数组

池化层的实现
池化层也要使用 im2col展开输入数据
不过,池化的情况下,在通道方向上是独立的,这一点和卷积层不同。具体地讲,池化的应用区域按通道单独展开
如下图:
在这里插入图片描述
像这样展开之后,只需对展开的矩阵求各行的最大值,并转换为合适的形状即可
在这里插入图片描述
池化层的实现按下面 3 个阶段进行

  1. 展开输入数据。
  2. 求各行的最大值。
  3. 转换为合适的输出大小。

下面看Python的实际实现:

class Pooling:def __init__(self, pool_h, pool_w, stride=1, pad=0):self.pool_h = pool_hself.pool_w = pool_wself.stride = strideself.pad = paddef forward(self, x):N, C, H, W = x.shapeout_h = init(1+ (H - self.pool_h) / self.stride)out_w = init(1+ (W - self.pool_w) / self.stride)# 展开(1)col = im2col(x, self.pool_h, self.pool_w, self.stride, self.pad)col = col.reshape(-1, self.pool_h*pool_w)# 最大值(2)out = np.max(col, axis=1)# 转换(3)out = out.reshape(N, out_h, out_w, C).transpose(0,3,1,2)return out

http://www.ppmy.cn/news/521489.html

相关文章

T1153,T1121,T1154

T1153,T1121,T1154 T1153T1121T1154计蒜客网址 T1153 蒜术师给了你一个 1010 个整数的序列,要求对其重新排序。排序要求: 奇数在前,偶数在后; 奇数按从大到小排序; 偶数按从小到大排序。 解析&#xf…

T2081,T2131,T1136

T2081,T2131,T1136 T2081T2131T1136计蒜客网址 T2081 国王将金币作为工资,发放给忠诚的骑士。第一天,骑士收到一枚金币;之后两天(第二天和第三天),每天收到两枚金币;之后…

日志打印[System.out.println(s)]引起的血案

日志打印引起的血案 起因 某天突然收到Cat关于某个站点的异常告警,错误数达到1000多个。过了一分钟后就恢复了。通过Cat错误日志,发现是调用下游服务异常。 报错有两种错误 com.alipay.sofa.rpc.core.exception.SofaRpcException: com.alipay.remoting.rpc.excep…

T2046,T1112,T1114

T2046,T1112,T1114 T2046T1112T1114计蒜客网址 T2046 凯凯刚写了一篇美妙的作文,请问这篇作文的标题中有多少个字符? 注意:标题中可能包含大、小写英文字母、数字字符、空格和换行符。统计标题字符数时,空格和换行符…

T1191,T1142,T1312,T1957

T1191,T1142,T1312,T1957 T1191T1142T1312T1957 T1191 一个笼子里面关了鸡和兔子(鸡有 2 只脚,兔子有 4 只脚,没有残疾的)。已经知道了笼子里面脚的总数 a,问笼子里面至少有多少只动…

Java | extends关键字【面向对象的第二大特征——继承】

CSDN话题挑战赛第2期 参赛话题:Java技术分享 Java之extends关键字 一、继承的概念引入1、继承是什么?有什么好处?2、怎么继承?格式是怎样的?3、继承之后会怎样呢?4、Java继承与C继承的区别 二、简单案例&am…

T2135,T1429,T1133,T1246

T2135,T1429,T1133,T1246 T2135T1429T1133T1246参考文献 T2135 某小学最近得到了一笔赞助,打算拿出其中一部分为学习成绩优秀的前 5 名学生发奖学金。期末,每个学生都有 3 门课的成绩:语文、数学、英语。先…

Linux kernel的中断子系统之(七):GIC代码分析

转载地址:https://www.cnblogs.com/arnoldlu/p/7599595.html 总结: 原文地址:《linux kernel的中断子系统之(七):GIC代码分析》 参考代码:http://elixir.free-electrons.com/linux/v3.17-rc3/s…