【Pytorch】大语言模型中的CrossEntropyLoss

ops/2024/9/29 1:34:28/

文章目录

  • 前言
  • 什么是CrossEntropyLoss
  • 语言模型中的CrossEntropyLoss
    • 计算loss的前期准备
    • CrossEntropyLoss的输入
    • CrossEntropyLoss的输出
  • 额外说明

前言

在大语言模型时代,我们常常使用交叉熵损失函数来计算loss,因此,理解该loss的计算流程有助于帮助我们对训练过程有更清晰的认知。本文从以下几个角度介绍nn.CrossEntropyLoss()

  • 使用该函数的前期准备:如何组织函数的输入(logits & labels)
  • 该函数流程
  • 常用参数
  • 该文章内容仅为个人理解,如有误解,欢迎讨论

什么是CrossEntropyLoss

这部分并不是本文的重点,我们仅介绍在语言模型的训练过程中,如何利用该loss

  • 相关信息可见:本人博客
  • 以及官网:CrossEntropyLoss官网

语言模型中的CrossEntropyLoss

计算loss的前期准备

huggingface-transformers源码中,我们在语言模型forward中总是能看到这样一段函数。我们以LlamaForCausalLM为例:Llama源码

if labels is not None:# Shift so that tokens < n predict nshift_logits = logits[..., :-1, :].contiguous()shift_labels = labels[..., 1:].contiguous()# Flatten the tokensloss_fct = CrossEntropyLoss()shift_logits = shift_logits.view(-1, self.config.vocab_size)shift_labels = shift_labels.view(-1)# Enable model parallelismshift_labels = shift_labels.to(shift_logits.device)loss = loss_fct(shift_logits, shift_labels)if not return_dict:output = (logits,) + outputs[1:]return (loss,) + output if loss is not None else output

对于Decoder-only模型,在训练时,我们的目标是next token prediction,任务流程如下

  • 假定我们是常规的问答任务,问题是“where is the capital of China“,label为“The capital is Beijing”。该任务的目标为,当输入为“where is the capital of China“时,

  • 我们对question和label进行拼接和tokenize化,一般转化结果 (tokenize忽略) 为:< bos > where is the capital of China < sep > The capital is Beijing < eos >

    • < bos>为句子开头的标志
    • < sep>用于分隔question和label,本质作用是,当模型看到时就知道:问题结束了,下一个token要输出答案了
    • < eos>为生成结束的标志
    • 假定每个词算一个token (忽略空格),那么输入一共有13个token
  • 这时我们将整个序列输入到模型中,模型在每个token的位置都生成一个向量,我们利用lm_head将最后一层的hidden state转化成词表大小的向量logits,用于后续利用Softmax确定每个token的概率

  • 现在模型有了输出logits,怎么计算loss?

    • 对比labels和logits之间的差异来计算loss

    • 现在一共有13个token,生成了13个logits,每个logits都是用于生成next token的。那么很直接的,我们来对比该logits生成的next token准不准就好了

      • 输入:< bos> where is the capital of China < sep> The capital is Beijing < eos>

      • 对比情况为:< sep>->The, The->capital, …, is->Beijing, Beijing->< eos>

        • < sep>对应位置要生成The,…, Beijing对应位置要输出< eos>
      • 我们可以将输入右移一位作为labels: where is the capital of China < sep> The capital is Beijing

        • 可以看到,对于输入来说, < eos>位置没有对应的需要生成的token,因此我们去掉该token
        • 对于labels,< bos>不需要生成,因此我们去掉该token
      • 因此,我们在计算loss时,对logits去尾,labels是输入掐头且右移一位

      • 在代码中对应

          shift_logits = logits[..., :-1, :].contiguous()shift_labels = labels[..., 1:].contiguous()
        

CrossEntropyLoss的输入

此时还不能直接将shift_logitsshift_labels进行对比,来计算loss。因为我们上面的操作只是为了<sep> The capital is BeijingThe capital is Beijing <eos>中的token能一一对应起来,对于其他部分生成的token,我们并没有要求(因为不是answer,不需要生成)

  • CrossEntropyLoss函数中有一个参数为ignore_idx默认值为-100。labels值设置为-100的位置不会计算loss
  • 因此我们将除了需要计算loss的位置 (最后5个位置)的labels都设置为-100
  • 最终,需要输入到CrossEntropyLoss中的inputs和labels为
    • inputs为: [, where, is, the, capital, of, China, < sep>, The, capital, is, Beijing]对应的logits
      • 注意:不需要进行Softmax,直接传logits即可,函数内部有更稳定的Softmax计算方式
    • labels为: [-100, -100, -100, -100, -100, -100, -100, The, capital, is, Beijing, < eos>]
    • 我们在训练时,构造输入和labels要注意构造为这种形式

CrossEntropyLoss的输出

默认情况下,输出为mean,即各个token计算得到loss的平均值(在token-level上平均,分母是token的个数)

import torch
import torch.nn as nn# 假设有 3 个类,logits 形状为 (batch_size=3, num_classes=3)
logits = torch.tensor([[2.0, 1.0, 0.1], [0.5, 2.5, 0.3], [1.5, 0.5, 2.0]])# 标签,其中第二个样本的标签为 ignore_index (-100)
labels = torch.tensor([0, -100, 2])# 定义 CrossEntropyLoss
criterion = nn.CrossEntropyLoss()# 计算损失
loss = criterion(logits, labels)print(f"Loss: {loss}")
>>> Loss: 0.51058030128479
  • 常用参数:

    • reduction:控制loss的输出形式,共三种'none', 'mean', 'sum',默认为'mean'

      • mean: 每个token计算得到的loss的平均值

      • none: 直接返回每个token计算得到的loss

        • 例子:

          import torch
          import torch.nn as nn# 假设有 3 个类,logits 形状为 (batch_size=3, num_classes=3)
          logits = torch.tensor([[2.0, 1.0, 0.1], [0.5, 2.5, 0.3], [1.5, 0.5, 2.0]])# 标签,其中第二个样本的标签为 ignore_index (-100)
          labels = torch.tensor([0, -100, 2])# 定义 CrossEntropyLoss
          criterion = nn.CrossEntropyLoss(reduction='none')# 计算损失
          loss = criterion(logits, labels)print(f"Loss: {loss}")
          >>> Loss: tensor([0.4170, 0.0000, 0.6041])
          
      • sum: 所有token对应loss求和

额外说明

对最上面的代码补充说明

  shift_logits = shift_logits.view(-1, self.config.vocab_size)shift_labels = shift_labels.view(-1)
  • 训练数据往往是按batch组织的,shape为(batch_size, seq_len, vocab_size)
  • 我们将所有batch的token压缩为一个序列,计算整个序列的loss,这样比较方便

http://www.ppmy.cn/ops/117746.html

相关文章

蓝桥杯【物联网】零基础到国奖之路:九. I2C

蓝桥杯【物联网】零基础到国奖之路:九. I2C 第一节 I2C概念第二节 I2C的物理层第三节 I2C的协议层 第一节 I2C概念 中文叫集成电路总线&#xff0c;是一种串行通信总线&#xff0c;使用多主从架构&#xff0c;由飞利浦公司1980年代初设计&#xff0c;方便主板、嵌入式系统或手…

python的函数

python中的函数 函数参数不可变对象参数可变对象参数参数调用的几种方法 *表示隔开关键字参数匿名函数函数的返回return强制位置参数错误写法正确写法 函数 函数是组织好的一组代码块&#xff0c;可以实现单一或关联功能的代码段。函数也可以提高应用的模块性以及代码的重复利…

【高分系列卫星简介】

高分系列卫星是中国国家高分辨率对地观测系统&#xff08;简称“高分工程”&#xff09;的重要组成部分&#xff0c;旨在提供全球范围内的高分辨率遥感数据&#xff0c;广泛应用于环境监测、灾害应急、城市规划、农业估产等多个领域。以下是对高分系列卫星及其数据、相关参数和…

✨机器学习笔记(五)—— 神经网络,前向传播,TensorFlow

Course2-Week1: https://github.com/kaieye/2022-Machine-Learning-Specialization/tree/main/Advanced%20Learning%20Algorithms/week1机器学习笔记&#xff08;五&#xff09; 1️⃣神经网络&#xff08;Neural Network&#xff09;2️⃣前向传播&#xff08;Forward propaga…

trixbox call php发起电话呼叫

调用方法&#xff1a; asterisk 命令行 OK originate sip/801 extension 802 originate sip/802 extension 9013816338277default good bye挂断 originate sip/802 extension 9013816338277from-internal OK Asterisk Call Manager (AMI)呼叫可以 http://xxxx/voip/c…

C++ Mean Shift算法

原理 每个样本点最终会移动到核概率密度的峰值&#xff0c;移动到相同峰值的样本点属于同一种颜色 关键代码 template <typename PointType> inline typename MeanShift<PointType>::PointsVector MeanShift<PointType>::meanshift(const PointsVector &am…

使用Hutool-poi封装Apache POI进行Excel的上传与下载

介绍 Hutool-poi是针对Apache POI的封装&#xff0c;因此需要用户自行引入POI库,Hutool默认不引入。到目前为止&#xff0c;Hutool-poi支持&#xff1a; Excel文件&#xff08;xls, xlsx&#xff09;的读取&#xff08;ExcelReader&#xff09;Excel文件&#xff08;xls&…

828华为云征文 | 华为云 X 实例服务器存储性能测试与优化策略

目录 引言 1 华为云 X 实例服务器概述 2 存储性能测试方法与工具 2.1 测试方法 2.2 测试工具 3 FIO&#xff08;Flexible I/O Tester&#xff09;读写性能测试 3.1 顺序读写测试 3.2 随机读写测试 4 hdparm性能测试 4.1 实际读取速度测试 4.2 缓存读取速度测试 4.3…