[深度学习]卷积神经网络CNN

news/2024/9/29 23:58:33/

1 图像基础知识

import numpy as np
import matplotlib.pyplot as plt
# 图像数据
#img=np.zeros((200,200,3))
img=np.full((200,200,3),255)
# 可视化
plt.imshow(img)
plt.show()
# 图像读取
img=plt.imread('img.jpg')
plt.imshow(img)
plt.show()

2 CNN概述

  • 卷积层conv+relu
  • 池化层pool
  • 全连接层FC/Linear

3 卷积层

 

import matplotlib.pyplot as plt
import torch
from torch import nn
# 数据
img=plt.imread('img.jpg')
print(img.shape)
# conv
img=torch.tensor(img).permute(2,0,1).unsqueeze(0).to(torch.float32)
conv=nn.Conv2d(in_channels=3,out_channels=5,kernel_size=(3,5),stride=(1,2),padding=2)
# 处理
fm=conv(img)
print(fm.shape)

4 池化层

  • 下采样:样本减少
  • 上采样(深采样):样本增多
  • 最大池化相交平均池化使用更多
  • 通常kernel_size=(3,3),stride=(2,2),padding=(自定义)

import torch
from torch import nn
# 创建数据
torch.random.manual_seed(22)
data=torch.randint(0,10,[1,3,3],dtype=torch.float32)
print(data)

# 最大池化
pool=nn.MaxPool2d(kernel_size=(2,2),stride=(1,1),padding=0)
print(pool(data))

# 平均池化
pool=nn.AvgPool2d(kernel_size=(2,2),stride=(1,1),padding=0)
print(pool(data))

5 图像分类案例(LeNet)

import torch
import torch.nn as nn
from torchvision.datasets import CIFAR10
from torchvision.transforms import ToTensor
from torchvision.transforms import Compose
import matplotlib.pyplot as plt
from torchsummary import summary
from torch import optim
from torch.utils.data import DataLoader
# 获取数据
train_dataset=CIFAR10(root='cnn_net',train=True,transform=Compose([ToTensor()]),download=True)
test_dataset=CIFAR10(root='cnn_net',train=False,transform=Compose([ToTensor()]),download=True)
print(train_dataset.class_to_idx)
print(train_dataset.data.shape)
print(test_dataset.data.shape)

plt.imshow(test_dataset.data[100])
plt.show()
print(test_dataset.targets[100])

# 模型构建
class ImageClassification(nn.Module):def __init__(self):super().__init__()self.conv1=nn.Conv2d(in_channels=3,out_channels=6,kernel_size=3,stride=1,padding=0)self.conv2=nn.Conv2d(in_channels=6,out_channels=16,kernel_size=3,stride=1,padding=0)self.pool1=nn.MaxPool2d(kernel_size=2,stride=2)self.pool2=nn.MaxPool2d(kernel_size=2,stride=2)self.fc1=nn.Linear(in_features=576,out_features=120)self.fc2=nn.Linear(in_features=120,out_features=84)self.out=nn.Linear(in_features=84,out_features=10)def forward(self,x):x=self.pool1(torch.relu(self.conv1(x)))x=self.pool2(torch.relu(self.conv2(x)))x=x.reshape(x.size(0),-1)x=torch.relu(self.fc1(x))x=torch.relu(self.fc2(x))out=self.out(x)return outmodel=ImageClassification()
summary(model,(3,32,32),batch_size=1)
----------------------------------------------------------------Layer (type)               Output Shape         Param #
================================================================Conv2d-1             [1, 6, 30, 30]             168MaxPool2d-2             [1, 6, 15, 15]               0Conv2d-3            [1, 16, 13, 13]             880MaxPool2d-4              [1, 16, 6, 6]               0Linear-5                   [1, 120]          69,240Linear-6                    [1, 84]          10,164Linear-7                    [1, 10]             850
================================================================
Total params: 81,302
Trainable params: 81,302
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.01
Forward/backward pass size (MB): 0.08
Params size (MB): 0.31
Estimated Total Size (MB): 0.40
----------------------------------------------------------------
# 模型训练
optimizer=optim.Adam(model.parameters(),lr=0.0001,betas=[0.9,0.99])
error=nn.CrossEntropyLoss()
epoches=10
for epoch in range(epoches):dataloader=DataLoader(train_dataset,batch_size=2,shuffle=True)loss_sum=0num=0.1for x,y in dataloader:y_=model(x)loss=error(y_,y)loss_sum+=loss.item()num+=1optimizer.zero_grad()loss.backward()optimizer.step()print(loss_sum/num)
# 模型保存
torch.save(model.state_dict(),'model.pth')
# 模型预测
test_dataloader=DataLoader(test_dataset,batch_size=8,shuffle=False)
model.load_state_dict(torch.load('model.pth',weights_only=False))
corr=0
num=0
for x,y in test_dataloader:y_=model(x)out=torch.argmax(y_,dim=-1)corr+=(out==y).sum()num+=len(y)print(corr/num)

