Pytorch模型测试时显存一直上升导致爆显存

news/2025/1/11 22:43:21/

问题描述

首先说明: 由于我的测试集很大, 因此需要对测试集进行分批次推理.

在写代码的时候发现进行训练的时候大概显存只占用了2GB左右, 而且训练过程中显存占用量也基本上是不变的. 而在测试的时候, 发现显存在每个batch数据推理后逐渐增加, 直至最后导致爆显存, 程序fail.

这里放一下我测试的代码:

	y, y_ = torch.Tensor(), torch.Tensor()for batch in tqdm(loader):x, batch_y = batch[0], batch[1]batch_y_ = model(x)y = torch.cat([y, batch_y], dim=0)y_ = torch.cat([y_, batch_y_], dim=0)

解决方法

遇到问题后我就进行单步调试, 然后观察显存的变化. 发现在模型推理这一步, 每一轮次显存都会增加.

batch_y_ = model(x)

这里令人费解的是, 模型推理实际上在训练和测试中都是存在的, 为什么训练的时候就不会出现这个问题呢.

最后发现其实是在训练的时候有这样一步与测试不同:

self.optimizer.zero_grad()

⭐️ 在训练时, 每一个batch后都会将模型的梯度进行一次清零. 而测试的时候我则没有加这一步, 这样的话每次模型再做推理的时候都会产生新的梯度, 并累积到显存当中.

清楚了问题, 那么解决方法也就随之而来, 在测试的时候让模型不要记录梯度就好, 因为其实也用不到:

with torch.no_grad():test()

总结

  • 训练的时候每一个batch结束除了梯度反向传播, 还要提前清理梯度
  • 梯度如果不清理的话, 会在显存中累积下来

http://www.ppmy.cn/news/869743.html

相关文章

dataloader合理设置num_works和batchsize 避免爆内存

dataloader合理设置num_works和batchsize,避免爆内存 1.关乎内存2.关乎显存3.总结 个人总结,禁止以任何形式的转载!! 1.关乎内存 1)dataloader会开启num_works个进程,如图所示:(这里设置的是6…

理清ROS通信的一些细节

目标:掌握ros的python编程 基本教程:https://www.bilibili.com/video/BV1sU4y1z7mw/?spm_id_from333.788&vd_source32148098d54c83926572ec0bab6a3b1d terminator 快捷键需要自己去重新启用 ctrlshifte 横向分屏 ctrlshifto 纵向分屏 ctrlshiftw …

CPU GPU爆显存

用CPU的环境训练》? 换成GPU环境 爆显存

TensorFlow 显存占用率高 GPU利用率低

文章目录 nvidia-smi指令动态刷新GPU信息显存占用高,但是CPU使用率低回头再看 nvidia-smi指令 命令位置: 所以Path中添加环境变量: C:\Program Files\NVIDIA Corporation\NVSMI 试验一哈: 要注意的点: Driver Vers…

[数组]移除元素

我用好长时间才写出来,看了题解感觉他思路贼好 一、leecode题目链接 力扣 二、题解 数组中移除元素并不容易! | LeetCode:27. 移除元素_哔哩哔哩_bilibili 三、代码 1、老师的 思路 1)用快慢指针的思路来解决问题 slow指针…

【Jeston Nano】刷机

Jeston Nano刷机 有两种方法 1.使用SD卡 2.使用SDK Manager 使用SD卡 1.下载镜像 JetPack JetPack存档 下载4.6版本的 找到Nano 2.烧写SD image 下载balenaetcher,链接:balenaetcher 选择合适的版本。 windows版本的可以直接下载安装 选择镜像…

Shuttle ESB实现消息推送

ESB全称Enterprise Service Bus,即企业服务总线。它是传统中间件技术与XML、Web服务等技术结合的产物。 ESB的出现改变了传统的软件架构,能够提供比传统中间件产品更为便宜的解决方式。同一时候它还能够消除不同应用之间的技术差异,让不同的应…