mamba复现—mamba+yolov8魔改(win)

embedded/2024/9/23 3:04:18/

Mamba复现出现的问题

安装下列步骤一步步走

一、

注:若是Windows环境下python一定是3.10版本的,要不然trition无法安装

conda create -n mamba python=3.10
conda activate mamba 
conda install cudatoolkit==11.8 -c nvidia
pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 --index-url https://download.pytorch.org/whl/cu118
conda install -c "nvidia/label/cuda-11.8.0" cuda-nvcc
conda install packaging
二、安装trition

由于是先在Windows10上进行调试,然后再在linux服务器上进行跑实验,所以这里先以Windows为准,
Windows下的trition无法直接pip,需要对其源码进行修改,网上有大神编译了Win下的二进制文件的安装包,但是只适用于python3.10!!!
下载到本地后,在anacoda终端中,切换到tritan所在文件夹,输入pip install 进行安装。

pip install triton-2.0.0-cp310-cp310-win_amd64.whl
mambassm_21">三、安装causal-conv1d、mamba-ssm

causal-conv1d == 1.1.1
mamba-ssm 1.1.2
(亲测有效,有博主mamba-ssm
1.1.1,我试了会报错)

方法一:
1、causal-conv1d

由于是Windows下,所以采用源码安装,去git上下载(https://gitcode.com/Dao-AILab/causal-conv1d/tags?utm_source=csdn_github_accelerator&isLogin=1),
下载到本地后解压,然后切换到该文件下,输入pip install .进行安装,可能会出现以下报错,

User
WARNING: Ignoring invalid distribution -orch (c:\users\16786\.conda\envs\yolov8\lib\site-packages)
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Collecting mamba-ssmDownloading https://pypi.tuna.tsinghua.edu.cn/packages/d3/12/dc792f3136fc8969ac6404f091135ab1aa9260a978a625a77a3cce5299dd/mamba_ssm-1.2.0.post1.tar.gz (34 kB)Preparing metadata (setup.py) ... errorerror: subprocess-exited-with-error× python setup.py egg_info did not run successfully.│ exit code: 1╰─> [11 lines of output]Traceback (most recent call last):File "<string>", line 2, in <module>File "<pip-setuptools-caller>", line 34, in <module>File "C:\Users\16786\AppData\Local\Temp\pip-install-qnerb67y\mamba-ssm_d5a352e22e5a430989b40813c061fa67\setup.py", line 103, in <module>raise RuntimeError(RuntimeError: mamba_ssm is only supported on CUDA 11.6 and above.  Note: make sure nvcc has a supported version by running nvcc -V.torch.__version__  = 1.13.1+cu117[end of output]note: This error originates from a subprocess, and is likely not a problem with pip.
error: metadata-generation-failed× Encountered error while generating package metadata.
╰─> See above for output.note: This is an issue with the package mentioned above, not pip.
hint: See above for details.

原因有两种,
1)有时候缓存文件可能会导致安装出错。你可以尝试清理 pip 或 conda 的缓存

pip cache purge

2)由于cuda版本不对,我这边遇到的情况是版本不对,于是乎我又安装了CUDA11.8和CUDNN,就是双CUD环境,因为其他模型需要11.2的CUDA。
在这里插入图片描述
然后再输入pip install .就可以了。

之后在mamba源码 setup.py修改配置

FORCE_BUILD = os.getenv("MAMBA_FORCE_BUILD", "FALSE") == "FALSE"
SKIP_CUDA_BUILD = os.getenv("MAMBA_SKIP_CUDA_BUILD", "FALSE") == "FALSE"
mambassm_80">2、mamba-ssm

pip install mamba-ssm,有时候会出错

方法二:
下载causal-conv1d:

Dao-AILabcausal-conv1d
mamba-ssm:
state-spacesmamba
我的causal-conv1d 正常安装了,所以这边以mamba-ssm为例,下载后
pip install mamba_ssm-1.1.1+cu118torch2.1cxx11abiTRUE-cp310-cp310-linux_x86_64.whl
在这里插入图片描述