优化方向


http://www.ppmy.cn/news/1532155.html

相关文章

基本数据结构简记

简单记录一下常见的一些数据结构的概念,不包含树和图。 一、线性数据结构 1、主要成员(或形式) 栈,队列,双端队列,列表 2、特点 有两端 区分方式:元素添加与移除方式 二、栈 1、特点 添加…

行内对齐 vertical-align

MDN vertical-align 在CSS中,文本的垂直对齐通常指的是行内元素(inline elements)或表格单元格中的内容如何相对于其容器进行上下对齐。对于这些情况,可以使用 vertical-align 属性来控制。 vertical-align 属性适用于以下几种情…

构建网络遇到的问题-AlexNet

1.对模型进行初始化采用的一般代码 def _initialize_weights(self):for m in self.modules(): # 遍历模型每一层if isinstance(m, nn.Conv2d): # 判定m层是否属于nn.Conv2d类型nn.init.kaiming_normal_(m.weight, modefan_out, nonlinearityrelu)if m.bias is not None:nn.in…

专深与广博的平衡艺术

一、引言 ----  随着人工智能(AI)和生成式人工智能(AIGC)如ChatGPT、Midjourney、Claude等大语言模型的快速发展,AI辅助编程工具正逐渐成为程序员日常工作的得力助手。这一变革不仅对编程工作方式产生了深远影响&…

在云渲染中3D工程文件安全性怎么样?

在云渲染中,3D工程文件的安全性是用户最关心的问题之一。随着企业对数据保护意识的增强,云渲染平台采取了严格的安全措施和加密技术,以确保用户数据的安全性和隐私性。 云渲染平台为了保障用户数据的安全,采取了多层次的安全措施。…

Redis系列补充:聊聊布隆过滤器(go语言实践篇)

1 介绍 布隆过滤器(Bloom Filter)是 Redis 4.0 版本之后提供的新功能,我们一般将它当做插件加载到 Redis Service服务器中,给 Redis 提供强大的滤重功能。 它是一种概率性数据结构,可用于判断一个元素是否存在于一个集…

【小bug】使用 RestTemplate 工具从 JSON 数据反序列化为 Java 对象时报类型转换异常

起因:今天编写一个请求时需要通过RestTemplate调用外部接口,获取一些信息,但是在获取了外部接口响应内容后,使用强制转换发现报了类型转换异常。之前也遇到过,但是没记录下来,今天又查了一遍……干脆记录一…

Tomcat中间件常见漏洞复现

#1.CVE-2017-12615 -----Tomcat put方法任意文件写入漏洞 1.打开靶场 cd vulhub/tomcat/CVE-2017-12615 docker-compose up -d docker ps 2.访问8080端口,来到靶场 3.首页进抓包,Tomcat允许适⽤put⽅法上传任意⽂件类型,但不允许jsp后缀…