【视频分割】【深度学习】MiVOS官方Pytorch代码--Propagation模块FusionNet网络解析

news/2024/12/2 17:58:39/

【视频分割】【深度学习】MiVOS官方Pytorch代码–Propagation模块FusionNet网络解析

MiVOS模型将交互到掩码和掩码传播分离,从而实现更高的泛化性和更好的性能。单独训练的交互模块将用户交互转换为对象掩码,传播模块使用一种新的top-k过滤策略在读取时空存储器时进行临时传播,本博客将讲解Propagation(用户交互产生分割图)模块的深度网络代码,Propagation模块封装了PropagationNet和FusionNet模型。

文章目录

  • 【视频分割】【深度学习】MiVOS官方Pytorch代码--Propagation模块FusionNet网络解析
  • 前言
  • AttentionMemory
  • FusionNet类
  • fuse_one_frame
  • 总结


前言

在详细解析MiVOS代码之前,首要任务是成功运行MiVOS代码【win10下参考教程】,后续学习才有意义。
本博客讲解Propagation模块的深度网络(FusionNet)代码,不再复述其他功能模块代码。
MiVOS原论文中关于Fusion Module的示意图:

关键帧是用户在某一帧有交互行为,传播帧是根据这些交互行为而需要改变的帧。


AttentionMemory

注意力区域:在model/propagation/prop_net.py文件内
pos_mask和neg_mask是关键帧的mask与当前传播帧上次的mask之间进行算术运算操作得到的"差异",attn_memory(AttentionMemory)方法通过Memory key特征和Query key特征计算得到weight map(权重图),然后pos_mask和neg_mask做加权获得pos_map和neg_map。

class AttentionMemory(nn.Module):def __init__(self, k):super().__init__()self.k = kdef forward(self, mk, qk): """T=1 only. Only needs to obtain W"""B, CK, _, H, W = mk.shapemk = mk.view(B, CK, H*W) mk = torch.transpose(mk, 1, 2)          # B * HW * CKqk = qk.view(1, CK, H*W).expand(B, -1, -1) / math.sqrt(CK)  # B * CK * HWaffinity = torch.bmm(mk, qk)            # B * HW * HWaffinity = F.softmax(affinity, dim=1)return affinity

pos_mask和neg_mask分别做加权获得新的pos_map和neg_map后拼接。

def get_W(self, mk16, qk):W = self.attn_memory(mk16, qk)return W      
def get_attention(self, mk16, pos_mask, neg_mask, qk16):b, _, h, w = pos_mask.shapenh = h//16nw = w//16W = self.get_W(mk16, qk16)pos_map = (F.interpolate(pos_mask, size=(nh,nw), mode='area').view(b, 1, nh*nw) @ W)neg_map = (F.interpolate(neg_mask, size=(nh,nw), mode='area').view(b, 1, nh*nw) @ W)attn_map = torch.cat([pos_map, neg_map], 1)attn_map = attn_map.reshape(b, 2, nh, nw)attn_map = F.interpolate(attn_map, mode='bilinear', size=(h,w), align_corners=False)return attn_map

weight map(权重图)是关键帧的Memory key 和当前传播的帧Query key矩阵相乘计算而来,而后加权到pos_mask和neg_mask获得pos_map和neg_map。PropagationNet也有一部类似的操作,注意区分。

FusionNet类

在model/fusion_net.py内

class FusionNet(nn.Module):def __init__(self):super().__init__()self.conv1 = nn.Sequential(nn.Conv2d(9, 32, kernel_size=3, padding=1, stride=1),nn.ReLU(),)self.conv2 = nn.Sequential(nn.Conv2d(32, 32, kernel_size=3, padding=1, stride=1),nn.ReLU(),nn.Conv2d(32, 32, kernel_size=3, padding=1, stride=1),)self.conv3 = nn.Sequential(nn.Conv2d(32, 32, kernel_size=3, padding=1, stride=1),nn.ReLU(),nn.Conv2d(32, 32, kernel_size=3, padding=1, stride=1),)self.relu = nn.ReLU()self.final_conv = nn.Conv2d(32, 1, kernel_size=3, padding=1, stride=1)def forward(self, im, seg1, seg2, attn, time):'''Args:im: 原始图片seg1: 当前传播帧上次生成的maskseg2: PropagationNet生成的当前传播帧maskattn: 注意力区域time: 时间Returns:'''h, w = im.shape[-2:]time = time.unsqueeze(2).unsqueeze(2)time = time.expand(-1, -1, h, w)x = torch.cat([im, seg1, seg2, attn, time], 1)x = self.conv1(x)r = self.conv2(x)x = self.relu(x + r)r = self.conv3(x)x = self.relu(x + r)x = self.final_conv(x)return x

网络结构如下图所示:

fuse_one_frame

在inference_core.py内
时间相关其实就是看当前传播帧离前向传播和反向传播的终点的距离,现在有了原始图片、当前传播帧上次的mask、PropagationNet输出传播帧的mask和注意力区域就能通过fuse_net(FusionNet)融合出传播帧此次的mask。

def fuse_one_frame(self, tc, tr, ti, prev_mask, curr_mask, mk16, qk16):assert(tc<ti<tr or tr<ti<tc)    # 必须在符合的传播范围内prob = torch.zeros((self.k, 1, self.nh, self.nw), dtype=torch.float32, device=self.device)nc = abs(tc-ti) / abs(tc-tr)nr = abs(tr-ti) / abs(tc-tr)# 时间相关dist = torch.FloatTensor([nc, nr]).to(self.device).unsqueeze(0)for k in range(1, self.k+1):# 注意力位置attn_map = self.prop_net.get_attention(mk16[k-1:k], self.pos_mask_diff[k:k+1], self.neg_mask_diff[k:k+1], qk16)# 融合过程w = torch.sigmoid(self.fuse_net(self.get_image_buffered(ti), prev_mask[k:k+1].to(self.device), curr_mask[k:k+1].to(self.device), attn_map, dist))prob[k-1] = w return aggregate_wbg(prob, keep_bg=True)

总结

尽可能简单、详细的介绍MiVOS中Propagation模块中FusionNet网络的代码。后续会讲解MiVOS的训练。


http://www.ppmy.cn/news/38912.html

相关文章

【观察】坚持科技创新,天翼云铸牢数字中国关键底座

毫无疑问&#xff0c;今天“算力就是生产力”已成为业界共识&#xff0c;特别是算力作为数字经济时代的关键生产力要素&#xff0c;已成为了挖掘数据要素价值&#xff0c;推动数字经济发展的核心支撑力和驱动力。但也要看到&#xff0c;随着数据空前地增长和扩张&#xff0c;加…

【从零开始学习 UVM】11.4、UVM Register Layer —— UVM Register Model 实战项目(RAL实战,交通灯为例)

文章目录 DesignInterfaceRegister Model ExampleRegister EnvironmentAPB Agent ExampleTestbench EnvironmentSequencesTest在之前的几篇文章中,我们已经了解了寄存器模型是什么以及如何使用它来访问给定设计中的寄存器。现在让我们看一个完整的例子,展示如何为给定设计编写…

创新之路永不止步,看亚马逊云科技“Serverless First”进阶之路

科技云报道原创。 2009年&#xff0c;加州大学伯克利分校一个研究团队以独特视角发布了一篇文献&#xff0c;正式定义了云计算。自此&#xff0c;千行百业的IT基础设施开启上云之路。 2019年&#xff0c;该研究团队在《Cloud Programming Simplified》预言&#xff1a;“Serv…

过程控制系统中的模块技术MTP

在过程自动化行业中&#xff0c;模块化设备概念近年来越来越受欢迎。其中最热门的是MTP。MTP称为模块类型封装&#xff0c;它是过程工业自动化技术用户协会&#xff08;NAMUR&#xff09;提出的过程自动化行业的模块化标准&#xff0c;通过这种模型&#xff0c;开发工作的重点从…

C++中运算符new的深入讲解

目录一 new运算符语法二 new 如何工作&#xff08;理解运算符new 和函数operator new&#xff09;三 operator new函数四 一个综合性的例子五 使用new可能导致的内存泄漏写这篇文章的原因是&#xff1a;有一次&#xff0c;我见到了类似下面的代码&#xff0c;我感到很惊奇&…

miniprogram-to-uniapp使用指南(各种小程序项目转换为uni-app项目)

小程序分类&#xff1a;uni-app qq小程序 支付宝小程序 百度小程序 钉钉小程序 微信小程序 小程序转成uni_app 小程序转为uni_app 小程序转uni_app 小程序转换 工具现在支持npm全局库、HBuilderX插件两种方式使用&#xff0c;任君选择&#xff0c;HBuilderX插件地址&#xff1a…

PDF怎么加密?11 款最好的 PDF 加密软件

保护您的商业文档和机密数据免受黑客攻击和欺诈的愿望可能成为寻找最佳 PDF 加密软件的重要动机。此类程序可防止未经授权的入侵者访问您的数据。 黑客可以访问电子文档和表格&#xff0c;尤其是当您共享它们时。这意味着如果您不保护您的 PDF&#xff0c;您将面临很大的风险。…

Himall商城BillingApplication获取店铺财务总览、根据日期获取该日期的结算周期

/// <summary> /// 获取店铺财务总览 /// </summary> /// <param name"shopId"></param> /// <returns></returns> public static ShopBillingIndex GetShopBillingIndex(long shopId)…