从零构建属于自己的GPT系列3:模型训练2(训练函数解读、模型训练函数解读、代码逐行解读)

news/2024/11/20 8:23:18/

🚩🚩🚩Hugging Face 实战系列 总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在PyCharm中进行
本篇文章配套的代码资源已经上传

从零构建属于自己的GPT系列1:数据预处理
从零构建属于自己的GPT系列2:模型训练1
从零构建属于自己的GPT系列3:模型训练2
从零构建属于自己的GPT系列4:模型训练3

3 数据加载函数

def load_dataset(logger, args):logger.info("loading training dataset")train_path = args.train_pathwith open(train_path, "rb") as f:train_list = pickle.load(f)train_dataset = CPMDataset(train_list, args.max_len)return train_dataset
  1. 日志报告加载训练数据
  2. 训练数据路径
  3. 将以二进制形式存储的数据文件
  4. 使用 pickle 加载到内存中的 train_list 变量中
  5. 加载CPMDataset包,将train_list从索引转化为torch tensor
  6. 返回tensor
    在这里插入图片描述

4 训练函数

def train(model, logger, train_dataset, args):train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, collate_fn=collate_fn,drop_last=True)logger.info("total_steps:{}".format(len(train_dataloader)* args.epochs))t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.epochsoptimizer = transformers.AdamW(model.parameters(), lr=args.lr, eps=args.eps)scheduler = transformers.get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)# 设置warmuplogger.info('start training')train_losses = []   # 记录每个epoch的平均lossfor epoch in range(args.epochs):train_loss = train_epoch(model=model, train_dataloader=train_dataloader,optimizer=optimizer, scheduler=scheduler,logger=logger, epoch=epoch, args=args)train_losses.append(round(train_loss, 4))logger.info("train loss list:{}".format(train_losses))logger.info('training finished')logger.info("train_losses:{}".format(train_losses))
  1. 训练函数
  2. 制作Dataloader
  3. 制作Dataloader
  4. 制作Dataloader
  5. 制作Dataloader
  6. 日志添加信息Dataloader*epochs的数量
  7. 记录数据长度到t_total变量中
  8. 指定优化器
  9. 学习率衰减策略,从transformers包中调用现成的get_linear_schedule_with_warmup方法
  10. 设置warmup等参数
  11. 学习率衰减策略
  12. 日志添加信息开始训练
  13. 记录所有epoch的训练损失,以求每个epoch的平均loss
  14. 遍历每个epoch
  15. 指定一个我们自己写的train_epoch函数1
  16. train_epoch函数2
  17. train_epoch函数3
  18. train_epoch函数4
  19. 记录损失,只保存4位小数
  20. 记录日志信息训练损失
  21. 记录日志信息训练完成
  22. 最后一句是在日志中保存所有损失吗?

5 迭代训练函数train_epoch

def train_epoch(model, train_dataloader, optimizer, scheduler, logger, epoch, args):model.train()device = args.deviceignore_index = args.ignore_indexepoch_start_time = datetime.now()total_loss = 0  # 记录下整个epoch的loss的总和epoch_correct_num = 0   # 每个epoch中,预测正确的word的数量epoch_total_num = 0  # 每个epoch中,预测的word的总数量for batch_idx, (input_ids, labels) in enumerate(train_dataloader):try:input_ids = input_ids.to(device)labels = labels.to(device)outputs = model.forward(input_ids, labels=labels)logits = outputs.logitsloss = outputs.lossloss = loss.mean()batch_correct_num, batch_total_num = calculate_acc(logits, labels, ignore_index=ignore_index)epoch_correct_num += batch_correct_numepoch_total_num += batch_total_numbatch_acc = batch_correct_num / batch_total_numtotal_loss += loss.item()if args.gradient_accumulation_steps > 1:loss = loss / args.gradient_accumulation_stepsloss.backward()torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)if (batch_idx + 1) % args.gradient_accumulation_steps == 0:optimizer.step()scheduler.step()optimizer.zero_grad()if (batch_idx + 1) % args.log_step == 0:logger.info("batch {} of epoch {}, loss {}, batch_acc {}, lr {}".format(batch_idx + 1, epoch + 1, loss.item() * args.gradient_accumulation_steps, batch_acc, scheduler.get_lr()))del input_ids, outputsexcept RuntimeError as exception:if "out of memory" in str(exception):logger.info("WARNING: ran out of memory")if hasattr(torch.cuda, 'empty_cache'):torch.cuda.empty_cache()else:logger.info(str(exception))raise exceptionepoch_mean_loss = total_loss / len(train_dataloader)epoch_mean_acc = epoch_correct_num / epoch_total_numlogger.info("epoch {}: loss {}, predict_acc {}".format(epoch + 1, epoch_mean_loss, epoch_mean_acc))logger.info('saving model for epoch {}'.format(epoch + 1))model_path = join(args.save_model_path, 'epoch{}'.format(epoch + 1))if not os.path.exists(model_path):os.mkdir(model_path)model_to_save = model.module if hasattr(model, 'module') else modelmodel_to_save.save_pretrained(model_path)logger.info('epoch {} finished'.format(epoch + 1))epoch_finish_time = datetime.now()logger.info('time for one epoch: {}'.format(epoch_finish_time - epoch_start_time))return epoch_mean_loss
  1. train_epoch函数
  2. 指定训练模式
  3. 训练设备
  4. 需要忽略的索引
  5. 当前epoch开启的具体时间
  6. 当前epoch的loss总和
  7. 当前epoch预测词正确的总数量
  8. 每个epoch需要预测的测的总数量
  9. for训练从train_dataloader遍历取数据
  10. 捕捉异常
  11. 输入词的索引数据进入训练设备
  12. 标签数据进入训练设备
  13. 输入数据经过前向传播得到输出
  14. 经过softmax后的输出
  15. 得到损失
  16. 平均损失
  17. 通过calculate_acc函数统计该batch的预测token的正确数与总数
  18. 统计该epoch的预测token的正确数
  19. 统计该epoch的预测token的总数
  20. 计算该batch的accuracy
  21. 获得损失值的标量累加到当前epoch总损失
  22. 如果当前的梯度累加步数大于1
  23. 对当前累加的损失对梯度累加步数求平均
  24. 损失反向传播
  25. 梯度裁剪:梯度裁剪的目的是控制梯度的大小,防止梯度爆炸的问题。在训练神经网络时,梯度可能会变得非常大,导致优化算法出现数值不稳定的情况。裁剪梯度就是将梯度的整体范数限制在一个特定的阈值之内
  26. 达到梯度累加的次数后
  27. 更新参数
  28. 更新学习率
  29. 梯度清零
  30. 梯度累加次数为0时,也就是参数更新时
  31. 记录日志
  32. 记录的各个参数占位符
  33. 占位符对应的各个变量
  34. 删除两个变量,释放内存
  35. 捕捉到异常
  36. 如果异常的信息中的字符串包含内存不足的问题,也就是显卡内存不足
  37. 将该问题添加到日志信息
  38. 当显卡内存占用过多时
  39. 手动释放显卡内存
  40. 如果不是显卡内存不足
  41. 记录日志
  42. 返回异常
  43. 记录当前epoch的平均loss
  44. 记录当前epoch的平均accuracy
  45. 日志记录信息
  46. 记录的信息为当前epoch索引、损失、准确率
  47. 日志记录信息,当前保存的模型以及对于的epoch索引
  48. 保存模型的地址
  49. 如果地址不存在
  50. 创建该地址
  51. 确保得到不是外壳对象
  52. 保存模型
  53. 日志记录信息训练完成
  54. 记录完成时间
  55. 记录当前epoch训练所花费时间
  56. 返回epoch平均损失

