基于Transformer Models模型完成学习训练模型

news/2024/9/25 17:10:39/

在编程之前需要准备一些文件:

首先,先win+R打开运行框,输入:PowerShell后

输入:

pip install -U huggingface_hub

下载完成后,指定我们的环境变量:

$env:HF_ENDPOINT = "https://hf-mirror.com"

然后下载模型:

huggingface-cli download --resume-download gpt2 --local-dir "D:\Pythonxiangmu\PythonandAI\Transformer Models\gpt-2"

工程目录地址

然后下载数据量:

huggingface-cli download --repo-type dataset --resume-download wikitext --local-dir "D:\Pythonxiangmu\PythonandAI\Transformer Models\gpt-2"

工程目录地址在PowerShell中下载完这些后,可以开始代码啦

import torch
from torch.utils.data import Dataset, DataLoader
from transformers import (AutoTokenizer,AutoModelForCausalLM,AdamW,get_linear_schedule_with_warmup,set_seed,
)
from torch.optim import AdamW# 设置随机种子以确保结果可复现
set_seed(42)class TextDataset(Dataset):def __init__(self, tokenizer, texts, block_size=128):self.tokenizer = tokenizerself.examples = [self.tokenizer(text, return_tensors="pt", padding='max_length', truncation=True, max_length=block_size) fortextin texts]# 在tokenizer初始化后,确保unk_token已设置print(f"Tokenizer's unk_token: {self.tokenizer.unk_token}, unk_token_id: {self.tokenizer.unk_token_id}")def __len__(self):return len(self.examples)def __getitem__(self, i):item = self.examples[i]# 替换所有不在vocab中的token为unk_token_idfor key in item.keys():item[key] = torch.where(item[key] >= self.tokenizer.vocab_size, self.tokenizer.unk_token_id, item[key])return itemdef train(model, dataloader, optimizer, scheduler, de, tokenizer):model.train()for batch in dataloader:input_ids = batch['input_ids'].to(de)# 添加日志输出检查input_idsif torch.any(input_ids >= model.config.vocab_size):print("Warning: Some input IDs are outside the model's vocabulary.")print(f"Max input ID: {input_ids.max()}, Vocabulary Size: {model.config.vocab_size}")attention_mask = batch['attention_mask'].to(de)labels = input_ids.clone()labels[labels[:, :] == tokenizer.pad_token_id] = -100outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)loss = outputs.lossloss.backward()optimizer.step()scheduler.step()optimizer.zero_grad()def main():local_model_path = "D:/Pythonxiangmu/PythonandAI/Transformer Models/gpt-2"tokenizer = AutoTokenizer.from_pretrained(local_model_path)# 确保pad_token已经存在于tokenizer中,对于GPT-2,它通常自带pad_tokenif tokenizer.pad_token is None:special_tokens_dict = {'pad_token': '[PAD]'}tokenizer.add_special_tokens(special_tokens_dict)model = AutoModelForCausalLM.from_pretrained(local_model_path, pad_token_id=tokenizer.pad_token_id)else:model = AutoModelForCausalLM.from_pretrained(local_model_path)model.to(device)train_texts = ["The quick brown fox jumps over the lazy dog.","In the midst of chaos, there is also opportunity.","To be or not to be, that is the question.","Artificial intelligence will reshape our future.","Every day is a new opportunity to learn something.","Python programming enhances problem-solving skills.","The night sky sparkles with countless stars.","Music is the universal language of mankind.","Exploring the depths of the ocean reveals hidden wonders.","A healthy mind resides in a healthy body.","Sustainability is key for our planet's survival.","Laughter is the shortest distance between two people.","Virtual reality opens doors to immersive experiences.","The early morning sun brings hope and vitality.","Books are portals to different worlds and minds.","Innovation distinguishes between a leader and a follower.","Nature's beauty can be found in the simplest things.","Continuous learning fuels personal growth.","The internet connects the world like never before."# 更多训练文本...]dataset = TextDataset(tokenizer, train_texts, block_size=128)dataloader = DataLoader(dataset, batch_size=4, shuffle=True)optimizer = AdamW(model.parameters(), lr=5e-5)total_steps = len(dataloader) * 5  # 假设训练5个epochscheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)for epoch in range(5):  # 训练5个epochtrain(model, dataloader, optimizer, scheduler, device, tokenizer)  # 使用正确的变量名dataloader并传递tokenizer# 保存微调后的模型model.save_pretrained("path/to/save/fine-tuned_model")tokenizer.save_pretrained("path/to/save/fine-tuned_tokenizer")if __name__ == "__main__":device = torch.device("cuda" if torch.cuda.is_available() else "cpu")main()


http://www.ppmy.cn/news/1445352.html

相关文章

线性代数 --- 计算斐波那契数列第n项的快速算法(矩阵的n次幂)

计算斐波那契数列第n项的快速算法(矩阵的n次幂) The n-th term of Fibonacci Numbers: 斐波那契数列的是一个古老而又经典的数学数列,距今已经有800多年了。关于斐波那契数列的计算方法不难,只是当我们希望快速求出其数列中的第100&#xff0…

Slave SQL线程与PXB FTWRL死锁问题分析

1. 问题背景 2.27号凌晨生产环境MySQL备库在执行备份期间出现因FLUSH TABLES WITH READ LOCK未释放导致备库复制延时拉大,慢日志内看持锁接近25分钟未释放。 版本: MySQL 5.7.21PXB 2.4.18 慢查询日志: 备份脚本中的备份命令:…

Android数据恢复:如何在手机上恢复丢失的文件和照片

我们都有 我们错误地从手机中删除重要内容的时刻。确实如此 不一定是我们的错。其他人可以对您的手机数据执行此操作 有意或无意。这在某个时间点发生在我们所有人身上。 但是,今天市场上有各种各样的软件可以 帮助恢复已删除的文件。这些类型的软件被归类为数据恢复…

在HTML中使用JavaScript实时显示当前日期和时间(结尾完整例程)

在Web开发中,经常需要在网页上显示当前的日期和时间。HTML本身并不具备这样的动态功能,但我们可以借助JavaScript来实现。JavaScript是一种常用的前端脚本语言,它可以轻松地获取系统时间,并将其插入到HTML元素中。 下面是一个简单…

opencv-基本操作

本篇文章,我们将聊一聊利用opencv进行的一些基本操作,以便后续我们利用opencv进行更加复杂的处理。 1、图像的读取、显示与保存 opencv中利用cv2.imread读取RGB图像,利用cv2.imshow() 进行图像的显示。 # 注意: cv2.imread读取RGB图像时, 返回…

TCP、UDP客户端

TCP客户端 #include <mystdio.h> #define CLI_PORT 6666 #define CLI_IP "192.168.124.210" int main(int argc, const char *argv[])//argv[1] IP argv[2] 端口号 { if(argc <3) { printf("请在命令传参端口号和IP地址\n");…

《python编程从入门到实践》day16

昨日知识点回顾 从模块中导入类/模块 今日知识点学习 第十章 文件和异常 10.1 从文件中读取数据 10.1.1 读取整个文件 txt文件与程序文件在同一级目录 with open(pi_digits.txt) as file_object:contents file_object.read() print(contents)# 运行结果&#xff1a; # 3.1…

大数据学习笔记11-Hadoop基础2

一: 分布式的基础架构分析[重要] 集群架构模式: 主从架构(中心化): 主角色 master: 发号施令,负责任务的接受和分配 从角色 slave: 负责干活 主备架构:可以解决中心化存在的问题 主角色active : 正常工作 备角色standby : 观察主角色工作,并实时备份主角色数据,当主角色宕机…