LLM学习笔记(11)pipeline() 函数的幕后工作

devtools/2024/12/4 21:37:19/

Hugging Face 的 pipeline 背后做了什么?

Hugging Face 的 pipeline 是一个高层封装工具,简化了许多繁琐的操作,使得开发者可以快速调用 NLP 模型完成复杂任务。以示例中的 情感分析任务 (sentiment-analysis) 为例,pipeline 背后执行了以下三个主要步骤:

1. 预处理(Preprocessing)

  • 目标: 将原始文本转化为模型可以理解的输入格式。
  • 操作:
    1. 使用 Tokenizer(分词器)将输入的自然语言文本分解为 token(子词单元)。
    2. 将 token 转换为对应的数字 ID,构建模型所需的输入格式。
    3. 添加特殊标记(如 [CLS][SEP])和填充(padding),使输入满足模型的固定格式。
  • 输入:
    • 原始文本:"I've been waiting for a HuggingFace course my whole life."
  • 输出:
    • 输入 ID(input IDs):[101, 2023, 2607, 2003, 6429, 999, 102](这是 token 对应的数字化表示)。

2. 模型推理(Model Forward Pass)

  • 目标: 将预处理后的输入传递给模型进行推理,生成输出 logits。
  • 操作:
    1. 输入数字化的 token ID 和相关信息(如 attention masks)到模型。
    2. 模型计算输出 logits(未归一化的分数),表示每个类别的置信度。
  • 输入:
    • 数字化表示的 input IDs
  • 输出:
    • Logits:[-4.3630, 4.6859](表示类别的未归一化分数,越大置信度越高)。

3. 后处理(Postprocessing)

  • 目标: 将模型输出的 logits 转换为人类可读的格式。
  • 操作:
    1. 使用 Softmax 函数将 logits 转换为概率值。
    2. 找到概率值最高的类别并标注为最终的预测结果。
  • 输入:
    • Logits:[-4.3630, 4.6859]
  • 输出:
    • 概率分布:[NEGATIVE: 0.11%, POSITIVE: 99.89%]
    • 最终结果:{'label': 'POSITIVE', 'score': 0.9598048329353333}

不涉及JSON文件

在 Hugging Face 的 pipeline 背后的 三个步骤之间(预处理、模型推理、后处理),并不是通过 JSON 文件传递数据。这些步骤是通过 Python 对象和内存中的数据直接传递的,使用的是 Python 和 PyTorch/TensorFlow 等深度学习框架的高效内存操作(如张量和字典),而不是文件操作。

数据传递的具体机制

1. 预处理(Preprocessing)

  • 输入数据: 由用户提供的自然语言文本。
  • 输出数据: 转换后的模型输入格式(token IDs 和相关信息)。
  • 传递方式: 直接将数据存储在内存中的 PyTorch TensorNumPy Array 对象中。

2. 模型推理(Model Forward Pass)

  • 输入数据: 预处理后的 token IDs 和 attention masks 等。
  • 输出数据: 模型生成的 logits(未归一化分数)。
  • 传递方式: 数据通过 Python 对象直接传递,通常是 torch.Tensortf.Tensor 格式。

3. 后处理(Postprocessing)

  • 输入数据: 模型推理生成的 logits。
  • 输出数据: 转换为概率分布和最终结果(如 {'label': 'POSITIVE', 'score': 0.95})。
  • 传递方式: 数据在内存中以 Python 的标准数据类型(如列表或字典)返回。

为什么不是 JSON 文件?

以下是 Hugging Face 在 pipeline 内部没有使用 JSON 文件作为数据传递方式的原因:

1. 数据传递的高效性
  • JSON 文件涉及文件的读写操作,会显著增加延迟。
  • 直接在内存中操作数据(如 torch.Tensor)效率更高,适合深度学习模型的推理过程。
2. 数据格式的复杂性
  • 模型的输入(如 token IDs 和 attention masks)和输出(如 logits)通常是多维张量(tensor)。
  • JSON 不适合存储复杂的高维数据结构,转换为 JSON 会引入额外的复杂性。
3. 实时性需求
  • pipeline 的设计目标是实时处理数据,而文件读写会对实时性产生不必要的影响。

虽然 pipeline 的内部步骤没有使用 JSON 文件,但在以下场景下,JSON 可能会被用到:

1. 用户输入/输出

  • 如果用户从外部文件(如 JSON 文件)加载输入,或将输出保存为 JSON 文件,JSON 会用于持久化存储。

2. API 通信

  • 如果 pipeline 被部署为 API 服务(如通过 REST API 提供服务),通常会使用 JSON 格式传递请求和响应。

