Transformer应用之构建聊天机器人(二)

news/2024/11/23 3:11:29/

四、模型训练解析

在PyTorch提供的“Chatbot Tutorial”中,关于训练提到了2个小技巧:

  • 使用”teacher forcing”模式,通过设置参数“teacher_forcing_ratio”来决定是否需要使用当前标签词汇来作为decoder的下一个输入,而不是把decoder当前预测出来的词汇当做decoder的下一个输入,这是因为存在这样的情况,如果当前预测出来的词汇跟输入词汇从语义上来讲没有多大关联时,如果继续使用预测出来的词汇来训练模型,有可能就会造成比较大的预测偏差,从而导致模型训练后的预测效果很差,如果改为直接使用输入词汇对应的目标词汇(标签)来作为decoder的下一个输入,相当于进行强制纠偏,使decoder训练时输出与输入之间不至于出现偏差很大的情况。
  • 第2个小技巧是使用梯度裁剪(Gradient Clipping),这是一种常用的防止梯度爆炸的技术。在深度学习训练过程中,因为网络层数较多,梯度可能会非常大,导致模型无法收敛。梯度裁剪的目的就是限制梯度的大小,使其不超过一个预设的阈值,从而避免梯度爆炸的问题。

训练过程如下:

  1. 输入语句正向传播通过encoder
  2. 使用SOS token作为decoder的初始输入,使用encoder的final hidden state来初始化decoder的hidden state
  3. Decoder端根据输入单步执行产生输出
  4. 如果执行”teacher forcing”模式,则把当前对应的目标词汇(标签)作为decoder的下一个输入,否则使用当前decoder的输出词汇作为decoder的下一个输入
  5. 计算并累加损失
  6. 执行反向传播
  7. 执行梯度裁剪
  8. 更新decoder和encoder的模型参数

以下是代码示例:

以下是Transformer模型训练代码示例,

  • 首先把输入sequence(对话输入),输出sequence(对话输出),以及各自的mask传入模型做正向传播
  • 计算预测结果与标签的损失,然后反向传播更新模型参数
  • 训练时可以使用验证集(dev dataset)对训练效果进行评估

五、模型预测(推理)过程解析

下面这个图描述了Transformer的预测推理过程:

  • 假设使用两个encoder和两个decoder来构成这个Transformer模型,首先把输入语句转为embedding词向量,并加入位置编码信息
  • 正向传播通过encoder1,它的输出再通过encoder2,期间会使用多头注意力机制对输入序列中的每个词向量并行地进行注意力Q,K,V的计算
  • Decoder1使用<START> token进行初始化,并使用带掩码多头注意力机制进行计算,并且需要根据前面encoder2的输出进行注意力的计算,然后输出预测得到的词汇
  • Decoder1输出的词汇作为decoder2的输入,同样decoder2在进行多头注意力计算时也需要使用encoder2的注意力计算输出结果
  • Decoder2的输出传入线性层,之后使用Softmax函数转为0到1之间的概率,然后可以使用greedy search(贪心解码)算法得到概率最高的词汇作为预测结果

下面是预测相关代码的示例:

再来看下PyTorch提供的聊天机器人样例的预测操作:

  • 用户输入正向传播通过encoder模型
  • 把encoder的final hidden layer作为decoder模型的first hidden input
  • 使用SOS_token作为decoder的第一个输入来初始化模型
  • decoder根据encoder的输出(上篇文章提到的“Luong attention”注意力机制计算),以及当前decoder的输入,hidden state来输出预测得到的词汇(迭代操作)
  • 使用Softmax计算概率并根据概率获取最有可能出现的词汇
  • 把当前预测得到的词汇作为下一个decoder的输入
  • 收集所有预测得到的词汇

以下是预测相关代码的示例:

六、聊天机器人对话效果解析

