简化的动态稀疏视觉Transformer的PyTorch代码

ops/2025/2/13 20:49:57/

存一串代码(简化的动态稀疏视觉Transformer的PyTorch代码)


import torch 
import torch.nn  as nn 
import torch.nn.functional  as F class DynamicSparseAttention(nn.Module): def __init__(self, dim, num_heads=8, dropout=0.1): super().__init__() self.num_heads  = num_heads self.head_dim  = dim // num_heads self.scale  = self.head_dim  ** -0.5 self.qkv  = nn.Linear(dim, dim * 3, bias=False) self.attn_drop  = nn.Dropout(dropout) self.proj  = nn.Linear(dim, dim) self.proj_drop  = nn.Dropout(dropout) def forward(self, x): B, N, C = x.shape  qkv = self.qkv(x).reshape(B,  N, 3, self.num_heads,  self.head_dim).permute(2,  0, 3, 1, 4) q, k, v = qkv.unbind(0)  attn = (q @ k.transpose(-2,  -1)) * self.scale  attn = attn.softmax(dim=-1)  attn = self.attn_drop(attn)  x = (attn @ v).transpose(1, 2).reshape(B, N, C) x = self.proj(x)  x = self.proj_drop(x)  return x class HierarchicalRoutingBlock(nn.Module): def __init__(self, dim, num_heads=8, mlp_ratio=4., dropout=0.1): super().__init__() self.norm1  = nn.LayerNorm(dim) self.attn  = DynamicSparseAttention(dim, num_heads, dropout) self.norm2  = nn.LayerNorm(dim) self.mlp  = nn.Sequential( nn.Linear(dim, int(dim * mlp_ratio)), nn.GELU(), nn.Dropout(dropout), nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(dropout) ) def forward(self, x): x = x + self.attn(self.norm1(x))  x = x + self.mlp(self.norm2(x))  return x class DynamicSparseVisionTransformer(nn.Module): def __init__(self, img_size=224, patch_size=16, num_classes=1000, dim=768, num_heads=8, depth=12, mlp_ratio=4., dropout=0.1): super().__init__() self.patch_embed  = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size) self.pos_embed  = nn.Parameter(torch.zeros(1,  (img_size // patch_size) ** 2, dim)) self.dropout  = nn.Dropout(dropout) self.blocks  = nn.ModuleList([HierarchicalRoutingBlock(dim, num_heads, mlp_ratio, dropout) for _ in range(depth)]) self.norm  = nn.LayerNorm(dim) self.head  = nn.Linear(dim, num_classes) if num_classes > 0 else nn.Identity() def forward(self, x): x = self.patch_embed(x).flatten(2).transpose(1,  2) x = x + self.pos_embed  x = self.dropout(x)  for blk in self.blocks:  x = blk(x) x = self.norm(x)  x = x[:, 0] x = self.head(x)  return x # 使用 
model = DynamicSparseVisionTransformer() 
x = torch.randn(1,  3, 224, 224) 
output = model(x) 
print(output.shape)  

代码解释
DynamicSparseAttention:实现动态稀疏注意力模块。
HierarchicalRoutingBlock:实现层次化路由块,包含注意力模块和多层感知机。
DynamicSparseVisionTransformer:实现完整的动态稀疏视觉Transformer模型,包括补丁嵌入、位置嵌入、层次化路由块和分类头。


http://www.ppmy.cn/ops/158131.html

相关文章

http 与 https 的区别?

HTTP(超文本传输协议)和 HTTPS(安全超文本传输协议)是互联网通信的基础协议。随着网络技术的发展和安全需求的提升,HTTPS变得越来越重要。本文将深入探讨HTTP与HTTPS之间的区别,包括其工作原理、安全性、性能、应用场景及未来发展等。 1. HTTP与HTTPS的基本概念 1.1 HT…

从MyBatis-Plus看Spring Boot自动配置原理

一、问题引入&#xff1a;神秘的配置生效之谜 当我们使用MyBatis-Plus时&#xff0c;只需在pom.xml中添加依赖&#xff1a; <dependency><groupId>com.baomidou</groupId><artifactId>mybatis-plus-boot-starter</artifactId><version>3…

Vision Transformer:打破CNN垄断,全局注意力机制重塑计算机视觉范式

目录 引言 一、ViT模型的起源和历史 二、什么是ViT&#xff1f; 图像处理流程 图像切分 展平与线性映射 位置编码 Transformer编码器 分类头&#xff08;Classification Head&#xff09; 自注意力机制 注意力图 三、Coovally AI模型训练与应用平台 四、ViT与图像…

App UI自动化--Appium学习--第二篇

如果第一篇在运行代码的时候出现问题&#xff0c;建议参考我的上一篇文章解决。 1、APP界面信息获取 adb logcat|grep -i displayed代码含义是获取当前应用的包名和界面名。 根据日志信息修改代码当中的包名和界面名&#xff0c;就可以跳转对应的界面。 2、界面元素获取 所…

Seaweedfs(master volume filer) docker run参数帮助文档

文章目录 进入容器后执行获取weed -h英文中文 weed server -h英文中文 weed volume -h英文中文 关键点测试了一下&#xff0c;这个-volume.minFreeSpace string有点狠&#xff0c;比如设置值为10&#xff08;10%&#xff09;&#xff0c;它直接给系统只留下10%的空间&#xff0…

python+unity落地方案实现AI 换脸融合

先上效果再说技术结论&#xff0c;使用的是自行搭建的AI人脸融合库&#xff0c;可以离线不受限制无限次生成&#xff0c;有需要的可以后台私信python ai换脸融合。 TODO 未来的方向&#xff1a;3D人脸融合和AI数据训练 这个技术使用的是openvcinsighface&#xff0c;openvc…

支持向量机原理

支持向量机&#xff08;简称SVM&#xff09;虽然诞生只有短短的二十多年&#xff0c;但是自一诞生便由于它良好的分类性能席卷了机器学习领域。如果不考虑集成学习的算法&#xff0c;不考虑特定的训练数据集&#xff0c;尤其在分类任务中表现突出。在分类算法中的表现SVM说是排…

华为云的分布式缓存服务适合什么场景

华为云的分布式缓存服务&#xff08;DCS&#xff09;适用于多种场景&#xff0c;能够有效提升系统的性能和可靠性。以下是九河云总结的其主要适用场景&#xff1a; 高并发读取场景 在电商、社交平台等高并发应用中&#xff0c;华为云DCS可以将热点数据缓存到内存中&#xff0c…