【DeepLearning-8】MobileViT模块配置

news/2024/12/22 19:33:59/

完整代码: 

import torch
import torch.nn as nn
from einops import rearrange
def conv_1x1_bn(inp, oup):return nn.Sequential(nn.Conv2d(inp, oup, 1, 1, 0, bias=False),nn.BatchNorm2d(oup),nn.SiLU())
def conv_nxn_bn(inp, oup, kernal_size=3, stride=1):return nn.Sequential(nn.Conv2d(inp, oup, kernal_size, stride, 1, bias=False),nn.BatchNorm2d(oup),nn.SiLU())
class PreNorm(nn.Module):def __init__(self, dim, fn):super().__init__()self.norm = nn.LayerNorm(dim)self.fn = fn # mgdef forward(self, x, **kwargs):return self.fn(self.norm(x), **kwargs)
class Attention(nn.Module):def __init__(self, dim, heads=8, dim_head=64, dropout=0.):super().__init__()inner_dim = dim_head *  headsproject_out = not (heads == 1 and dim_head == dim)self.heads = headsself.scale = dim_head ** -0.5self.attend = nn.Softmax(dim = -1)self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)self.to_out = nn.Sequential(nn.Linear(inner_dim, dim),nn.Dropout(dropout)# mg) if project_out else nn.Identity()def forward(self, x):qkv = self.to_qkv(x).chunk(3, dim=-1)q, k, v = map(lambda t: rearrange(t, 'b p n (h d) -> b p h n d', h = self.heads), qkv)dots = torch.matmul(q, k.transpose(-1, -2)) * self.scaleattn = self.attend(dots)out = torch.matmul(attn, v)out = rearrange(out, 'b p h n d -> b p n (h d)')return self.to_out(out)
class FeedForward(nn.Module):def __init__(self, dim, hidden_dim, dropout=0.):super().__init__()self.net = nn.Sequential(nn.Linear(dim, hidden_dim),nn.SiLU(),nn.Dropout(dropout),nn.Linear(hidden_dim, dim),nn.Dropout(dropout))def forward(self, x):return self.net(x)
class UserDefined(nn.Module):def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):super().__init__()self.layers = nn.ModuleList([])for _ in range(depth):self.layers.append(nn.ModuleList([PreNorm(dim, Attention(dim, heads, dim_head, dropout)),PreNorm(dim, FeedForward(dim, mlp_dim, dropout))]))def forward(self, x):for attn, ff in self.layers:x = attn(x) + xx = ff(x) + xreturn xclass IRBlock(nn.Module):def __init__(self, inp, oup, stride=1, expansion=4):super().__init__()self.stride = strideassert stride in [1, 2]hidden_dim = int(inp * expansion)self.use_res_connect = self.stride == 1 and inp == oupif expansion == 1: # 构建没有扩展层的卷积块self.conv = nn.Sequential(# 深度可分离卷积(Depthwise Convolution)nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),nn.BatchNorm2d(hidden_dim),nn.SiLU(),# “线性”逐点卷积 (Pointwise-Linear Convolution)nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),nn.BatchNorm2d(oup),)else:  # 构建包含扩展层的卷积块self.conv = nn.Sequential(# 逐点卷积 (Pointwise Convolution)nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),nn.BatchNorm2d(hidden_dim),nn.SiLU(),# 深度可分离卷积 (Depthwise Convolution)nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),nn.BatchNorm2d(hidden_dim),nn.SiLU(),# “线性”逐点卷积 (Pointwise-Linear Convolution)nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),nn.BatchNorm2d(oup),)def forward(self, x):if self.use_res_connect:return x + self.conv(x)else:return self.conv(x)class MobileViTBv3(nn.Module):def __init__(self, channel, dim, depth=2, kernel_size=3, patch_size=(2, 2), mlp_dim=int(64*2), dropout=0.):super().__init__()self.ph, self.pw = patch_sizeself.mv01 = IRBlock(channel, channel) self.conv1 = conv_nxn_bn(channel, channel, kernel_size)self.conv3 = conv_1x1_bn(dim, channel)self.conv2 = conv_1x1_bn(channel, dim)self.transformer = UserDefined(dim, depth, 4, 8, mlp_dim, dropout)self.conv4 = conv_nxn_bn(2 * channel, channel, kernel_size)def forward(self, x):y = x.clone()x = self.conv1(x)x = self.conv2(x)z = x.clone()_, _, h, w = x.shapex = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw)x = self.transformer(x)x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h//self.ph, w=w//self.pw, ph=self.ph, pw=self.pw)x = self.conv3(x)x = torch.cat((x, z), 1)x = self.conv4(x)x = x + yx = self.mv01(x)return x