3. 分布式计算

  • 在一些需要跨进程通信的场景中,可能会使用 JSON 或其他序列化格式传递数据。

使用分词器进行预处理

神经网络模型无法直接处理文本,因此首先需要通过预处理环节将文本转换为模型可以理解的数字。具体地,我们会使用每个模型对应的分词器 (tokenizer) 来进行:

1. 分词(Tokenization)

  • 将文本分解为较小的单元(单词或子词,称为 tokens)。

2. 将 tokens 转换为数字 ID

  • 分词器通过查找模型词汇表,将每个 token 映射到对应的数字编码(token IDs)。

3. 特殊标记

  • 添加模型需要的特殊标记,如句子开头的 [CLS] 和结尾的 [SEP]

4. 填充(Padding)

  • 对于批量输入,分词器会将短句填充到相同长度,确保张量形状一致。

我们对输入文本的预处理需要与模型自身预训练时的操作完全一致,只有这样模型才可以正常地工作。

注意,每个模型都有特定的预处理操作,如果对要使用的模型不熟悉,可以通过 Model Hub 查询。这里我们使用 AutoTokenizer 类和它的 from_pretrained() 函数,它可以自动根据模型 checkpoint 名称来获取对应的分词器。

代码

from transformers import AutoTokenizer

checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

raw_inputs = [
    "I've been waiting for a HuggingFace course my whole life.",
    "I hate this so much!"
]

inputs = tokenizer(raw_inputs, padding=True, truncation=True, return_tensors="pt")
print(inputs)

加载分词器

tokenizer = AutoTokenizer.from_pretrained(checkpoint)

  • checkpoint 指定模型名称 distilbert-base-uncased-finetuned-sst-2-english
  • AutoTokenizer.from_pretrained() 根据模型自动加载对应的分词器。

输入文本

raw_inputs = [
    "I've been waiting for a HuggingFace course my whole life.",
    "I hate this so much!"
]

  • 这个由中括号框住的内容是 Python 中的列表(List) 格式,而其中的每个元素是 字符串(String)

列表(List):

  • Python 中的一种数据结构,可以存储多个有序的元素。
  • 列表中的元素可以是任意类型(如字符串、数字、布尔值、甚至嵌套的列表)。
  • 使用中括号 [] 表示。
  • 列表元素之间用逗号 , 分隔。

分词和编码

inputs = tokenizer(raw_inputs, padding=True, truncation=True, return_tensors="pt")

  • padding=True 对短句进行填充,使所有句子的长度一致。
  • truncation=True 截断超长文本,确保输入长度不超过模型最大限制。
  • return_tensors="pt" 将输出结果转换为 PyTorch 的张量格式。

输出结果解析

字段解释

  • input_ids

    • 表示每个句子经过分词后转化的 token IDs。
    • 每一行对应一个输入句子,数字表示模型词汇表中 token 的索引。
  • attention_mask

    • 用于标识哪些位置是有效 token,哪些是填充 token:
      • 1 表示有效 token。
      • 0 表示填充 token(padding)。

将预处理好的输入送入模型

基础模型:AutoModel

加载基础模型(AutoModel)

model = AutoModel.from_pretrained(checkpoint)

  • 这是一个基础的 Transformer 模型,仅计算中间层特征(hidden states)。
  • 并没有特定的任务头(head),因此不能直接用于分类等具体任务。

输出内容

  • 输出的 last_hidden_state 是一个三维张量,维度为 (batch_size, sequence_length, hidden_size)
  • torch.Size([2, 16, 768])
    • Batch size: 2(因为输入了两句文本)。
    • Sequence length: 16(分词器将短句填充到相同长度)。
    • Hidden size: 768(DistilBERT 的特征维度)。

特定任务模型:AutoModelForSequenceClassification

加载任务模型

  • AutoModelForSequenceClassification 是为文本分类任务设计的模型。
  • 在基础模型的顶部添加了分类头(classification head),用于输出分类 logits。

输出内容

  • logits 分类器的输出结果,表示每个类别的未归一化分数。
  • torch.Size([2, 2])
    • Batch size 为 2(两条输入)。
    • 每条输入有两个分类分数(如 POSITIVENEGATIVE)。

分类任务的完整流程

  • 输入经过基础模型(计算 hidden states)。
  • hidden states 被传入分类头,得到每个类别的 logits。
第一步:输入经过基础模型
  • 模型首先接收已经分词和编码的输入数据,包括 input_ids(token 的数字表示)和 attention_mask(标记有效 token 的掩码)。
  • 基础模型(如 AutoModelBERT)会将这些输入传递给其内部的 Transformer 网络,逐层处理,生成 hidden states(隐藏层特征表示)。

