【深度学习实验】卷积神经网络(二):实现简单的二维卷积神经网络

news/2024/9/23 0:27:38/

目录

一、实验介绍

二、实验环境

1. 配置虚拟环境

2. 库版本介绍

三、实验内容

0. 导入必要的工具包

1. 二维互相关运算(corr2d)

2. 二维卷积层类(Conv2D)

a. __init__(初始化)

b. forward(前向传播函数)

3. 模型训练


一、实验介绍

        本实验实现了一个简单的二维卷积神经网络,包括二维互相关运算函数和自定义二维卷积层类,并对一个随机生成是二维张量进行了卷积操作。

 二、实验环境

    本系列实验使用了PyTorch深度学习框架,相关操作如下:

1. 配置虚拟环境

conda create -n DL python=3.7 
conda activate DL
pip install torch==1.8.1+cu102 torchvision==0.9.1+cu102 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
conda install matplotlib
 conda install scikit-learn

2. 库版本介绍

软件包本实验版本目前最新版
matplotlib3.5.33.8.0
numpy1.21.61.26.0
python3.7.16
scikit-learn0.22.11.3.0
torch1.8.1+cu1022.0.1
torchaudio0.8.12.0.2
torchvision0.9.1+cu1020.15.2

三、实验内容

ChatGPT:

        卷积神经网络(Convolutional Neural Network,简称CNN)是一种深度学习模型,广泛应用于图像识别、计算机视觉和模式识别等领域。它的设计灵感来自于生物学中视觉皮层的工作原理。

        卷积神经网络通过多个卷积层、池化层全连接层组成。

  • 卷积层主要用于提取图像的局部特征,通过卷积操作和激活函数的处理,可以学习到图像的特征表示。
  • 池化层则用于降低特征图的维度,减少参数数量,同时保留主要的特征信息。
  • 全连接层则用于将提取到的特征映射到不同类别的概率上,进行分类或回归任务。

        卷积神经网络在图像处理方面具有很强的优势,它能够自动学习到具有层次结构的特征表示,并且对平移、缩放和旋转等图像变换具有一定的不变性。这些特点使得卷积神经网络成为图像分类、目标检测、语义分割等任务的首选模型。除了图像处理,卷积神经网络也可以应用于其他领域,如自然语言处理和时间序列分析。通过将文本或时间序列数据转换成二维形式,可以利用卷积神经网络进行相关任务的处理。

本系列为实验内容,对理论知识不进行详细阐释

(咳咳,其实是没时间整理,待有缘之时,回来填坑)

0. 导入必要的工具包

import torch
from torch import nn
import torch.nn.functional as F
  • torch.nn:PyTorch中的神经网络模块,提供了各种神经网络层和函数。
  • torch.nn.functional:PyTorch中的函数形式的神经网络层,如激活函数和损失函数等。
 

1. 二维互相关运算(corr2d)

def corr2d(X, K): h, w = K.shapeY = torch.zeros((X.shape[0] - h + 1, X.shape[1] - w + 1))for i in range(Y.shape[0]):for j in range(Y.shape[1]):Y[i, j] = (X[i:i + h, j:j + w] * K).sum()return Y

  • 输入:输入张量X和卷积核张量K。
  • 输出:互相关运算结果张量Y,形状为(X.shape[0] - K.shape[0] + 1, X.shape[1] - K.shape[1] + 1)。
  • 通过两个嵌套的循环遍历输出张量Y的每个元素,使用局部相乘和求和的方式计算互相关运算结果。

2. 二维卷积层类(Conv2D)

class Conv2D(nn.Module):def __init__(self, kernel_size, weight=None):super().__init__()if weight is not None:self.weight = weightelse:self.weight = nn.Parameter(torch.rand(kernel_size))self.bias = nn.Parameter(torch.zeros(1))def forward(self, x):return corr2d(x, self.weight) + self.bias

a. __init__(初始化)

  • 接受一个kernel_size参数作为卷积核的大小,并可选地接受一个weight参数作为卷积核的权重。
  • 如果没有提供weight参数,则会随机生成一个与kernel_size相同形状的权重,并将其设置为可训练的参数(nn.Parameter)。
  • 定义了一个偏置项bias,也将其设置为可训练的参数。

