A Chinese Chatbot Dialogue Generation Model Implemented in PyTorch


This is a Chinese chatbot dialogue generation model implemented in PyTorch.

1. Data Preparation

The code assumes two files, questions.txt and answers.txt, which contain the input and output sequences respectively.
The load_data function reads these files and returns a list of sentences.
The build_vocab function builds the vocabulary dictionaries word2index and index2word by iterating over the sentences.
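
To make the expected inputs concrete, here is a minimal, self-contained sketch of the data layout. It assumes questions.txt and answers.txt are UTF-8 text files aligned line by line (line i of one answers line i of the other); the sample sentences are invented for illustration. load_data in the full listing below does no more than split such a file into lines.

```python
# Hypothetical sample pairs; the two files are assumed to be aligned line by line.
sample_questions = ["今天天气怎么样?", "你叫什么名字?"]
sample_answers = ["今天天气晴朗。", "我叫羲和。"]

with open('questions.txt', 'w', encoding='utf-8') as f:
    f.write('\n'.join(sample_questions))
with open('answers.txt', 'w', encoding='utf-8') as f:
    f.write('\n'.join(sample_answers))
```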

2. Model Definition

The Encoder and Decoder classes define the architecture of the seq2seq model.
The Encoder consumes the input sequence and returns its outputs together with the final hidden state (the model uses a GRU, so there is no separate cell state).
The Decoder takes the encoder's hidden state and generates the output sequence step by step.
The Seq2Seq class combines the encoder and decoder and adds a classification head for an auxiliary task.
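
The classes themselves appear in the full listing below; as a quick orientation, the following self-contained sketch traces the tensor shapes a GRU-based encoder of this kind produces (the toy sizes are assumptions chosen only to make the shapes visible, not values taken from the article's data):

```python
import torch
import torch.nn as nn

# Toy sizes for illustration.
batch_size, seq_len, hidden_size, vocab_size = 2, 5, 256, 1000

embedding = nn.Embedding(vocab_size, hidden_size)
gru = nn.GRU(hidden_size, hidden_size, num_layers=1, batch_first=True)

tokens = torch.randint(0, vocab_size, (batch_size, seq_len))  # padded index tensor
embedded = embedding(tokens)                                  # (batch, seq_len, hidden)
outputs, hidden = gru(embedded)
print(outputs.shape)  # torch.Size([2, 5, 256]) -> one output per time step
print(hidden.shape)   # torch.Size([1, 2, 256]) -> final hidden state per layer
```

The final hidden state is what the Seq2Seq class feeds both to the decoder and to its auxiliary classification head.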

3. Training

The train function trains the model with the Adam optimizer and a cross-entropy loss.
The model is trained for the specified number of epochs, and the loss is computed and printed for each epoch.
After training finishes, the model is saved to the file model.pth.
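
The listing below defines train but never calls it; one plausible way to invoke it is sketched here. The epoch count is an arbitrary assumption, and passing './models/model.pth' keeps the save path consistent with the path the listing later loads from.

```python
# Hypothetical training entry point; assumes model and dataloader from the listing below.
if __name__ == "__main__":
    train(model, dataloader, num_epochs=20, learning_rate=0.001, save_path='./models/model.pth')
```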

4. Prediction

The predict function takes an input sentence and uses the trained model to generate an output sequence.
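
A minimal usage sketch, assuming the listing below has been run and a trained model is available (the question string is made up):

```python
# Hypothetical query against the trained model.
reply = predict("今天天气怎么样?")
print(reply)
```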

5. Data Augmentation

The data_augmentation function applies several augmentation techniques to the input sentence (a usage sketch follows the list):

  • Randomly inserting a token
  • Randomly deleting a token
  • Randomly swapping two tokens
  • Back-translation
  • Randomly replacing a token with another vocabulary token
  • Replacing tokens with synonyms
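
The listing defines data_augmentation but never wires it into training; one possible way to use it, assuming the questions and answers lists from the data-preparation step, is sketched here. Note that because build_vocab only sees the original sentences, tokens introduced by back-translation may fall outside the vocabulary and map to <UNK>.

```python
# Hypothetical use of data_augmentation: double the training pairs before building the dataset.
augmented_questions = questions + [data_augmentation(q) for q in questions]
augmented_answers = answers + answers  # keep each augmented question aligned with its original answer
```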

Note that some parts of the code are incomplete or commented out, so you may need to modify or complete them for your specific use case.
The code follows:

python">import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import random
import tkinter as tk
import jieba
import matplotlib.pyplot as plt
import os
from googletrans import Translator  # 用于回译# 中文词汇表和索引映射
word2index = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3}
index2word = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}# 使用 jieba 进行中文分词
def tokenize_chinese(sentence):tokens = jieba.lcut(sentence)return tokens# 构建词汇表
def build_vocab(sentences):global word2index, index2wordvocab_size = len(word2index)for sentence in sentences:for token in tokenize_chinese(sentence):if token not in word2index:word2index[token] = vocab_sizeindex2word[vocab_size] = tokenvocab_size += 1return vocab_size# 将句子转换为张量
def sentence_to_tensor(sentence, max_length=50):tokens = tokenize_chinese(sentence)indices = [word2index.get(token, word2index["<UNK>"]) for token in tokens]indices += [word2index["<PAD>"]] * (max_length - len(indices))return torch.tensor(indices, dtype=torch.long), len(indices)# 读取问题和答案文件
def load_data(file_path):with open(file_path, 'r', encoding='utf-8') as f:lines = f.read().splitlines()return lines# 假设数据文件是 'questions.txt' 和 'answers.txt'
question_file = 'questions.txt'
answer_file = 'answers.txt'
questions = load_data(question_file)
answers = load_data(answer_file)# 获取词汇表大小
vocab_size = build_vocab(questions + answers)# 定义数据集
class ChatDataset(Dataset):def __init__(self, questions, answers, labels):self.questions = questionsself.answers = answersself.labels = labelsdef __len__(self):return len(self.questions)def __getitem__(self, idx):input_tensor, input_length = sentence_to_tensor(self.questions[idx])target_tensor, target_length = sentence_to_tensor(self.answers[idx])label = self.labels[idx]return input_tensor, target_tensor, input_length, target_length, label# 自定义 collate 函数
def collate_fn(batch):inputs, targets, input_lengths, target_lengths, labels = zip(*batch)inputs = nn.utils.rnn.pad_sequence(inputs, batch_first=True, padding_value=word2index["<PAD>"])targets = nn.utils.rnn.pad_sequence(targets, batch_first=True, padding_value=word2index["<PAD>"])labels = torch.tensor(labels)return inputs, targets, torch.tensor(input_lengths), torch.tensor(target_lengths), labels# 创建数据集和数据加载器
labels = [0] * len(questions)  # 假设所有数据都属于同一类别
dataset = ChatDataset(questions, answers, labels)
dataloader = DataLoader(dataset, batch_size=60, shuffle=True, collate_fn=collate_fn)# 定义模型结构
class Encoder(nn.Module):def __init__(self, input_size, hidden_size, num_layers=1):super(Encoder, self).__init__()self.embedding = nn.Embedding(input_size, hidden_size)self.gru = nn.GRU(hidden_size, hidden_size, num_layers, batch_first=True)def forward(self, input_seq, input_lengths, hidden=None):embedded = self.embedding(input_seq)packed = nn.utils.rnn.pack_padded_sequence(embedded, input_lengths, batch_first=True, enforce_sorted=False)outputs, hidden = self.gru(packed, hidden)outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)return outputs, hiddenclass Decoder(nn.Module):def __init__(self, output_size, hidden_size, num_layers=1):super(Decoder, self).__init__()self.embedding = nn.Embedding(output_size, hidden_size)self.gru = nn.GRU(hidden_size, hidden_size, num_layers, batch_first=True)self.out = nn.Linear(hidden_size, output_size)self.softmax = nn.LogSoftmax(dim=1)def forward(self, input_step, hidden, encoder_outputs):embedded = self.embedding(input_step)gru_output, hidden = self.gru(embedded, hidden)output = self.softmax(self.out(gru_output.squeeze(1)))return output, hiddenclass Seq2Seq(nn.Module):def __init__(self, encoder, decoder, device, tokenizer):super(Seq2Seq, self).__init__()self.encoder = encoderself.decoder = decoderself.device = deviceself.tokenizer = tokenizerself.classifier = nn.Linear(encoder.hidden_size, 1)  # 分类头def forward(self, input_tensor, target_tensor, input_lengths, target_lengths, teacher_forcing_ratio=0.5):batch_size = input_tensor.size(0)max_target_len = max(target_lengths)vocab_size = self.decoder.out.out_featuresoutputs = torch.zeros(batch_size, max_target_len, vocab_size).to(self.device)encoder_outputs, encoder_hidden = self.encoder(input_tensor, input_lengths)decoder_input = torch.tensor([[word2index["<SOS>"]] * batch_size], device=self.device).transpose(0, 1)decoder_hidden = encoder_hiddenfor t in range(max_target_len):decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden, encoder_outputs)outputs[:, t, :] = decoder_outputtop1 = decoder_output.argmax(1)decoder_input = target_tensor[:, t].unsqueeze(1) if random.random() < teacher_forcing_ratio else top1.unsqueeze(1)classification_output = self.classifier(encoder_hidden[-1])  # 分类任务输出return outputs, classification_output# 实例化模型和优化器
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoder = Encoder(vocab_size, hidden_size=256).to(device)
decoder = Decoder(vocab_size, hidden_size=256).to(device)# 检查是否存在已保存的模型和分词器
model_path = './models/model.pth'
tokenizer_path = './models/tokenizer.pth'if os.path.exists(model_path) and os.path.exists(tokenizer_path):print("Loading existing model and tokenizer...")model = torch.load(model_path)tokenizer = torch.load(tokenizer_path)word2index = tokenizer['word2index']index2word = tokenizer['index2word']
else:print("Creating new model and tokenizer...")model = Seq2Seq(encoder, decoder, device, tokenizer={'word2index': word2index, 'index2word': index2word}).to(device)tokenizer = {'word2index': word2index, 'index2word': index2word}def train(model, dataloader, num_epochs, learning_rate=0.001, save_path='model.pth'):criterion = nn.CrossEntropyLoss(ignore_index=word2index["<PAD>"])classifier_criterion = nn.BCEWithLogitsLoss()optimizer = optim.Adam(model.parameters(), lr=learning_rate)loss_values = []for epoch in range(num_epochs):model.train()total_loss = 0total_class_loss = 0for inputs, targets, input_lengths, target_lengths, labels in dataloader:inputs, targets = inputs.to(device), targets.to(device)input_lengths = input_lengths.cpu().clone().detach()target_lengths = target_lengths.cpu().clone().detach()labels = labels.to(device).float()optimizer.zero_grad()outputs, classification_output = model(inputs, targets, input_lengths, target_lengths)loss = criterion(outputs.view(-1, vocab_size), targets.view(-1))class_loss = classifier_criterion(classification_output.squeeze(), labels)total_loss += loss.item()total_class_loss += class_loss.item()(loss + class_loss).backward()optimizer.step()avg_loss = total_loss / len(dataloader)avg_class_loss = total_class_loss / len(dataloader)loss_values.append(avg_loss)print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {avg_loss:.4f}, Class Loss: {avg_class_loss:.4f}")# 验证model.eval()with torch.no_grad():val_loss = 0val_class_loss = 0correct_predictions = 0total_samples = 0for inputs, targets, input_lengths, target_lengths, labels in dataloader:inputs, targets = inputs.to(device), targets.to(device)input_lengths = input_lengths.cpu().clone().detach()target_lengths = target_lengths.cpu().clone().detach()labels = labels.to(device).float()outputs, classification_output = model(inputs, targets, input_lengths, target_lengths, teacher_forcing_ratio=0)loss = criterion(outputs.view(-1, vocab_size), targets.view(-1))class_loss = classifier_criterion(classification_output.squeeze(), labels)val_loss += loss.item()val_class_loss += class_loss.item()# 计算准确率predicted_indices = outputs.argmax(dim=2)for pred, target, target_len in zip(predicted_indices, targets, target_lengths):pred = pred[:target_len]target = target[:target_len]correct = (pred == target).all().item()if correct:correct_predictions += 1total_samples += 1# 计算分类准确率predicted_labels = (classification_output > 0).float()correct_labels = (predicted_labels == labels).sum().item()total_labels = labels.size(0)correct_predictions += correct_labelstotal_samples += total_labelsval_accuracy = correct_predictions / total_samples if total_samples > 0 else 0val_avg_loss = val_loss / len(dataloader)val_avg_class_loss = val_class_loss / len(dataloader)print(f"Validation Loss: {val_avg_loss:.4f}, Validation Class Loss: {val_avg_class_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}")torch.save(model, save_path)plt.plot(range(1, num_epochs + 1), loss_values)plt.xlabel('Epoch')plt.ylabel('Loss')plt.title('Training Loss Curve')plt.show()# 预测函数
def predict(question):model.eval()with torch.no_grad():input_tensor, input_length = sentence_to_tensor(question)input_tensor = input_tensor.unsqueeze(0).to(device)input_length = [input_length]encoder_outputs, encoder_hidden = model.encoder(input_tensor, input_length)decoder_input = torch.tensor([[word2index["<SOS>"]]], device=device)decoder_hidden = encoder_hiddendecoded_words = []for _ in range(50):  # 设置一个较大的最大长度来避免潜在的循环decoder_output, decoder_hidden = model.decoder(decoder_input, decoder_hidden, encoder_outputs)top1 = decoder_output.argmax(1).item()if top1 == word2index["<EOS>"]:breakelse:decoded_words.append(index2word[top1])decoder_input = torch.tensor([[top1]], device=device)return ''.join(decoded_words)# 数据增强函数
def data_augmentation(sentence):tokens = tokenize_chinese(sentence)augmented_sentence = []# 随机插入if random.random() < 0.1:insert_token = random.choice(list(word2index.keys()))insert_index = random.randint(0, len(tokens))tokens.insert(insert_index, insert_token)# 随机删除if random.random() < 0.1:delete_index = random.randint(0, len(tokens) - 1)del tokens[delete_index]# 随机交换if len(tokens) > 1 and random.random() < 0.1:index1, index2 = random.sample(range(len(tokens)), 2)tokens[index1], tokens[index2] = tokens[index2], tokens[index1]# 回译if random.random() < 0.1:translator = Translator()translated = translator.translate(sentence, src='zh-cn', dest='en').textback_translated = translator.translate(translated, src='en', dest='zh-cn').texttokens = tokenize_chinese(back_translated)# 随机替换if random.random() < 0.1:replace_index = random.randint(0, len(tokens) - 1)tokens[replace_index] = random.choice(list(word2index.keys()))# 同义词替换if random.random() < 0.1:syn_dict = {'好': ['优秀', '出色', '质量高'], '坏': ['差', '劣质', '质量低']}for i, token in enumerate(tokens):if token in syn_dict:tokens[i] = random.choice(syn_dict[token])augmented_sentence = ''.join(tokens)return augmented_sentence# 创建图形界面
def on_predict():question = question_entry.get()if question.strip() == "":result_label.config(text="请输入有效的问题。")returnanswer = predict(question)answer = " ".join(answer.split())result_label.config(text=f'Answer: {answer}')def on_clear():question_entry.delete(0, 'end')# 创建主窗口
root = tk.Tk()
root.title("羲和")# 输入框
question_label = tk.Label(root, text="请输入你的问题:")
question_label.pack()
question_entry = tk.Entry(root, width=50)
question_entry.pack()# 生成按钮
generate_button = tk.Button(root, text="生成答案", command=on_predict)
generate_button.pack(side=tk.LEFT, padx=10)# 清除按钮
clear_button = tk.Button(root, text="清除", command=on_clear)
clear_button.pack(side=tk.LEFT)# 结果标签
result_label = tk.Label(root, text="")
result_label.pack(pady=10)# 添加提示信息
tip_label = tk.Label(root, text="提示:本模型可能存在一定的局限性,答案仅供参考。")
tip_label.pack()question_entry.focus_set()  # 生成答案后自动选中输入框# 主事件循环
root.mainloop()# 在程序结束时释放 GPU 内存
if torch.cuda.is_available():torch.cuda.empty_cache() 
