[Bug] Flashinfer-TRTLLM kernel fails for MoE model when using TP only (no EP), intermediate size = 704 before sharding for nvfp4 #3206

@juhi10071998

Description

Failure

The FlashInfer TRT-LLM kernel errors out for an NVFP4 MoE model whose intermediate size is 704 (before sharding).
The model loads and works fine with TP=4 and --enable-expert-parallel in vllm serve, but fails in the TP-only case.

Error trace

(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870] WorkerProc failed to start.                                                                                       
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870] Traceback (most recent call last):                                                                                
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870]   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/executor/multiproc_executor.py", line 837, in worker_main 
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870]     worker = WorkerProc(*args, **kwargs)                                                                          
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                          
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870]   File "/usr/local/lib/python3.12/dist-packages/vllm/tracing/otel.py", line 178, in sync_wrapper                  
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870]     return func(*args, **kwargs)                                                                                  
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870]            ^^^^^^^^^^^^^^^^^^^^^                                                                                  
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870]   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/executor/multiproc_executor.py", line 619, in __init__    
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870]     self.worker.load_model()                                                                                      
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870]   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/worker/gpu_worker.py", line 323, in load_model            
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870]     self.model_runner.load_model(load_dummy_weights=load_dummy_weights)                                           
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870]   File "/usr/local/lib/python3.12/dist-packages/vllm/tracing/otel.py", line 178, in sync_wrapper                  
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870]     return func(*args, **kwargs)                                                                                  
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870]            ^^^^^^^^^^^^^^^^^^^^^                                                                                  
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870]   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/worker/gpu_model_runner.py", line 4793, in load_model     
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870]     self.model = model_loader.load_model(                                                                         
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870]                  ^^^^^^^^^^^^^^^^^^^^^^^^                                                                         
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870]   File "/usr/local/lib/python3.12/dist-packages/vllm/tracing/otel.py", line 178, in sync_wrapper                  
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870]     return func(*args, **kwargs)                                                                                  
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870]            ^^^^^^^^^^^^^^^^^^^^^                                                                                  
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/model_loader/base_loader.py", line 80, in load_model
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870]     process_weights_after_loading(model, model_config, target_device)                                             
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/model_loader/utils.py", line 107, in process_weights_after_loading
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870]     quant_method.process_weights_after_loading(module)                                                            
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/quantization/modelopt.py", line 1391, in process_weights_after_loading
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870]     ) = convert_to_nvfp4_moe_kernel_format(                                                                       
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870]         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                       
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py", line 346, in convert_to_nvfp4_moe_kernel_format
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870]     ) = prepare_nvfp4_moe_layer_for_fi_or_cutlass(                                                                
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870]         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py", line 361, in prepare_nvfp4_moe_layer_for_fi_or_cutlass
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870]     w13, w13_scale, w2, w2_scale = prepare_static_weights_for_trtllm_fp4_moe(                                     
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870]                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                     
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py", line 221, in prepare_static_weights_for_trtllm_fp4_moe
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870]     permute_sf_indices = _maybe_get_cached_w3_w1_permute_indices(                                                 
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870]                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                 
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870]   File "/usr/local/lib/python3.12/dist-packages/flashinfer/fused_moe/core.py", line 121, in _maybe_get_cached_w3_w1_permute_indices
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870]     permute1 = get_shuffle_matrix_sf_a_row_indices(                                                               
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870]                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                               
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870]   File "/usr/local/lib/python3.12/dist-packages/flashinfer/utils.py", line 873, in get_shuffle_matrix_sf_a_row_indices
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870]     assert M % 128 == 0                                                                                           
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870]            ^^^^^^^^^^^^                                                                                           
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870] AssertionError

Things tried

  1. Added a PR and a test in vLLM to invoke the FlashInfer TRT-LLM kernel when the model has GELU activation; otherwise it falls back to Marlin. PR: [Kernel][MoE] Support GELU on TRT-LLM NvFP4 fused MoE for Gemma4 (vllm-project/vllm#41050). Please use this PR to reproduce the error.

  2. I added some prints inside the process_weights_after_loading method; the resulting shapes can help with reproducing the issue. The checkpoints come from ModelOpt in NVFP4. Code location: https://github.com/vllm-project/vllm/blob/296741d0257107a9d0301409005c85d38bb247bc/vllm/model_executor/layers/quantization/modelopt.py#L1380

  [Image: printed weight shapes]
  3. Updated min_alignment from 16 to 64 in vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py. With this change the model loads fine, but the output is gibberish, which likely hints at an issue on the kernel side in handling the padding.
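
For reference, here is a rough sketch of the arithmetic that would be consistent with the behaviour above. It rests on assumptions that are not confirmed by the trace: that the `M` checked by `assert M % 128 == 0` is the row count of the fused w13 (gate + up) tensor, i.e. twice the per-rank intermediate size; that with expert parallelism the intermediate dimension is not sharded; and that padding rounds the per-rank intermediate size up to `min_alignment`. The helper names `round_up` and `w13_rows` are hypothetical and exist only for this illustration.

```python
def round_up(x: int, multiple: int) -> int:
    """Round x up to the nearest multiple of `multiple`."""
    return (x + multiple - 1) // multiple * multiple


def w13_rows(intermediate_size: int, tp_shards: int, min_alignment: int) -> int:
    """Hypothetical model of M: rows of the fused w13 tensor on one rank.

    Assumes the intermediate dimension is split evenly across `tp_shards`,
    padded up to `min_alignment`, and then doubled for gate + up.
    """
    per_rank = intermediate_size // tp_shards
    padded = round_up(per_rank, min_alignment)
    return 2 * padded


cases = [
    ("TP=4 with EP (intermediate dim assumed unsharded)", 1, 16),
    ("TP=4 only, min_alignment=16", 4, 16),
    ("TP=4 only, min_alignment=64", 4, 64),
]

for label, shards, alignment in cases:
    m = w13_rows(704, shards, alignment)
    print(f"{label}: M={m}, M % 128 = {m % 128}")
```

Under these assumptions only the TP-only case with min_alignment=16 leaves a non-zero remainder, which would line up with the assertion firing there and with the model loading once min_alignment is raised to 64. It says nothing about why the padded case produces gibberish.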

A similar issue about not being able to serve NVFP4 with TP only was filed in vLLM: vllm-project/vllm#39595
