Test environment:
Anaconda3 + Python 3.10
pip list
Package Version
------------------ ------------
attrs 24.3.0
Automat 24.8.1
buildtools 1.0.6
causal-conv1d 1.1.1
certifi 2024.12.14
cffi 1.15.0
charset-normalizer 3.4.1
colorama 0.4.6
constantly 23.10.4
docopt 0.6.2
einops 0.8.0
filelock 3.16.1
fsspec 2024.12.0
furl 2.1.3
greenlet 3.1.1
huggingface-hub 0.27.0
hyperlink 21.0.0
idna 3.10
incremental 24.7.2
Jinja2 3.1.5
mamba_ssm 1.1.3
MarkupSafe 3.0.2
mpmath 1.3.0
networkx 3.4.2
ninja 1.11.1.3
numpy 1.24.1
orderedmultidict 1.0.1
packaging 24.2
pillow 11.0.0
pip 24.2
pycparser 2.22
python-dateutil 2.9.0.post0
PyYAML 6.0.2
redo 3.0.0
regex 2024.11.6
requests 2.32.3
safetensors 0.4.5
setuptools 68.2.2
simplejson 3.19.3
six 1.17.0
SQLAlchemy 2.0.36
sympy 1.13.3
tokenizers 0.21.0
tomli 2.2.1
torch 2.1.1+cu118
torchaudio 2.1.1+cu118
torchvision 0.16.1+cu118
tqdm 4.67.1
transformers 4.47.1
triton 2.1.0
Twisted 24.11.0
typing_extensions 4.12.2
urllib3 2.3.0
wheel 0.44.0
zope.interface 7.2
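Only a few entries in the list above matter directly for this test: torch, triton, causal-conv1d and mamba_ssm. As a minimal sketch, the snippet below reads the installed versions of just those packages using the standard-library importlib.metadata module, which can be easier to compare across machines than a full pip list. The helper name installed_versions and the KEY_PACKAGES selection are illustrative assumptions, not part of mamba_ssm.

```python
from importlib import metadata

# Packages most relevant to a mamba_ssm install.
# This selection is an illustrative assumption; adjust as needed.
KEY_PACKAGES = ["torch", "triton", "causal-conv1d", "mamba_ssm"]


def installed_versions(names):
    """Map each package name to its installed version, or None if it is missing."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # The package is not installed in the current environment.
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions(KEY_PACKAGES).items():
        print(f"{name}: {version or 'not installed'}")
```

Run inside the same conda environment as the test, it should report the versions shown in the list above for those four packages.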
Test code:
import torch
from mamba_ssm import Mamba

batch, length, dim = 2, 64, 16
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim,  # Model dimension d_model
    d_state=16,   # SSM state expansion factor
    d_conv=4,     # Local convolution width
    expand=2,     # Block expansion factor
).to("cuda")
y = model(x)
assert y.shape == x.shape
print('success')
Output: