Pytorch封装简单RNN模型,进行中文训练及文本预测

ops/2024/9/23 14:24:42/

简述

使用pytorch封装简单RNN模型,使用单层nn.RNNnn.Linear等实现,然后做简单的文本预测。

数据集

代码参考李沐:https://zh-v2.d2l.ai/chapter_recurrent-neural-networks/rnn-concise.html,但他使用的是一篇英文小说,
这里改为使用COIG-CQIA中文数据集中的:douban_book_introduce.jsonlruozhiba_ruozhiba_ruozhiba.jsonl两个文件,本文目的是为了学习rnn,所以数据集比较简单,不过这个数据集由于都是问答形式,不像小说那样有主题性,所以感觉学习效果不好。理想的应该还是找个中文长篇小说之类。

COIG-CQIA: https://huggingface.co/datasets/m-a-p/COIG-CQIA

另外由于COIG-CQIA的数据是指令问答形式的json文件,所以这里稍作处理,改为单个问题+答案为一行的纯文本txt格式, 去除其它json字段及各种符号。

代码如下:

def jsonl_to_txt(dir_path):  dict_list = []  jsonl_list = os.listdir(dir_path)  qa_list = list()  chars_to_remove = r'[,。?;、:“”:!~()『』「」【】\"\[\]➕〈〉/<>()‰\%《》\*\?\-\.…·○01234567890123456789•\n\t abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ—*]'  for jsonl in jsonl_list:  path = os.path.join(dir_path, jsonl)  print(path)  with open(path, 'r', encoding='utf-8') as f:  jsonl_data = f.readlines()  for line in jsonl_data:  line_dict = JSON.loads(line)  qa = line_dict['instruction'] + line_dict['output']  qa = re.sub(chars_to_remove, '', qa).strip()  qa_list.append(qa)  path = os.path.join(dir_path, 'chengyu_qa.txt')  with open(path, 'w', encoding='utf-8') as f:  f.write('\n'.join(qa_list))  if __name__ == '__main__':  dir_path = '../data/COIG-CQIA'  jsonl_to_txt(dir_path)  print()

上面处理完毕后,还需要进行词元化、构建词典等步骤,参考:
python实现简单中文词元化、词典构造、时序数据集封装等-CSDN博客

模型封装

RNN — PyTorch 2.4 documentation

可以先观察一下tensorboard的add_graph函数对模型可视化后的结构:

在这里插入图片描述

这里使用单层的RNN(nn.RNN有默认参数num_layers=1),nn.functional.one_hot是为了实现单词的向量化表示,后续可以优化成nn.Embedding来做词向量。

nn.functional.one_hot前将x进行了转置,这里有点抽象,来关注一下nn.RNN的参数要求,便可理解。

先看x的初始shape为(batch_size, time_size),转置并向量化后为(time_size, batch_size, vocab_size)

若不转置直接向量化,则为(batch_size, time_size, vocab_size),实际上这两种格式的数据nn.RNN都支持。

但若为(batch_size, time_size, vocab_size)形式,则需在创建nn.RNN实例时指定参数batch_first=False。

在这里插入图片描述

另外,还需要提供一个初始的隐状态,这里用init_state函数实现。

在这里插入图片描述

class SimpleRNNModel(nn.Module):  def __init__(self, vocab_size, hidden_size):  super(SimpleRNNModel, self).__init__()  self.vocab_size = vocab_size  self.hidden_size = hidden_size  self.rnn = nn.RNN(vocab_size, hidden_size)  self.linear = nn.Linear(hidden_size, vocab_size)  def forward(self, x, hidden_state=None):  x = nn.functional.one_hot(x.T.long(), num_classes=self.vocab_size)  x = x.to(torch.float32)  outputs, hidden_state = self.rnn(x, hidden_state)  # rrn的outputs.shape(N, L, D*H)  outputs = outputs.reshape(-1, self.hidden_size)  outputs = self.linear(outputs)  return outputs, hidden_state  def init_state(self, device, batch_size=1):  return torch.zeros((self.rnn.num_layers, batch_size, self.hidden_size), device=device)  

梯度裁剪

源自李沐:https://zh-v2.d2l.ai/chapter_recurrent-neural-networks/rnn-scratch.html

def grad_clipping(net, max_norm):  if isinstance(net, nn.Module):  params = [p for p in net.parameters() if p.requires_grad]  else:  params = net.params  norm = torch.sqrt(sum(torch.sum((p.grad ** 2)) for p in params))  if norm > max_norm:  for param in params:  param.grad[:] *= max_norm / norm

模型训练

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  
print(f'\ndevice: {device}')  corpus, vocab = load_corpus("../data/COIG-CQIA/qa_list.txt")  vocab_size = len(vocab)  
hidden_size = 256  
epochs = 5  
batch_size = 50  
learning_rate = 0.01  
time_size = 4  
max_grad_max_norm = 0.5  dataset = make_dataset(corpus=corpus, time_size=time_size)  
data_loader = data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)  net = SimpleRNNModel(vocab_size, hidden_size)  
net.to(device)  # print(net.state_dict())  criterion = nn.CrossEntropyLoss()  
criterion.to(device)  
optimizer = optim.Adam(net.parameters(), lr=learning_rate)  writer = SummaryWriter('./train_logs')  
# 随便定义个输入, 好使用add_graph  
tmp = torch.rand((batch_size, time_size)).to(device)  
writer.add_graph(net, tmp)  loss_counter = 0  
total_loss = 0  
ppl_list = list()  
total_train_step = 0  for epoch in range(epochs):  print('------------Epoch {}/{}'.format(epoch + 1, epochs))  for X, y in data_loader:  X, y = X.to(device), y.to(device)  # 如果各个批次间的时序是连续的,则可以把上次的hidden_state传入下个批次, 不然就要重置hidden_state  # 这里batch_size=X.shape[0]是因为在加载数据时, DataLoader没有设置丢弃不完整的批次, 所以存在实际批次不满足设定的batch_size  hidden_state = net.init_state(batch_size=X.shape[0], device=device)  outputs, hidden_state = net(X, hidden_state=hidden_state)  optimizer.zero_grad()  # y也变成 时间序列*批次大小的行数, 才和 outputs 一致  y = y.T.reshape(-1)  # 交叉熵的第二个参数需要LongTorch  loss = criterion(outputs, y.long())  loss.backward()  # 求完梯度之后可以考虑梯度裁剪, 再更新梯度  grad_clipping(net, max_grad_max_norm)  optimizer.step()  total_loss += loss.item()  loss_counter += 1  total_train_step += 1  if total_train_step % 10 == 0:  print(f'Epoch: {epoch + 1}, 累计训练次数: {total_train_step}, 本次loss: {loss.item():.4f}')  writer.add_scalar('train_loss', loss.item(), total_train_step)  ppl = np.exp(total_loss / loss_counter)  ppl_list.append(ppl)  print(f'Epoch {epoch + 1} 结束, batch_loss_average: {total_loss / loss_counter}, perplexity: {ppl}')  writer.add_scalar('ppl', ppl, epoch + 1)  total_loss = 0  loss_counter = 0  torch.save(net.state_dict(), './save/epoch_{}_ppl_{}.pth'.format(epoch + 1, ppl))  writer.close()

tensorboard训练过程观察

横轴为训练epoch。

在这里插入图片描述

横轴为训练次数。

在这里插入图片描述

文本预测

这里首先完善模型的预测函数(该函数放到模型中):

def predict(self, prefix, num_preds, vocab, device):  state = self.init_state(batch_size=1, device=device)  # prefix为字符, 转成索引  outputs = [vocab.word2idx(prefix[0])]  get_input = lambda: torch.tensor([outputs[-1]], device=device).reshape((1, 1))  # 一个字符一个字符跑一遍, 对用户输入进行预热, 即对输入的各个字符间建立联系  for y in prefix[1:]:  # 预热期  _, state = self.forward(get_input(), state)  outputs.append(vocab.word2idx(y))  # 刚好每次都用上一次的预测值做输入  for _ in range(num_preds):  # 预测num_preds步  y, state = self.forward(get_input(), state)  outputs.append(int(y.argmax(dim=1).reshape(1)))  return ''.join([vocab.idx2word(i) for i in outputs])

实现对提示词处理及预测函数的调用:

注意:这里的语料库应和训练使用的一致。

