NLP transformers - 文本分类

devtools/2024/9/23 4:30:38/

在这里插入图片描述

Text classification

文章目录

  • Text classification
    • 加载 IMDb 数据集
    • Preprocess 预处理
    • Evaluate
    • Train
    • Inference


本文翻译自:Text classification
https://huggingface.co/docs/transformers/tasks/sequence_classification
notebook : https://colab.research.google.com/github/huggingface/notebooks/blob/main/transformers_doc/en/pytorch/sequence_classification.ipynb


文本分类是一种常见的 NLP 任务,它为文本分配标签或类别。一些大公司在生产中运行文本分类,以实现广泛的实际应用。最流行的文本分类形式之一是 情感分析,它为文本序列分配 🙂 积极、🙁 消极或 😐 中性等标签。

本指南将向您展示:

  1. 在IMDb数据集上微调DistilBERT,以确定电影评论是正面还是负面。
  2. 使用您的微调模型进行推理。

本教程中演示的任务由以下模型架构支持:

ALBERT, BART, BERT, BigBird, BigBird-Pegasus, BioGpt, BLOOM, CamemBERT, CANINE, CodeLlama, ConvBERT, CTRL, Data2VecText, DeBERTa, DeBERTa-v2, DistilBERT, ELECTRA, ERNIE, ErnieM, ESM, Falcon, FlauBERT, FNet, Funnel Transformer, Gemma, GPT-Sw3, OpenAI GPT-2, GPTBigCode, GPT Neo, GPT NeoX, GPT-J, I-BERT, Jamba, LayoutLM, LayoutLMv2, LayoutLMv3, LED, LiLT, LLaMA, Longformer, LUKE, MarkupLM, mBART, MEGA, Megatron-BERT, Mistral, Mixtral, MobileBERT, MPNet, MPT, MRA, MT5, MVP, Nezha, Nyströmformer, OpenLlama, OpenAI GPT, OPT, Perceiver, Persimmon, Phi, PLBart, QDQBert, Qwen2, Qwen2MoE, Reformer, RemBERT, RoBERTa, RoBERTa-PreLayerNorm, RoCBert, RoFormer, SqueezeBERT, StableLm, Starcoder2, T5, TAPAS, Transformer-XL, UMT5, XLM, XLM-RoBERTa, XLM-RoBERTa-XL, XLNet, X-MOD, YOSO


在开始之前,请确保已安装所有必需的库:

pip install transformers datasets evaluate accelerate

我们鼓励您登录 Hugging Face 帐户,以便您可以上传模型并与社区分享。出现提示时,输入您的令牌进行登录:

from huggingface_hub import notebook_loginnotebook_login()

加载 IMDb 数据集

首先从 🤗 数据集库加载 IMDb 数据集:

from  datasets import load_datasetimdb = load_dataset("imdb")

然后看一个数据样例:

IMDB[ “测试” ][ 0 ]
{"label" : 0 ,"text" : "我喜欢科幻小说,并且愿意忍受很多。... 一切又来了。” ,
}

该数据集中有两个字段:

  • text: 影评文字。
  • label: 0:表示负面评论或1正面评论的值。

Preprocess 预处理

下一步是加载 DistilBERT 分词器来预处理该text字段:

from transformers import AutoTokenizertokenizer = AutoTokenizer.from _pretrained( "distilbert/distilbert-base-uncased" )

创建一个预处理函数来对text序列进行标记和截断,使其长度不超过 DistilBERT 的最大输入长度:

def  preprocess_function ( Examples ):return tokenizer(examples[ "text" ], truncation= True )

要将预处理函数应用于整个数据集,请使用 🤗 数据集 map 函数。
您可以map通过设置 batched=True 一次处理数据集的多个元素来加快速度:

tokenized_imdb = imdb.map(preprocess_function, batched=True)

现在使用 DataCollatorWithPadding 创建一批示例。在整理过程中 动态地将句子填充 到批次中的最长长度,比将整个数据集填充到最大长度更有效。

