pytorch综合多个弱分类器,投票机制,进行手写数字分类(boosting)

news/2024/11/23 13:04:13/

首先,这个文章的出发点就是让一个网络一个图片进行预测,在直观上不如多个网络对一个图片进行预测之后再少数服从多数效果好。

也就是对于任何一个分类任务,训练n个弱分类器,也就是分类准确度只比随机猜好一点,那么当n足够大的时候,通过投票机制,也能提升很大的准确度:毕竟每个网络都分错同一个数据的可能性会降低。

接下来就是代码实现。

import torch
import torchvision
import torch.nn as nn
from torch.utils.data import DataLoader
from collections import Counter
import numpy as npclass MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.input_layer=nn.Sequential(nn.Linear(28*28,30),nn.Tanh(),)self.output_layer=nn.Sequential(nn.Linear(30,10),#nn.Sigmoid())def forward(self, x):x=x.view(x.size(0),-1)x=self.input_layer(x)x=self.output_layer(x)return xtrans=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize([.5],[.5]),]
)
BATCHSIZE=100
DOWNLOAD_MNIST=False
EPOCHES=200
LR=0.001train_data=torchvision.datasets.MNIST(root="./mnist",train=True,transform=trans,download=DOWNLOAD_MNIST,
)
test_data=torchvision.datasets.MNIST(root="./mnist",train=False,transform=trans,download=DOWNLOAD_MNIST,
)
train_loader=DataLoader(train_data,batch_size=BATCHSIZE,shuffle=True)
test_loader =DataLoader(test_data,batch_size=BATCHSIZE,shuffle=False)mlps=[MLP().cuda() for i in range(10)]
optimizer=torch.optim.Adam([{"params":mlp.parameters()} for mlp in mlps],lr=LR)loss_function=nn.CrossEntropyLoss()for ep in range(EPOCHES):for img,label in train_loader:img,label=img.cuda(),label.cuda()optimizer.zero_grad()#10个网络清除梯度for mlp in mlps:out=mlp(img)loss=loss_function(out,label)loss.backward()#网络们获得梯度optimizer.step()pre=[]vote_correct=0mlps_correct=[0 for i in range(len(mlps))]for img,label in test_loader:img,label=img.cuda(),label.cuda()for i, mlp in  enumerate( mlps):out=mlp(img)_,prediction=torch.max(out,1) #按行取最大值pre_num=prediction.cpu().numpy()mlps_correct[i]+=(pre_num==label.cpu().numpy()).sum()pre.append(pre_num)arr=np.array(pre)pre.clear()result=[Counter(arr[:,i]).most_common(1)[0][0] for i in range(BATCHSIZE)]vote_correct+=(result == label.cpu().numpy()).sum()print("epoch:" + str(ep)+"总的正确率"+str(vote_correct/len(test_data)))for idx, coreect in enumerate( mlps_correct):print("网络"+str(idx)+"的正确率为:"+str(coreect/len(test_data)))

 可以看到虽然网络模型很简单,但是通过多个弱分类模型的投票,得到的结果也是比其中任何一个网络的效果都要好不少的。应该关注相对提升,不应该关注绝对提升。

这些网络模型的架构一致,只是初始化不一样。如果模型之间架构差别比较大,比如有简单的cnn,dnn,rnn,svm等等,效果可能更好。


http://www.ppmy.cn/news/679447.html

相关文章

利用C#编写网页投票器程序

一、前言 看个图,了解下投票的过程。 提交投票信息 投票页 ――――――――>投票信息处理页 反馈投票结果 (请求页)<―――――――(响应页) 一般情况下,填写投票信息,然后点提交按钮发送到响应页&#xf…

Docker的一些基本概念

使用Docker的原因 (面试题) 1.简化程序 2.避免选择恐惧 3.节省开支 4.持续交付和部署 5.更轻松的迁移 应用场景 Web应用的自动化打包和发布 自动化测试和持续集成、发布 在服务型环境中部署和调整数据库或其他的后台应用 从头编译或者扩展现有的OpenShift或Cloud Foun…

我的世界1.20.1优化模组推荐:全新的58个模组让帧数更高

by LingLing1301 我的世界优化模组列表(无mod冲突) 如果mod在3个月内没有更新,或者mod有兼容性错误,将删除mod 模组列表: Fabric APIStarlightSmooth BootBetter Fps - Render DistanceMinecraftCapes ModIris Shader…

把steam上下载的GTA5转移到Epic中,免除Epic再次下载GTA5的方法

之前都是用steam玩GTA5的,而且还是白嫖我朋友的,所以Epic上一可以免费领GTA5就去了,但领了之后发现在Epic上还需要再下载一个GTA5就感到崩溃,毕竟游戏太大了。终于在网上找到了方法解决,现在记录一下。 步骤&#xff1…

gta5因为计算机丢失xinput1,GTA5 运行缺少这个xinput1-3.dll,怎样办

软件安装:装机软件必备包 网络游戏,英文名称为Online Game,又称 “在线游戏”,简称“网游”。指以互联网为传输媒介,以游戏运营商服务器和用户计算机为处理终端,以游戏客户端软件为信息交互窗口的旨在实现娱乐、休闲、交流和取得虚拟成就的具有可持续性的个体性多人在线游…

gta5oracle.yft原文件,GTA5 addonpeds2.2[添加人物模型的人物模型选择器]

添加人物模型的AddonPeds.ASI 文件以及人物模型选择器是一个允许你无需替换任意文件就能直接添加新款人物模型至你的游戏之中去的脚本。 已经包含了由 Quechus13制作的超胆侠和超人角色作为添加的人物样板。 安装说明 拖放压缩包中AddonPedsPatch.exe文件至你的游戏根目录下 拖…

gta5r星服务器无限载入,GTA5及R星平台加载不出来问题解决办法

《GTA5》是一款违法主题动作冒险类游戏,其多人在线形式故事情节,可让玩家体会和实际国际相同的国际观。可是许多入坑的玩家,反应常常遇到在翻开《GTA5》或许R星渠道时总是加载不出来,一向呈现转圈圈状况。 这是由于游戏服务器在国…

gta5oracle原文件,GTA5.exe GTAVLauncher.exe 1.0.323.1 侠盗猎车5原版执行文件

如果您下载的是DLL文件(如果是其他软件请无视下面的信息): 1、下载后根据您系统的情况选择X86/X64,X86为32位电脑,X64为64位电脑。 如果您不知道是X86还是X64,可以住个尝试。 2、把dll文件拷贝到对应目录 C:\Windows\System (Windows 95/98/Me) C:\WINNT\System32 (Windows…