基于Transformer的聊天机器人和PyTorch提供的聊天机器人都使用同样的训练语料(“Cornell Movie-Dialogs Corpus.”)进行训练,基于Transformer的聊天机器人模型训练了20个epochs,输入语句最大长度设置为60,PyTorch提供的聊天机器人训练配置如下:

clip = 50.0

teacher_forcing_ratio = 1.0

learning_rate = 0.0001

decoder_learning_ratio = 5.0

n_iteration = 4000

print_every = 1

save_every = 500

使用同样的测试对话语料分别对两个模型进行测试,基于Transformer模型的对话测试结果如下:

PyTorch提供的聊天机器人对话测试结果如下:


http://www.ppmy.cn/news/91655.html

相关文章

调用华为API实现图像搜索

调用华为API实现图像搜索 1、作者介绍2、华为API介绍2.1 华为云图像搜索2.2 图像搜索应用场景2.2.1商品图片搜索2.2.2版权图片搜索 2.3 调用华为API实现图像标签 3、实验过程3.1完整代码3.2运行结果3.3常见错误 1、作者介绍 张勇进&#xff0c;男&#xff0c;西安工程大学电子…

计算机组成原理-指令系统-指令格式及寻址方式

目录 一、指令的定义 1.1 扩展操作码指令格式 二、指令寻址方式 2.1 顺序寻址 2.2 跳跃寻址 三、 数据寻址 3.1 直接寻址 3.2 间接寻址 3.3 寄存器寻址 ​ 3.4 寄存器间接寻址 3.5 隐含寻址 3.6 立即寻址 3.7 偏移地址 3.7.1 基址寻址 3.7.2 变址寻址 3.7.3 相对寻址…

大数据框架-Hadoop

大数据框架-Hadoop 1.什么是大数据 大数据是指由传统数据处理工具难以处理的规模极大、结构复杂或速度极快的数据集合。这些数据集合通常需要使用先进的计算和分析技术才能够处理和分析&#xff0c;因此大数据技术包括了大数据存储、大数据处理和大数据分析等方面的技术和工具…

Elasticsearch扫盲篇

1. 什么是搜索&#xff1f; 在日常的工作和学习中&#xff0c;当我们说想找查询任何的信息的时候&#xff0c;可能第一时间会想到上百度或者谷歌去搜索一下。比如说找一部自己喜欢的电影&#xff0c;或者说找一本喜欢的书&#xff0c;或者找一条感兴趣的新闻。但是百度和谷歌不…

【前端面试知识点】- 大厂面试(三)

目录 1.JS垃圾回收机制2.form表单的使用以及相应的属性值3.vue中的双向数据绑定原理&#xff08;重点深入学习&#xff09;4.vue单向数据流5.undefined、null、NaN有什么区别&#xff1f;7.如何获取安全的undefined 值&#xff1f;8.Object.is() 与比较操作符 “ ” 、 “” 的…

随机森林Proximity实现及应用

随机森林Proximity实现及应用 1 算法1.1 随机森林Proximity简介1.2 RF-GAP1.3 实现代码 2 应用2.1 离群点(outlier)检测2.1.1 原理和实现2.1.2 实验结果 附录 项目主页&#xff1a;randomforest C implementation of random forests classification, regression, proximity and…

总线系统 - 总线带宽 计算

文章目录 1 概述2 总线带宽2.1 衡量计算机系统的指标2.2 计算公式 3 扩展3.1 GB、MB、KB、B 换算公式 1 概述 #mermaid-svg-PlQh8E6rS0IQvpEb {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-PlQh8E6rS0IQvpEb .erro…

selenium UI自动化中文件上传的两种方式

前言 文件上传是自动化中很常见的一个功能&#xff0c;那么对于文件上传你又有多少了解呢&#xff1f;请往下看 1、被测产品中文件上传的功能非常普遍&#xff0c;一般情况下需要将准备好的文件放在预定的路径下&#xff0c;然后在自动化测试的脚本中&#xff0c;去预置的路径…