从零构建属于自己的GPT系列1:数据预处理
从零构建属于自己的GPT系列2:模型训练1
从零构建属于自己的GPT系列3:模型训练2
从零构建属于自己的GPT系列4:模型训练3


http://www.ppmy.cn/news/1260349.html

相关文章

《C++ primer》 anki学习卡片txt输出101张,更新至第2章,截止2023年12月6日

C程序中{{c1::}}要有main函数 一定 不一定 一个函数定义包含哪几个部分 返回类型 函数名(形参列表) { 函数体 } main函数返回值为0时,表示{{c1::}} 成功 失败 无论你使用命令行页面或者IDE,大多数编译器都要求程序源码存储在一…

1-Tornado的介绍

1 tornado的介绍 **Tornado**是一个用Python编写的可扩展的、无阻塞的**Web应用程序框架**和**Web服务器**。 它是由FriendFeed开发使用的;该公司于2009年被Facebook收购,而Tornado很快就开源了龙卷风以其高性能着称。它的设计允许处理大量并发连接&…

【1day】致远 A8+ OA wpsAssistServlet任意文件读取漏洞学习

注:该文章来自作者日常学习笔记,请勿利用文章内的相关技术从事非法测试,如因此产生的一切不良后果与作者无关。 目录 一、漏洞描述 二、影响版本 三、资产测绘 四、漏洞复现

nextTick

在下次 DOM 更新循环结束之后执行延迟回调。在修改数据之后立即使用这个方法,获取更新后的 DOM。 // 修改数据 vm.msg Hello // DOM 还没有更新 Vue.nextTick(function () {// DOM 更新了 })切换页签,不流畅,所以用nextTick,等页…

易宝OA 两处任意文件上传漏洞复现

0x01 产品简介 易宝OA系统是一种专门为企业和机构的日常办公工作提供服务的综合性软件平台,具有信息管理、 流程管理 、知识管理(档案和业务管理)、协同办公等多种功能。 0x02 漏洞概述 易宝OA系统UploadFile、BasicService.asmx等接口处存在文件上传漏洞,未授权的攻击者可…

Kotlin(十四) 扩展函数和运算符重载

目录 扩展函数 语法结构 代码示例 运算符重载 语法结构 一元操作符 二元操作符 数值类型操作符 等于和不等于操作符 比较操作符 调用操作符 扩展函数 语法结构 对于扩张函数的语法结构其实很简单,你想在那个类中添加扩张函数,那么你就用该类…

占用站点资源,无法正常登录?这个功能帮助解决

在企业里随着PDM用户的增加PDM管理员是否发现原本的站点已经不够用出现部分用户占用站点资源导致其他用户无法正常登录导致该问题无法解决,本篇介绍PDM自动下线的功能助力企业解决问题,更好的帮助企业完成PDM的正常使用 今天我给大家带来的就是SOLIDWOR…

风力发电对讲 IP语音对讲终端IP安防一键呼叫对讲 医院对讲终端SV-6005网络音频终端

风力发电对讲 IP语音对讲终端IP安防一键呼叫对讲 医院对讲终端SV-6005网络音频终端 目 录 1、产品规格 2、接口使用 2.1、侧面接口功能 2.2、背面接口功能 2.3、面板接口功能 3、功能使用 1、产品规格 输入电源: 12V~24V的直流电源 网络接口&am…