【深度学习实战(17)】计算语义分割的性能指标mIOU

embedded/2024/9/24 23:25:07/

一、指标介绍

在训练语义分割模型时,我们不仅需要知道训练,验证损失,还想要知道性能指标。

二、计算流程

(1)读取验证集的图片和标签(mask图)
(2)对模型预测的特征图进行解码,获得预测的mask图
(3)创建num_class x num_class尺寸的混淆矩阵hist
(4)将标签mask图和预测mask图转换为numpy数组
(5)将两个numpy数组展平为一维数组,使用np.bincount逐像素计算,再reshape,结果累计在混淆矩阵hist中
(6)根据hist混淆矩阵,计算语义分割指标mIOU,PA_Recall,Precision

三、相应代码

(1)读取验证集的图片和标签(mask图)

for image_id in tqdm(self.image_ids):#-------------------------------##   从文件中读取图像#-------------------------------#image_path  = os.path.join(self.dataset_path, "JPEGImages/"+image_id+".jpg")image       = Image.open(image_path)#------------------------------##   获得预测特征图#------------------------------#image       = self.get_miou_png(image)image.save(os.path.join(pred_dir, image_id + ".png"))

(2)对模型预测的特征图进行解码,获得预测的mask图

#------------------------------#
#   获得预测特征图
#------------------------------#
image       = self.get_miou_png(image)
image.save(os.path.join(pred_dir, image_id + ".png"))def get_miou_png(self, image):#---------------------------------------------------------##   在这里将图像转换成RGB图像,防止灰度图在预测时报错。#   代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB#---------------------------------------------------------#image       = cvtColor(image)orininal_h  = np.array(image).shape[0]orininal_w  = np.array(image).shape[1]#---------------------------------------------------------##   给图像增加灰条,实现不失真的resize#   也可以直接resize进行识别#---------------------------------------------------------#image_data, nw, nh  = resize_image(image, (self.input_shape[1],self.input_shape[0]))#---------------------------------------------------------##   添加上batch_size维度#---------------------------------------------------------#image_data  = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, np.float32)), (2, 0, 1)), 0)with torch.no_grad():images = torch.from_numpy(image_data)if self.cuda:images = images.cuda()#---------------------------------------------------##   图片传入网络进行预测#---------------------------------------------------#pr = self.net(images)[0]#---------------------------------------------------##   取出每一个像素点的种类#---------------------------------------------------#pr = F.softmax(pr.permute(1,2,0),dim = -1).cpu().numpy()#--------------------------------------##   将灰条部分截取掉#--------------------------------------#pr = pr[int((self.input_shape[0] - nh) // 2) : int((self.input_shape[0] - nh) // 2 + nh), \int((self.input_shape[1] - nw) // 2) : int((self.input_shape[1] - nw) // 2 + nw)]#---------------------------------------------------##   进行图片的resize#---------------------------------------------------#pr = cv2.resize(pr, (orininal_w, orininal_h), interpolation = cv2.INTER_LINEAR)#---------------------------------------------------##   取出每一个像素点的种类#---------------------------------------------------#pr = pr.argmax(axis=-1)image = Image.fromarray(np.uint8(pr))return image

(3)创建num_class x num_class尺寸的混淆矩阵hist

print('Num classes', num_classes)  
#-----------------------------------------#
#   创建一个全是0的矩阵,是一个混淆矩阵
#-----------------------------------------#
hist = np.zeros((num_classes, num_classes))

(4)将标签mask图和预测mask图转换为numpy数组

#------------------------------------------------#
#   读取每一个(图片-标签)对
#------------------------------------------------#
for ind in range(len(gt_imgs)): #------------------------------------------------##   读取一张图像分割结果,转化成numpy数组#------------------------------------------------#pred = np.array(Image.open(pred_imgs[ind]))  #------------------------------------------------##   读取一张对应的标签,转化成numpy数组#------------------------------------------------#label = np.array(Image.open(gt_imgs[ind]))  

(5)将两个numpy数组展平为一维数组,使用np.bincount逐像素计算,再reshape,结果累计在混淆矩阵hist中

