基于Pytorch框架的深度学习MODNet网络精细人像分割系统源码

server/2024/9/24 20:05:11/

 第一步:准备数据

人像精细分割数据,可分割出头发丝,为PPM-100开源数据

第二步:搭建模型

MODNet网络结构如图所示,主要包含3个部分:semantic estimation(S分支)、detail prediction(D分支)、semantic-detail fusion(F分支)。

网络结构简单描述一下:

输入一幅图像I,送入三个模块:S、D、F;
S模块:在低分辨率分支进行语义估计,在backbone最后一层输出接上e-ASPP得到语义feature map Sp;
D模块:在高分辨率分支进行细节预测,通过融合来自低分辨率分支的信息得到细节feature map Dp;
F模块:融合来自低分辨率分支和高分辨率分支的信息,得到alpha matte ap;
对S、D、F模块,均使用来自GT的显式监督信息进行监督训练。

第三步:代码

1)损失函数为:L2损失

2)网络代码:

import torch
import torch.nn as nn
import torch.nn.functional as Ffrom .backbones import SUPPORTED_BACKBONES#------------------------------------------------------------------------------
#  MODNet Basic Modules
#------------------------------------------------------------------------------class IBNorm(nn.Module):""" Combine Instance Norm and Batch Norm into One Layer"""def __init__(self, in_channels):super(IBNorm, self).__init__()in_channels = in_channelsself.bnorm_channels = int(in_channels / 2)self.inorm_channels = in_channels - self.bnorm_channelsself.bnorm = nn.BatchNorm2d(self.bnorm_channels, affine=True)self.inorm = nn.InstanceNorm2d(self.inorm_channels, affine=False)def forward(self, x):bn_x = self.bnorm(x[:, :self.bnorm_channels, ...].contiguous())in_x = self.inorm(x[:, self.bnorm_channels:, ...].contiguous())return torch.cat((bn_x, in_x), 1)class Conv2dIBNormRelu(nn.Module):""" Convolution + IBNorm + ReLu"""def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, with_ibn=True, with_relu=True):super(Conv2dIBNormRelu, self).__init__()layers = [nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)]if with_ibn:       layers.append(IBNorm(out_channels))if with_relu:layers.append(nn.ReLU(inplace=True))self.layers = nn.Sequential(*layers)def forward(self, x):return self.layers(x) class SEBlock(nn.Module):""" SE Block Proposed in https://arxiv.org/pdf/1709.01507.pdf """def __init__(self, in_channels, out_channels, reduction=1):super(SEBlock, self).__init__()self.pool = nn.AdaptiveAvgPool2d(1)self.fc = nn.Sequential(nn.Linear(in_channels, int(in_channels // reduction), bias=False),nn.ReLU(inplace=True),nn.Linear(int(in_channels // reduction), out_channels, bias=False),nn.Sigmoid())def forward(self, x):b, c, _, _ = x.size()w = self.pool(x).view(b, c)w = self.fc(w).view(b, c, 1, 1)return x * w.expand_as(x)#------------------------------------------------------------------------------
#  MODNet Branches
#------------------------------------------------------------------------------class LRBranch(nn.Module):""" Low Resolution Branch of MODNet"""def __init__(self, backbone):super(LRBranch, self).__init__()enc_channels = backbone.enc_channelsself.backbone = backboneself.se_block = SEBlock(enc_channels[4], enc_channels[4], reduction=4)self.conv_lr16x = Conv2dIBNormRelu(enc_channels[4], enc_channels[3], 5, stride=1, padding=2)self.conv_lr8x = Conv2dIBNormRelu(enc_channels[3], enc_channels[2], 5, stride=1, padding=2)self.conv_lr = Conv2dIBNormRelu(enc_channels[2], 1, kernel_size=3, stride=2, padding=1, with_ibn=False, with_relu=False)def forward(self, img, inference):enc_features = self.backbone.forward(img)enc2x, enc4x, enc32x = enc_features[0], enc_features[1], enc_features[4]enc32x = self.se_block(enc32x)lr16x = F.interpolate(enc32x, scale_factor=2, mode='bilinear', align_corners=False)lr16x = self.conv_lr16x(lr16x)lr8x = F.interpolate(lr16x, scale_factor=2, mode='bilinear', align_corners=False)lr8x = self.conv_lr8x(lr8x)pred_semantic = Noneif not inference:lr = self.conv_lr(lr8x)pred_semantic = torch.sigmoid(lr)return pred_semantic, lr8x, [enc2x, enc4x] class HRBranch(nn.Module):""" High Resolution Branch of MODNet"""def __init__(self, hr_channels, enc_channels):super(HRBranch, self).__init__()self.tohr_enc2x = Conv2dIBNormRelu(enc_channels[0], hr_channels, 1, stride=1, padding=0)self.conv_enc2x = Conv2dIBNormRelu(hr_channels + 3, hr_channels, 3, stride=2, padding=1)self.tohr_enc4x = Conv2dIBNormRelu(enc_channels[1], hr_channels, 1, stride=1, padding=0)self.conv_enc4x = Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1)self.conv_hr4x = nn.Sequential(Conv2dIBNormRelu(3 * hr_channels + 3, 2 * hr_channels, 3, stride=1, padding=1),Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1),Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1),)self.conv_hr2x = nn.Sequential(Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1),Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1),Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1),Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1),)self.conv_hr = nn.Sequential(Conv2dIBNormRelu(hr_channels + 3, hr_channels, 3, stride=1, padding=1),Conv2dIBNormRelu(hr_channels, 1, kernel_size=1, stride=1, padding=0, with_ibn=False, with_relu=False),)def forward(self, img, enc2x, enc4x, lr8x, inference):img2x = F.interpolate(img, scale_factor=1/2, mode='bilinear', align_corners=False)img4x = F.interpolate(img, scale_factor=1/4, mode='bilinear', align_corners=False)enc2x = self.tohr_enc2x(enc2x)hr4x = self.conv_enc2x(torch.cat((img2x, enc2x), dim=1))enc4x = self.tohr_enc4x(enc4x)hr4x = self.conv_enc4x(torch.cat((hr4x, enc4x), dim=1))lr4x = F.interpolate(lr8x, scale_factor=2, mode='bilinear', align_corners=False)hr4x = self.conv_hr4x(torch.cat((hr4x, lr4x, img4x), dim=1))hr2x = F.interpolate(hr4x, scale_factor=2, mode='bilinear', align_corners=False)hr2x = self.conv_hr2x(torch.cat((hr2x, enc2x), dim=1))pred_detail = Noneif not inference:hr = F.interpolate(hr2x, scale_factor=2, mode='bilinear', align_corners=False)hr = self.conv_hr(torch.cat((hr, img), dim=1))pred_detail = torch.sigmoid(hr)return pred_detail, hr2xclass FusionBranch(nn.Module):""" Fusion Branch of MODNet"""def __init__(self, hr_channels, enc_channels):super(FusionBranch, self).__init__()self.conv_lr4x = Conv2dIBNormRelu(enc_channels[2], hr_channels, 5, stride=1, padding=2)self.conv_f2x = Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1)self.conv_f = nn.Sequential(Conv2dIBNormRelu(hr_channels + 3, int(hr_channels / 2), 3, stride=1, padding=1),Conv2dIBNormRelu(int(hr_channels / 2), 1, 1, stride=1, padding=0, with_ibn=False, with_relu=False),)def forward(self, img, lr8x, hr2x):lr4x = F.interpolate(lr8x, scale_factor=2, mode='bilinear', align_corners=False)lr4x = self.conv_lr4x(lr4x)lr2x = F.interpolate(lr4x, scale_factor=2, mode='bilinear', align_corners=False)f2x = self.conv_f2x(torch.cat((lr2x, hr2x), dim=1))f = F.interpolate(f2x, scale_factor=2, mode='bilinear', align_corners=False)f = self.conv_f(torch.cat((f, img), dim=1))pred_matte = torch.sigmoid(f)return pred_matte#------------------------------------------------------------------------------
#  MODNet
#------------------------------------------------------------------------------class MODNet(nn.Module):""" Architecture of MODNet"""def __init__(self, in_channels=3, hr_channels=32, backbone_arch='mobilenetv2', backbone_pretrained=True):super(MODNet, self).__init__()self.in_channels = in_channelsself.hr_channels = hr_channelsself.backbone_arch = backbone_archself.backbone_pretrained = backbone_pretrainedself.backbone = SUPPORTED_BACKBONES[self.backbone_arch](self.in_channels)self.lr_branch = LRBranch(self.backbone)self.hr_branch = HRBranch(self.hr_channels, self.backbone.enc_channels)self.f_branch = FusionBranch(self.hr_channels, self.backbone.enc_channels)for m in self.modules():if isinstance(m, nn.Conv2d):self._init_conv(m)elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.InstanceNorm2d):self._init_norm(m)if self.backbone_pretrained:self.backbone.load_pretrained_ckpt()                def forward(self, img, inference):pred_semantic, lr8x, [enc2x, enc4x] = self.lr_branch(img, inference)pred_detail, hr2x = self.hr_branch(img, enc2x, enc4x, lr8x, inference)pred_matte = self.f_branch(img, lr8x, hr2x)return pred_semantic, pred_detail, pred_mattedef freeze_norm(self):norm_types = [nn.BatchNorm2d, nn.InstanceNorm2d]for m in self.modules():for n in norm_types:if isinstance(m, n):m.eval()continuedef _init_conv(self, conv):nn.init.kaiming_uniform_(conv.weight, a=0, mode='fan_in', nonlinearity='relu')if conv.bias is not None:nn.init.constant_(conv.bias, 0)def _init_norm(self, norm):if norm.weight is not None:nn.init.constant_(norm.weight, 1)nn.init.constant_(norm.bias, 0)