hidden states 是什么?

  • Hidden states:
    • 表示每个 token 的上下文特征,是模型内部计算出的高维向量(例如,768 维或更高)。
    • 它们是文本经过 Transformer 网络后的中间语义表示。

hidden states 的作用:

  • 它们是基础模型的输出,是后续任务(如分类、问答等)的输入。
  • 类似于“语义向量”,代表了模型对文本的深层次理解。
第二步:hidden states 被传入分类头(Classification Head)
  • 分类头(head):

    • 一个额外的神经网络层,通常是全连接层(fully connected layer),用于特定任务(如分类)。
    • 它会接收基础模型的 hidden states,并将其映射到类别概率空间。
  • 工作原理:

    • 分类头会根据 hidden states 计算出每个类别的分数,称为 logits
    • 例如,对于二分类任务(如情感分析),会输出两个 logits(分别对应 POSITIVENEGATIVE)。
第三步:从 logits 得到分类结果
  • Logits:

    • logits 是未归一化的分数,表示每个类别的可能性大小。
    • 例如,[0.5, -1.2] 可能表示 POSITIVENEGATIVE 的初始分数。
  • Softmax:

    • 一般会对 logits 应用 Softmax 函数,将分数转换为概率分布。
    • 例如,[0.5, -1.2] 经过 Softmax 后,可能得到 [0.88, 0.12],表示 POSITIVE 的概率为 88%。
  • 最终分类:

    • 选择概率最大的类别作为最终分类结果。
    • 例如,POSITIVE 是概率最大的类别,因此输出 POSITIVE


http://www.ppmy.cn/devtools/139450.html

相关文章

Zookeeper的通知机制是什么?

大家好,我是锋哥。今天分享关于【Zookeeper的通知机制是什么?】面试题。希望对大家有帮助; Zookeeper的通知机制是什么? 1000道 互联网大厂Java工程师 精选面试题-Java资源分享网 Zookeeper的通知机制主要通过Watcher实现,它是Zookeeper客…

Linux 各个目录作用

刚毕业的时候学习Linux基础知识,发现了一份特别好的文档快乐的 Linux 命令行,翻译者是happypeter,作者当年也在慕课录制了react等前端相关的视频,通俗易懂,十分推荐 关于Linux的目录,多数博客已有详细介绍…

【VPX312-0】基于3U VPX总线架构的XC7VX690T FPGA数据预处理平台

产品概述 VPX312-0是一款基于3U VPX总线架构的XC7VX690T FPGA高性能数据预处理平台,该平台采用1片Xilinx的28nm Virtex-7系列FPGA XC7VX690T作为主处理器,主要完成数据的采集、处理以及传输的功能。 板卡的FPGA支持2组64位DDR3 SDRAM高速数据缓存&…

第九章 使用Apache服务部署静态网站

1. 网站服务程序 1970 年,作为互联网前身的 ARPANET(阿帕网)已初具雏形,并开始向非军用部门开放,许多大学和商业机构开始陆续接入。虽然彼时阿帕网的规模(只有 4 台主机联网运行)还不如现在的…

【C++】从零到一掌握红黑树:数据结构中的平衡之道

个人主页: 起名字真南的CSDN博客 个人专栏: 【数据结构初阶】 📘 基础数据结构【C语言】 💻 C语言编程技巧【C】 🚀 进阶C【OJ题解】 📝 题解精讲 目录 前言1 红黑树的概念**红黑树的五大性质** 2 红黑树的实现2.1 红黑树的结构…

DevOps工程技术价值流:GitLab源码管理与提交流水线实践

在当今快速迭代的软件开发环境中,DevOps(开发运维一体化)已经成为提升软件交付效率和质量的关键。而GitLab,作为一个全面的开源DevOps平台,不仅提供了强大的版本控制功能,还集成了持续集成/持续交付(CI/CD)…

QT学习笔记-QStringList,QTimer

QStringList-存储和管理一系列的字符串 在Qt框架中&#xff0c;QStringList 是一个模板类 QList<QString> 的特化&#xff0c;专门用于处理 QString 对象&#xff08;即Qt中的字符串&#xff09;的列表。当你看到这样的声明&#xff1a; QStringList m_rec_topicList; …

大数据新视界 -- 大数据大厂之 Hive 数据压缩:优化存储与传输的关键(上)(19/ 30)

&#x1f496;&#x1f496;&#x1f496;亲爱的朋友们&#xff0c;热烈欢迎你们来到 青云交的博客&#xff01;能与你们在此邂逅&#xff0c;我满心欢喜&#xff0c;深感无比荣幸。在这个瞬息万变的时代&#xff0c;我们每个人都在苦苦追寻一处能让心灵安然栖息的港湾。而 我的…