解决size mismatch for embedding.embed_dict.userid.weight

news/2025/2/3 21:08:22/

文章目录

  • 一、问题描述
  • 二、解决方法
  • 三、其他问题
  • Reference

一、问题描述

导入之前训练好的模型权重后使用模型预测时如题报错size mismatch for embedding.embed_dict.userid.weight

state_dict = torch.load(model_path)
model.load_state_dict(state_dict)

二、解决方法

是因为导入的模型权重(之前训练好、保存的)的维度和当前定义的model的权重维度不同,所以我选择修改下当前定义的model,即将自己返回如下beat_sparse_features等的dataloader,其读取的数据换成之前模型训练的数据,使得模型定义后的model的模型权重和导入的权重一致。

model = DeepFM(deep_features=beat_dense_features + beat_sparse_features,fm_features=beat_sparse_features,mlp_params={"dims": [256, 128], "dropout": 0.2, "activation": "relu"},
)

当然如果根据大家的实际情况改动,如很多时候实例化模型时改变实参即可。

三、其他问题

可能还有其他情况也会报这个错,如导入预训练模型进行微调,首先加载预训练模型权重:

model = models.resnet34(pretrained=False)
pretrained_dict = torch.load('./pretrain/resnet34-333f7ec4.pth')
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model.state_dict()}
model.load_state_dict(pretrained_dict)
model.fc = torch.nn.Linear(512, 5) # 512为原始fc的数目,5是自己任务的分类数

由于分类类别不一致,报错size mismatch for fc.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([x]).,这里可以选择不加载fc层:

model = models.resnet34(pretrained=False)
pretrained_dict = torch.load('./pretrain/resnet34-333f7ec4.pth')
model_dict = model.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items() if (k in model_dict and 'fc' not in k)} # 将'fc'这一层的权重选择不加载即可。
model_dict.update(pretrained_dict) # 更新权重
model.load_state_dict(model_dict)

可能还有其他情况,如NLP词表维度不一致等等,后面遇到再更新该帖。如有不对之处,恳请大佬们指正!

Reference

[1] 解决CNN中训练权重参数不匹配size mismatch for fc.weight,size mismatch for fc.bias
[2] torch 封装文本数据预处理、训练、评估、预测过程
[3] 关于Pytorch加载模型参数的避坑指南


http://www.ppmy.cn/news/2852.html

相关文章

Docker容器化技术入门(一)Docker简介

Docker容器化技术入门(一)Docker简介前言(一)Docker简介1 Docker是什么?1.1 Docker的出现1.2 Docker的理念1.3 一句话2 容器与虚拟机比较2.1 容器发展简史2.2 传统虚拟机技术2.3 容器虚拟化技术2.4 对比3 Docker能干什…

实操1 : 如何统计输入单词中出现次数最多的字母并将其打印到控制台

(一) 问题描述 输入一行单词 统计其中出现的最多次数的字母并将其打印到控制台 输入 longlonglongistoolong输出 o 6(二) 解决方案 # -*- coding: utf-8 -*-w input() d {}for i in w:d[i] d.get(i, 0) 1 ds sorted(d.items(), keylambda x: x[1], reverseTrue)Max ds[0…

『SnowFlake』雪花算法的详解及时间回拨解决方案

📣读完这篇文章里你能收获到 图文形式为你讲解原生雪花算法的特征及原理了解时间回拨的概念以及可能引起发此现象的操作掌握时间回拨的解决方案—基于时钟序列的雪花算法关于雪花算法的常见问题解答 文章目录一、原生的雪花算法1. 简介2. 特征3. 原理3.1 格式&…

1158 Telefraud Detection

Telefraud(电信诈骗) remains a common and persistent problem in our society. In some cases, unsuspecting victims lose their entire life savings. To stop this crime, you are supposed to write a program to detect those suspects from a hu…

GC 算法总结_java培训

1.标记清除压缩(Mark-Sweep-Compact) 标记清除、标记压缩的结合使用 原理java培训GC 算法总结 2.算法总结 内存效率:复制算法>标记清除算法>标记整理算法(此处的效率只是简单的对比时间复杂度,实际情况不一定如此)。 内…

基于多目标优化算法的电力系统分析(Matlab代码实现)

💥💥💥💞💞💞欢迎来到本博客❤️❤️❤️💥💥💥 🎉作者研究:🏅🏅🏅主要研究方向是电力系统和智能算法、机器学…

Spring中的Bean的实例化

Bean的实例化1. Bean的配置2.Bean的实例化2.1 构造器实例化2.2 静态工厂方式实例化2.3 实例工厂方式实例化1. Bean的配置 Spring 可以被看作是一个大型工厂,这个工厂的作用就是生产和管理 Spring 容器中的Bean。如果想要在项目中使用这个工厂,就需要开发…

是时候给钉钉和腾讯会议算算账了

杨净 萧箫 发自 凹非寺量子位 | 公众号 QbitAI这几天,工作和上课等事情开始有回归线下的迹象,腾讯会议、钉钉似乎也可以松口气了。毕竟云会议的这两大APP,前段时间一直在被网友找平替。一来,它们要收费了;二来&#xf…