前言
不知道大家有没有像我一样,每换一次不一样的模型,就要输入不同的num_classes和name_classes,反正我是很头疼诶,尤其是项目里面不止一个模型的时候,更新的时候看着就很头疼,然后就想着直接输入模型权值文件的path该多好,然后我就搞起来了。
在自己的类中加入想要加入数据信息
class your_nets(nn.Module):def __init__(self, num_classes = 21,name_classes=None):super(your_nets, self).__init__()self.num_classes = num_classesself.name_classes = name_classes
训练过程之保存文件
model = your_nets(num_classes=num_classes, name_classes=name_classes)save_dict = {'state_dict': model.state_dict(),'num_classes': model.num_classes,'name_classes': model.name_classes}torch.save(save_dict, os.path.join(save_dir, "best_epoch_weights.pth"))
使用
model = get_nets_class(model_path=model_path)class get_nets_class(object):def __init__(self ,**kwargs):device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')load_dict = torch.load(self.model_path, map_location=device)state_dict =load_dict['state_dict']num_classes = load_dict['num_classes']name_classes = load_dict['name_classes']if num_classes is not None and name_classes is not None:self.num_classes =num_classesself.name_classes = name_classesself.net = your_nets(num_classes=self.num_classes,name_classes=name_classe)self.net.load_state_dict(state_dict)else:self.net = your_nets(num_classes=self.num_classes, backbone=self.backbone)self.net.load_state_dict(load_dict)self.net = self.net.eval()def predict(self,image,name_classes,object_list):#你的预处理操作,没有就忽略image_data = preprocess(image)with torch.no_grad():# 推理pr = self.net(images)[0]# softmax 得出概率 pr.permute(1, 2, 0), dim=-1为我自己的操作,没有请忽略pr = F.softmax(pr.permute(1, 2, 0), dim=-1).cpu().numpy()#你的后处理操作,没有就忽略pr = postprocess(pr)#这一步与object_list有关 object_list是你想要模型去预测的内容# 例如你训练了识别cat、dog、pig、person的类别 那么你想只识别人,那么就object_list=['person'] if object_list is not None:model_object_list = [name_classes.index(i) for i in object_list if i in name_classes]temp_list = [i for i in range(len(name_classes))]remove_list = [i for i in temp_list if i not in model_object_list]for i in remove_list:pr[pr==i] = 0retuen pr
我是觉得已经很详细了,大家要是不懂可以再问,我可以慢慢改进,每个人的写法都不一样 。
欢迎大家点赞加收藏哟~