从0开始深度学习(33)——循环神经网络的简洁实现

server/2024/11/29 1:32:25/

本章使用Pytorch的API实现RNN上的语言模型训练

0 导入库

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from collections import Counter
import re
import math
from tqdm import tqdm

1 准备数据

需要对文本进行预处理,比如转换为小写、去除标点符号等,以减少词汇量并简化问题,然后构建词汇表,即创建一个字符到索引的映射和一个索引到字符的映射,最后将将文本转换为整数序列,这些整数代表词汇表中的位置。

# 1. 加载数据
def load_data(file_path):with open(file_path, 'r') as f:lines = f.readlines()text = ''.join([line.strip().lower() for line in lines])# 使用正则表达式去除标点符号和数字text = re.sub(r'[^\w\s]', '', text)  # 去除标点符号text = re.sub(r'\d+', '', text)      # 去除数字return text# 2. 文本预处理
def preprocess_text(text):tokens = list(text)  # 将文本切分为字符vocab = sorted(set(tokens))  # 构建词表token_to_idx = {token: idx for idx, token in enumerate(vocab)}  # 词元到索引的映射idx_to_token = {idx: token for token, idx in token_to_idx.items()}  # 索引到词元的映射token_indices = [token_to_idx[token] for token in tokens]  # 把文本转化为索引列表return token_indices, token_to_idx, idx_to_token, vocab

2 创建数据集

从文本中提取固定长度的子序列作为输入,并将紧随其后的字符作为目标输出,最后将这些序列转换为适合输入到RNN模型的张量格式

# 数据集类
class TextDataset(Dataset):def __init__(self, token_indices, seq_len):self.data = token_indicesself.seq_len = seq_lendef __len__(self):return len(self.data) - self.seq_lendef __getitem__(self, idx):# 输入数据是从当前位置到指定序列长度的位置的数据,即一个序列x = self.data[idx:idx + self.seq_len]# 目标数据是输入数据的下一个位置的数据,即单个字符y = self.data[idx + 1:idx + self.seq_len + 1]return torch.tensor(x, dtype=torch.long), torch.tensor(y, dtype=torch.long)# 转化为Tensor

3 构建RNN模型

使用Pytorch构建RNN模型

class SimpleRNN(nn.Module):def __init__(self, vocab_size, hidden_size):super(SimpleRNN, self).__init__()self.hidden_size = hidden_size # 隐藏层形状self.rnn = nn.RNN(vocab_size, hidden_size, batch_first=True)'''vocab_size:特征的数量,即词汇表的大小hidden_size:隐藏层的状态向量的维度batch_first:决定了输入和输出张量的形状如果batch_first=True,输入和输出张量的形状将是(batch_size,sequence_length, input_size)。如果batch_first=False,输入和输出张量的形状将是 (sequence_length, batch_size, input_size)。'''self.fc = nn.Linear(hidden_size, vocab_size)def forward(self, x, hidden=None):out, hidden = self.rnn(x, hidden)  # RNN层out = self.fc(out)  # 全连接层return out, hidden

4 训练模型

在训练前,需要把数据转化为one-hot编码,以增强特征属性,添加困惑度作为评价指标,使用早停法提前结束训练,避免过拟合

# 4. 训练模型
def train_model(model, dataloader, val_dataloader, criterion, vocab_size, optimizer, device, num_epochs=100, patience=5, min_delta=0.001):assert vocab_size is not None, "vocab_size must be provided"model.to(device)  # 将模型移动到指定设备model.train()  # 设置模型为训练模式best_val_loss = float('inf')epochs_no_improve = 0for epoch in range(num_epochs):total_loss = 0# 训练阶段with tqdm(dataloader, desc=f'Epoch {epoch+1}/{num_epochs} (Training)', unit='batch') as tepoch:for inputs, targets in tepoch:# 将数据移动到指定设备inputs, targets = inputs.to(device), targets.to(device)  # 将输入数据转换为 one-hot 编码inputs_one_hot = F.one_hot(inputs, num_classes=vocab_size).float()# 清零梯度optimizer.zero_grad()  # 前向传播outputs, _ = model(inputs_one_hot)# 计算损失loss = criterion(outputs.view(-1, vocab_size), targets.view(-1))# 反向传播和优化loss.backward()optimizer.step()total_loss += loss.item()tepoch.set_postfix(loss=loss.item())average_loss = total_loss / len(dataloader)perplexity = math.exp(average_loss)  # 计算困惑度print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {average_loss:.4f}, Perplexity: {perplexity:.4f}')# 验证阶段model.eval()val_loss = 0with torch.no_grad():with tqdm(val_dataloader, desc=f'Epoch {epoch+1}/{num_epochs} (Validation)', unit='batch') as tepoch:for inputs, targets in tepoch:inputs, targets = inputs.to(device), targets.to(device)inputs_one_hot = F.one_hot(inputs, num_classes=vocab_size).float()outputs, _ = model(inputs_one_hot)loss = criterion(outputs.view(-1, vocab_size), targets.view(-1))val_loss += loss.item()tepoch.set_postfix(loss=loss.item())average_val_loss = val_loss / len(val_dataloader)print(f'Validation Loss: {average_val_loss:.4f}')# 检查是否需要早停if average_val_loss < best_val_loss - min_delta:best_val_loss = average_val_lossepochs_no_improve = 0else:epochs_no_improve += 1if epochs_no_improve >= patience:print(f'Early stopping at epoch {epoch+1}')breakmodel.train()  # 回到训练模式

5 预测模型

我们的输入必须大于seq_len,不然就不符合输入格式(可以使用补全,这里不展开),对于单词或者句子,需要把他们分割为字符,然后转换为token序列,作为输入

