基于Pytorch框架的深度学习MODNet网络精细人像分割系统源码

news/2024/9/25 23:56:37/

 第一步:准备数据

人像精细分割数据,可分割出头发丝,为PPM-100开源数据

第二步:搭建模型

MODNet网络结构如图所示,主要包含3个部分:semantic estimation(S分支)、detail prediction(D分支)、semantic-detail fusion(F分支)。

网络结构简单描述一下:

输入一幅图像I,送入三个模块:S、D、F;
S模块:在低分辨率分支进行语义估计,在backbone最后一层输出接上e-ASPP得到语义feature map Sp;
D模块:在高分辨率分支进行细节预测,通过融合来自低分辨率分支的信息得到细节feature map Dp;
F模块:融合来自低分辨率分支和高分辨率分支的信息,得到alpha matte ap;
对S、D、F模块,均使用来自GT的显式监督信息进行监督训练。

第三步:代码

1)损失函数为:L2损失

2)网络代码:

import torch
import torch.nn as nn
import torch.nn.functional as Ffrom .backbones import SUPPORTED_BACKBONES#------------------------------------------------------------------------------
#  MODNet Basic Modules
#------------------------------------------------------------------------------class IBNorm(nn.Module):""" Combine Instance Norm and Batch Norm into One Layer"""def __init__(self, in_channels):super(IBNorm, self).__init__()in_channels = in_channelsself.bnorm_channels = int(in_channels / 2)self.inorm_channels = in_channels - self.bnorm_channelsself.bnorm = nn.BatchNorm2d(self.bnorm_channels, affine=True)self.inorm = nn.InstanceNorm2d(self.inorm_channels, affine=False)def forward(self, x):bn_x = self.bnorm(x[:, :self.bnorm_channels, ...].contiguous())in_x = self.inorm(x[:, self.bnorm_channels:, ...].contiguous())return torch.cat((bn_x, in_x), 1)class Conv2dIBNormRelu(nn.Module):""" Convolution + IBNorm + ReLu"""def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, with_ibn=True, with_relu=True):super(Conv2dIBNormRelu, self).__init__()layers = [nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)]if with_ibn:       layers.append(IBNorm(out_channels))if with_relu:layers.append(nn.ReLU(inplace=True))self.layers = nn.Sequential(*layers)def forward(self, x):return self.layers(x) class SEBlock(nn.Module):""" SE Block Proposed in https://arxiv.org/pdf/1709.01507.pdf """def __init__(self, in_channels, out_channels, reduction=1):super(SEBlock, self).__init__()self.pool = nn.AdaptiveAvgPool2d(1)self.fc = nn.Sequential(nn.Linear(in_channels, int(in_channels // reduction), bias=False),nn.ReLU(inplace=True),nn.Linear(int(in_channels // reduction), out_channels, bias=False),nn.Sigmoid())def forward(self, x):b, c, _, _ = x.size()w = self.pool(x).view(b, c)w = self.fc(w).view(b, c, 1, 1)return x * w.expand_as(x)#------------------------------------------------------------------------------
#  MODNet Branches
#------------------------------------------------------------------------------class LRBranch(nn.Module):""" Low Resolution Branch of MODNet"""def __init__(self, backbone):super(LRBranch, self).__init__()enc_channels = backbone.enc_channelsself.backbone = backboneself.se_block = SEBlock(enc_channels[4], enc_channels[4], reduction=4)self.conv_lr16x = Conv2dIBNormRelu(enc_channels[4], enc_channels[3], 5, stride=1, padding=2)self.conv_lr8x = Conv2dIBNormRelu(enc_channels[3], enc_channels[2], 5, stride=1, padding=2)self.conv_lr = Conv2dIBNormRelu(enc_channels[2], 1, kernel_size=3, stride=2, padding=1, with_ibn=False, with_relu=False)def forward(self, img, inference):enc_features = self.backbone.forward(img)enc2x, enc4x, enc32x = enc_features[0], enc_features[1], enc_features[4]enc32x = self.se_block(enc32x)lr16x = F.interpolate(enc32x, scale_factor=2, mode='bilinear', align_corners=False)lr16x = self.conv_lr16x(lr16x)lr8x = F.interpolate(lr16x, scale_factor=2, mode='bilinear', align_corners=False)lr8x = self.conv_lr8x(lr8x)pred_semantic = Noneif not inference:lr = self.conv_lr(lr8x)pred_semantic = torch.sigmoid(lr)return pred_semantic, lr8x, [enc2x, enc4x] class HRBranch(nn.Module):""" High Resolution Branch of MODNet"""def __init__(self, hr_channels, enc_channels):super(HRBranch, self).__init__()self.tohr_enc2x = Conv2dIBNormRelu(enc_channels[0], hr_channels, 1, stride=1, padding=0)self.conv_enc2x = Conv2dIBNormRelu(hr_channels + 3, hr_channels, 3, stride=2, padding=1)self.tohr_enc4x = Conv2dIBNormRelu(enc_channels[1], hr_channels, 1, stride=1, padding=0)self.conv_enc4x = Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1)self.conv_hr4x = nn.Sequential(Conv2dIBNormRelu(3 * hr_channels + 3, 2 * hr_channels, 3, stride=1, padding=1),Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1),Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1),)self.conv_hr2x = nn.Sequential(Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1),Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1),Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1),Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1),)self.conv_hr = nn.Sequential(Conv2dIBNormRelu(hr_channels + 3, hr_channels, 3, stride=1, padding=1),Conv2dIBNormRelu(hr_channels, 1, kernel_size=1, stride=1, padding=0, with_ibn=False, with_relu=False),)def forward(self, img, enc2x, enc4x, lr8x, inference):img2x = F.interpolate(img, scale_factor=1/2, mode='bilinear', align_corners=False)img4x = F.interpolate(img, scale_factor=1/4, mode='bilinear', align_corners=False)enc2x = self.tohr_enc2x(enc2x)hr4x = self.conv_enc2x(torch.cat((img2x, enc2x), dim=1))enc4x = self.tohr_enc4x(enc4x)hr4x = self.conv_enc4x(torch.cat((hr4x, enc4x), dim=1))lr4x = F.interpolate(lr8x, scale_factor=2, mode='bilinear', align_corners=False)hr4x = self.conv_hr4x(torch.cat((hr4x, lr4x, img4x), dim=1))hr2x = F.interpolate(hr4x, scale_factor=2, mode='bilinear', align_corners=False)hr2x = self.conv_hr2x(torch.cat((hr2x, enc2x), dim=1))pred_detail = Noneif not inference:hr = F.interpolate(hr2x, scale_factor=2, mode='bilinear', align_corners=False)hr = self.conv_hr(torch.cat((hr, img), dim=1))pred_detail = torch.sigmoid(hr)return pred_detail, hr2xclass FusionBranch(nn.Module):""" Fusion Branch of MODNet"""def __init__(self, hr_channels, enc_channels):super(FusionBranch, self).__init__()self.conv_lr4x = Conv2dIBNormRelu(enc_channels[2], hr_channels, 5, stride=1, padding=2)self.conv_f2x = Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1)self.conv_f = nn.Sequential(Conv2dIBNormRelu(hr_channels + 3, int(hr_channels / 2), 3, stride=1, padding=1),Conv2dIBNormRelu(int(hr_channels / 2), 1, 1, stride=1, padding=0, with_ibn=False, with_relu=False),)def forward(self, img, lr8x, hr2x):lr4x = F.interpolate(lr8x, scale_factor=2, mode='bilinear', align_corners=False)lr4x = self.conv_lr4x(lr4x)lr2x = F.interpolate(lr4x, scale_factor=2, mode='bilinear', align_corners=False)f2x = self.conv_f2x(torch.cat((lr2x, hr2x), dim=1))f = F.interpolate(f2x, scale_factor=2, mode='bilinear', align_corners=False)f = self.conv_f(torch.cat((f, img), dim=1))pred_matte = torch.sigmoid(f)return pred_matte#------------------------------------------------------------------------------
#  MODNet
#------------------------------------------------------------------------------class MODNet(nn.Module):""" Architecture of MODNet"""def __init__(self, in_channels=3, hr_channels=32, backbone_arch='mobilenetv2', backbone_pretrained=True):super(MODNet, self).__init__()self.in_channels = in_channelsself.hr_channels = hr_channelsself.backbone_arch = backbone_archself.backbone_pretrained = backbone_pretrainedself.backbone = SUPPORTED_BACKBONES[self.backbone_arch](self.in_channels)self.lr_branch = LRBranch(self.backbone)self.hr_branch = HRBranch(self.hr_channels, self.backbone.enc_channels)self.f_branch = FusionBranch(self.hr_channels, self.backbone.enc_channels)for m in self.modules():if isinstance(m, nn.Conv2d):self._init_conv(m)elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.InstanceNorm2d):self._init_norm(m)if self.backbone_pretrained:self.backbone.load_pretrained_ckpt()                def forward(self, img, inference):pred_semantic, lr8x, [enc2x, enc4x] = self.lr_branch(img, inference)pred_detail, hr2x = self.hr_branch(img, enc2x, enc4x, lr8x, inference)pred_matte = self.f_branch(img, lr8x, hr2x)return pred_semantic, pred_detail, pred_mattedef freeze_norm(self):norm_types = [nn.BatchNorm2d, nn.InstanceNorm2d]for m in self.modules():for n in norm_types:if isinstance(m, n):m.eval()continuedef _init_conv(self, conv):nn.init.kaiming_uniform_(conv.weight, a=0, mode='fan_in', nonlinearity='relu')if conv.bias is not None:nn.init.constant_(conv.bias, 0)def _init_norm(self, norm):if norm.weight is not None:nn.init.constant_(norm.weight, 1)nn.init.constant_(norm.bias, 0)

