Pytorch——训练时,冻结网络部分参数的方法

ops/2024/9/24 8:47:14/

一、原理:

要固定训练网络的哪几层,只需要找到这几层参数(parameter),然后将其 .requires_grad 属性设置为 False 。然后修改优化器,只将不被冻结的层传入。

二、效果

  1. 节省显存:不将不更新的参数传入optimizer
  2. 提升速度:将不更新的参数的requires_grad设置为False,节省了计算这部分参数梯度的时间

三、代码:

.requires_grad 属性设置为 False

# 根据参数层的 name 来进行冻结
unfreeze_layers = ["text_id"] # 用列表
# 设置冻结参数:
for name, param in model.named_parameters():# print(name, param.shape)# 错误判定:# if name.split(".")[0] in unfreeze_layers: # 不要用in来判定,因为"id"也在"text_id"的in中。# 正确判定:for unfreeze_layer in unfreeze_layers:if name.split(".")[0] != unfreeze_layer:param.requires_grad = Falseprint(name, param.requires_grad)else:print(name, param.requires_grad)
# 冻结整个网络
for param in self.model.parameters():param.requires_grad = False
# 查看冻结参数与否:
for name, param in self.clip_model.named_parameters():print(name, param.requires_grad)

修改优化器

# 只将未被冻结的层,传入优化器
optimizer = optim.SGD(filter(lambda p : p.requires_grad, model.parameters()), lr=1e-2)

四、其他知识

  1. 模型权重冻结:一些情况下,我们可能只需要对模型进行推断,而不需要调整模型的权重。通过调用model.eval(),可以防止在推断过程中更新模型的权重。
  2. with torch.no_grad(): # 禁用梯度计算以加快计算速度。
  3. 训练完train_datasets之后,model要来测试样本了。在model(test_datasets)之前,需要加上model.eval(). 否则的话,有输入数据,即使不训练,它也会改变权值。这是model中含有batch normalization层所带来的的性质。
    eval()时,pytorch会自动把BN和DropOut固定住,不会取平均,而是用训练好的值。不然的话,一旦test的batch_size过小,很容易就会被BN层导致生成图片颜色失真极大。eval()在非训练的时候是需要加的,没有这句代码,一些网络层的值会发生变动,不会固定,你神经网络每一次生成的结果也是不固定的,生成质量可能好也可能不好。
  4. 何时用model.eval() :训练完train样本后,生成的模型model要用来测试样本。在model(test)之前,需要加上model.eval(),否则的话,有输入数据,即使不训练,它也会改变权值。这是model中含有BN层和Dropout所带来的的性质。在eval/test过程中,需要显示地让model调用eval(),此时模型会把BN和Dropout固定住,不会取平均,而是用训练好的值。
  5. 何时用with torch.no_grad():无论是train() 还是eval() 模式,各层的gradient计算和存储都在进行且完全一致,只是在eval模式下不会进行反向传播。而with torch.no_grad()则主要是用于停止autograd模块的工作,以起到加速和节省显存的作用。它的作用是将该with语句包裹起来的部分停止梯度的更新,从而节省了GPU算力和显存,但是并不会影响dropout和BN层的行为。若想节约算力,可在test阶段带上torch.no_grad()。
  6. with torch.no_grad() 主要是用于停止autograd模块的工作,以起到加速和节省显存的作用。它的作用是将该with语句包裹起来的部分停止梯度的更新,从而节省了GPU算力和显存,但是并不会影响dropout和BN层的行为。如果不在意显存大小和计算时间的话,仅仅使用model.eval()已足够得到正确的validation/test的结果;而with torch.no_grad()则是更进一步加速和节省gpu空间(因为不用计算和存储梯度),从而可以更快计算,也可以跑更大的batch来测试。

参考文章

  1. 知乎讨论
  2. 第十三章 深度解读预训练与微调迁移,模型冻结与解冻(工具)
  3. 【PyTorch】搞定网络训练中的model.train()和model.eval()模式

http://www.ppmy.cn/ops/4764.html

相关文章

A-1:树状数组

A-1:树状数组 1.介绍Q1:树状数组解决什么问题?Q2:树状数组的使用1.前置知识:lowbit(x)2.单点修改3.求[1,n]的和4.区间查询5.hh Q3:树状数组是否优化了Q4:上图上例子解释上面说的东西(Important) 2.习题练习 1.介绍 树状数组是一个比较难以理解的高级数据…

WPF App.xaml 中添加多个ResourceDictionary

在WPF应用程序中,App.xaml 文件是一个常用的集中位置来管理应用级别的资源,包括样式、模板、图像、数据转换器等。为了添加多个 ResourceDictionary 到 App.xaml 中,可以利用 ResourceDictionary 的 MergedDictionaries 属性。这个属性允许您…

2024npm国内镜像源

npm 官方原始镜像网址是:https://registry.npmjs.org/ 淘宝 NPM 镜像:https://registry.npmmirror.com 阿里云 NPM 镜像:https://npm.aliyun.com 腾讯云 NPM 镜像:https://mirrors.cloud.tencent.com/npm/ 华为云 NPM 镜像&#x…

【切换网络连接后】VMware虚拟机网络配置【局域网通信】

初次安装Linux虚拟机以及切换网络都需要配置虚拟机网络, 从而使得win主机内通过远程连接工具能够连接该虚拟机, 而不是在虚拟机内操作。 本片文章你将了解到网络切换后如何配置虚拟机网络的一些基础操作,以及局域网通信的一些基础知识。 …

链表经典算法OJ题目

1.单链表相关经典算OJ题目1:移除链表元素 思路一 直接在原链表里删除val元素,然后让val前一个结点和后一个节点连接起来。 这时我们就需要3个指针来遍历链表: pcur —— 判断节点的val值是否于给定删除的val值相等 prev ——保存pcur的前…

深入理解汇编:平栈、CALL和RET指令详解

​视频学习下载地址:​​https://pan.quark.cn/s/04e6946a803a​​ 汇编语言以其接近硬件的特性和高效的执行速度,在系统编程、性能优化和逆向工程中占有不可或缺的地位。本文将深入探讨汇编语言中的平栈操作以及​​CALL​​​和​​RET​​指令&#…

聚观早报 | 华为Pura70系列先锋计划;月之暗面升级Kimi

聚观早报每日整理最值得关注的行业重点事件,帮助大家及时了解最新行业动态,每日读报,就读聚观365资讯简报。 整理丨Cutie 4月19日消息 华为Pura70系列先锋计划 月之暗面升级Kimi OPPO Find X7将推白色版本 波士顿动力推出人形机器人 v…

vue flvjs 播放视频

写在前面: 之前使用过vodiejs插件播放过mp4视频格式的视频; 此次需要使用flvjs插件播放rtsp视频格式的视频; 因为视频的数据格式不同,所以对应的插件不同。 思维导图: 参考链接:rtmp、rtsp、flv、m3u8、 …