刘二大人《Pytorch深度学习实践》第八讲加载数据集

news/2024/10/30 23:20:42/

文章目录

  • Epoch、Batch-Size、Iterations
  • Dataset、DataLoader
  • 课上代码
  • torchvision中数据集的加载

Epoch、Batch-Size、Iterations

在这里插入图片描述
1、所有的训练集进行了一次前向和反向传播,叫做一个Epoch
2、在深度学习训练中,要给整个数据集分成多份,即mini-batch,每个mini-batch所包含的样本的数量叫做Batch-Size
3、因为数据集分成了多个mini-batch,有多少份mini-batch就有多少个Iteration,每进行一次mini-batch的前向和后向传播,就会进行一次权重参数的更新,在一个Epoch中,有多少个Iteration,就更新了多少次权重参数

Dataset、DataLoader

在这里插入图片描述
1、DataSet 是抽象类,不能实例化对象,需要自己定义类继承该抽象类并实现其中的方法
2、init()函数里面主要用来加载数据集,分成x_data,y_data
3、__getitem()__主要根据下表来获取数据集
4、len() 主要用来返回数据集的个数
5、DataLoader是Pytorch中用来处理模型输入数据的一个工具类。组合了数据集(dataset) + 采样器(sampler),并在数据集上提供单线程或多线程(num_workers )的可迭代对象。在DataLoader中有多个参数,这些参数中重要的几个参数的含义说明如下:

 1. epoch:所有的训练样本输入到模型中称为一个epoch; 2. iteration:一批样本输入到模型中,成为一个Iteration;3. batchszie:批大小,决定一个epoch有多少个Iteration;4. 迭代次数(iteration)=样本总数(epoch)/批尺寸(batchszie)5. dataset (Dataset) – 决定数据从哪读取或者从何读取;6. batch_size (python:int, optional) – 批尺寸(每次训练样本个数,默认为1)7. shuffle (bool, optional) –每一个 epoch是否为乱序 (default: False);8. num_workers (python:int, optional) – 是否多进程读取数据(默认为0);9. drop_last (bool, optional) – 当样本数不能被batchsize整除时,最后一批数据是否舍弃(default: False)10. pin_memory(bool, optional) - 如果为True会将数据放置到GPU上去(默认为false) 

在这里插入图片描述

课上代码

import torch 
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoaderclass DiabetesDataset (Dataset):def __init__(self):xy = np.loadtxt('diabetes.csv', delimiter=',', dtype = np.float32)self.len = xy.shape[0]self.x_data = torch.from_numpy (xy[:,:-1])self.y_data = torch.from_numpy (xy[:,[-1]])def __getitem__(self, index):return self.x_data[index], self.y_data[index]def __len__(self):return self.lendataset = DiabetesDataset()
train_loader = DataLoader (dataset=dataset, batch_size=32, shuffle=True, num_workers=0)class Model(torch.nn.Module):def __init__(self):super(Model, self).__init__()self.linear1 = torch.nn.Linear(8, 6) # 输入数据x的特征是8维,x有8个特征self.linear2 = torch.nn.Linear(6, 4)self.linear3 = torch.nn.Linear(4, 1)self.sigmoid = torch.nn.Sigmoid() # 将其看作是网络的一层,而不是简单的函数使用def forward(self, x):x = self.sigmoid(self.linear1(x))x = self.sigmoid(self.linear2(x))x = self.sigmoid(self.linear3(x)) # y hatreturn xmodel = Model()criterion = torch.nn.BCELoss(reduction='mean')  
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)for epoch in range (100):for i, data in enumerate (train_loader, 0):inputs, labels = datay_pred = model (inputs)loss = criterion (y_pred, labels)print (epoch, i, loss.item())optimizer.zero_grad()loss.backward ()optimizer.step()

torchvision中数据集的加载

在这里插入图片描述
在这里插入图片描述


http://www.ppmy.cn/news/42093.html

相关文章

【Docker】Docker常规软件的安装

总体步骤 搜索镜像 拉取镜像 查看镜像 启动镜像 停止容器 移除容器 示例(安装mysql) 搜索镜像 docker search mysql[root192 ~]# docker search mysql NAME DESCRIPTION STARS OF…

基于改进多目标灰狼优化算法的考虑V2G技术的风、光、荷、储微网多目标日前优化调度研究(Matlab代码实现)

💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…

基于深度学习的花卉识别

1、数据集 春天来了,我在公园的小道漫步,看着公园遍野的花朵,看起来真让人心旷神怡,一周工作带来的疲惫感顿时一扫而光。难得一个糙汉子有闲情逸致俯身欣赏这些花朵儿,然而令人尴尬的是,我一朵都也不认识。…

用ChatGPT怎么赚钱?普通人用这5个方法也能赚到生活费

ChatGPT在互联网火得一塌糊涂,因为它可以帮很多人解决问题。比如:帮编辑人员写文章,还可以替代程序员写代码,帮策划人员写文案策划等等。ChatGPT这么厉害,能否用它来赚钱呢?今天和大家分享用ChatGPT赚钱的5…

银行数字化转型导师坚鹏:ChatGPT解密与银行应用案例

ChatGPT解密与银行应用案例 ——开启人类AI新纪元 打造数字化转型新利器 课程背景: 很多企业和员工存在以下问题:不清楚ChatGPT对我们有什么影响?不知道ChatGPT的发展现状及作用?不知道ChatGPT的银行业应用案例?…

Matlab基础

Matlab基础目录Matlab变量特殊常量变量的命名规则变量定义与赋值变量的显示变量的存取变量的清楚变量的检查数组和矩阵一维数组的创建和元素提取一维数组的创建一维数组的提取二维数组的创建与元素提取二维数组的创建二维矩阵元素提取字符数组和空数组矩阵的基本算术运算数据可…

应用程序接口(API)安全的入门指南

本文简单回顾了 API 的发展历史,其基本概念、功能、相关协议、以及使用场景,重点讨论了与之相关的不同安全要素、威胁、认证方法、以及十二项优秀实践。 根据有记录的历史,随着 Salesforce 的销售自动化解决方案的推出,首个 We…

Linux下异步socket客户端

文章目录socket 客户端1. 创建socketsocket()函数返回值2. 设置socket的属性connect函数sockaddr_in结构体inet_pton函数3. fcntl设置非阻塞4. recv函数socket 客户端 1. 创建socket socket()函数 #include <sys/socket.h> int socket(int domain, int type, int proto…