99"""
1010
1111from vllm import ModelRegistry
12+ from vllm .model_executor .models .config import (
13+ MODELS_CONFIG_MAP ,
14+ HybridAttentionMambaModelConfig ,
15+ MambaModelConfig ,
16+ VerifyAndUpdateConfig ,
17+ )
1218from vllm .transformers_utils .model_arch_config_convertor import (
1319 MODEL_ARCH_CONFIG_CONVERTORS ,
1420 ModelArchConfigConvertorBase ,
@@ -70,6 +76,39 @@ def get_head_size(self) -> int:
7076 return self ._get_first_attention_block ().get ("head_size" , 0 )
7177
7278
class Apriel2ModelConfig(VerifyAndUpdateConfig):
    """Pick the right vLLM config handler for an Apriel2 checkpoint.

    An Apriel2 decoder may consist solely of attention layers, solely of
    mamba layers, or a mix of both.  vLLM's stock ``is_hybrid`` dispatch
    always goes through ``HybridAttentionMambaModelConfig``, which fails on
    a pure-mamba model (``ZeroDivisionError`` when ``num_kv_heads=0``).

    To avoid that, this handler reads ``layers_block_type`` from the HF
    config and forwards to the handler that matches the layer mix it finds.
    """

    @staticmethod
    def verify_and_update_config(vllm_config) -> None:
        """Delegate to the hybrid or mamba handler based on layer types.

        Does nothing when the HF config has no ``layers_block_type``
        attribute (treated as a standard transformer) or when no entry
        equals ``"mamba"`` (pure attention needs no special handling).
        """
        block_types = getattr(
            vllm_config.model_config.hf_config, "layers_block_type", None
        )
        if block_types is None:
            # No per-layer type information available; nothing to adjust.
            return

        def _contains(kind):
            return any(entry == kind for entry in block_types)

        uses_attention = _contains("attention")
        uses_mamba = _contains("mamba")

        if not uses_mamba:
            # Pure attention (or no recognised mixer): leave config as is.
            return

        # With mamba present, attention layers make the model a hybrid (the
        # hybrid handler deals with attention/mamba page-size alignment);
        # without them the plain mamba handler is the right one.
        handler = (
            HybridAttentionMambaModelConfig if uses_attention else MambaModelConfig
        )
        handler.verify_and_update_config(vllm_config)
110+
111+
73112def register ():
74113 """Register Apriel2 models and config convertors with vLLM.
75114
@@ -130,7 +169,7 @@ def register():
130169 # Best-effort only; vLLM can still proceed with the generic config.
131170 pass
132171
133- # Register model class
172+ # Register model class and config handler.
134173 # Note: some exported checkpoints may list "Apriel2ForConditionalGeneration"
135174 # in config.json's "architectures". vLLM's model selection is driven by that
136175 # field, so we alias it to the same vLLM implementation for text-only usage.
@@ -139,3 +178,8 @@ def register():
139178 arch ,
140179 "fast_llm_external_models.apriel2.vllm:Apriel2ForCausalLM" ,
141180 )
181+ # Register in MODELS_CONFIG_MAP so vLLM calls our handler instead of
182+ # relying on the is_hybrid class attribute dispatch (which can't handle
183+ # models that are sometimes hybrid, sometimes pure-mamba).
184+ if arch not in MODELS_CONFIG_MAP :
185+ MODELS_CONFIG_MAP [arch ] = Apriel2ModelConfig
0 commit comments