FlashInfer trtllm kernel for nvfp4 MoE model with intermediate size 704 errors out.
The model loads and works fine with TP=4 and enable-expert-parallel in vllm serve, but fails in the TP-only case.
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870] WorkerProc failed to start.
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870] Traceback (most recent call last):
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870] File "/usr/local/lib/python3.12/dist-packages/vllm/v1/executor/multiproc_executor.py", line 837, in worker_main
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870] worker = WorkerProc(*args, **kwargs)
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870] ^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870] File "/usr/local/lib/python3.12/dist-packages/vllm/tracing/otel.py", line 178, in sync_wrapper
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870] return func(*args, **kwargs)
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870] ^^^^^^^^^^^^^^^^^^^^^
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870] File "/usr/local/lib/python3.12/dist-packages/vllm/v1/executor/multiproc_executor.py", line 619, in __init__
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870] self.worker.load_model()
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870] File "/usr/local/lib/python3.12/dist-packages/vllm/v1/worker/gpu_worker.py", line 323, in load_model
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870] self.model_runner.load_model(load_dummy_weights=load_dummy_weights)
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870] File "/usr/local/lib/python3.12/dist-packages/vllm/tracing/otel.py", line 178, in sync_wrapper
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870] return func(*args, **kwargs)
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870] ^^^^^^^^^^^^^^^^^^^^^
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870] File "/usr/local/lib/python3.12/dist-packages/vllm/v1/worker/gpu_model_runner.py", line 4793, in load_model
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870] self.model = model_loader.load_model(
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870] ^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870] File "/usr/local/lib/python3.12/dist-packages/vllm/tracing/otel.py", line 178, in sync_wrapper
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870] return func(*args, **kwargs)
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870] ^^^^^^^^^^^^^^^^^^^^^
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870] File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/model_loader/base_loader.py", line 80, in load_model
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870] process_weights_after_loading(model, model_config, target_device)
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870] File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/model_loader/utils.py", line 107, in process_weights_after_loading
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870] quant_method.process_weights_after_loading(module)
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870] File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/quantization/modelopt.py", line 1391, in process_weights_after_loading
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870] ) = convert_to_nvfp4_moe_kernel_format(
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870] File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py", line 346, in convert_to_nvfp4_moe_kernel_format
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870] ) = prepare_nvfp4_moe_layer_for_fi_or_cutlass(
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870] File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py", line 361, in prepare_nvfp4_moe_layer_for_fi_or_cutlass
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870] w13, w13_scale, w2, w2_scale = prepare_static_weights_for_trtllm_fp4_moe(
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870] File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py", line 221, in prepare_static_weights_for_trtllm_fp4_moe
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870] permute_sf_indices = _maybe_get_cached_w3_w1_permute_indices(
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870] File "/usr/local/lib/python3.12/dist-packages/flashinfer/fused_moe/core.py", line 121, in _maybe_get_cached_w3_w1_permute_indices
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870] permute1 = get_shuffle_matrix_sf_a_row_indices(
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870] File "/usr/local/lib/python3.12/dist-packages/flashinfer/utils.py", line 873, in get_shuffle_matrix_sf_a_row_indices
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870] assert M % 128 == 0
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870] ^^^^^^^^^^^^
(Worker_TP2 pid=459163) ERROR 04-29 07:06:45 [multiproc_executor.py:870] AssertionError
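The failing check is the `assert M % 128 == 0` at the bottom of the trace. A minimal sketch of the arithmetic, under the assumption (not confirmed by the trace) that `M` is the row count of the fused w3/w1 weight, i.e. `2 * intermediate_size_per_partition`, shows why expert parallelism passes while TP-only sharding of intermediate size 704 does not:

```python
# Hypothetical model of the shapes reaching get_shuffle_matrix_sf_a_row_indices;
# the helper below is illustrative, not actual vLLM/FlashInfer code.

INTERMEDIATE_SIZE = 704

def w13_rows(intermediate_size: int, tp_size: int, expert_parallel: bool) -> int:
    # With expert parallelism the experts are distributed across ranks and the
    # intermediate dimension stays whole; with TP-only it is sliced per rank.
    per_rank = intermediate_size if expert_parallel else intermediate_size // tp_size
    return 2 * per_rank  # gate (w1) and up (w3) projections fused

# TP=4 + EP: intermediate stays 704, M = 1408 = 11 * 128 -> the model loads.
assert w13_rows(INTERMEDIATE_SIZE, 4, expert_parallel=True) % 128 == 0
# TP=2, no EP: M = 704, and 704 % 128 == 64 -> the assertion fires.
assert w13_rows(INTERMEDIATE_SIZE, 2, expert_parallel=False) % 128 != 0
```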
Things tried
Added a PR with a test in vLLM to invoke the FlashInfer TRT-LLM kernel when the model uses GELU activation; otherwise it falls back to Marlin: [Kernel][MoE] Support GELU on TRT-LLM NvFP4 fused MoE for Gemma4 vllm-project/vllm#41050. Please use this PR to reproduce the error.
I added some prints inside the process_weights_after_loading method to capture the shapes that help reproduce the issue. The checkpoints come from ModelOpt in nvfp4: https://github.com/vllm-project/vllm/blob/296741d0257107a9d0301409005c85d38bb247bc/vllm/model_executor/layers/quantization/modelopt.py#L1380
Changing min_alignment from 16 to 64 in vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py lets the model load, but I see gibberish output, which likely points to an issue on the kernel side in handling the padding.
A similar issue was filed in vLLM about not being able to serve TP-only for nvfp4: vllm-project/vllm#39595
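For reference, a rough sketch of why bumping min_alignment from 16 to 64 makes the 128-divisibility check pass for intermediate size 704. The `round_up` helper is illustrative and assumes min_alignment pads the per-rank intermediate size; it is not the actual vLLM code:

```python
def round_up(x: int, align: int) -> int:
    # Round x up to the nearest multiple of align.
    return (x + align - 1) // align * align

intermediate = 704

for tp in (2, 4):
    per_rank = intermediate // tp  # TP-only sharding: 352 or 176
    # min_alignment=16 leaves the per-rank size unchanged (already 16-aligned),
    # so the fused w3/w1 row count 2 * per_rank is not a multiple of 128.
    assert (2 * round_up(per_rank, 16)) % 128 != 0
    # min_alignment=64 pads 352 -> 384 (TP=2) and 176 -> 192 (TP=4), giving
    # row counts 768 and 384, both multiples of 128 -> loading succeeds, but
    # the kernel apparently mishandles the padded region (gibberish output).
    assert (2 * round_up(per_rank, 64)) % 128 == 0
```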