from transformers import DataCollatorWithPaddingdata_collator = DataCollatorWithPadding(tokenizer=tokenizer)

Evaluate

在训练期间包含指标通常有助于评估模型的性能。您可以使用 🤗 Evaluate库快速加载评估方法。对于此任务,加载准确性指标(请参阅 🤗 评估快速浏览以了解有关如何加载和计算指标的更多信息):

import evaluateaccuracy = evaluate.load("accuracy")

然后创建一个传递预测和标签的函数来compute计算准确性:

import numpy as npdef compute_metrics(eval_pred):predictions, labels = eval_predpredictions = np.argmax(predictions, axis=1)return accuracy.compute(predictions=predictions, references=labels) 

您的compute_metrics函数现在已准备就绪,您将在设置训练时返回该函数。


Train

在开始训练模型之前,请使用id2labellabel2id ,创建预期 id 到其标签的映射:

id2label = {0: "NEGATIVE", 1: "POSITIVE"}
label2id = {"NEGATIVE": 0, "POSITIVE": 1}

如果您不熟悉使用 Trainer 微调模型,
请查看基本教程:<(https://huggingface.co/docs/transformers/training#train-with-pytorch-trainer>

您现在就可以开始训练您的模型了!使用 AutoModelForSequenceClassification 加载 DistilBERT以及预期标签的数量和标签映射:

from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainermodel = AutoModelForSequenceClassification.from_pretrained("distilbert/distilbert-base-uncased", num_labels=2, id2label=id2label, label2id=label2id
)

此时,只剩下三步:

  1. 在TrainingArguments中定义训练超参数。
    唯一必需的参数是output_dir指定保存模型的位置。您可以通过设置将此模型推送到 Hub push_to_hub=True(您需要登录 Hugging Face 才能上传模型)。
    在每个 epoch 结束时,Trainer 将评估准确性并保存训练检查点。
  2. 将训练参数以及模型、数据集、分词器、数据整理器和compute_metrics函数传递给Trainer 。
  3. 调用 train() 来微调您的模型。
training_args = TrainingArguments(output_dir="my_awesome_model",learning_rate=2e-5,per_device_train_batch_size=16,per_device_eval_batch_size=16,num_train_epochs=2,weight_decay=0.01,evaluation_strategy="epoch",save_strategy="epoch",load_best_model_at_end=True,push_to_hub=True,
)trainer = Trainer(model=model,args=training_args,train_dataset=tokenized_imdb["train"],eval_dataset=tokenized_imdb["test"],tokenizer=tokenizer,data_collator=data_collator,compute_metrics=compute_metrics,
)trainer.train()

当您传递 token 给Trainer时, 它默认应用动态填充tokenizer。在这种情况下,您不需要显式指定数据整理器。

训练完成后,使用 push_to_hub()方法将您的模型共享到 Hub,以便每个人都可以使用您的模型:

trainer.push_to_hub()

有关如何微调文本分类模型的更深入示例,请查看相应的 PyTorch 笔记本 或 TensorFlow 笔记本。


Inference

太好了,现在您已经微调了模型,您可以使用它进行推理!

获取一些您想要进行推理的文本:

text = “这是一部杰作。并不完全忠实于原著,但从头到尾都令人着迷。可能是三本书中我最喜欢的。”

尝试微调模型进行推理的最简单方法是在 pipeline() 中使用它。使用您的模型实例化pipeline情感分析,并将文本传递给它:

from transformers import pipelineclassifier = pipeline("sentiment-analysis", model="stevhliu/my_awesome_model")
classifier(text)

如果您愿意,您还可以手动复制 pipeline 的结果:


对文本进行分词并返回 PyTorch 张量:

from transformers import AutoTokenizertokenizer = AutoTokenizer.from_pretrained("stevhliu/my_awesome_model")
inputs = tokenizer(text, return_tensors="pt")

将您的输入传递给模型并返回logits

from transformers import AutoModelForSequenceClassificationmodel = AutoModelForSequenceClassification.from_pretrained("stevhliu/my_awesome_model")with torch.no_grad():logits = model(**inputs).logits

获取概率最高的类,并使用模型的id2label映射将其转换为文本标签:

predicted_class_id = logits.argmax().item()
model.config.id2label[predicted_class_id]
# -> 'POSITIVE'

2024-04-28(日)


http://www.ppmy.cn/devtools/22803.html

相关文章

web server apache tomcat11-27-Security Considerations

前言 整理这个官方翻译的系列&#xff0c;原因是网上大部分的 tomcat 版本比较旧&#xff0c;此版本为 v11 最新的版本。 开源项目 从零手写实现 tomcat minicat 别称【嗅虎】心有猛虎&#xff0c;轻嗅蔷薇。 系列文章 web server apache tomcat11-01-官方文档入门介绍 web…

Linxu系统服务管理,systemd知识/进程优先级/平均负载/php进程CPU100%怎么解决系列知识!

shell脚本&#xff08;命令&#xff09;放后台 sleep 300& 放到后台运行&#xff0c;脚本或命令要全路径 nohup&#xff1a;用户推出系统进程继续工作 【功能说明】 nohup 命令可以将程序以忽略挂起信号的方式运行起来&#xff0c;被运行程序的输出信息将不会显示到终端 如…

VoxAtnNet:三维点云卷积神经网络

VoxAtnNet:三维点云卷积神经网络 摘要IntroductionProposed VoxAtnNet 3D Face PAD3D face point cloud presentation attack Dataset (3D-PCPA) VoxAtnNet: A 3D Point Clouds Convolutional Neural Network for 摘要 面部生物识别是智能手机确保可靠和可信任认证的重要组件。…

ai智能机器人语音后端识别处理呼叫系统部署

人工智能是推动科技跨越发展、产业优化升级、生产力整体跃升的重要战略资源。随着一系列支持人工智能发展政策的相继落地&#xff0c;相关产业的创新活力也被日益激发&#xff0c;推动现有商业体系内各个产业加速变革。在人工智能领域&#xff0c;电话机器人落地的速度也在加快…

【排序算法】第一章:插入排序----直接插入排序与希尔排序的详解和对比

&#x1fae1;和我一起感受 两种排序算法的魅力吧&#xff01; 前言&#xff1a; 理解排序算法最好的方法就是&#xff1a;先单趟后整体 先从一个元素的一趟开始理解再扩展到所有元素的排序 一、直接插入排序 理解排序算法最好的方法就是&#xff1a;先单趟后整体 插入排序&a…

学习STM32第二十天

低功耗编程 一、修改主频 STM32F4xx系列主频为168MHz&#xff0c;当板载8MHz晶振时&#xff0c;系统时钟HCLK满足公式 H C L K H S E P L L N P L L M P L L P HCLK \frac{HSE \times PLLN}{PLLM \times PLLP} HCLKPLLMPLLPHSEPLLN​&#xff0c;在文件stm32f4xx.h中可修…

免费分享一套SpringBoot企业人事管理系统(员工管理,工资管理,档案管理,招聘管理),帅呆了~~

大家好&#xff0c;我是java1234_小锋老师&#xff0c;看到一个不错的SpringBoot企业人事管理系统(员工管理&#xff0c;工资管理&#xff0c;档案管理&#xff0c;招聘管理)&#xff0c;分享下哈。 项目视频演示 【免费】SpringBoot企业人事管理系统(员工管理&#xff0c;工…

分布式与一致性协议之CAP(三)

CAP ACID理论:CAP的"酸"&#xff0c;追求一致性。 提到ACID,它很容易理解&#xff0c;在单机上实现也不难&#xff0c;比如可以通过锁、时间序列等机制保障操作的顺序执行&#xff0c;让系统实现ACID特性。但是一说要实现分布式系统的ACID特性比较难实现呢&#xf…