使用gpt2-medium基座说明模型微调

news/2024/10/24 14:25:19/

预训练与微调的背景

  • 预训练:在大规模数据集上训练模型,以捕捉通用的特征和模式。例如,GPT-2 模型在大量文本上进行训练,学习语言的基本结构和语法。
  • 微调:在特定领域或任务的数据上对预训练模型进行训练,以使其更好地适应特定需求。微调通常需要的数据量少于从头开始训练模型所需的数据量。

微调的过程

微调过程通常包括以下几个步骤:

  1. 选择预训练模型:选择一个适合任务的预训练模型,通常根据模型在相似任务上的表现来决定。
  2. 准备数据:收集并清洗与目标任务相关的数据,确保数据的质量和代表性。
  3. 调整模型参数
    • 学习率:微调时通常使用较小的学习率,因为模型已经在大规模数据上学习到了丰富的特征,微调的目的是精细调整这些特征。
    • 冻结部分层:在某些情况下,可以选择冻结预训练模型的某些层,只训练后面的几层,以避免破坏已经学习到的知识。
  4. 训练过程:使用特定任务的数据对模型进行训练,计算损失并进行反向传播,以更新模型参数。
  5. 评估与优化:在验证集上评估模型性能,根据需要调整超参数或训练策略,直到达到满意的结果。

下面使用gpt2-medium 来说明

import torch
import tqdm
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from torch.utils.data import Dataset, DataLoadernum_epochs=1000
# 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 加载预训练模型和tokenizer
model_name = "gpt2-medium"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)
model.to(device)# 定义微调数据集
class CustomDataset(Dataset):def __init__(self, data):self.input_ids = datadef __len__(self):return len(self.input_ids)def __getitem__(self, index):return torch.tensor(self.input_ids[index])training = False
if training:# 加载和处理数据train_data = ['我家门前有两棵树,一棵是枣树,另一棵也是枣树。']  # 微调用的训练数据tokenizer.pad_token = tokenizer.eos_tokentokenizer.add_special_tokens({'pad_token': '[PAD]'})train_encodings = tokenizer(train_data, truncation=True, padding=True)train_dataset = CustomDataset(train_encodings["input_ids"])train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True)# 设置优化器optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)# 微调模型model.train()for id, epoch in enumerate(tqdm.tqdm(range(num_epochs))):for batch in train_dataloader:input_ids = batch.to(device)model.zero_grad()outputs = model(input_ids, labels=input_ids)loss = outputs.lossloss.backward()optimizer.step()# 保存微调后的模型model.save_pretrained("data/nlp_model")tokenizer.save_pretrained("data/nlp_model")if not training:# 加载微调后的模型和tokenizermodel_name = "data/nlp_model"  # 微调后模型的路径tokenizer = GPT2Tokenizer.from_pretrained(model_name)model = GPT2LMHeadModel.from_pretrained(model_name)# 设置设备device = torch.device("cuda" if torch.cuda.is_available() else "cpu")model.to(device)# 生成文本prompt = "门前有树是我家的"input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)# 生成对应的 attention maskattention_mask = (input_ids != 0).int()output = model.generate(input_ids, max_length=100, num_return_sequences=1, attention_mask=attention_mask)generated_text = tokenizer.decode(output[0], skip_special_tokens=True)print("Generated Text:")print(generated_text)

代码展示了如何使用 PyTorch 和 Hugging Face 的 Transformers 库微调一个预训练的 GPT-2 模型,具体是 gpt2-medium 版本,并使用中文文本进行训练和生成任务。以下是代码的详细解释和模型微调的技术要点。

代码解释

  1. 导入所需的库

    import torch
    import tqdm
    from transformers import GPT2LMHeadModel, GPT2Tokenizer
    from torch.utils.data import Dataset, DataLoader
    

    这部分代码导入了 PyTorch、进度条库 tqdm、Hugging Face 的 Transformers 库中的模型和 tokenizer,以及 PyTorch 的数据集和数据加载器工具。

  2. 设置设备

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    

    检测是否有可用的 GPU,如果有,则使用 GPU。

  3. 加载模型和 tokenizer

    model_name = "gpt2-medium"
    tokenizer = GPT2Tokenizer.from_pretrained(model_name)
    model = GPT2LMHeadModel.from_pretrained(model_name)
    model.to(device)
    

    加载预训练的 GPT-2 模型和相应的 tokenizer,并将模型移动到指定的设备上。

  4. 定义自定义数据集

    class CustomDataset(Dataset):def __init__(self, data):self.input_ids = datadef __len__(self):return len(self.input_ids)def __getitem__(self, index):return torch.tensor(self.input_ids[index])
    

    创建一个自定义数据集类,继承自 PyTorch 的 Dataset 类,主要用于处理输入数据。

  5. 数据处理和准备

    train_data = ['我家门前有两棵树,一棵是枣树,另一棵也是枣树。']
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    train_encodings = tokenizer(train_data, truncation=True, padding=True)
    train_dataset = CustomDataset(train_encodings["input_ids"])
    train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True)
    
    • 定义训练数据。
    • 设置 tokenizer 的填充标记。
    • 对文本进行编码,生成输入 ID。
    • 创建一个数据集实例和数据加载器,以便在训练过程中批量处理数据。
  6. 设置优化器

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
    

    使用 AdamW 优化器,设置学习率为 1e-5

  7. 微调模型

    model.train()
    for id, epoch in enumerate(tqdm.tqdm(range(num_epochs))):for batch in train_dataloader:input_ids = batch.to(device)model.zero_grad()outputs = model(input_ids, labels=input_ids)loss = outputs.lossloss.backward()optimizer.step()
    
    • 将模型置于训练模式。
    • 遍历多个训练轮次(epochs)。
    • 对每个批次进行前向传播、计算损失、反向传播和优化步骤。
  8. 保存微调后的模型

    model.save_pretrained("data/nlp_model")
    tokenizer.save_pretrained("data/nlp_model")
    

    保存微调后的模型和 tokenizer。

  9. 文本生成

    prompt = "xxxx"
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)attention_mask = (input_ids != 0).int()
    output = model.generate(input_ids, max_length=100, num_return_sequences=1, attention_mask=attention_mask)
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)print("Generated Text:")
    print(generated_text)
    
    • 使用微调后的模型生成文本,给定一个提示词(prompt)。
    • 编码提示词并生成文本序列,最后解码为可读文本并输出。