# 如果图像分割结果与标签的大小不一样,这张图片就不计算
if len(label.flatten()) != len(pred.flatten()):  print('Skipping: len(gt) = {:d}, len(pred) = {:d}, {:s}, {:s}'.format(len(label.flatten()), len(pred.flatten()), gt_imgs[ind],pred_imgs[ind]))continue#------------------------------------------------#
#   对一张图片计算21×21的hist矩阵,并累加
#------------------------------------------------#
hist += fast_hist(label.flatten(), pred.flatten(), num_classes) def fast_hist(a, b, n):#--------------------------------------------------------------------------------##   a是转化成一维数组的标签,形状(H×W,);b是转化成一维数组的预测结果,形状(H×W,)#--------------------------------------------------------------------------------#k = (a >= 0) & (a < n)#--------------------------------------------------------------------------------##   np.bincount计算了从0到n**2-1这n**2个数中每个数出现的次数,返回值形状(n, n)#   返回中,写对角线上的为分类正确的像素点#--------------------------------------------------------------------------------#return np.bincount(n * a[k].astype(int) + b[k], minlength=n ** 2).reshape(n, n)  

(6)根据hist混淆矩阵,计算语义分割指标mIOU,PA_Recall,Precision

#------------------------------------------------#
#   计算所有验证集图片的逐类别mIoU值
#------------------------------------------------#
IoUs        = per_class_iu(hist)
PA_Recall   = per_class_PA_Recall(hist)
Precision   = per_class_Precision(hist)

四、完整代码

class EvalCallback():def __init__(self, net, input_shape, num_classes, image_ids, dataset_path, log_dir, cuda, \miou_out_path=".temp_miou_out", eval_flag=True, period=1):super(EvalCallback, self).__init__()self.net                = netself.input_shape        = input_shapeself.num_classes        = num_classesself.image_ids          = image_idsself.dataset_path       = dataset_pathself.log_dir            = log_dirself.cuda               = cudaself.miou_out_path      = miou_out_pathself.eval_flag          = eval_flagself.period             = periodself.image_ids          = [image_id.split()[0] for image_id in image_ids]self.mious      = [0]self.epoches    = [0]if self.eval_flag:with open(os.path.join(self.log_dir, "epoch_miou.txt"), 'a') as f:f.write(str(0))f.write("\n")def get_miou_png(self, image):#---------------------------------------------------------##   在这里将图像转换成RGB图像,防止灰度图在预测时报错。#   代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB#---------------------------------------------------------#image       = cvtColor(image)orininal_h  = np.array(image).shape[0]orininal_w  = np.array(image).shape[1]#---------------------------------------------------------##   给图像增加灰条,实现不失真的resize#   也可以直接resize进行识别#---------------------------------------------------------#image_data, nw, nh  = resize_image(image, (self.input_shape[1],self.input_shape[0]))#---------------------------------------------------------##   添加上batch_size维度#---------------------------------------------------------#image_data  = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, np.float32)), (2, 0, 1)), 0)with torch.no_grad():images = torch.from_numpy(image_data)if self.cuda:images = images.cuda()#---------------------------------------------------##   图片传入网络进行预测#---------------------------------------------------#pr = self.net(images)[0]#---------------------------------------------------##   取出每一个像素点的种类#---------------------------------------------------#pr = F.softmax(pr.permute(1,2,0),dim = -1).cpu().numpy()#--------------------------------------##   将灰条部分截取掉#--------------------------------------#pr = pr[int((self.input_shape[0] - nh) // 2) : int((self.input_shape[0] - nh) // 2 + nh), \int((self.input_shape[1] - nw) // 2) : int((self.input_shape[1] - nw) // 2 + nw)]#---------------------------------------------------##   进行图片的resize#---------------------------------------------------#pr = cv2.resize(pr, (orininal_w, orininal_h), interpolation = cv2.INTER_LINEAR)#---------------------------------------------------##   取出每一个像素点的种类#---------------------------------------------------#pr = pr.argmax(axis=-1)image = Image.fromarray(np.uint8(pr))return imagedef on_epoch_end(self, epoch, model_eval):if epoch % self.period == 0 and self.eval_flag:self.net    = model_evalgt_dir      = os.path.join(self.dataset_path, "SegmentationClass/")pred_dir    = os.path.join(self.miou_out_path, 'detection-results')if not os.path.exists(self.miou_out_path):os.makedirs(self.miou_out_path)if not os.path.exists(pred_dir):os.makedirs(pred_dir)print("Get miou.")for image_id in tqdm(self.image_ids):#-------------------------------##   从文件中读取图像#-------------------------------#image_path  = os.path.join(self.dataset_path, "JPEGImages/"+image_id+".jpg")image       = Image.open(image_path)#------------------------------##   获得预测特征图#------------------------------#image       = self.get_miou_png(image)image.save(os.path.join(pred_dir, image_id + ".png"))print("Calculate miou.")_, IoUs, _, _ = compute_mIoU(gt_dir, pred_dir, self.image_ids, self.num_classes, None)  # 执行计算mIoU的函数temp_miou = np.nanmean(IoUs) * 100self.mious.append(temp_miou)self.epoches.append(epoch)with open(os.path.join(self.log_dir, "epoch_miou.txt"), 'a') as f:f.write(str(temp_miou))f.write("\n")plt.figure()plt.plot(self.epoches, self.mious, 'red', linewidth = 2, label='train miou')plt.grid(True)plt.xlabel('Epoch')plt.ylabel('Miou')plt.title('A Miou Curve')plt.legend(loc="upper right")plt.savefig(os.path.join(self.log_dir, "epoch_miou.png"))plt.cla()plt.close("all")print("Get miou done.")shutil.rmtree(self.miou_out_path)