第四步:搭建GUI界面

第五步:整个工程的内容

有训练代码和训练好的模型以及训练过程,提供数据,提供GUI界面代码

代码见:基于Pytorch框架的深度学习MODNet网络精细人像分割系统源码

有问题可以私信或者留言,有问必答

e0420b8919fe4c1cb3d1e3dd52176a8a.png


http://www.ppmy.cn/server/121482.html

相关文章

python自学笔记

python部分总结 主要记录的是python与之前学的语言的不同之处 函数总结 首字母大写: name.title() 删除右边空格(暂时):name.rstrip() 删除左边空格(暂时):name.lstrip() 删除前缀(暂时):name.removeprefi…

计算机网络的性能指标

【王道的书没有视频里讲的详细,这里把视频里的课件和笔记扒下来以供复习】 主机的网卡速率上限表示该主机的接收数据/上传数据均不能超过这个阈值。 E g . \rm Eg. Eg. 游戏延迟,反映的就是“手机——服务器”之间的往返时延。

学习记录:js算法(四十三):翻转二叉树

文章目录 翻转二叉树我的思路网上思路递归栈 总结 翻转二叉树 给你一棵二叉树的根节点 root ,翻转这棵二叉树,并返回其根节点 图一: 图二: 示例 1:(如图一) 输入:root [4,2,7,1…

Gateway--服务网关

网关简介 大家都知道在微服务架构中,一个系统会被拆分为很多个微服务。那么作为客户端要如何去调用 这么多的微服务呢?如果没有网关的存在,我们只能在客户端记录每个微服务的地址,然后分别去调用。 这样的架构,会存在…

SpringCloud~

帮你轻松入门SpringCloud~ 1 微服务概述 1.1什么是微服务 如idea中使用maven建立的一个个moudle,它具体是使用SpringBoot开发的一个小模块,专业的事交给专业的模块来做,每个模块完成一个具体的任务或功能。 1.2 什么是微服务架构 它将单一应用…

Java 入门指南:JVM(Java虚拟机)垃圾回收机制 —— 垃圾回收算法

文章目录 垃圾回收机制垃圾判断算法引用计数法可达性分析算法虚拟机栈中的引用(方法的参数、局部变量等)本地方法栈中 JNI 的引用类静态变量运行时常量池中的常量 垃圾收集算法Mark-Sweep(标记-清除)算法Copying(标记-…

SQLAlchemy思维导图

SQLAlchemy思维导图 创建一个 SQLAlchemy 的思维导图可以帮助你理解其核心概念、组件及其工作方式。虽然我不能直接绘制图形,但我可以提供一个文本格式的思维导图结构,方便你在任何思维导图工具中创建。 SQLAlchemy 思维导图结构 SQLAlchemy ├── 1. ORM (对象关系映射)…

基于PHP的电脑线上销售系统

作者:计算机学姐 开发技术:SpringBoot、SSM、Vue、MySQL、JSP、ElementUI、Python、小程序等,“文末源码”。 专栏推荐:前后端分离项目源码、SpringBoot项目源码、SSM项目源码 系统展示 【2025最新】基于phpMySQL的电脑线上销售系…