PyTorch Lightning快速学习教程一:快速训练一个基础模型

news/2024/11/28 8:32:22/

粉丝量突破1200了!找到了喜欢的岗位,毕业上班刚好也有20天,为了督促自己终身学习的态度,继续开始坚持写写博客,沉淀并总结知识!
介绍:PyTorch Lightning是针对科研人员、机器学习开发者专门设计的,能够快速复用代码的一个工具,避免了因为每次都编写相似的代码而带来的时间成本。其可以理解为,lightning设计了一个,能够快速搭建训练验证测试模型的整套代码模板,我们只需要编写设计需要的模型、超参数、优化器等,直接套进去即可。lightning的优势在于:灵活性高、可读性强、支持多卡训练、内置测试、内置日志等。

前置掌握知识:Python和PyTorch的使用

链接:https://lightning.ai/

快速安装:pip install lightning

1.添加依赖包

需要添加相应的依赖,包括os,torch工具包,torch数据载入等依赖

import os		
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import lightning.pytorch as pl

2.定义模型

PyTorch定义模型案例如下,定义好了方便后续的调用

class Encoder(nn.Module):def __init__(self):super().__init__()self.l1 = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))def forward(self, x):return self.l1(x)	# 全连接 激活 全连接class Decoder(nn.Module):def __init__(self):super().__init__()self.l1 = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))def forward(self, x):return self.l1(x)	# 全连接 激活 全连接

3.定义网络架构

定义网络模型,自定义模型名字,并继承lightning.pytorch.LightningModule类,如下代码

  • training_step定义了与nn.Module之间交互

  • configure_optimizers为模型定义优化器

class LitAutoEncoder(pl.LightningModule):def __init__(self, encoder, decoder):super().__init__()self.encoder = encoderself.decoder = decoderdef training_step(self, batch, batch_idx):# training_step defines the train loop.x, y = batchx = x.view(x.size(0), -1)z = self.encoder(x)x_hat = self.decoder(z)loss = F.mse_loss(x_hat, x)return lossdef configure_optimizers(self):optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)return optimizer

4.定义训练集

定义DataLoader,这一点跟PyTorch调模型的流程一样,如下调用了MNIST公开数据集

dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
train_loader = DataLoader(dataset)

5.训练数据

使用Lightning来处理所有的训练,如下代码。

# model 模型
autoencoder = LitAutoEncoder(Encoder(), Decoder())# train model 训练
trainer = pl.Trainer()
trainer.fit(model=autoencoder, train_dataloaders=train_loader)

一般的训练过程,需要设计如下代码,进行遍历和循环训练,Lightning会消除这些繁琐的过程,使用Lightning,可以将所有这些技术混合在一起,而无需每次都重写一个新的循环。

autoencoder = LitAutoEncoder(Encoder(), Decoder())
optimizer = autoencoder.configure_optimizers()for batch_idx, batch in enumerate(train_loader):loss = autoencoder.training_step(batch, batch_idx)loss.backward()optimizer.step()optimizer.zero_grad()

完整代码

# coding:utf-8
import torch, torch.nn as nn, torch.utils.data as data, torchvision as tv, torch.nn.functional as F
import lightning as L# --------------------------------
# Step 1: 定义一个 LightningModule
# --------------------------------
# A LightningModule (nn.Module subclass) defines a full *system*
# (例如: an LLM, diffusion model, autoencoder, or simple image classifier).class LitAutoEncoder(L.LightningModule):def __init__(self):super().__init__()self.encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3))self.decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28))def forward(self, x):# forward 定义了一次 预测/推理 行为embedding = self.encoder(x)return embeddingdef training_step(self, batch, batch_idx):# training_step 定义了一次训练的迭代, 和forward相互独立x, y = batchx = x.view(x.size(0), -1)z = self.encoder(x)x_hat = self.decoder(z)loss = F.mse_loss(x_hat, x)self.log("train_loss", loss)return lossdef configure_optimizers(self):optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)return optimizer# -------------------
# Step 2: 定义数据集
# -------------------
dataset = tv.datasets.MNIST(".", download=True, transform=tv.transforms.ToTensor())
train, val = data.random_split(dataset, [55000, 5000])# -------------------
# Step 3: 开始训练
# -------------------
autoencoder = LitAutoEncoder()
trainer = L.Trainer(accelerator="gpu")	
trainer.fit(autoencoder, data.DataLoader(train,batch_size=128), data.DataLoader(val))

http://www.ppmy.cn/news/978068.html

相关文章

leetcode 面试题 判定是否互为字符重排

⭐️ 题目描述 🌟 leetcode链接:判定是否互为字符重排 思路: 两个字符串的每个字母和数量都相等。那么 s2 一定可以排成 s1 字符串。 代码: bool CheckPermutation(char* s1, char* s2){char hash1[26] {0};char hash2[26] {…

华为OD机试真题100分题 C/C++/Java/Python/Js/Go【补种未成活胡杨】

题目描述: 【补种未成活胡杨】 近些年来,我国防沙治沙取得显著成果。某沙漠新种植N棵胡杨(编号1-N),排成一排。一个月后,有M棵胡杨未能成活。现可补种胡杨K棵,请问如何补种(只能补种,不能新种),可以得到最多的连续胡杨树? 输入描述: N 总种植数量 M 未成活胡杨数…

基于OpenCV的红绿灯识别

基于OpenCV的红绿灯识别 技术背景 为了实现轻舟航天机器人实现红绿灯的识别,决定采用传统算法OpenCV视觉技术。 技术介绍 航天机器人的红绿灯识别主要基于传统计算机视觉技术,利用OpenCV算法对视频流进行处理,以获取红绿灯的状态信息。具…

draw.io画图时,用一个箭头(线段)连结一个矩形和直线时,发现,无论怎么调节,都无法使其无缝连接。

问题描述:draw.io画图时,用一个箭头(线段)连结一个矩形和直线时,发现,无论怎么调节,都无法使其无缝连接。要么少一段,如图1所示。要么多一段,如图2所示。 图1&#xff0c…

心海舟楫、三一重工面试(部分)

心海舟楫 一道算法题: 我开始给出的是暴力解法,时间复杂度O(n^2)。 在面试官的提示下,实现了时间复杂度为O(n)的解法。 三一重工 没啥特别的

类和对象下

目录 初始化列表stakc关键字友元友元函数友元类 内部类匿名对象拷贝对象时编译器的优化构造函数中的隐式类型转换连续构造拷贝构造 初始化列表 前面我们了解了类的构造函数,知道了构造函数体赋值,其实C构造函数中还有一个初始化列表也可以进行初始化。 …

四、运算符(1)

本章概要 开始使用优先级赋值 方法调用中的别名现象 算术运算符 一元加减运算符 递增和递减 Java 是从 C 的基础上做了一些改进和简化发展而成的。对于 C/C 程序员来说,Java 的运算符并不陌生。如果你已了解 C 或 C,大可以跳过本章和下一章&#xff0c…

tqdm进度条

from time import sleep from faker import Faker fFaker(“en-us”) alist [f.name for _ in range(50)] from tqdm import tqdm,trange p1 for i in tqdm(alist): pp1 p1 for i in trange(50): p*(i1) sleep(0.05) proc_nartqdm(range(50)) for i in proc_nar: # 设置前…