一、指标介绍
在训练语义分割模型时，我们不仅需要关注训练损失和验证损失，还希望了解模型的性能指标（如 mIoU）。
二、计算流程
(1)读取验证集的图片和标签(mask图)
(2)对模型预测的特征图进行解码,获得预测的mask图
(3)创建尺寸为 num_classes × num_classes 的混淆矩阵 hist
(4)将标签mask图和预测mask图转换为numpy数组
(5)将两个 numpy 数组展平为一维数组，使用 np.bincount 逐像素统计每个（标签类别, 预测类别）组合出现的次数，再 reshape 为 num_classes × num_classes，结果累加到混淆矩阵 hist 中
(6)根据混淆矩阵 hist 计算语义分割指标 mIoU、PA_Recall 和 Precision
三、相应代码
(1)读取验证集的图片和标签(mask图)
# Excerpt from EvalCallback.on_epoch_end: `self` and `pred_dir` come from that method.
for image_id in tqdm(self.image_ids):
    #-------------------------------#
    #   Read the input image from disk
    #-------------------------------#
    image_path = os.path.join(self.dataset_path, "JPEGImages/"+image_id+".jpg")
    image = Image.open(image_path)
    #------------------------------#
    #   Predict the class-index mask and save it as <pred_dir>/<image_id>.png
    #------------------------------#
    image = self.get_miou_png(image)
    image.save(os.path.join(pred_dir, image_id + ".png"))
(2)对模型预测的特征图进行解码,获得预测的mask图
#------------------------------#
#   Predict the class-index mask for this image and save it
#------------------------------#
image = self.get_miou_png(image)
image.save(os.path.join(pred_dir, image_id + ".png"))