http://www.ppmy.cn/embedded/6877.html

相关文章

在PostgreSQL中,如何创建一个触发器并在特定事件发生时执行自定义操作?

文章目录 解决方案示例代码1. 创建自定义函数2. 创建触发器 解释 在PostgreSQL中&#xff0c;触发器&#xff08;trigger&#xff09;是一种数据库对象&#xff0c;它能在特定的事件&#xff08;如INSERT、UPDATE或DELETE&#xff09;发生时自动执行一系列的操作。这些操作可以…

redmibook 14 2020 安装 ubuntu

1. 参考博客 # Ubuntu20.10系统安装 -- 小米redmibook pro14 https://zhuanlan.zhihu.com/p/616543561# ubuntu18.04 wifi 问题 https://blog.csdn.net/u012748494/article/details/105421656/# 笔记本电脑安装了Ubuntu系统设置关盖/合盖不挂起/不睡眠 https://blog.csdn.net/…

socket编程——tcp

在我这篇博客&#xff1a;网络——socket编程中介绍了关于socket编程的一些必要的知识&#xff0c;以及介绍了使用套接字在udp协议下如何通信&#xff0c;这篇博客中&#xff0c;我将会介绍如何使用套接字以及tcp协议进行网络通信。 1. 前置准备 在进行编写代码之前&#xff…

JDK 11下载、安装、配置

下载 到Oracle管网下载JDK 11&#xff0c;下载前需要登录&#xff0c;否则直接点下载会出现502 bad gateway。 下载页面链接 https://www.oracle.com/hk/java/technologies/downloads/#java11-windows 登录 有些人可能没有Oracle账号&#xff0c;注册也比较慢&#xff0c;有需…

VR全景:为户外游玩体验插上科技翅膀

随着VR全景技术的愈发成熟&#xff0c;无数人感到惊艳&#xff0c;也让各行各业看到了一片光明的发展前景。尤其是越来越多的文旅景区开始引入VR全景技术&#xff0c;相较于以往的静态风景图&#xff0c;显然现在的VR全景结合了动态图像和声音更加吸引人。 VR全景技术正在逐步改…

【rust简单工具理解】

1.map方法 map这个闭包的本质就是映射 let numbers vec![1, 2, 3, 4, 5]; let numbers_f64: Vec<f64> numbers.into_iter().map(|&x| x as f64).collect(); println!("{:?}", numbers_f64); // 输出: [1.0, 2.0, 3.0, 4.0, 5.0]2.and_then and_then …

基于Google Gemini 探索大语言模型在医学领域应用评估和前景

概述 近年来&#xff0c;大规模语言模型&#xff08;LLM&#xff09;在理解和生成人类语言方面取得了显著的飞跃&#xff0c;这些进步不仅推动了语言学和计算机编程的发展&#xff0c;还为多个领域带来了创新的突破。特别是模型如GPT-3和PaLM&#xff0c;它们通过吸收海量文本…

Unity 获取指定文件夹及其子文件夹下所有文件的方法

在Unity中&#xff0c;我们可以使用System.IO命名空间中的Directory和File类来获取指定文件夹及其子文件夹下的所有文件。 一、只获取文件夹下所有文件&#xff1a; using System.Collections.Generic; using System.IO; using UnityEngine;public class FileScanner : MonoB…