测试下微调效果

对于资源并不充沛的公司而言

一个可行的思路是结合参数较小的模型进行微调,再利用向量数据库和知识图谱使用去实现RAG

模型微调的技术要点

  1. 数据准备:微调时使用的数据应与目标应用场景相符,以便模型能够学习特定的上下文和语言特征。

  2. 超参数设置:学习率、批量大小、训练轮数等超参数对模型性能有重要影响。通常需要通过实验来找到最适合的设置。

  3. 损失计算:在微调过程中,通常使用模型输出的损失值进行优化,以指导模型学习。

  4. 模型保存:微调后的模型需要保存,以便后续使用或部署。

  5. 文本生成:使用微调后的模型生成文本时,可以通过调整 max_lengthnum_return_sequences 等参数来控制生成文本的长度和数量。


http://www.ppmy.cn/news/1541613.html

相关文章

1024程序员节祝福

1024程序员节祝福 在每年的10月24日,我们迎来了属于程序员的节日——1024程序员节。这个特殊的日子,既是对广大程序员辛勤工作的致敬,也是对他们在科技创新和数字时代进步中做出贡献的认可。在这个值得庆祝的日子里,我想对所有程…

力扣每日一题3185. 构成整天的下标对数目 II

今天的题目没啥好说的,就是昨天的题目的进阶版,用昨天题解的最终版就可以直接过了 今天的就不写思路了,有需要就看昨天的就好了 力扣每日打卡挑战 3184. 构成整天的下标对数目 I class Solution { public:int countCompleteDayPairs(vecto…

WPF+Mvvm项目入门完整教程-基于SqlSugar的数据库实例(三)

目录 数据库实现创建数据库类库资源获取 在上一节中,我们实现了主页UI框架和基础菜单功能,本节主要实现数据库的类库创建、数据功能接口以及泛型方法实现。本例使用的数据库为 MySql数据库,ORM框架采用 SqlSugar 实现。 数据库实现 创建数据…

【计算机网络 - 基础问题】每日 3 题(四十九)

✍个人博客:https://blog.csdn.net/Newin2020?typeblog 📣专栏地址:http://t.csdnimg.cn/fYaBd 📚专栏简介:在这个专栏中,我将会分享 C 面试中常见的面试题给大家~ ❤️如果有收获的话,欢迎点赞…

百度开源语音识别强大工具PaddleSpeech从0到1快速上手:安装、部署、Debug与测试详尽指南

目录 Introduction 导言PaddleSpeech安装部署和测试环境要求:安装参考:安装整体过程如下:使用代码示例:Bug处理模型选择性能测试 参考资料其它资料下载 Introduction 导言 在当今快速发展的人工智能领域,语音识别技术…

JavaScript 在网页设计中的四大精彩案例:画布时钟、自动轮播图、表单验证与可拖动元素

在网页开发中,JavaScript 发挥着至关重要的作用,为网页带来丰富的交互性和动态效果,极大地提升了用户体验。本文将通过几个具体案例展示 JavaScript 的强大魅力。 一、美丽的画布时钟 这是一个使用 JavaScript 在网页上创建美丽画布时钟的案…

太空探索如何引领我们找到真正的命运?

为什么要进行太空探索呢?为什么我们人类要接过这一棒并奋力前行呢?这个问题的答案比太空探索给我们带来诸如果味饮料糖和智能手机等好处要深刻得多,甚至可能更加深远。 答案甚至比我们在其他国家占领月球部分区域以获取月球资源或我们有朝一…

又一次升级:字节在用大模型在做推荐啦!

原文链接 字节前几天2024年9年19日公开发布的论文《HLLM:通过分层大型语言模型增强基于物品和用户模型的序列推荐效果》。 文字、图片、音频、视频这四大类信息载体,在生产端都已被AI生成赋能助力,再往前一步,一定需要一个更强势…