Pytorch——训练时,冻结网络部分参数的方法

embedded/2024/9/24 1:25:57/

一、原理:

要固定训练网络的哪几层,只需要找到这几层参数(parameter),然后将其 .requires_grad 属性设置为 False 。然后修改优化器,只将不被冻结的层传入。

二、效果

  1. 节省显存:不将不更新的参数传入optimizer
  2. 提升速度:将不更新的参数的requires_grad设置为False,节省了计算这部分参数梯度的时间

三、代码:

.requires_grad 属性设置为 False

# 根据参数层的 name 来进行冻结
unfreeze_layers = ["text_id"] # 用列表
# 设置冻结参数:
for name, param in model.named_parameters():# print(name, param.shape)# 错误判定:# if name.split(".")[0] in unfreeze_layers: # 不要用in来判定,因为"id"也在"text_id"的in中。# 正确判定:for unfreeze_layer in unfreeze_layers:if name.split(".")[0] != unfreeze_layer:param.requires_grad = Falseprint(name, param.requires_grad)else:print(name, param.requires_grad)
# 冻结整个网络
for param in self.model.parameters():param.requires_grad = False
# 查看冻结参数与否:
for name, param in self.clip_model.named_parameters():print(name, param.requires_grad)

修改优化器

# 只将未被冻结的层,传入优化器
optimizer = optim.SGD(filter(lambda p : p.requires_grad, model.parameters()), lr=1e-2)

四、其他知识

  1. 模型权重冻结:一些情况下,我们可能只需要对模型进行推断,而不需要调整模型的权重。通过调用model.eval(),可以防止在推断过程中更新模型的权重。
  2. with torch.no_grad(): # 禁用梯度计算以加快计算速度。
  3. 训练完train_datasets之后,model要来测试样本了。在model(test_datasets)之前,需要加上model.eval(). 否则的话,有输入数据,即使不训练,它也会改变权值。这是model中含有batch normalization层所带来的的性质。
    eval()时,pytorch会自动把BN和DropOut固定住,不会取平均,而是用训练好的值。不然的话,一旦test的batch_size过小,很容易就会被BN层导致生成图片颜色失真极大。eval()在非训练的时候是需要加的,没有这句代码,一些网络层的值会发生变动,不会固定,你神经网络每一次生成的结果也是不固定的,生成质量可能好也可能不好。
  4. 何时用model.eval() :训练完train样本后,生成的模型model要用来测试样本。在model(test)之前,需要加上model.eval(),否则的话,有输入数据,即使不训练,它也会改变权值。这是model中含有BN层和Dropout所带来的的性质。在eval/test过程中,需要显示地让model调用eval(),此时模型会把BN和Dropout固定住,不会取平均,而是用训练好的值。
  5. 何时用with torch.no_grad():无论是train() 还是eval() 模式,各层的gradient计算和存储都在进行且完全一致,只是在eval模式下不会进行反向传播。而with torch.no_grad()则主要是用于停止autograd模块的工作,以起到加速和节省显存的作用。它的作用是将该with语句包裹起来的部分停止梯度的更新,从而节省了GPU算力和显存,但是并不会影响dropout和BN层的行为。若想节约算力,可在test阶段带上torch.no_grad()。
  6. with torch.no_grad() 主要是用于停止autograd模块的工作,以起到加速和节省显存的作用。它的作用是将该with语句包裹起来的部分停止梯度的更新,从而节省了GPU算力和显存,但是并不会影响dropout和BN层的行为。如果不在意显存大小和计算时间的话,仅仅使用model.eval()已足够得到正确的validation/test的结果;而with torch.no_grad()则是更进一步加速和节省gpu空间(因为不用计算和存储梯度),从而可以更快计算,也可以跑更大的batch来测试。

参考文章

  1. 知乎讨论
  2. 第十三章 深度解读预训练与微调迁移,模型冻结与解冻(工具)
  3. 【PyTorch】搞定网络训练中的model.train()和model.eval()模式

http://www.ppmy.cn/embedded/6828.html

相关文章

redis实现未支付时间超时就删除订单,并给前端反应一个已过期

1.创建订单缓存,设置过期时间为一分钟 now 是一个表示当前时间的对象,offset 方法用于对当前时间进行偏移。 redisTemplate.expireAt(paymentKey, now.offset(DateField.SECOND, 60)); 2.创建KeyExpiredListener类并且继承KeyExpirationEventMessageLis…

十大排序——5.选择排序

这篇文章我们来介绍一下选择排序 目录 1.介绍 2.代码实现 3.小结与思考 1.介绍 选择排序:选择排序( Selection sort)是一种简单直观的排序算法。它的工作原理是每一趟从待排序的数据元素中选出最小(或最大)的一个…

通过Dockerfile 创建 kali-novnc

创建Dockerfile # 使用官方Kali镜像作为基础镜像 FROM kalilinux/kali-rolling# 设置工作目录 WORKDIR /app# 将当前目录下的所有文件复制到工作目录中 COPY ./run.sh . RUN chmod x /app/run.sh# 安装项目依赖 RUN apt update -y RUN apt upgrade -y# 安装中文字体支持 apt …

《Linux运维总结:Kylin V10+ARM架构CPU基于docker-compose一键离线部署redis6.2.8之容器版哨兵集群》

总结:整理不易,如果对你有帮助,可否点赞关注一下? 更多详细内容请参考:《Linux运维篇:Linux系统运维指南》 一、部署背景 由于业务系统的特殊性,我们需要面向不通的客户安装我们的业务系统&…

酷得智能 无人机方案开发

东莞市酷得智能科技有限公司,是一家专业的技术服务公司,致力于为各类智能硬件提供高效、稳定、安全的底层驱动解决方案。拥有一支经验丰富、技术精湛的团队,能够为客户提供全方位的底层驱动开发服务。 无人机功能介绍: 1、自动跟…

【春秋云境】CVE-2023-4450 jeect-boot queryFieldBySql接口RCE漏洞

靶场介绍 JeecgBoot 是一个开源的低代码开发平台,Jimureport 是低代码报表组件之一。当前漏洞在 1.6.1 以下的 Jimureport 组件库中都存在,由于未授权的 API /jmreport/queryFieldBySql 使用了 freemarker 解析 SQL 语句从而导致了 RCE 漏洞的产生。 开…

【QT】QChartView和QChart的一些图表设置

enum RubberBand {NoRubberBand 0x0,VerticalRubberBand 0x1,HorizontalRubberBand 0x2,RectangleRubberBand 0x3};在 Qt Charts 中,QChartView 类提供了一些方法和属性来控制图表的渲染和交互行为。这些方法包括 setRenderHint 和 setRubberBand,它…

【iOS安全】iOS ARM汇编

mov指令 MOV X22, X0 将X0的值移到X22中 参数传递 参数1:寄存器X0传递 参数2:寄存器X1传递 参数3:寄存器X2传递 参数4:寄存器X3传递 如果需要传递更多参数,会使用栈来传递 返回值 ARM架构下,通常使用…