Unet项目解析(6): 图像分块、整合 / 数据对齐、网络输出转成图像

news/2024/12/1 20:32:29/

项目GitHub主页:https://github.com/orobix/retina-unet

参考论文:Retina blood vessel segmentation with a convolution neural network (U-net)


1. 训练数据

1.1 训练图像、训练金标准随机分块

主代码:

# 训练集太少,采用分块的方法进行训练
def get_data_training(DRIVE_train_imgs_original,  #训练图像路径DRIVE_train_groudTruth,     #金标准图像路径patch_height,patch_width,N_subimgs,inside_FOV):train_imgs_original = load_hdf5(DRIVE_train_imgs_original)train_masks = load_hdf5(DRIVE_train_groudTruth) #visualize(group_images(train_imgs_original[0:20,:,:,:],5),'imgs_train').show() train_imgs = my_PreProc(train_imgs_original) # 图像预处理 归一化等train_masks = train_masks/255.train_imgs = train_imgs[:,:,9:574,:]   # 图像裁剪 size=565*565train_masks = train_masks[:,:,9:574,:] # 图像裁剪 size=565*565data_consistency_check(train_imgs,train_masks) # 训练图像和金标准图像一致性检查assert(np.min(train_masks)==0 and np.max(train_masks)==1) #金标准图像 2类 0-1print ("\n train images/masks shape:")print (train_imgs.shape)print ("train images range (min-max): " +str(np.min(train_imgs)) +' - '+str(np.max(train_imgs)))print ("train masks are within 0-1\n")# 从整张图像中-随机提取-训练子块patches_imgs_train, patches_masks_train =extract_random(train_imgs,train_masks,patch_height,patch_width,N_subimgs,inside_FOV)data_consistency_check(patches_imgs_train, patches_masks_train) # 训练图像子块和金标准图像子块一致性检查print ("\n train PATCHES images/masks shape:")print (patches_imgs_train.shape)print ("train PATCHES images range (min-max): " +str(np.min(patches_imgs_train)) +' - '+str(np.max(patches_imgs_train)))return patches_imgs_train, patches_masks_train

随机提取子块:

# 训练集图像 随机 提取子块
def extract_random(full_imgs,full_masks, patch_h,patch_w, N_patches, inside=True):if (N_patches%full_imgs.shape[0] != 0): # 检验每张图像应该提取多少块print "N_patches: plase enter a multiple of 20"exit()assert (len(full_imgs.shape)==4 and len(full_masks.shape)==4)  # 张量尺寸检验assert (full_imgs.shape[1]==1 or full_imgs.shape[1]==3)  # 通道检验assert (full_masks.shape[1]==1)   # 通道检验assert (full_imgs.shape[2] == full_masks.shape[2] and full_imgs.shape[3] == full_masks.shape[3]) # 尺寸检验patches = np.empty((N_patches,full_imgs.shape[1],patch_h,patch_w)) # 训练图像总子块patches_masks = np.empty((N_patches,full_masks.shape[1],patch_h,patch_w)) # 训练金标准总子块img_h = full_imgs.shape[2]  img_w = full_imgs.shape[3] patch_per_img = int(N_patches/full_imgs.shape[0])  # 每张图像中提取的子块数量print ("patches per full image: " +str(patch_per_img))iter_tot = 0   # 图像子块总量计数器for i in range(full_imgs.shape[0]):  # 遍历每一张图像k=0 # 每张图像子块计数器while k <patch_per_img:x_center = random.randint(0+int(patch_w/2),img_w-int(patch_w/2)) # 块中心的范围y_center = random.randint(0+int(patch_h/2),img_h-int(patch_h/2))if inside==True:if is_patch_inside_FOV(x_center,y_center,img_w,img_h,patch_h)==False:continuepatch = full_imgs[i,:,y_center-int(patch_h/2):y_center+int(patch_h/2),x_center-int(patch_w/2):x_center+int(patch_w/2)]patch_mask = full_masks[i,:,y_center-int(patch_h/2):y_center+int(patch_h/2),x_center-int(patch_w/2):x_center+int(patch_w/2)]patches[iter_tot]=patch # size=[Npatches, 3, patch_h, patch_w]patches_masks[iter_tot]=patch_mask # size=[Npatches, 1, patch_h, patch_w]iter_tot +=1   # 子块总量计数器k+=1  # 每张图像子块总量计数器return patches, patches_masks

数据一致性检查函数:

# 训练集图像 和 金标准图像一致性检验
def data_consistency_check(imgs,masks):assert(len(imgs.shape)==len(masks.shape))assert(imgs.shape[0]==masks.shape[0])assert(imgs.shape[2]==masks.shape[2])assert(imgs.shape[3]==masks.shape[3])assert(masks.shape[1]==1)assert(imgs.shape[1]==1 or imgs.shape[1]==3)

1.2 训练金标准改写成Une输出形式

# 将金标准图像改写成模型输出形式
def masks_Unet(masks): # size=[Npatches, 1, patch_height, patch_width]assert (len(masks.shape)==4)assert (masks.shape[1]==1 )im_h = masks.shape[2]im_w = masks.shape[3]masks = np.reshape(masks,(masks.shape[0],im_h*im_w)) # 单像素建模new_masks = np.empty((masks.shape[0],im_h*im_w,2)) # 二分类输出for i in range(masks.shape[0]):for j in range(im_h*im_w):if  masks[i,j] == 0:new_masks[i,j,0]=1 # 金标准图像的反转new_masks[i,j,1]=0 # 金标准图像else:new_masks[i,j,0]=0new_masks[i,j,1]=1return new_masks

2. 网络输出转换成图像子块

# 网络输出 size=[Npatches, patch_height*patch_width, 2]
def pred_to_imgs(pred, patch_height, patch_width, mode="original"):assert (len(pred.shape)==3)  assert (pred.shape[2]==2 )  # 确认是否为二分类pred_images = np.empty((pred.shape[0],pred.shape[1]))  #(Npatches,height*width)if mode=="original": # 网络概率输出for i in range(pred.shape[0]):for pix in range(pred.shape[1]):pred_images[i,pix]=pred[i,pix,1] #pred[:, :, 0] 是反分割图像输出 pred[:, :, 1]是分割输出elif mode=="threshold": # 网络概率-阈值输出for i in range(pred.shape[0]):for pix in range(pred.shape[1]):if pred[i,pix,1]>=0.5:pred_images[i,pix]=1else:pred_images[i,pix]=0else:print ("mode " +str(mode) +" not recognized, it can be 'original' or 'threshold'")exit()# 改写成(Npatches,1, height, width)pred_images = np.reshape(pred_images,(pred_images.shape[0],1, patch_height, patch_width)) return pred_images

3. 测试图像按顺序分块、预测子块重新整合成图像

3.1 测试图像分块

def get_data_testing_overlap(DRIVE_test_imgs_original, DRIVE_test_groudTruth, Imgs_to_test, # 20patch_height, patch_width, stride_height, stride_width):test_imgs_original = load_hdf5(DRIVE_test_imgs_original)test_masks = load_hdf5(DRIVE_test_groudTruth)test_imgs = my_PreProc(test_imgs_original)test_masks = test_masks/255.test_imgs = test_imgs[0:Imgs_to_test,:,:,:]test_masks = test_masks[0:Imgs_to_test,:,:,:]test_imgs = paint_border_overlap(test_imgs, patch_height, # 拓展图像 可以准确划分patch_width, stride_height, stride_width)assert(np.max(test_masks)==1  and np.min(test_masks)==0)print ("\n test images shape:")print (test_imgs.shape)print ("\n test mask shape:")print (test_masks.shape)print ("test images range (min-max): " +str(np.min(test_imgs)) +' - '+str(np.max(test_imgs)))# 按照顺序提取图像快 方便后续进行图像恢复(作者采用了overlap策略)patches_imgs_test = extract_ordered_overlap(test_imgs,patch_height,patch_width,stride_height,stride_width)print ("\n test PATCHES images shape:")print (patches_imgs_test.shape)print ("test PATCHES images range (min-max): " +str(np.min(patches_imgs_test)) +' - '+str(np.max(patches_imgs_test)))return patches_imgs_test, test_imgs.shape[2], test_imgs.shape[3], test_masks #原始大小

原始图像进行拓展填充:

def paint_border_overlap(full_imgs, patch_h, patch_w, stride_h, stride_w):assert (len(full_imgs.shape)==4)  #4D arraysassert (full_imgs.shape[1]==1 or full_imgs.shape[1]==3)  #check the channel is 1 or 3img_h = full_imgs.shape[2]  #height of the full imageimg_w = full_imgs.shape[3] #width of the full imageleftover_h = (img_h-patch_h)%stride_h  #leftover on the h dimleftover_w = (img_w-patch_w)%stride_w  #leftover on the w dimif (leftover_h != 0):  #change dimension of img_htmp_full_imgs = np.zeros((full_imgs.shape[0],full_imgs.shape[1],img_h+(stride_h-leftover_h),img_w))tmp_full_imgs[0:full_imgs.shape[0],0:full_imgs.shape[1],0:img_h,0:img_w] = full_imgsfull_imgs = tmp_full_imgsif (leftover_w != 0):   #change dimension of img_wtmp_full_imgs = np.zeros((full_imgs.shape[0],full_imgs.shape[1],full_imgs.shape[2],img_w+(stride_w - leftover_w)))tmp_full_imgs[0:full_imgs.shape[0],0:full_imgs.shape[1],0:full_imgs.shape[2],0:img_w] = full_imgsfull_imgs = tmp_full_imgsreturn full_imgs

按顺序提取图像子块:

# 按照顺序对拓展后的图像进行子块采样
def extract_ordered_overlap(full_imgs, patch_h, patch_w,stride_h,stride_w):assert (len(full_imgs.shape)==4)  assert (full_imgs.shape[1]==1 or full_imgs.shape[1]==3)  img_h = full_imgs.shape[2]  img_w = full_imgs.shape[3] assert ((img_h-patch_h)%stride_h==0 and (img_w-patch_w)%stride_w==0)N_patches_img = ((img_h-patch_h)//stride_h+1)*((img_w-patch_w)//stride_w+1)  # 每张图像采集到的子图像N_patches_tot = N_patches_img*full_imgs.shape[0] # 测试集总共的子图像数量patches = np.empty((N_patches_tot,full_imgs.shape[1],patch_h,patch_w))iter_tot = 0   for i in range(full_imgs.shape[0]):  for h in range((img_h-patch_h)//stride_h+1):for w in range((img_w-patch_w)//stride_w+1):patch = full_imgs[i,:,h*stride_h:(h*stride_h)+patch_h,w*stride_w:(w*stride_w)+patch_w]patches[iter_tot]=patchiter_tot +=1   #totalassert (iter_tot==N_patches_tot)return patches 

3.2 对于图像子块进行复原

# [Npatches, 1, patch_h, patch_w]  img_h=new_height[588] img_w=new_width[568] stride-[10,10]
def recompone_overlap(preds, img_h, img_w, stride_h, stride_w):assert (len(preds.shape)==4)  # 检查张量尺寸assert (preds.shape[1]==1 or preds.shape[1]==3)patch_h = preds.shape[2]patch_w = preds.shape[3]N_patches_h = (img_h-patch_h)//stride_h+1 # img_h方向包括的patch_h数量N_patches_w = (img_w-patch_w)//stride_w+1 # img_w方向包括的patch_w数量N_patches_img = N_patches_h * N_patches_w # 每张图像包含的patch的数目assert (preds.shape[0]%N_patches_img==0   N_full_imgs = preds.shape[0]//N_patches_img # 全幅图像的数目full_prob = np.zeros((N_full_imgs,preds.shape[1],img_h,img_w))full_sum = np.zeros((N_full_imgs,preds.shape[1],img_h,img_w))k = 0 #迭代所有的子块for i in range(N_full_imgs):for h in range((img_h-patch_h)//stride_h+1):for w in range((img_w-patch_w)//stride_w+1):full_prob[i,:,h*stride_h:(h*stride_h)+patch_h,w*stride_w:(w*stride_w)+patch_w]+=preds[k]full_sum[i,:,h*stride_h:(h*stride_h)+patch_h,w*stride_w:(w*stride_w)+patch_w]+=1k+=1assert(k==preds.shape[0])assert(np.min(full_sum)>=1.0) final_avg = full_prob/full_sum # 叠加概率 / 叠加权重 : 采用了均值的方法print final_avg.shapeassert(np.max(final_avg)<=1.0)assert(np.min(final_avg)>=0.0)return final_avg

http://www.ppmy.cn/news/314621.html

相关文章

openmv学习之旅②之色块追踪算法的改善

大家好&#xff0c;我是杰杰。 实在不好意思&#xff0c;最近比较忙&#xff0c;之前说的连载现在才更新出来。 从上一篇openmv的学习中openmv学习之旅①我们可以很简单运用micropython在openmv上做我们想做的事情。 Python这个东西用起来是很简单的&#xff0c;&#xff0c;…

诺基亚智能手机与NFC功能推出

诺基亚智能手机与NFC功能推出 我收到了我长达十年的朋友威乐Makinnen谁出席推出三款全新的诺基亚智能手机分别是700,701和600 8月24日其在Symbian上运行的所有代号为百丽的电话。诺基亚700仅重96gm&#xff0c;并在110 X 50.7点x9.7毫米它不仅是诺基亚在塞班范围最小巧的智能手…

诺基亚X3兴奋

诺基亚X3兴奋 多年来&#xff0c;诺基亚一直被认为是产生对高端品质和用户友好的移动电话的领先的手机制​​造商。对于每一个诺基亚洁具你会购买&#xff0c;你会放心的无与伦比的质量和功能。 的宗旨&#xff0c;以填补客户的需求越多&#xff0c;诺基亚已经推出了另一种手机…

应用层下的人脸识别(一):图像获取

本文为大家总结了人脸识别技术在安防领域应用的完整流程&#xff0c;以及产品设计的细节。其中包括&#xff1a;如何获取最佳图像&#xff0c;如何进行设备对接等经验。 图像获取是人脸识别的第一步&#xff0c;人脸识别项目中图像来源主要依靠各类监控相机&#xff0c;图像质量…

使用手机作单反相机的遥控器

2019独角兽企业重金招聘Python工程师标准>>> 你的相机用什么方式取景&#xff1f;液晶显示器&#xff1f;光学取景器&#xff1f;还是电子取景器&#xff1f;我们今天要介绍的就是颠覆大部分人使用习惯的一种取景方式&#xff0c;用手机的液晶屏取 景。这里我们要用…

加载图片节省内存的方法

加载图片节省内存1&#xff0c;在Image Views中调整图片大小如果要在UIImageView中显示一个来自bundle的图片&#xff0c;应该保证图片的大小和UIIimageView的大小是相同的&#xff0c;在运行中缩放图片是很消耗资源的&#xff0c;特别是UIImageView嵌套在UIScrollView中的情况…

设置通知栏的背景颜色或全幅背景

1. 效果图 2. 在xml布局中添加一个背景图片 , 这里就不贴代码了 3. 在MainActivity中添加如下代码 protected void onCreate(Bundle savedInstanceState) {super.onCreate(savedInstanceState);setContentView(R.layout.activity_main);translucentStatusBar(this, false);}sta…

图像几何校正

几何校正中混淆的概念 名词描述几何校正几何畸变会给基于遥感图像的定量分析、变化检测、图像融合 、地图测量或更新等处理带来误差&#xff08;主要指二维平面坐标&#xff09;&#xff0c;所以需要针对图像的几何畸变进行校正&#xff0c;也就是几何校正。图像配准图像配准与…