causal-conv1d文件下下载界面如下在这里插入图片描述
安装成功
在这里插入图片描述

也可以直接拉取Docker镜像

参考:直接使用Mamba基础环境docker镜像

此时可以进行mamba的编译了,但是会出现没有模块selective_scan_cuda,

方法一:

此时我们可以将mamba_ssm->ops/selective_scan_interface.py 的import selective_scan_cuda注释掉,然后对该文件的selective_scan_fn和mamba_inner_fn函数进行修改。

###原代码
def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,return_last_state=False):"""if return_last_state is True, returns (out, last_state)last_state has shape (batch, dim, dstate). Note that the gradient of the last state isnot considered in the backward pass."""return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)def mamba_inner_fn(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,out_proj_weight, out_proj_bias,A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,C_proj_bias=None, delta_softplus=True
):return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,out_proj_weight, out_proj_bias,A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)
##修改后的代码
def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,return_last_state=False):"""if return_last_state is True, returns (out, last_state)last_state has shape (batch, dim, dstate). Note that the gradient of the last state isnot considered in the backward pass."""return selective_scan_ref(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)def mamba_inner_fn(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,out_proj_weight, out_proj_bias,A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,C_proj_bias=None, delta_softplus=True
):return mamba_inner_ref(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,out_proj_weight, out_proj_bias,A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)

yolov8进行魔改

这里我参考别的博主进行复现魔改,发现根本跑不动,陷入死循环,下面是该播主给的backone以及MambaLayer

