LLM学习笔记-5

ops/2024/10/18 18:23:49/

目录

  • 1.多层神经网络的实现
  • 2. 训练轮次示例
  • 3. 保存并加载模型
  • 4. 使用GPU加速训练
  • 5. 使用上面所教,进行一次训练

摘要:今天想整理一下Pytorch常用操作,以便以后进行预习(不是)
在这里插入图片描述

1.多层神经网络的实现

这是常用的操作,要会

class NeuralNetwork(torch.nn.Module):def __init__(self, num_inputs, num_outputs):super().__init__()self.layers = torch.nn.Sequential(# 第一个隐藏层torch.nn.Linear(num_inputs, 30),torch.nn.ReLU(),# 第二个隐藏层torch.nn.Linear(30, 20),torch.nn.ReLU(),# 输出层torch.nn.Linear(20, num_outputs),)def forward(self, x):logits = self.layers(x)return logitsmodel = NeuralNetwork(50, 3)
print(model)

NeuralNetwork(
(layers): Sequential(
(0): Linear(in_features=50, out_features=30, bias=True)
(1): ReLU()
(2): Linear(in_features=30, out_features=20, bias=True)
(3): ReLU()
(4): Linear(in_features=20, out_features=3, bias=True)
)
)

2. 训练轮次示例

import torch.nn.functional as Ftorch.manual_seed(123)
model = NeuralNetwork(num_inputs=2, num_outputs=2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.5)num_epochs = 3for epoch in range(num_epochs):model.train()for batch_idx, (features, labels) in enumerate(train_loader):logits = model(features)loss = F.cross_entropy(logits, labels) # 损失函数optimizer.zero_grad()loss.backward()optimizer.step()### 日志print(f"Epoch: {epoch+1:03d}/{num_epochs:03d}"f" | Batch {batch_idx:03d}/{len(train_loader):03d}"f" | Train/Val Loss: {loss:.2f}")model.eval()# 可选的模型评估指标

Epoch: 001/003 | Batch 000/002 | Train/Val Loss: 0.75
Epoch: 001/003 | Batch 001/002 | Train/Val Loss: 0.65
Epoch: 002/003 | Batch 000/002 | Train/Val Loss: 0.44
Epoch: 002/003 | Batch 001/002 | Train/Val Loss: 0.13
Epoch: 003/003 | Batch 000/002 | Train/Val Loss: 0.03
Epoch: 003/003 | Batch 001/002 | Train/Val Loss: 0.00

3. 保存并加载模型

就一句话

torch.save(model.state_dict(), "model.pth")

4. 使用GPU加速训练

我们常常说的CUDA就是在GPU上训练

import torch
# 显示PyTorch是否支持GPU
print(torch.cuda.is_available())

如果显示True,则代表可以用GPU,否则则要用CPU

# 根据设备可用情况选择设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

5. 使用上面所教,进行一次训练

创建了一个简单的神经网络模型来对二分类问题进行训练,并且使用了 PyTorch 提供的 Dataset 和 DataLoader 类来加载数据集并进行批处理。此外,你还定义了一个函数来计算模型的准确率。

import torch
X_train = torch.tensor([[-1.2, 3.1],[-0.9, 2.9],[-0.5, 2.6],[2.3, -1.1],[2.7, -1.5]
])
y_train = torch.tensor([0, 0, 0, 1, 1])
X_test = torch.tensor([[-0.8, 2.8],[2.6, -1.6],
])
y_test = torch.tensor([0, 1])from torch.utils.data import Dataset
class ToyDataset(Dataset):def __init__(self, X, y):self.features = Xself.labels = ydef __getitem__(self, index):one_x = self.features[index]one_y = self.labels[index]return one_x, one_ydef __len__(self):return self.labels.shape[0]
train_ds = ToyDataset(X_train, y_train)
test_ds = ToyDataset(X_test, y_test)from torch.utils.data import DataLoader
torch.manual_seed(123)
train_loader = DataLoader(dataset=train_ds,batch_size=2,shuffle=True,num_workers=1,drop_last=True
)
test_loader = DataLoader(dataset=test_ds,batch_size=2,shuffle=False,num_workers=1
)class NeuralNetwork(torch.nn.Module):def __init__(self, num_inputs, num_outputs):super().__init__()self.layers = torch.nn.Sequential(# 第一个隐藏层torch.nn.Linear(num_inputs, 30),torch.nn.ReLU(),# 第二个隐藏层torch.nn.Linear(30, 20),torch.nn.ReLU(),# 输出层torch.nn.Linear(20, num_outputs),)def forward(self, x):logits = self.layers(x)return logits# 使用accuracy(准确率)作为指标
def compute_accuracy(model, dataloader, device):model = model.eval()correct = 0.0total_examples = 0for idx, (features, labels) in enumerate(dataloader):# 将数据移动到指定的设备上features, labels = features.to(device), labels.to(device) # Newwith torch.no_grad():logits = model(features)# 获取预测结果并计算准确数量predictions = torch.argmax(logits, dim=1)compare = labels == predictionscorrect += torch.sum(compare)total_examples += len(compare)# 计算并返回准确率return (correct / total_examples).item()import torch.nn.functional as F
# 设置随机数种子,以确保可复现性
torch.manual_seed(123)
# 创建神经网络模型
model = NeuralNetwork(num_inputs=2, num_outputs=2)
# 根据设备可用情况选择设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 将模型移动到所选设备上
model = model.to(device)
# 定义优化器,使用随机梯度下降 (SGD)
optimizer = torch.optim.SGD(model.parameters(), lr=0.5)
# 定义训练循环的 epoch 数量
num_epochs = 3
for epoch in range(num_epochs):model.train()for batch_idx, (features, labels) in enumerate(train_loader):features, labels = features.to(device), labels.to(device) logits = model(features)loss = F.cross_entropy(logits, labels) # 损失函数optimizer.zero_grad()loss.backward()optimizer.step()### 训练日志print(f"Epoch: {epoch+1:03d}/{num_epochs:03d}"f" | Batch {batch_idx:03d}/{len(train_loader):03d}"f" | Train/Val Loss: {loss:.2f}")model.eval()print('accuracy',str(compute_accuracy(model, train_loader, device=device)))

Epoch: 001/003 | Batch 000/002 | Train/Val Loss: 0.75
Epoch: 001/003 | Batch 001/002 | Train/Val Loss: 0.65
Epoch: 002/003 | Batch 000/002 | Train/Val Loss: 0.44
Epoch: 002/003 | Batch 001/002 | Train/Val Loss: 0.13
Epoch: 003/003 | Batch 000/002 | Train/Val Loss: 0.03
Epoch: 003/003 | Batch 001/002 | Train/Val Loss: 0.00
accuracy:1.0


http://www.ppmy.cn/ops/28144.html

相关文章

【Jenkins】持续集成与交付 (七):Gitlab添加组、创建用户、创建项目和源码上传到Gitlab仓库

🟣【Jenkins】持续集成与交付 (七):Gitlab添加组、创建用户、创建项目和源码上传到Gitlab仓库 1、创建组2、创建用户3、将用户添加到组中4、在用户组中创建项目5、源码上传到Gitlab仓库5.1 初始化版本控制5.2 将文件添加到暂存区5.3 提交代码到本地仓库5.4 推送代码到 Git…

数据结构-二叉树的遍历

二叉树的遍历广义上是指下面我们说的七种遍历 深度优先搜索 : 递归完成 前序 中序 后序 的遍历 广度优先搜索 : 层序遍历(借助队列) 非递归的迭代法完成前中后遍历(借助栈) 代码合集如下 package TreeDemo; import java.util.*; public class BinaryTreeTest {public static c…

RunnerGo四月更新:强化UI自动化测试与UI录制插件功能

RunnerGo最近更新的 UI自动化测试和UI录制插件可以让测试人员更高效地布置UI自动化场景。这次优化升级的插件录制能力,可以更准确的定位元素并执行步骤,并增加了局部截图功能,准确查看定位的元素位置等。 UI插件V2.0介绍 接下来,让…

力扣501,二叉树中的众数

501. 二叉搜索树中的众数 - 力扣(LeetCode) 给你一个含重复值的二叉搜索树(BST)的根节点 root ,找出并返回 BST 中的所有 众数(即,出现频率最高的元素)。 如果树中有不止一个众数&…

低代码工业组态数字孪生平台

2024 两会热词「新质生产力」凭借其主要特征——高科技、高效能及高质量,引发各界关注。在探索构建新质生产力的重要议题中,数据要素被视为土地、劳动力、资本和技术之后的第五大生产要素。数据要素赋能新质生产力发展主要体现为:生产力由生产…

Devops部署maven项目

这里讲下应用k8s集群devops持续集成部署maven项目的流程。 failed to verify certificate: x509: certificate signed by unknown authority 今天在执行kubectl get nodes的时候报的证书验证问题,看了一圈首次搭建k8s的都是高频出现的问题。 couldn’t get curren…

C# Winform父窗体打开新的子窗体前,关闭其他子窗体

随着Winform项目越来越多,界面上显示的窗体越来越多,窗体管理变得更加繁琐。有时候我们要打开新窗体,然后关闭多余的其他窗体,这个时候如果一个一个去关闭就会变得很麻烦,而且可能还会出现遗漏的情况。这篇文章介绍了三…

网络安全之弱口令与命令爆破(中篇)(技术进阶)

目录 一,什么是弱口令? 二,为什么会产生弱口令呢? 三,字典的生成 四,使用Burpsuite工具验证码爆破 总结 笔记改错 一,什么是弱口令? 弱口令就是容易被人们所能猜到的密码呗&a…