def predict(state_dict_path, vocab, prefix=None, num_preds=3):  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  vocab_size = len(vocab)  hidden_size = 256  net = SimpleRNNModel(vocab_size, hidden_size).to(device)  net.load_state_dict(torch.load(state_dict_path, map_location=device, weights_only=True))  net.eval()  with torch.no_grad():  outputs = net.predict(prefix=prefix, num_preds=num_preds, vocab=vocab, device=device)  return outputs  if __name__ == '__main__':  corpus, vocab = load_corpus("../data/COIG-CQIA/qa_list.txt")  # corpus, vocab = load_corpus("../data/COIG-CQIA/chengyu_qa.txt")  # print(len(vocab))  # idx = [vocab.word2idx(ch) for ch in prefix]  path = "../save/Simple/新建文件夹/state_dict-time_size_30-ppl_1.pth"  prefix = "有什么超赞的诗句"  print(f'提示词: {prefix}')  outputs = predict(path, vocab, prefix=prefix, num_preds=22)  print(f'预测输出: {outputs}\n')

http://www.ppmy.cn/ops/99976.html

相关文章

jvm监控工具一览

下面是对 BTrace、JAD、JMAP、JSTAT、JSTACK、JINFO 以及 MARK 工具的比较表&#xff1a; 工具/属性功能适用场景使用难度是否侵入式是否需要重启 JVMBTrace动态跟踪和监控 Java 应用程序性能分析、故障排查、日志收集、安全监控中等无侵入式否JAD反编译 Java 字节码文件&…

ubuntu 不生成core 的可能原因

一、首先检查 $ cat /proc/sys/kernel/core_pattern $ cat /proc/sys/kernel/core_pattern|/usr/share/apport/apport -p%p -s%s -c%c -d%d -P%P -u%u -g%g -- %E 系统当前的/proc/sys/kernel/core_pattern设置为&#xff1a; |/usr/share/apport/apport -p%p -s%s -c%c -d…

稳石机器人 | 工业级AMR S1200L,专为多样化需求设计,柔性拓展更易用

近日&#xff0c;稳石机器人重磅推出基于新品控制器ROC1000的全新移动机器人AMR S1200L&#xff0c;专为满足生产制造和仓储物流的多样化需求而设计&#xff0c;无需改造现场&#xff0c;最快可在1周内完成部署。 重载型AMR-S1200L设计注重实用性和灵活性&#xff0c;可在室内…

Springcloud从零开始---Service业务模块(三)

上篇&#xff1a;Springcloud从零开始---Zuul&#xff08;二&#xff09;-CSDN博客 Service模块是客户端模块&#xff0c;用户编写业务逻辑代码和功能实现。前端请求发送到Zuul网关再有网关发送到Service服务&#xff0c;可以是系统的安全性提升。 开始继上篇Springcloud从零…

软件测试——设计测试用例

用例 边界值 取边界值次边界值边界值有效则次边界值取有效&#xff0c;二者相反 场景法 这些具体的方法&#xff0c;旨在提高我们的测试思路提高我们设计测试用例的能力 正交表法 1.分析需求 2.使用工具 只填写部分时如何选择 如输入选项有5种&#xff0c;则需要32种&…

集合及数据结构第十节(下)————常用接口介绍、堆的应用和java对象的比较

系列文章目录 集合及数据结构第十节&#xff08;下&#xff09;————常用接口介绍和堆的应用 常用接口介绍和堆的应用 PriorityQueue的特性.PriorityQueue常用接口介绍top-k问题堆排序PriorityQueue中插入对象元素的比较.对象的比较.集合框架中PriorityQueue的比较方式 文…

软件项目需求分析报告(doc原件全文)

第3章 技术要求 3.1 软件开发要求 第4章 项目建设内容 第5章 系统安全需求 5.1 物理设计安全 5.2 系统安全设计 5.3 网络安全设计 5.4 应用安全设计 5.5 对用户安全管理 5.6 其他信息安全措施 第6章 其他非功能需求 6.1 性能设计 6.2 稳定性设计 6.3 安全性设计 6.4 兼容性设计…

卡通头像生成器.exe

下载 关键代码 void Widget::PostRequest() {if(ui->lineEdit->text().isEmpty()){qDebug()<<"lineEdit is empty";//生成随机数auto randNum QRandomGenerator::global()->generate();auto randUrl url.url() QString::number(randNum) "…