文件配置在D:\yolov5-master\models路径下


http://www.ppmy.cn/news/1335397.html

相关文章

一、对人工智能大模型了解与认知

黑8说 月黑风高,乌云密布,树木低垂,黯淡沉闷。这黎明前的风暴,预示着新时代的变革即将到来。 在一个8线小城市的办公室中 黑8对主任说: 世界上有男人、女人、人妖,米国有1/3男,2/3女…&#xff…

Python学习笔记--创建最简单的自定义异常类

在Python中,当创建一个函数时,它应该执行一些操作或返回一些值。如果函数为空,则没有实际的操作或返回值,这是不符合函数设计的初衷的。因此,在Python中,函数体不能为空,必须至少包含一个语句&a…

Stable Diffusion与Midjourney:如何做出明智之选?

Stable Diffusion与Midjourney:如何做出明智之选? 在人工智能领域中,Stable Diffusion和Midjourney是两个备受瞩目的技术。它们各自具有独特的特点和优势,但选择哪一个更适合您的需求呢?本文将为您详细分析两者的差异…

Linux使用宝塔面板一键部署Discuz论坛并实现固定地址公网访问

文章目录 前言1.安装基础环境2.一键部署Discuz3.安装cpolar工具4.配置域名访问Discuz5.固定域名公网地址6.配置Discuz论坛 前言 Crossday Discuz! Board(以下简称 Discuz!)是一套通用的社区论坛软件系统,用户可以在不需要任何编程的基础上&a…

day33_js

今日内容 0 复习昨日 1 JS概述 2 JS的引入方式 3 JS语法 3.1 变量 3.2 基本数据类型 3.3 引用类型 3.4 数组类型 3.5 日期类型 3.6 运算符(算术运算,逻辑,关系运算,三目运算) 3.7 分支 3.8 循环 3.9 函数(重点) 3 常见弹窗函数 alter,confirm,prompt 0 复习昨日 1 盒子模型 对d…

支付宝开通GPT4.0,最新经验分享

ChatGPT是由OpenAI开发的一种生成式对话模型,具有生成对话响应的能力。它是以GPT(Generative Pre-trained Transformer)为基础进行训练的,GPT是一种基于Transformer架构的预训练语言模型,被广泛用于各种自然语言处理任…

DPlayer m3u8 视频禁止下载

1. 介绍 正常的 m3u8 格式视频通过控制台是无法下载的,但是可以通过插件下载,下面介绍如何规避这个问题。 思路:后端生成一个一次性的密钥,前端放在请求头中,可以防止大部分插件下载。这里只说前端。 2. 实现 集成 …

【INTEL(ALTERA)】为什么 niosv-download 实用程序无法下载 NiosV 处理器应用程序 ELF 文件

说明 当您执行以下任务时,英特尔 Quartus Prime Pro Edition 软件版本 21.3 和 21.4 中会显示以下错误消息: 使用 niosv-download 实用程序将 Nios V 处理器应用程序 ELF 文件下载到英特尔 FPGAs。使用 openocd-cfg-gen 实用程序生成 OpenOCD 配置文件…