# EvalCallback.get_miou_png, shown here outside its class.
def get_miou_png(self, image):
    """Predict a per-pixel class-index mask for one PIL image.

    The image is letterboxed to ``self.input_shape`` (indexed as height, width),
    run through ``self.net``, the padding is cropped from the softmax output,
    the probabilities are resized back to the original resolution and the
    arg-max class of each pixel is returned as a uint8 ``PIL.Image``.
    """
    #---------------------------------------------------------#
    #   Convert to RGB so grayscale images don't fail at prediction time.
    #   Only RGB prediction is supported; every other mode is converted to RGB.
    #---------------------------------------------------------#
    image = cvtColor(image)
    orininal_h = np.array(image).shape[0]
    orininal_w = np.array(image).shape[1]
    #---------------------------------------------------------#
    #   Add gray bars (letterbox) for a distortion-free resize;
    #   a plain resize would also work.
    #---------------------------------------------------------#
    image_data, nw, nh = resize_image(image, (self.input_shape[1],self.input_shape[0]))
    #---------------------------------------------------------#
    #   Apply preprocess_input, HWC -> CHW, and add the batch dimension
    #---------------------------------------------------------#
    image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, np.float32)), (2, 0, 1)), 0)

    with torch.no_grad():
        images = torch.from_numpy(image_data)
        if self.cuda:
            images = images.cuda()
        #---------------------------------------------------#
        #   Forward pass; [0] drops the batch dimension
        #---------------------------------------------------#
        pr = self.net(images)[0]
        #---------------------------------------------------#
        #   Per-pixel class probabilities, channels moved last
        #---------------------------------------------------#
        pr = F.softmax(pr.permute(1,2,0),dim = -1).cpu().numpy()
        #--------------------------------------#
        #   Crop away the letterbox padding
        #--------------------------------------#
        pr = pr[int((self.input_shape[0] - nh) // 2) : int((self.input_shape[0] - nh) // 2 + nh), \
                int((self.input_shape[1] - nw) // 2) : int((self.input_shape[1] - nw) // 2 + nw)]
        #---------------------------------------------------#
        #   Resize the probabilities back to the original image size
        #---------------------------------------------------#
        pr = cv2.resize(pr, (orininal_w, orininal_h), interpolation = cv2.INTER_LINEAR)
        #---------------------------------------------------#
        #   Take the most probable class for every pixel
        #---------------------------------------------------#
        pr = pr.argmax(axis=-1)

    # NOTE(review): np.uint8 wraps class indices >= 256, so this assumes num_classes <= 256.
    image = Image.fromarray(np.uint8(pr))
    return image
(3)创建尺寸为 num_classes × num_classes 的混淆矩阵 hist
print('Num classes', num_classes)
#-----------------------------------------#
#   Create an all-zero num_classes x num_classes confusion matrix.
#   Rows index the ground-truth class and columns the predicted class,
#   matching fast_hist's label * n + prediction encoding.
#-----------------------------------------#
hist = np.zeros((num_classes, num_classes))
(4)将标签mask图和预测mask图转换为numpy数组
#------------------------------------------------#
#   Iterate over every (prediction, label) image pair.
#   NOTE(review): gt_imgs / pred_imgs are built elsewhere (not shown);
#   they are assumed to be index-aligned lists of file paths.
#------------------------------------------------#
for ind in range(len(gt_imgs)):
    #------------------------------------------------#
    #   Read one predicted segmentation mask into a numpy array
    #------------------------------------------------#
    pred = np.array(Image.open(pred_imgs[ind]))
    #------------------------------------------------#
    #   Read the matching ground-truth label mask into a numpy array
    #------------------------------------------------#
    label = np.array(Image.open(gt_imgs[ind]))
(5)将两个 numpy 数组展平为一维数组，使用 np.bincount 逐像素统计每个（标签类别, 预测类别）组合出现的次数，再 reshape 为 num_classes × num_classes，结果累加到混淆矩阵 hist 中
# 如果图像分割结果与标签的大小不一样,这张图片就不计算
if len(label.flatten()) != len(pred.flatten()): print('Skipping: len(gt) = {:d}, len(pred) = {:d}, {:s}, {:s}'.format(len(label.flatten()), len(pred.flatten()), gt_imgs[ind],pred_imgs[ind]))continue#------------------------------------------------#
# 对一张图片计算21×21的hist矩阵,并累加
#------------------------------------------------#
hist += fast_hist(label.flatten(), pred.flatten(), num_classes) def fast_hist(a, b, n):#--------------------------------------------------------------------------------## a是转化成一维数组的标签,形状(H×W,);b是转化成一维数组的预测结果,形状(H×W,)#--------------------------------------------------------------------------------#k = (a >= 0) & (a < n)#--------------------------------------------------------------------------------## np.bincount计算了从0到n**2-1这n**2个数中每个数出现的次数,返回值形状(n, n)# 返回中,写对角线上的为分类正确的像素点#--------------------------------------------------------------------------------#return np.bincount(n * a[k].astype(int) + b[k], minlength=n ** 2).reshape(n, n)
(6)根据混淆矩阵 hist 计算语义分割指标 mIoU、PA_Recall 和 Precision
#------------------------------------------------#
#   Per-class metrics over all validation images, computed from the
#   accumulated confusion matrix hist. mIoU is the mean of IoUs.
#   NOTE(review): per_class_* helpers are defined elsewhere; presumably each
#   returns one value per class — confirm.
#------------------------------------------------#
IoUs = per_class_iu(hist)
PA_Recall = per_class_PA_Recall(hist)
Precision = per_class_Precision(hist)
四、完整代码
class EvalCallback():
    """Epoch-end callback that measures segmentation mIoU and logs/plots it.

    Every ``period`` epochs (when ``eval_flag`` is true) it:
    - predicts a class-index mask for each image in ``image_ids`` and writes the masks as PNGs to a scratch directory;
    - scores them with ``compute_mIoU`` against ``<dataset_path>/SegmentationClass/``;
    - appends the mIoU (in percent) to ``<log_dir>/epoch_miou.txt``;
    - redraws ``<log_dir>/epoch_miou.png``.
    """

    def __init__(self, net, input_shape, num_classes, image_ids, dataset_path, log_dir, cuda,
                 miou_out_path=".temp_miou_out", eval_flag=True, period=1):
        """
        Args:
            net: model used for prediction; replaced by ``model_eval`` at each evaluation.
            input_shape: network input size, indexed as (height, width).
            num_classes: number of classes, forwarded to ``compute_mIoU``.
            image_ids: entries naming the evaluation images. Only the first
                whitespace-separated token of each entry is used; blank
                entries are skipped.
            dataset_path: root directory holding ``JPEGImages/`` and ``SegmentationClass/``.
            log_dir: directory receiving ``epoch_miou.txt`` and ``epoch_miou.png``.
            cuda: move input tensors to the GPU before inference.
            miou_out_path: scratch directory for predicted masks; deleted after each evaluation.
            eval_flag: when False, no evaluation or logging happens.
            period: evaluate only on epochs where ``epoch % period == 0``.
        """
        super(EvalCallback, self).__init__()

        self.net = net
        self.input_shape = input_shape
        self.num_classes = num_classes
        self.dataset_path = dataset_path
        self.log_dir = log_dir
        self.cuda = cuda
        self.miou_out_path = miou_out_path
        self.eval_flag = eval_flag
        self.period = period

        # Keep only the id token of each entry. Blank entries (e.g. a trailing
        # newline in a list file) are skipped instead of raising IndexError
        # on "".split()[0].
        self.image_ids = [entry.split()[0] for entry in image_ids if entry.strip()]

        # Histories start at (epoch 0, mIoU 0) so the curve has an origin point.
        self.mious = [0]
        self.epoches = [0]

        if self.eval_flag:
            # Seed the log file with the same epoch-0 entry.
            with open(os.path.join(self.log_dir, "epoch_miou.txt"), 'a') as f:
                f.write(str(0))
                f.write("\n")

    def get_miou_png(self, image):
        """Predict a per-pixel class-index mask for one PIL image.

        Returns a ``PIL.Image`` with the same height and width as ``image`` whose
        pixel values are the arg-max class indices, stored as uint8.
        """
        # Convert to RGB so grayscale images don't fail at prediction time;
        # every non-RGB mode is converted to RGB.
        image = cvtColor(image)
        original_h, original_w = np.array(image).shape[:2]

        # Letterbox (pad with gray bars) for a distortion-free resize to the
        # network input; a plain resize would also work.
        image_data, nw, nh = resize_image(image, (self.input_shape[1], self.input_shape[0]))
        # Apply preprocess_input, HWC -> CHW, and add the batch dimension.
        image_data = np.expand_dims(
            np.transpose(preprocess_input(np.array(image_data, np.float32)), (2, 0, 1)), 0
        )

        # Bounds of the real image content inside the letterboxed input.
        top = int((self.input_shape[0] - nh) // 2)
        bottom = int((self.input_shape[0] - nh) // 2 + nh)
        left = int((self.input_shape[1] - nw) // 2)
        right = int((self.input_shape[1] - nw) // 2 + nw)

        with torch.no_grad():
            images = torch.from_numpy(image_data)
            if self.cuda:
                images = images.cuda()
            # Forward pass; [0] drops the batch dimension.
            pr = self.net(images)[0]
            # Per-pixel class probabilities with the class axis moved last.
            pr = F.softmax(pr.permute(1, 2, 0), dim=-1).cpu().numpy()
            # Crop away the letterbox padding.
            pr = pr[top:bottom, left:right]
            # Resize the probabilities back to the original image size.
            pr = cv2.resize(pr, (original_w, original_h), interpolation=cv2.INTER_LINEAR)
            # Most probable class for every pixel.
            pr = pr.argmax(axis=-1)

        # NOTE: np.uint8 wraps class indices >= 256, so this assumes num_classes <= 256.
        return Image.fromarray(np.uint8(pr))

    def _predict_to_dir(self, pred_dir):
        """Predict a mask for every id in ``self.image_ids`` and save it as ``<pred_dir>/<id>.png``."""
        for image_id in tqdm(self.image_ids):
            image_path = os.path.join(self.dataset_path, "JPEGImages/" + image_id + ".jpg")
            # The context manager closes the source file once prediction is done.
            with Image.open(image_path) as image:
                mask = self.get_miou_png(image)
            mask.save(os.path.join(pred_dir, image_id + ".png"))

    def _record_miou(self, epoch, miou):
        """Append ``miou`` to the history and the log file, then redraw the mIoU curve."""
        self.mious.append(miou)
        self.epoches.append(epoch)

        with open(os.path.join(self.log_dir, "epoch_miou.txt"), 'a') as f:
            f.write(str(miou))
            f.write("\n")

        plt.figure()
        plt.plot(self.epoches, self.mious, 'red', linewidth=2, label='train miou')
        plt.grid(True)
        plt.xlabel('Epoch')
        plt.ylabel('Miou')
        plt.title('A Miou Curve')
        plt.legend(loc="upper right")
        plt.savefig(os.path.join(self.log_dir, "epoch_miou.png"))
        plt.cla()
        plt.close("all")

    def on_epoch_end(self, epoch, model_eval):
        """Evaluate mIoU with ``model_eval`` if this epoch is due; otherwise do nothing."""
        if not (epoch % self.period == 0 and self.eval_flag):
            return

        self.net = model_eval
        gt_dir = os.path.join(self.dataset_path, "SegmentationClass/")
        pred_dir = os.path.join(self.miou_out_path, 'detection-results')
        # makedirs also creates miou_out_path, the parent of pred_dir.
        os.makedirs(pred_dir, exist_ok=True)

        print("Get miou.")
        self._predict_to_dir(pred_dir)

        print("Calculate miou.")
        _, IoUs, _, _ = compute_mIoU(gt_dir, pred_dir, self.image_ids, self.num_classes, None)
        # Mean over classes in percent; NaN entries are ignored by nanmean.
        temp_miou = np.nanmean(IoUs) * 100
        self._record_miou(epoch, temp_miou)

        print("Get miou done.")
        shutil.rmtree(self.miou_out_path)