b. forward(前向传播函数)

        调用之前的corr2d函数,对输入x和卷积核权重self.weight进行相关性计算,并将计算结果与偏置项self.bias相加,作为前向传播的输出。

3. 模型测试

# 由于卷积层还未实现多通道,所以我们的图像也默认是单通道的
fake_image = torch.randn((5,5))
# 实例化卷积算子
conv = Conv2D(kernel_size=(3,3))
output = conv(fake_image)

        创建了一个大小为(5, 5)的随机输入图像fake_image,然后实例化了Conv2D类,传入了卷积核大小为(3, 3)。接着调用conv对象的forward方法,对fake_image进行卷积操作,并将结果保存在output变量中。最后输出output的形状。

注意:本实验仅简单的实现了二维卷积神经网络,只支持单通道的卷积操作,且不包含包含训练和优化等过程,欲知后事如何,请听下回分解。


http://www.ppmy.cn/news/1124539.html

相关文章

理德外汇:俄货币供应创新高,需平衡刺激与通胀

今年俄罗斯广义货币供应量M2继续快速增长,7月份同比增长24.7%,年初以来增长超过11%。这种增长很好地刺激了GDP,是今年GDP表现良好的主要原因之一,但其也加速了物价上涨,今年年初至今,俄通货膨胀率同比超过5…

LLM-TAP随笔——语言模型训练数据【深度学习】【PyTorch】【LLM】

文章目录 3、语言模型训练数据3.1、词元切分3.2、词元分析算法 3、语言模型训练数据 数据质量对模型影响非常大。 典型数据处理:质量过滤、冗余去除、隐私消除、词元切分等。 训练数据的构建时间、噪音或有害信息情况、数据重复率等因素都对模型性能有较大影响。训…

苹果Vision Pro头显内置AI芯片

苹果首席执行官蒂姆库克近日在接受采访时确认,备受瞩目的Vision Pro头显将按计划于明年初在美国上市。这款头显被认为是苹果自iPhone以来最重要的产品之一,售价高达3499美元。 蒂姆库克在接受CBS Sunday Morning的采访时透露,他的团队对Visi…

vue安装依赖报错install i 报错提示npm audit fix --force,or `npm audit` for details

vue项目执行npm install初始化后报错 run npm audit fix to fix them, or npm audit for details 出现这类提醒,按照如下操作进行 1、首先安装模块依赖: npm install (npm audit fix 含义: 检测项目依赖中的漏洞并自动安装需要…

太阳能供电模块

基于Solar Cell的锂电池充放电模块 由于一些需求,最近做了一款基于太阳能的锂电池充放电模块。该模块能够利用太阳能为锂电池充电和为负载提供5V的电压,在太阳能不充足的条件下,由锂电池提供需要的能量。 主要思路是将太阳能板获得的能量存储…

Ubuntu 安装PostgreSQL

网上有各种版本的,也可以去官网看官方的文档。我是下载的PostgreSQL-11.4版本的。找到以后直接复制网上的压缩包链接就可以。 $ mkdir /opt/postgresql && cd /opt/postgresql $ wget https://ftp.postgresql.org/pub/source/v11.4/postgresql-11.4.tar.gz…

Linux 终端命令总结

一、常用的七条命令 命令 对应英文作用lslist查看当前文件夹下的内容pwdprint work directory查看当前所在文件夹cd [目录名]change directory切换文件夹 touch [文件名]touch如果文件不存在新建文件mkdir [目录名]make directory创建目录rm[文件名]remo…

面试必杀技:Jmeter性能测试攻略大全(第二弹)

1. JMeter介绍与安装 JMeter介绍 JMeter是Apache组织开发的基于Java的压力测试工具。具有开源免费、框架灵活、多平台支持等优势。除了压力测试外,JMeter也可以应用的接口测试上。JMeter下载、安装及启动 下载: 访问JMeter官网:https://j…