pytorch——保存‘类别名与类别数量’到权值文件中

news/2025/2/19 15:46:43/

前言

不知道大家有没有像我一样,每换一次不一样的模型,就要输入不同的num_classes和name_classes,反正我是很头疼诶,尤其是项目里面不止一个模型的时候,更新的时候看着就很头疼,然后就想着直接输入模型权值文件的path该多好,然后我就搞起来了。

在自己的类中加入想要加入数据信息

class your_nets(nn.Module):def __init__(self, num_classes = 21,name_classes=None):super(your_nets, self).__init__()self.num_classes = num_classesself.name_classes = name_classes

训练过程之保存文件

      
model = your_nets(num_classes=num_classes, name_classes=name_classes)save_dict = {'state_dict': model.state_dict(),'num_classes': model.num_classes,'name_classes': model.name_classes}torch.save(save_dict, os.path.join(save_dir, "best_epoch_weights.pth"))

使用 

model = get_nets_class(model_path=model_path)class get_nets_class(object):def __init__(self ,**kwargs):device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')load_dict  = torch.load(self.model_path, map_location=device)state_dict =load_dict['state_dict']num_classes = load_dict['num_classes']name_classes = load_dict['name_classes']if num_classes is not None and name_classes is not None:self.num_classes =num_classesself.name_classes = name_classesself.net = your_nets(num_classes=self.num_classes,name_classes=name_classe)self.net.load_state_dict(state_dict)else:self.net = your_nets(num_classes=self.num_classes, backbone=self.backbone)self.net.load_state_dict(load_dict)self.net = self.net.eval()def predict(self,image,name_classes,object_list):#你的预处理操作,没有就忽略image_data = preprocess(image)with torch.no_grad():# 推理pr = self.net(images)[0]# softmax 得出概率 pr.permute(1, 2, 0), dim=-1为我自己的操作,没有请忽略pr = F.softmax(pr.permute(1, 2, 0), dim=-1).cpu().numpy()#你的后处理操作,没有就忽略pr = postprocess(pr)#这一步与object_list有关 object_list是你想要模型去预测的内容# 例如你训练了识别cat、dog、pig、person的类别 那么你想只识别人,那么就object_list=['person'] if object_list is not None:model_object_list = [name_classes.index(i) for i in object_list if i in name_classes]temp_list = [i for i in range(len(name_classes))]remove_list = [i for i in temp_list if i not in model_object_list]for i in remove_list:pr[pr==i] = 0retuen pr

我是觉得已经很详细了,大家要是不懂可以再问,我可以慢慢改进,每个人的写法都不一样 。

欢迎大家点赞加收藏哟~


http://www.ppmy.cn/news/1344146.html

相关文章

STM32学习笔记三——深度讲解GPIO及其应用

目录 STM32GPIO端口位基本结构图: 结构图I/O引脚: GPIO输入输出总结 1.GPIO引脚的四种输入方式及其特点: 1)上拉输入(GPIO_Mode_IPU) 2)下拉输入(GPIO_Mode_IPD) 3)模拟输入(GPIO_Mode_AIN) 4)浮空输入(GPIO_Mode_IN_FLOATING…

2024.2.5 寒假训练记录(19)

文章目录 牛客 寒假集训2A Tokitsukaze and Bracelet牛客 寒假集训2B Tokitsukaze and Cats牛客 寒假集训2D Tokitsukaze and Slash Draw牛客 寒假集训2E Tokitsukaze and Eliminate (easy)牛客 寒假集训2F Tokitsukaze and Eliminate (hard)牛客 寒假集训2I Tokitsukaze and S…

vue3学习——路由

安装 pnpm i vue-router/src/router/index.ts import { createRouter, createWebHashHistory } from vue-router import { constantRoute } from /router/routes.tslet router createRouter({history: createWebHashHistory(),routes: constantRoute, })export default rout…

【issue-YOLO】自定义数据集训练YOLO-v7 Segmentation

1. 拉取代码创建环境 执行nvidia-smi验证cuda环境是否可用;拉取官方代码; clone官方代码仓库 git clone https://github.com/WongKinYiu/yolov7;从main分支切换到u7分支 cd yolov7 && git checkout 44f30af0daccb1a3baecc5d80eae229…

十分钟掌握Go语言==运算符与reflect.DeepEqual函数处理interface{}值的比较规则

在 Go 语言中,interface{} 类型是一种特殊的接口类型,它表示任意类型的值。你可以使用 运算符来检测任意两个 interface{} 类型值的相等性,比较的规则和一般的接口类型一样,需要满足以下条件: 两个 interface{} 值的…

Redis 命令大全

文章目录 启动与连接Key(键)相关命令String(字符串)Hash(哈希)List(列表)Set(集合)Sorted Set(有序集合)其他常见命令HyperLogLog&…

用的到的linux-删除文件-Day3

前言: 上一节,我们讲到了怎么去移动文件,其中使用到两大类的脚本命令即cp和mv。各两种命令都可以完成移动,但是cp是复制粘贴的方式,可以选择原封不动的复制粘贴过来,即不修改文件及文件夹的创建时间等&…

Android中设置Toast.setGravity()了后没有效果

当设置 toast.setGravity()后,弹窗依旧从原来的位置弹出,不按设置方向弹出 类似以下代码: var toast Toast.makeText(this, R.string.ture_toast, Toast.LENGTH_SHORT)toast.setGravity(Gravity.TOP, 0, 0)//设置toast的弹出方向为屏幕顶部…