第四步:搭建GUI界面

第五步:整个工程的内容

有训练代码和训练好的模型以及训练过程,提供数据,提供GUI界面代码

代码见:基于Pytorch框架的深度学习MODNet网络精细人像分割系统源码

有问题可以私信或者留言,有问必答

e0420b8919fe4c1cb3d1e3dd52176a8a.png


http://www.ppmy.cn/news/1530504.html

相关文章

EMQX MQTT 服务器启用 SSL/TLS 安全连接,使用8883端口

1.提前下载安装openssl 2.新建openssl文件夹打开在命令行操作 3.按照下面的操作进行 MQTT 安全 作为基于现代密码学公钥算法的安全协议,TLS/SSL 能在计算机通讯网络上保证传输安全,EMQX 内置对 TLS/SSL 的支持,包括支持单/双向认证、X.509 …

QT窗口无法激活弹出问题排查记录

问题背景 问题环境 操作系统: 银河麒麟V10SP1qt版本 : 5.12.12 碰见了一个问题应用最小化,然后激活程序窗口无法弹出 这里描述一下代码的逻辑,使用QLocalServer实现一个单例进程,具体的功能就是在已存在一个程序A进程时,再启动这个程序A,新的程序A进程会被杀死,然后激活已存…

使用 Internet 共享 (ICS) 方式分配ip

