pytorch-解决过拟合之early stop和dropout

news/2024/9/25 5:32:35/

目录

  • 1. Early Stop
  • 2. 怎样Early Stop
  • 3. Dropout
  • 4. pytorch实现Dropout
  • 5. train和test时的Dropout
  • 6. 增加了vidom的示例代码

1. Early Stop

所谓的over fitting是训练集准确率在上升,但是test准确率开始下降了。
在测试集准确率达到最高点开始下降的时候停止训练,以防止over fitting
在这里插入图片描述

2. 怎样Early Stop

  • 用验证机来选择模型参数
  • 监测验证集的性能
  • 在性能最高点时停止训练

3. Dropout

Dropout就是使用一个概率来减少模型参数量,使得模型复杂度降低,从而降低over fitting的几率。
模型复杂度越低over fitting的几率也就越低,因此Dropout通过使某些连接p=wx=0,相当于断掉该条连接,从而减少了当前层到下一层的连接数。比如:有10k个连接,加了Dropout可能就变成了5k。
在这里插入图片描述

pytorchDropout_13">4. pytorch实现Dropout

在层与层之间使用torch.nn.Dropout增加Dropout,注意Drop是加在两层之间而不是层内的。
在这里插入图片描述

5. train和test时的Dropout

train的时候是可以使用Dropout的,但是test的时候一定不要使用,否则性能会下降,如果train使用了Dropout,那么test的时候要通过net_dropped.eval()取消掉Dropout
在这里插入图片描述

6. 增加了vidom的示例代码

import  torch
import  torch.nn as nn
import  torch.nn.functional as F
import  torch.optim as optim
from    torchvision import datasets, transformsfrom visdom import Visdombatch_size=200
learning_rate=0.01
epochs=10train_loader = torch.utils.data.DataLoader(datasets.MNIST('../data', train=True, download=True,transform=transforms.Compose([transforms.ToTensor(),# transforms.Normalize((0.1307,), (0.3081,))])),batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(datasets.MNIST('../data', train=False, transform=transforms.Compose([transforms.ToTensor(),# transforms.Normalize((0.1307,), (0.3081,))])),batch_size=batch_size, shuffle=True)class MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.model = nn.Sequential(nn.Linear(784, 200),nn.LeakyReLU(inplace=True),nn.Linear(200, 200),nn.LeakyReLU(inplace=True),nn.Linear(200, 10),nn.LeakyReLU(inplace=True),)def forward(self, x):x = self.model(x)return xdevice = torch.device('cuda:0')
net = MLP().to(device)
optimizer = optim.SGD(net.parameters(), lr=learning_rate)
criteon = nn.CrossEntropyLoss().to(device)viz = Visdom()viz.line([0.], [0.], win='train_loss', opts=dict(title='train loss'))
viz.line([[0.0, 0.0]], [0.], win='test', opts=dict(title='test loss&acc.',legend=['loss', 'acc.']))
global_step = 0for epoch in range(epochs):for batch_idx, (data, target) in enumerate(train_loader):data = data.view(-1, 28*28)data, target = data.to(device), target.cuda()logits = net(data)loss = criteon(logits, target)optimizer.zero_grad()loss.backward()# print(w1.grad.norm(), w2.grad.norm())optimizer.step()global_step += 1viz.line([loss.item()], [global_step], win='train_loss', update='append')if batch_idx % 100 == 0:print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_loader.dataset),100. * batch_idx / len(train_loader), loss.item()))test_loss = 0correct = 0for data, target in test_loader:data = data.view(-1, 28 * 28)data, target = data.to(device), target.cuda()logits = net(data)test_loss += criteon(logits, target).item()pred = logits.argmax(dim=1)correct += pred.eq(target).float().sum().item()viz.line([[test_loss, correct / len(test_loader.dataset)]],[global_step], win='test', update='append')viz.images(data.view(-1, 1, 28, 28), win='x')viz.text(str(pred.detach().cpu().numpy()), win='pred',opts=dict(title='pred'))test_loss /= len(test_loader.dataset)print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(test_loss, correct, len(test_loader.dataset),100. * correct / len(test_loader.dataset)))

http://www.ppmy.cn/news/1450838.html

相关文章

Qt Creator中变量与函数的注释 - 鼠标悬浮可显示

Qt Creator中变量与函数的注释 - 鼠标悬浮可显示 引言一、变量注释二、函数注释三、参考链接 引言 代码注释在软件开发中起着至关重要的作用。它们不仅有助于开发者理解和维护代码,还能促进团队协作,提高代码的可读性和可维护性。适当的注释应该是简洁明…

LeetCode135:分发糖果

题目描述 n 个孩子站成一排。给你一个整数数组 ratings 表示每个孩子的评分。 你需要按照以下要求,给这些孩子分发糖果: 每个孩子至少分配到 1 个糖果。 相邻两个孩子评分更高的孩子会获得更多的糖果。 请你给每个孩子分发糖果,计算并返回需…

python学习笔记----文件操作(八)

一、 open() 函数 在 Python 中,处理文件包括读取和写入操作,是通过使用内置的 open() 函数来实现的。 语法: open(file, mode"r", encoding"utf-8") file: 文件路径。mode: 文件打开模式: ‘r’&#xff…

【docker】maven 打包docker的插件学习

docker-maven-plugin GitHub地址:https://github.com/spotify/docker-maven-plugin 您可以使用此插件创建一个 Docker 映像,其中包含从 Maven 项目构建的工件。例如,Java 服务的构建过程可以输出运行该服务的 Docker 映像。 该插件是 Spot…

语言模型:智能化未来的钥匙

语言模型:智能化未来的钥匙 在当今信息爆炸的时代,人们对于有效处理和理解海量信息的需求日益增长。在这个背景下,语言模型崭露头角,成为解决信息处理难题的得力工具之一。而其中,AskBot大模型作为一项重要的技术创新&…

RustGUI学习(iced)之小部件(四):如何使用单选框radio部件?

前言 本专栏是学习Rust的GUI库iced的合集,将介绍iced涉及的各个小部件分别介绍,最后会汇总为一个总的程序。 iced是RustGUI中比较强大的一个,目前处于发展中(即版本可能会改变),本专栏基于版本0.12.1. 概述 这是本专栏的第四篇,主要讲述单选框按钮radio部件的使用,会结…

RabbitMQ入门教学(浅入浅出)

进程间通信 互联网的通讯时网络的基础,一般情况下互联网的资源数据对储存在中心服务器上,一般情况下个体对个体的访问仅限于局域网下,在公网即可完成资源的访问,如各种网站资源,下载资源,种子等。网络通讯…

Cesium 3dTileset 支持 uv 和 纹理贴图

原理: 使用自定义shader实现uv自动计算 贴图效果: uv效果: