Perceptron

news/2024/10/28 22:35:31/

Perceptron

原理

简单的感知机可以看作一个二分类,假定我们的公式为

f(x) = sign(w *x + b)

我们把 -b 做为一个标准,w* x 的结果与 -b 这个标准比较,

w*x > -b, f(x) = +1

w *x < -b, f(x) = -1

不难看出w是超平面的法向量,超平面上的向量与w的数量积为0。因此这个超平面就可以很好的区分我们的数据集。

而感知机就是来寻找w和b

优化方法

优化方法我们现有的方法比较多,诸如GD、SGD、Minibatch、Adam

当然我们的损失函数也包含多种,常见的有MSE, CrossEntropy.

这边简单展示一下MSE以及GD原理。

5J5vtA.jpg

5J5xfI.jpg

SoftMax

如果我们输出为多分类,那就成为一个SoftMax回归。

SoftMax回归和线性回归一样将输入特征与权重做线性叠加。与线性回归的一个主要不同在于,SoftMax回归的输出值个数等于标签里的类别数。

5JTopn.jpg

MLP

而我们给SoftMax回归增加隐藏层,就是我们所说的多层感知机,而

全连接层只是对数据做仿射变换,我们的方法是引入非线性变换,就是激活函数。

5JTjk4.jpg

代码实现

这边选用CIFAR10数据集来做演示。CIFAR10包含10个类别,每个类别600张32x32的彩色图像。

1.导入依赖包

import torch
import torchvision
from torch import nn
from d2l import torch as d2l
import os
import matplotlib.pyplot as plt
import torchvision.transforms as transforms

2.加载数据集

这边对图片进行归一化处理。

transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))])
train_data = torchvision.datasets.CIFAR10(root="data",download=True,train=True,transform=transform)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=4,shuffle=True, num_workers=8)val_data = torchvision.datasets.CIFAR10(root="data",download=True,train=False,transform=transform)
val_loader = torch.utils.data.DataLoader(val_data, batch_size=4,shuffle=True, num_workers=8)

3.定义模型及参数

用Sequential快速构建,对数据进行展平处理输入尺寸为图片尺寸 x 通道数,输出10分类,hidden layer设置为512。

net = nn.Sequential(nn.Flatten(),nn.Linear(1024*3, 512),nn.ReLU(),nn.Linear(512,10)
)

4.训练

损失计算选用交叉熵函数,优化器选用SGD,调用显卡运行。

loss = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr = 0.01)
epochs = 30device = "cuda:0"
train(net,train_loader,val_loader,epochs,optimizer,loss,device)

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-XrRdwV5J-1645528506210)(https://z3.ax1x.com/2021/10/26/5I1IN8.jpg)]

结果

可以看出我们的验证准确值过低,这主要是因为数据集特征不明显,我们在更换数据集验证。

更换数据集

选用7分类的海贼王图片进行训练,可以看出训练结果明显优于CIFAR数据集。

5X26HS.jpg

同时我们再挑选一张不在训练集的图片进行验证,发现结果正确。

5X2gAg.jpg


http://www.ppmy.cn/news/206799.html

相关文章

java学习笔记1

记录java学习&#xff0c;慢慢更新中 &#xff08;图片懒得放了qaq&#xff09; 文章目录 一、基础知识二、Java基础&#xff08;一&#xff09;Java基本语法&#xff08;二&#xff09;面向对象OOP2.1 java类及类的成员2.1.1 java类及类的成员&#xff1a;属性、方法、构造器、…

Facebook新研究优化硬件浮点运算,强化AI模型运行速率

选自code.fb&#xff0c;作者&#xff1a;JEFF JOHNSON&#xff0c;机器之心编译&#xff0c;参与&#xff1a;Geek AI、路。 近日&#xff0c;Facebook 发布文章&#xff0c;介绍了一项新研究&#xff0c;该研究提出了一种使人工智能模型高效运行的方法&#xff0c;从根本上优…

DnsJumper下载

准备工作 # 安装油猴插件&#xff0c;新建脚本 // UserScript // name 百度网盘简易下载助手&#xff08;直链下载复活版&#xff09; // namespace http://bd.softxm.cn/bd/ // version 1.5.5 // antifeature membership // description 一个纯净好用的直链…

脑肿瘤分割论文打卡2:E1D3 U-Net for Brain Tumor Segmentation

E1D3 U-Net for Brain Tumor Segmentation: Submission to the RSNA-ASNR-MICCAI BraTS 2021 challenge 【E1D3 U-Net 用于脑肿瘤分割】 Abstract1 Introduction2 Realted Works3 Methodology3.1 E1D3 U-Net :One Encoder, Three Decoders3.2 Training3.3 Testing 4Experiments…

跨模态学习能力再升级,EasyNLP电商文图检索效果刷新SOTA

作者&#xff1a;熊兮、欢夏、章捷、临在 导读 多模态内容&#xff08;例如图像、文本、语音、视频等&#xff09;在互联网上的爆炸性增长推动了各种跨模态模型的研究与发展&#xff0c;支持了多种跨模态内容理解任务。在这些跨模态模型中&#xff0c;CLIP&#xff08;Contra…

跨模态学习能力再升级,EasyNLP 电商文图检索效果刷新 SOTA

导读 多模态内容&#xff08;例如图像、文本、语音、视频等&#xff09;在互联网上的爆炸性增长推动了各种跨模态模型的研究与发展&#xff0c;支持了多种跨模态内容理解任务。在这些跨模态模型中&#xff0c;CLIP&#xff08;Contrastive Language-Image Pre-training&#x…

游戏思考17:寻路引擎recast和detour学习二:recast导航网格生成流程\源码剖析流程\局限性,附录计算点线面举例代码

一、recastnavigation使用介绍 1&#xff09;模式选择 Solo Mesh&#xff1a;单块生成 Tile Mesh&#xff1a;分块生成 Temp Obstacles&#xff1a;分块并支持动态阻挡这里测试的话选单块生成 2&#xff09;模型选择 官方自带3块地图&#xff0c;这里测试选择 nav_test.obj&a…

【Paddle笔记】使用PaddleGan工具包实现表情迁移(First Order Motion)

Paddle笔记 使用PaddleGan工具包实现表情迁移 FOM 1、环境安装1.1 运行环境1.1.1 Conda虚拟环境1.1.2 PyTorch1.1.3 Tensorflow1.2 Paddle核心框架1.2.1 安装Paddle框架1.2.2 验证框架是否安装成功1.3 PaddleGAN 生成对抗网络1.3.1 安装ppgan1.3.2 安装其他依赖1.3.3 下载示例代…