Setting up the mamba and mamba2 environment
Commands for installing mamba and mamba2:
conda create -n mamba_test python=3.10
conda activate mamba_test
conda install cudatoolkit=11.8 -c https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/linux-64/
pip install mamba_ssm-2.2.2+cu118torch2.0cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
pip install causal_conv1d-1.4.0+cu118torch2.0cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
pip install triton==2.1.0
pip install numpy==1.22.4
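The wheel file names above encode the CUDA version, PyTorch version, C++ ABI flag and Python tag they were built for, and a mismatch with the local environment is a common reason an install fails. The following is a minimal sketch that parses those tags from a file name and compares the Python tag with the running interpreter. The regular expression and the helper names `parse_wheel_name` and `python_tag_matches` are my own illustration, written only against the two file names used in this post; they are not part of pip or mamba_ssm.

```python
import re
import sys

# Pattern written for the two wheel names used in this post (an assumption,
# not a general wheel parser):
# {name}-{version}+cu{cuda}torch{torch}cxx11abi{TRUE|FALSE}-{python}-{abi}-{platform}.whl
WHEEL_RE = re.compile(
    r"^(?P<name>[A-Za-z0-9_]+)-(?P<version>[0-9.]+)"
    r"\+cu(?P<cuda>\d+)torch(?P<torch>[0-9.]+)cxx11abi(?P<cxx11abi>TRUE|FALSE)"
    r"-(?P<python>cp\d+)-(?P<abi>cp\d+)-(?P<platform>\w+)\.whl$"
)


def parse_wheel_name(filename):
    """Return the build tags encoded in the wheel file name as a dict of strings."""
    match = WHEEL_RE.match(filename)
    if match is None:
        raise ValueError(f"unrecognized wheel name: {filename}")
    return match.groupdict()


def python_tag_matches(info, version_info=sys.version_info):
    """Check whether the wheel's Python tag (e.g. cp310) fits the given interpreter version."""
    expected = f"cp{version_info[0]}{version_info[1]}"
    return info["python"] == expected


info = parse_wheel_name(
    "mamba_ssm-2.2.2+cu118torch2.0cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"
)
print(info)
print("Python tag matches this interpreter:", python_tag_matches(info))
```

For the wheels in this post the parsed tags are CUDA 118, torch 2.0 and cp310, which lines up with the python=3.10 and cudatoolkit=11.8 choices in the commands above.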
Download links for the corresponding whl files:
mamba_ssm download
causal_conv1d download
Working test code for mamba and mamba2:
import torch
from mamba_ssm import Mamba
batch, length, dim = 2, 64, 16
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim,  # Model dimension d_model
    d_state=16,   # SSM state expansion factor
    d_conv=4,     # Local convolution width
    expand=2,     # Block expansion factor
).to("cuda")
y = model(x)
print("Mamba result", y.shape)
assert y.shape == x.shape

import torch
from mamba_ssm import Mamba2

batch, length, dim = 2, 64, 512
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba2(
    # This module uses roughly 3 * expand * d_model^2 parameters
    # make sure d_model * expand / headdim = multiple of 8
    d_model=dim,  # Model dimension d_model
    d_state=64,   # SSM state expansion factor, typically 64 or 128
    d_conv=4,     # Local convolution width
    expand=2,     # Block expansion factor
    headdim=64,   # default 64
).to("cuda")
y = model(x)
print("Mamba2 result", y.shape)
assert y.shape == x.shape
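The comments in the test code give two rules of thumb: a block has roughly 3 * expand * d_model^2 parameters, and for Mamba2 the value d_model * expand / headdim should be a multiple of 8. The sketch below just does that arithmetic for the settings used above, without needing a GPU. The function names `approx_param_count` and `satisfies_head_constraint` are hypothetical helpers for illustration, not part of mamba_ssm, and the parameter count is only the rough estimate from the comment, not an exact figure.

```python
def approx_param_count(d_model, expand):
    """Rough estimate from the code comment: 3 * expand * d_model^2."""
    return 3 * expand * d_model ** 2


def satisfies_head_constraint(d_model, expand, headdim):
    """Check that d_model * expand / headdim is a whole number and a multiple of 8."""
    d_inner = d_model * expand
    if d_inner % headdim != 0:
        return False
    return (d_inner // headdim) % 8 == 0


# Mamba test settings: d_model=16, expand=2
print(approx_param_count(16, 2))              # 1536

# Mamba2 test settings: d_model=512, expand=2, headdim=64
print(approx_param_count(512, 2))             # 1572864
print(satisfies_head_constraint(512, 2, 64))  # True, since 512 * 2 / 64 = 16

# The small Mamba setting would not pass the Mamba2 check with headdim=64
print(satisfies_head_constraint(16, 2, 64))   # False, since 16 * 2 / 64 is not a whole number
```

This is one way to sanity-check a new d_model, expand and headdim combination before building the Mamba2 module on the GPU.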
Debugging references:
Mamba-2 Error: 'NoneType' object has no attribute 'causal_conv1d_fwd'
mamba_ssm and causal-conv1d installation tutorial
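Before working through the references above, it can help to see which packages actually import in the current environment. My understanding is that the 'NoneType' error tends to appear when the compiled causal_conv1d extension fails to import and is silently left as None, but treat that as an assumption rather than a confirmed diagnosis. The sketch below tries each import and prints the result. The helper `check_imports` is my own, and the module name `causal_conv1d_cuda` for the compiled extension is an assumption about how the package is laid out.

```python
import importlib


def check_imports(module_names):
    """Try to import each module and return {name: "ok" or the error text}."""
    results = {}
    for name in module_names:
        try:
            importlib.import_module(name)
            results[name] = "ok"
        except Exception as exc:  # ImportError, or a loader error from a mismatched binary
            results[name] = f"{type(exc).__name__}: {exc}"
    return results


# torch is listed first because compiled extensions usually need its shared
# libraries loaded before they can be imported.
# "causal_conv1d_cuda" is an assumed name for the compiled extension module.
MODULES = ["torch", "triton", "causal_conv1d", "causal_conv1d_cuda", "mamba_ssm"]

for name, status in check_imports(MODULES).items():
    print(f"{name:20s} {status}")
```

If causal_conv1d itself imports but the compiled extension does not, the wheel probably does not match the installed torch, CUDA or Python version, and reinstalling a matching whl is the first thing to try.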