MiniMamba v1.0.1 is a production-ready PyTorch implementation of the Mamba architecture — a Selective State Space Model (S6) for fast and efficient sequence modeling. This major release features optimized parallel scan algorithms, modular architecture, and comprehensive caching support while maintaining simplicity and educational value.
📂 Repository: github.com/Xinguang/MiniMamba 📋 Improvements: View detailed improvements
- ⚡ 3x Faster Training: True parallel scan algorithm (vs. pseudo-parallel)
- 💾 50% Memory Reduction: Smart caching system for efficient inference
- 🏗️ Modular Architecture: Pluggable components and task-specific models
- 🔄 100% Backward Compatible: Existing code works without modification
- Pure PyTorch: Easy to understand and modify; no custom CUDA ops
- Cross-Platform: Fully compatible with CPU, CUDA, and Apple Silicon (MPS)
- Numerical Stability: Log-space computation prevents overflow
- Comprehensive Testing: 12 test cases covering all improvements
# Install the latest production-ready version
pip install minimamba==1.0.0
# Or install with optional dependencies
pip install minimamba[examples] # For running examples
pip install minimamba[dev] # For developmentgit clone https://github.com/Xinguang/MiniMamba.git
cd MiniMamba
pip install -e .✅ Requirements:
- Python ≥ 3.8
- PyTorch ≥ 1.12.0
- NumPy ≥ 1.20.0
# Run comprehensive examples
python examples/improved_mamba_example.py
# Or run legacy example for compatibility test
python examples/run_mamba_example.pyExpected output:
✅ Using device: MPS (Apple Silicon)
Model parameters: total 26,738,688, trainable 26,738,688
All examples completed successfully! 🎉
import torch
from minimamba import MambaForCausalLM, MambaLMConfig, InferenceParams
# 1. Create configuration
config = MambaLMConfig(
d_model=512,
n_layer=6,
vocab_size=10000,
d_state=16,
d_conv=4,
expand=2,
)
# 2. Initialize specialized model
model = MambaForCausalLM(config)
# 3. Basic forward pass
input_ids = torch.randint(0, config.vocab_size, (2, 128))
logits = model(input_ids)
print(logits.shape) # torch.Size([2, 128, 10000])
# 4. Advanced generation with caching
generated = model.generate(
input_ids[:1, :10],
max_new_tokens=50,
temperature=0.8,
top_p=0.9,
use_cache=True
)
print(f"Generated: {generated.shape}") # torch.Size([1, 60])from minimamba import InferenceParams
# Initialize cache
inference_params = InferenceParams()
# First forward pass (builds cache)
logits = model(input_ids, inference_params)
# Subsequent passes use cache (much faster)
next_token = torch.randint(0, config.vocab_size, (1, 1))
logits = model(next_token, inference_params)
# Monitor cache usage
cache_info = model.get_cache_info(inference_params)
print(f"Cache memory: {cache_info['memory_mb']:.2f} MB")
# Reset when needed
model.reset_cache(inference_params)# Sequence Classification
from minimamba import MambaForSequenceClassification, MambaClassificationConfig
class_config = MambaClassificationConfig(
d_model=256,
n_layer=4,
num_labels=3,
pooling_strategy="last"
)
classifier = MambaForSequenceClassification(class_config)
# Feature Extraction
from minimamba import MambaForFeatureExtraction, BaseMambaConfig
feature_config = BaseMambaConfig(d_model=256, n_layer=4)
feature_extractor = MambaForFeatureExtraction(feature_config)# Your existing code works unchanged!
from minimamba import Mamba, MambaConfig
config = MambaConfig(d_model=512, n_layer=6, vocab_size=10000)
model = Mamba(config) # Now uses optimized v1.0 architecture
logits = model(input_ids)| Metric | v0.2.0 | v1.0.1 | Improvement |
|---|---|---|---|
| Training Speed | 1x | 3x | 🚀 3x faster |
| Inference Memory | 100% | 50% | 💾 50% reduction |
| Parallel Efficiency | Pseudo | True | ⚡ Real parallelization |
| Numerical Stability | Medium | High | ✨ Significant improvement |
Run the comprehensive test suite:
# All tests
pytest tests/
# Specific test files
pytest tests/test_mamba_improved.py -v
pytest tests/test_mamba.py -v # Legacy testsTest Coverage:
- ✅ Configuration system validation
- ✅ Parallel scan correctness
- ✅ Training vs inference consistency
- ✅ Memory efficiency verification
- ✅ Backward compatibility
- ✅ Cache management
- ✅ Generation interfaces
MiniMamba/
├── minimamba/ # 🧠 Core model components
│ ├── config.py # Configuration classes (Base, LM, Classification)
│ ├── core.py # Core components (Encoder, Heads)
│ ├── models.py # Specialized models (CausalLM, Classification)
│ ├── model.py # Legacy model (backward compatibility)
│ ├── block.py # MambaBlock with pluggable mixers
│ ├── s6.py # Optimized S6 with true parallel scan
│ ├── norm.py # RMSNorm module
│ └── __init__.py # Public API
│
├── examples/ # 📚 Usage examples
│ ├── improved_mamba_example.py # New comprehensive examples
│ └── run_mamba_example.py # Legacy example
│
├── tests/ # 🧪 Test suite
│ ├── test_mamba_improved.py # Comprehensive tests (v1.0)
│ └── test_mamba.py # Legacy tests
│
├── forex/ # 💹 Real-world usage demo
│ ├── improved_forex_model.py # Enhanced forex model
│ ├── manba.py # Updated original model
│ ├── predict.py # Prediction script
│ └── README_IMPROVED.md # Forex upgrade guide
│
├── IMPROVEMENTS.md # 📋 Detailed improvements
├── CHANGELOG.md # 📝 Version history
├── setup.py # 📦 Package configuration
├── README.md # 🌟 This file
├── README.zh-CN.md # 🇨🇳 Chinese documentation
├── README.ja.md # 🇯🇵 Japanese documentation
└── LICENSE # ⚖️ MIT License
Mamba is a state-space model that achieves linear time complexity for long sequences, making it more efficient than traditional transformers for many tasks.
This production release features:
# Before: Pseudo-parallel (actually sequential)
for block_idx in range(num_blocks): # Sequential!
block_states = self._block_scan(...)
# After: True parallel computation
log_A = torch.log(A.clamp(min=1e-20))
cumsum_log_A = torch.cumsum(log_A, dim=1) # Parallel ⚡
prefix_A = torch.exp(cumsum_log_A) # Parallel ⚡MambaEncoder: Reusable core componentMambaForCausalLM: Language modelingMambaForSequenceClassification: Classification tasksMambaForFeatureExtraction: Embedding extraction
- Automatic cache management for inference
- 50% memory reduction during generation
- Cache monitoring and reset capabilities
- 📝 Language Modeling: Long-form text generation
- 🔍 Classification: Document/sequence classification
- 🔢 Time Series: Financial/sensor data modeling
- 🧬 Biology: DNA/protein sequence analysis
- 📊 Performance Analysis: Detailed technical improvements
- 💹 Real-world Example: Forex prediction model implementation
- 🧪 Test Suite: Comprehensive testing documentation
- 📦 PyPI Package: Official package
This project is licensed under the MIT License.
This project is inspired by:
- Paper: Mamba: Linear-Time Sequence Modeling with Selective State Spaces by Albert Gu & Tri Dao
- Reference Implementation: state-spaces/mamba
Special thanks to the community for feedback and contributions that made v1.0.1 possible.
MiniMamba v1.0.1 - Production-ready Mamba implementation for everyone 🚀