Pytorch基础:torch.load_state_dict()方法在加载时不会检查类型

ops/2024/10/21 9:37:00/

相关阅读

Pytorch基础icon-default.png?t=N7T8https://blog.csdn.net/weixin_45791458/category_12457644.html?spm=1001.2014.3001.5482


        笔者在使用torch.nn.module的load_state_dict中出现了一个问题,一个被注册的张量在加载后居然没有变化,一开始以为是加载出现了问题,但发现其他参数加载成功,思索后发现是注册的张量的类型是整型而checkpoint中保存为浮点数类型,恰好注册时的默认值给的是0,而checkpoint中的浮点数又在0到1之间,因此出现了这个令人困惑的bug。

        下面首先复现这个bug。

import torch
import torch.nn as nn# 定义一个简单的线性模型,参数类型为整数
class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.register_buffer('test', torch.tensor(0)) # 注册一个整型张量# 创建一个简单模型实例
model = SimpleModel()# 创建一个浮点数作为参数
float_parameter = torch.tensor(0.6)# 将注册名指向另一个浮点型张量
model.test = float_parameter# 保存模型
torch.save(model.state_dict(), 'model.pth')# 直接使用原模型加载
checkpoint = torch.load('model.pth')
model.load_state_dict(checkpoint)# 打印加载后的参数
print(model.test)# 直接使用新模型加载
model_1 = SimpleModel()
model_1.load_state_dict(checkpoint)# 打印加载后的参数
print(model_1.test)
输出:
tensor(0.6000)
tensor(0)

        可以看到,当模型中注册的名字(test),指向了一个类型不符的张量后,并不会导致浮点型张量被截断为整型,这是因为此处是直接使用赋值号=,使名字指向了另一个张量。

        但使用load_state_dict()方法与使用赋值号是不同的,load_state_dict()方法的实现中,调用了_load_from_state_dict()方法,其中调用了copy_()方法,进行了原位(in-place)数据替换,这可能会进行截断,下面是原位替换的一个例子。

python">import torch# 创建两个张量
a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[5.1, 6.1], [7.1, 8.1]])# 查看张量对象的id
print(id(a))
print(id(b))# 查看底层存储的内存地址
print(a.storage().data_ptr())
print(b.storage().data_ptr())# 将张量 b 中的值复制到张量 a 中
a.copy_(b)# 打印复制后的结果
print(a)# 查看张量对象的id
print(id(a))
print(id(b))# 查看底层存储的内存地址
print(a.storage().data_ptr())
print(b.storage().data_ptr())
python">输出:
2604425272672
2604426953808  
2604511348096  
2602930352832  
tensor([[5, 6],[7, 8]])
2604425272672
2604426953808
2604511348096
2602930352832

        在保存了模型的状态字典后,使用load_state_dict()方法加载后,也不会有任何截断问题,因为对于原模型而言,名字test指向的是一个浮点型张量,此时原位替换,类型吻合。但是对于一个新的模型,此时的test指向的是一个整型张量,此时原位替换,会发生截断。

        因此,在注册一个张量时,需要确保其在注册时和保存时的类型吻合,此处除了指形状,还有类型,否则可能会出现意想不到的bug。


http://www.ppmy.cn/ops/27217.html

相关文章

推荐一个wordpress免费模板下载

首页大背景图,首屏2张轮播图,轮换展示,效果非常的炫酷,非常的哇噻,使用这个主题搭建的wordpress网站,超过了200个,虽然是一个老主题了,不过是经得起时间考验的,现在用起来…

Oracle数据常驻程序内存优化【数据库实例优化系列三】

Oracle程序常驻程序内存优化【数据库实例优化系列二】-CSDN博客 在生产中,为提高用户的访问速度。可以将经常使用的表常驻与内存中。避免频繁的访问磁盘,降低IO。 虽然会占用一定的内存,但是效果还是很明显的。 如果不是用了,DBA可以将其删除。 一、数据缓冲池 数据库…

CentOS/Anolis的Linux系统如何通过VNC登录远程桌面?

综述 需要在server端启动vncserver,推荐tigervnc的server 然后再本地点来启动client进行访问,访问方式是IPport(本质是传递数据包到某个ip的某个port) 然后需要防火墙开启端口 服务器上:安装和启动服务 安装服务 y…

PotatoPie 4.0 实验教程(32) —— FPGA实现摄像头图像浮雕效果

什么是浮雕效果? 浮雕效果是一种图像处理技术,用于将图像转换为看起来像浮雕一样的效果,给人一种凸起或凹陷的立体感觉,下面第二张图就是图像处理实现浮雕效果。 不过这个图是用Adobe公司的PS人工P图实现的,效果比较…

前端开发工程师——Vue

Vue学习笔记&#xff08;尚硅谷天禹老师&#xff09;_尚硅谷天禹老师vue2021讲课笔记下载-CSDN博客 模板语法 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta http-equiv"X-UA-Compatible" co…

飞书API(6):使用 pandas 处理数据并写入 MySQL 数据库

一、引入 上一篇了解了飞书 28 种数据类型通过接口读取到的数据结构&#xff0c;本文开始探讨如何将这些数据写入 MySQL 数据库。这个工作流的起点是从 API 获取到的一个完整的数据&#xff0c;终点是写入 MySQL 数据表&#xff0c;表结构和维格表结构类似。在过程中可以有不同…

Python和Julia河流湖泊沿海水域特征数值算法模型

&#x1f3af;要点 一维水流场景计算和绘图&#xff1a; &#x1f3af;恒定透射率水头和流量计算&#xff1a;&#x1f58a;两条完全穿透畜水层理想河流之间 | &#x1f58a;无承压畜水层两侧及两条完全穿透畜水层的补给 | &#x1f58a;分水岭或渗透性非常低的岩体的不渗透边…

【C++】模板初阶

&#x1f525;个人主页&#xff1a; Forcible Bug Maker &#x1f525;专栏&#xff1a; C 目录 前言泛型编程模板函数模板概念及简单使用函数模板的原理函数模板的实例化模板参数的匹配原则 类模板概念及简单使用类模板的实例化 结语 前言 本篇博客主要内容&#xff1a;初步接…