设备A使用dhcp的情况下,通过设备B分配ip并共享网络的方法。 启用网络共享(ICS)并配置 NAT Windows 自带的 Internet Connection Sharing (ICS) 功能可以简化 NAT 设置,允许共享一个网络连接给其他设备。 打开网络设置&#xff1…

Maven 项目无法下载某个依赖

问题描述 在使用 Maven 构建 Java 项目时,有时会遇到无法下载特定依赖的问题。这可能会影响项目的构建过程,导致构建失败。本文档旨在提供一系列步骤,帮助开发者定位并解决此类问题。 常见原因 依赖仓库配置错误:项目的 pom.xm…

使用Docker和Macvlan驱动程序模拟跨主机跨网段通信

以下是使用Docker和Macvlan驱动程序模拟跨主机跨网段通信的架构图: #mermaid-svg-b7wuGoTr6eQYSNHJ {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-b7wuGoTr6eQYSNHJ .error-icon{fill:#552222;}#mermai…

xpath应用大全

一、xpath在爬虫中的应用 1、/div 表示从根节点开始选取div节点 2、/span 表示从根节点开始选取span节点 3、//a 表示选取文档中所有a节点而不考虑其位置 4、class 表示选取名为class的属性 5、 . 表示选取当前节点 6、 .. 表示选取当前节点的父节点 7、/div/a 表示从根…

香港服务器PING测试有什么作用?

PING测试是一种常用的网络诊断工具,用于测试计算机与服务器之间的网络连通性和响应时间。对于香港服务器,进行PING测试有以下几个作用: 香港服务器PING测试的作用包括: 检查网络连通性:PING测试可以帮助确定从本地计算…

DataX实战:从MongoDB到MySQL的数据迁移--修改源码并测试打包

在现代数据驱动的业务环境中,数据迁移和集成是常见的需求。DataX,作为阿里云开源的数据集成工具,提供了强大的数据同步能力,支持多种数据源和目标端。本文将介绍如何使用DataX将数据从MongoDB迁移到MySQL。 环境准备 安装MongoDB…