def predict(model, token_to_idx, idx_to_token, start_text, length, device, unk_token='<UNK>'):model.to(device)model.eval()# 将起始文本转换为字符 token 序列input_tokens = []for char in start_text:if char in token_to_idx:input_tokens.append(token_to_idx[char])else:if unk_token in token_to_idx:input_tokens.append(token_to_idx[unk_token])  # 使用 <UNK> 表示未知字符else:raise ValueError(f"Character '{char}' not in vocabulary and no '<UNK>' token provided.")# 转换为 PyTorch Tensorinput_tensor = torch.tensor(input_tokens, dtype=torch.long).unsqueeze(0).to(device)generated_tokens = []with torch.no_grad():hidden = Nonefor i in range(length):# 将输入数据转换为 one-hot 编码inputs_one_hot = F.one_hot(input_tensor, num_classes=len(token_to_idx)).float()# 前向传播outputs, hidden = model(inputs_one_hot, hidden)# 获取最后一个时间步的输出output = outputs[0, -1, :]# 获取最大概率的 token_, top_index = output.topk(1)predicted_token = idx_to_token[top_index.item()]# 添加预测的 token 到生成的序列中generated_tokens.append(predicted_token)# 更新输入 tensorinput_tensor = torch.tensor([[top_index.item()]], dtype=torch.long).to(device)# 将生成的字符序列拼接成字符串generated_text = ''.join(generated_tokens)return start_text + generated_text

6 主函数

# 读取数据
file_path = '/home/caser/code/data/timemachine.txt'
text = load_data(file_path)
# 预处理数据
token_indices, token_to_idx, idx_to_token, vocab=preprocess_text(text)# 参数设置
seq_len = 5
batch_size = 64
hidden_size = 128
learning_rate = 0.01
num_epochs = 100
patience = 5  # 早停法的耐心值
min_delta = 0.001  # 早停法的最小改进阈值# 创建数据集和数据加载器
dataset = TextDataset(token_indices, seq_len)
train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)# 初始化模型和优化器
vocab_size = len(vocab)
model = SimpleRNN(vocab_size, hidden_size)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)# 选择设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 训练模型
train_model(model, train_dataloader, val_dataloader, criterion, vocab_size, optimizer, device, num_epochs, patience, min_delta)# 进行预测
start_text = 'the time traveller '
predicted_text = predict(model, token_to_idx, idx_to_token, start_text, length=50, device=device)
print(predicted_text)

运行结果:
在这里插入图片描述


http://www.ppmy.cn/server/145779.html

相关文章

etcd、kube-apiserver、kube-controller-manager和kube-scheduler有什么区别

在我们部署K8S集群的时候 初始化master节点之后&#xff08;在master上面执行这条初始化命令&#xff09; kubeadm init --apiserver-advertise-address10.0.1.176 --image-repository registry.aliyuncs.com/google_containers --kubernetes-version v1.16.0 --service…

云技术-docker

声明&#xff01; 学习视频来自B站up主 **泷羽sec** 有兴趣的师傅可以关注一下&#xff0c;如涉及侵权马上删除文章&#xff0c;笔记只是方便各位师傅的学习和探讨&#xff0c;文章所提到的网站以及内容&#xff0c;只做学习交流&#xff0c;其他均与本人以及泷羽sec团…

Spring Boot 实战:分别基于 MyBatis 与 JdbcTemplate 的数据库操作方法实现与差异分析

1. 数据库新建表 CREATE TABLE table_emp(id INT AUTO_INCREMENT,emp_name CHAR(100),age INT,emp_salary DOUBLE(10,5),PRIMARY KEY(id) );INSERT INTO table_emp(emp_name,age,emp_salary) VALUES("tom",18,200.33); INSERT INTO table_emp(emp_name,age,emp_sala…

CentOS上如何离线批量自动化部署zabbix 7.0版本客户端

CentOS上如何离线批量自动化部署zabbix 7.0版本客户端 管理的服务器大部分都是CentOS操作系统&#xff0c;版本主要是CentOS 7。因为监控服务器需要&#xff0c;要在前两天搭建的Zabbix 7.0系统上把这些CentOS 7系统都监控起来。因为服务器数量众多&#xff0c;而且有些服务器…

redislite:轻量级的嵌入式 Redis 解决方案

在现代应用程序中&#xff0c;数据存储和管理是至关重要的。Redis 是一个非常流行的内存数据结构存储&#xff0c;广泛用于缓存、会话存储和消息传递等场景。然而&#xff0c;在某些情况下&#xff0c;开发者并不希望在本地或服务器上维护一个独立的 Redis 实例。这时&#xff…

【cocos creator】下拉框

https://download.csdn.net/download/K86338236/90038176 const { ccclass, property } cc._decorator;type DropDownOptionData {optionString?: string,optionSf?: cc.SpriteFrame } type DropDownItemData {label: cc.Label,sprite: cc.Sprite,toggle: cc.Toggle }cccl…

JVM调优篇之JVM基础入门AND字节码文件解读

目录 Java程序编译class文件内容常量池附录-访问标识表附录-常量池类型列表 Java程序编译 Java文件通过编译成class文件后&#xff0c;通过JVM虚拟机解释字节码文件转为操作系统执行的二进制码运行。 规范 Java虚拟机有自己的一套规范&#xff0c;遵循这套规范&#xff0c;任…

行为型模式-命令模式

命令模式&#xff08;Command Pattern&#xff09;是一种行为设计模式&#xff0c;它将请求封装为一个对象&#xff0c;从而使你可以用不同的请求、队列或者日志来参数化对象。命令模式允许请求的发送者与接收者完全解耦。 关键组成部分 Command&#xff08;命令接口&#xff0…