Pytorch——训练时,冻结网络部分参数的方法

devtools/2024/9/24 4:28:23/

一、原理:

要固定训练网络的哪几层,只需要找到这几层参数(parameter),然后将其 .requires_grad 属性设置为 False 。然后修改优化器,只将不被冻结的层传入。

二、效果

  1. 节省显存:不将不更新的参数传入optimizer
  2. 提升速度:将不更新的参数的requires_grad设置为False,节省了计算这部分参数梯度的时间

三、代码:

.requires_grad 属性设置为 False

# 根据参数层的 name 来进行冻结
unfreeze_layers = ["text_id"] # 用列表
# 设置冻结参数:
for name, param in model.named_parameters():# print(name, param.shape)# 错误判定:# if name.split(".")[0] in unfreeze_layers: # 不要用in来判定,因为"id"也在"text_id"的in中。# 正确判定:for unfreeze_layer in unfreeze_layers:if name.split(".")[0] != unfreeze_layer:param.requires_grad = Falseprint(name, param.requires_grad)else:print(name, param.requires_grad)
# 冻结整个网络
for param in self.model.parameters():param.requires_grad = False
# 查看冻结参数与否:
for name, param in self.clip_model.named_parameters():print(name, param.requires_grad)

修改优化器

# 只将未被冻结的层,传入优化器
optimizer = optim.SGD(filter(lambda p : p.requires_grad, model.parameters()), lr=1e-2)

四、其他知识

  1. 模型权重冻结:一些情况下,我们可能只需要对模型进行推断,而不需要调整模型的权重。通过调用model.eval(),可以防止在推断过程中更新模型的权重。
  2. with torch.no_grad(): # 禁用梯度计算以加快计算速度。
  3. 训练完train_datasets之后,model要来测试样本了。在model(test_datasets)之前,需要加上model.eval(). 否则的话,有输入数据,即使不训练,它也会改变权值。这是model中含有batch normalization层所带来的的性质。
    eval()时,pytorch会自动把BN和DropOut固定住,不会取平均,而是用训练好的值。不然的话,一旦test的batch_size过小,很容易就会被BN层导致生成图片颜色失真极大。eval()在非训练的时候是需要加的,没有这句代码,一些网络层的值会发生变动,不会固定,你神经网络每一次生成的结果也是不固定的,生成质量可能好也可能不好。
  4. 何时用model.eval() :训练完train样本后,生成的模型model要用来测试样本。在model(test)之前,需要加上model.eval(),否则的话,有输入数据,即使不训练,它也会改变权值。这是model中含有BN层和Dropout所带来的的性质。在eval/test过程中,需要显示地让model调用eval(),此时模型会把BN和Dropout固定住,不会取平均,而是用训练好的值。
  5. 何时用with torch.no_grad():无论是train() 还是eval() 模式,各层的gradient计算和存储都在进行且完全一致,只是在eval模式下不会进行反向传播。而with torch.no_grad()则主要是用于停止autograd模块的工作,以起到加速和节省显存的作用。它的作用是将该with语句包裹起来的部分停止梯度的更新,从而节省了GPU算力和显存,但是并不会影响dropout和BN层的行为。若想节约算力,可在test阶段带上torch.no_grad()。
  6. with torch.no_grad() 主要是用于停止autograd模块的工作,以起到加速和节省显存的作用。它的作用是将该with语句包裹起来的部分停止梯度的更新,从而节省了GPU算力和显存,但是并不会影响dropout和BN层的行为。如果不在意显存大小和计算时间的话,仅仅使用model.eval()已足够得到正确的validation/test的结果;而with torch.no_grad()则是更进一步加速和节省gpu空间(因为不用计算和存储梯度),从而可以更快计算,也可以跑更大的batch来测试。

参考文章

  1. 知乎讨论
  2. 第十三章 深度解读预训练与微调迁移,模型冻结与解冻(工具)
  3. 【PyTorch】搞定网络训练中的model.train()和model.eval()模式

http://www.ppmy.cn/devtools/6605.html

相关文章

LeetCode 面试经典150题 202.快乐数

题目: 编写一个算法来判断一个数 n 是不是快乐数。 「快乐数」 定义为: 对于一个正整数,每一次将该数替换为它每个位置上的数字的平方和。然后重复这个过程直到这个数变为 1,也可能是 无限循环 但始终变不到 1。如果这个过程 结…

selenium反反爬虫,隐藏selenium特征

一、stealth.min.js 使用 用selenium爬网页时,常常碰到被检测到selenium ,会被服务器直接判定为非法访问,这个时候就可以用stealth.min.js 来隐藏selenium特征,达到绕过检测的目的 from selenium import webdriver from seleniu…

OpenHarmony实战开发-如何通过分割swiper区域,实现指示器导航点位于swiper下方的效果。

介绍 本示例介绍通过分割swiper区域,实现指示器导航点位于swiper下方的效果。 效果预览图 使用说明 1.加载完成后swiper指示器导航点,位于显示内容下方。 实现思路 1.将swiper区域分割为两块区域,上方为内容区域,下方为空白区…

STM32 CAN接收FIFO细节

STM32 CAN接收FIFO细节 简介 CAN外设一共有2个接收FIFO,每个FIFO中有3个邮箱,即最多可以缓存6个接收到的报文。 FIFO状态 EMPTY: 初始状态,表示FIFO为空,没有挂起的消息(FMP0x00),且没有发生…

【计算机网络】物理层

目录 物理层概述物理层接口特性 物理层下面的传输媒体导向型传输媒体非导向型传输媒体 传输方式串行传输和并行传输同步传输和异步传输单向通信,双向交替通信和双向同时通信 编码和调制编码与调制的基本概念常见编码方式 信道的极限容量(不全&#xff0c…

ubuntu常用方法

文本文件的创建: sudo touch ubuntu.txt move clock: sudo chmod 777 ubuntu.txt 安装chrome wget https://dl.google.com/linux/direct/google-chrome-stable_current_amd64.deb sudo apt install ./google-chrome-stable_current_amd64.deb .sh 文件的安装 例…

C语言中的控制语句(循环语句while、for)

循环语句 什么是循环 重复执行代码 为什么需要循环循环的实现方式 whiledo...whilefor while语句 语法格式&#xff1a; while (条件) {循环体…… } 需求&#xff1a;跑步5圈 示例代码&#xff1a; #include <stdio.h>int main() {// 需求跑步5圈// 1. 条件变量的…

B树(B-tree)

B树(B-tree) B树(B-tree)是一种自平衡的多路查找树&#xff0c;主要用于磁盘或其他直接存取的辅助存储设备 B树能够保持数据有序&#xff0c;并允许在对数时间内完成查找、插入及删除等操作 这种数据结构常被应用在数据库和文件系统的实现上 B树的特点包括&#xff1a; B树为…