# YOLOv8.0n backbone
backbone:# [from, repeats, module, args]- [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2      # 0.  320- [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4     # 1.  160- [-1, 3, MambaLayer, [128]]                # 2.  160- [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8     # 3.  80- [-1, 6, MambaLayer, [256]]                # 4.  80- [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16    # 5.  40- [-1, 6, MambaLayer, [512]]                # 6.  40- [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32   # 7.  20- [-1, 3, MambaLayer, [1024]]               # 8.  20- [-1, 1, SPPF, [1024, 5]]  # 9            # 9.  20
class MambaLayer(nn.Module):def __init__(self, dim, d_state=16, d_conv=4, expand=2):super().__init__()self.dim = dimself.norm = nn.LayerNorm(dim)self.mamba = Mamba(d_model=dim,  # Model dimension d_modeld_state=d_state,  # SSM state expansion factord_conv=d_conv,  # Local convolution widthexpand=expand,  # Block expansion factorbimamba_type="v2",)def forward(self, x):B, C = x.shape[:2]assert C == self.dimn_tokens = x.shape[2:].numel()img_dims = x.shape[2:]x_flat = x.reshape(B, C, n_tokens).transpose(-1, -2)x_norm = self.norm(x_flat)# x_norm = x_norm.to('cuda')x_mamba = self.mamba(x_norm)out = x_mamba.transpose(-1, -2).reshape(B, C, *img_dims)#out = out.to(x.device)return out**加粗样式**

然后我就按照他的进行复现,最后没能成功,感觉应该是我哪块有问题,然后我就将mamba直接和CBAM进行结合,效果好一点点吧,后面会进行更多的尝试。直接上我的代码

class MambaCBAM(nn.Module):# Convolutional Block Attention Moduledef __init__(self, c1, kernel_size=7, d_state=16, d_conv=4, expand=2):  # ch_in, kernelsprint(f"kernel_size = {kernel_size}")super().__init__()self.dim = c1self.channel_attention = ChannelAttention(c1)self.spatial_attention = SpatialAttention(kernel_size)# self.norm = nn.LayerNorm(self.dim)self.mamba = Mamba(d_model=self.dim,  # Model dimension d_modeld_state=d_state,  # SSM state expansion factord_conv=d_conv,  # Local convolution widthexpand=expand,  # Block expansion factorbimamba_type="v2",)def forward(self, x):# print(f"cbam x{x[0].size()}")cbma = self.spatial_attention(self.channel_attention(x))B, C = x.shape[:2]assert C == self.dimn_tokens = x.shape[2:].numel()  # 该行代码计算了输入张量x中获取了批量大小和通道数以外的所有维度的元素数量,即图像中的像素数或特征处理的长度。img_dims = x.shape[2:]  # 该行代码实现了输入张量x中获取批量大小和通道数除了所有维度的大小,即图像的高度和宽度。x_flat = x.reshape(B, C, n_tokens).transpose(-1, -2)if str(x.device) != 'cpu':x_mamba = self.mamba(x_flat)else:x_mamba = x_flatout= x_mamba.transpose(-1, -2).reshape(B, C, *img_dims)return out+cbma

http://www.ppmy.cn/embedded/39694.html

相关文章

JavaScript创建日期

创建日期 在JavaScript中创建日期有四种方法 ● 使用new Date() const now new Date(); console.log(now);● 直接输入月、日、年、时间 console.log(new Date(Aug 02 2024 18:05:41));● 也可以输入年月日 console.log(new Date(December 24, 2015));● 直接按照年、月、…

校园网拨号上网环境下多开虚拟机,实现宿主机与虚拟机互通,并访问外部网络

校园网某些登录客户端只允许同一时间一台设备登录&#xff0c;因此必须使用NAT模式共享宿主机的真实IP&#xff0c;相当于访问外网时只使用宿主机IP&#xff0c;此方式通过虚拟网卡与物理网卡之间的数据转发实现访问外网及互通 经验证&#xff0c;将centos的物理地址与主机物理…

Linux基础之git与调试工具gdb

目录 一、git的简单介绍和使用方法 1.1 git的介绍 1.2 git的使用方法 1.2.1 三板斧之git add 1.2.2 三板斧之git commit 1.2.3 三板斧之git push 二、gdb的介绍和一些基本使用方法 2.1 背景介绍 2.2 基本的使用方法 一、git的简单介绍和使用方法 1.1 git的介绍 Git是一…

MySQL 数据库中 Insert 语句的锁机制

在数据库系统中&#xff0c;Insert 语句是常用的操作之一&#xff0c;用于向数据库表中插入新的数据记录。然而&#xff0c;当多个会话&#xff08;或者线程&#xff09;同时对同一张表执行 Insert 操作时&#xff0c;可能会引发一些并发控制的问题&#xff0c;特别是涉及到锁的…

YOLOv5-7.0改进(四)添加EMA注意力机制

前言 关于网络中注意力机制的改进有很多种&#xff0c;本篇内容从EMA注意力机制开始&#xff01; 往期回顾 YOLOv5-7.0改进&#xff08;一&#xff09;MobileNetv3替换主干网络 YOLOv5-7.0改进&#xff08;二&#xff09;BiFPN替换Neck网络 YOLOv5-7.0改进&#xff08;三&…

万兆POE网络变压器90W的性能和作用

万兆POE网络变压器GX82405SP-90W是一种应用于网络设备的电力供应器件&#xff0c;它结合了数据传输和电力供应功能&#xff0c;可以为PoE&#xff08;Power over Ethernet&#xff09;设备提供高功率供电。它的性能和作用主要包括&#xff1a; 1. 高功率供电&#xff1a;万兆P…

读写备份寄存器BKP与实时时钟RTC

文章目录 读写备份寄存器接线图代码 RTC实时时钟接线图代码 读写备份寄存器 接线图 即接个3.3v的电源到VBT引脚 代码 代码效果&#xff1a;第一次写入备份寄存器&#xff0c;下载程序后再注释掉&#xff0c;再进行下载&#xff0c;之前写入的数据还会保存在备份寄存器中&am…

Midjourney Imagine API 申请及使用

Midjourney Imagine API 申请及使用 申请流程 要使用 Midjourney Imagine API&#xff0c;首先可以到 Midjourney Imagine API 页面点击「Acquire」按钮&#xff0c;获取请求所需要的凭证&#xff1a; 如果你尚未登录或注册&#xff0c;会自动跳转到登录页面邀请您来注册和登…