Skip to content

Commit 4779436

Browse files
oleksost, claude, and bigximik
authored
Vllm modelling cleanup (#483)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com> Co-authored-by: bigximik <denisko@live.com>
1 parent f424d2c commit 4779436

3 files changed

Lines changed: 199 additions & 273 deletions

File tree

apriel2-vllm-plugin/pyproject.toml

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
[build-system]
2+
requires = ["setuptools>=64"]
3+
build-backend = "setuptools.build_meta"
4+
5+
[project]
6+
name = "apriel2-vllm-plugin"
7+
version = "0.1.0"
8+
description = "Standalone vLLM plugin for Apriel2 models (extracted from Fast-LLM)"
9+
requires-python = ">=3.12"
10+
dependencies = [
11+
"torch",
12+
"transformers",
13+
"einops",
14+
]
15+
16+
[project.entry-points."vllm.general_plugins"]
17+
apriel2 = "fast_llm_external_models.apriel2.vllm.config_convertor:register"
18+
19+
[tool.setuptools.packages.find]
20+
where = [".."]
21+
include = [
22+
"fast_llm_external_models",
23+
"fast_llm_external_models.apriel2",
24+
"fast_llm_external_models.apriel2.vllm",
25+
]
26+
27+
[tool.setuptools.package-dir]
28+
"" = ".."

fast_llm_external_models/apriel2/vllm/config_convertor.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,12 @@
99
"""
1010

1111
from vllm import ModelRegistry
12+
from vllm.model_executor.models.config import (
13+
MODELS_CONFIG_MAP,
14+
HybridAttentionMambaModelConfig,
15+
MambaModelConfig,
16+
VerifyAndUpdateConfig,
17+
)
1218
from vllm.transformers_utils.model_arch_config_convertor import (
1319
MODEL_ARCH_CONFIG_CONVERTORS,
1420
ModelArchConfigConvertorBase,
@@ -70,6 +76,39 @@ def get_head_size(self) -> int:
7076
return self._get_first_attention_block().get("head_size", 0)
7177

7278

79+
class Apriel2ModelConfig(VerifyAndUpdateConfig):
    """Dispatching config handler for Apriel2's heterogeneous mixer layouts.

    An Apriel2 decoder may be built from attention blocks only, mamba blocks
    only, or a mix of both. vLLM's stock ``is_hybrid`` dispatch always routes
    hybrids through ``HybridAttentionMambaModelConfig``, which breaks on
    pure-mamba checkpoints (``ZeroDivisionError`` once ``num_kv_heads`` is 0).

    To avoid that, this handler reads ``layers_block_type`` from the HF config
    and forwards to whichever vLLM config handler matches the actual layer mix.
    """

    @staticmethod
    def verify_and_update_config(vllm_config) -> None:
        hf_config = vllm_config.model_config.hf_config
        block_types = getattr(hf_config, "layers_block_type", None)
        if block_types is None:
            # Fallback: no layer type info — assume standard transformer.
            return

        # Equality-based membership over the declared per-layer block kinds.
        kinds = set(block_types)
        uses_attention = "attention" in kinds
        uses_mamba = "mamba" in kinds

        if uses_attention and uses_mamba:
            # Hybrid: attention + mamba page size alignment required.
            HybridAttentionMambaModelConfig.verify_and_update_config(vllm_config)
        elif uses_mamba:
            # Pure mamba: enable FULL_AND_PIECEWISE, set mamba_block_size.
            MambaModelConfig.verify_and_update_config(vllm_config)
        # Pure attention: no special config needed.
110+
111+
73112
def register():
74113
"""Register Apriel2 models and config convertors with vLLM.
75114
@@ -130,7 +169,7 @@ def register():
130169
# Best-effort only; vLLM can still proceed with the generic config.
131170
pass
132171

133-
# Register model class
172+
# Register model class and config handler.
134173
# Note: some exported checkpoints may list "Apriel2ForConditionalGeneration"
135174
# in config.json's "architectures". vLLM's model selection is driven by that
136175
# field, so we alias it to the same vLLM implementation for text-only usage.
@@ -139,3 +178,8 @@ def register():
139178
arch,
140179
"fast_llm_external_models.apriel2.vllm:Apriel2ForCausalLM",
141180
)
181+
# Register in MODELS_CONFIG_MAP so vLLM calls our handler instead of
182+
# relying on the is_hybrid class attribute dispatch (which can't handle
183+
# models that are sometimes hybrid, sometimes pure-mamba).
184+
if arch not in MODELS_CONFIG_MAP:
185+
MODELS_CONFIG_MAP[arch] = Apriel2ModelConfig

0 commit comments

Comments
 (0)