[Perf] Use torch compile to fuse pack topk in trtllm moe #37695

Open
wzhao18 wants to merge 2 commits into vllm-project:main from wzhao18:wzhao/fuse-trtllm-moe-pack-topk

Conversation

Contributor

@wzhao18 wzhao18 commented Mar 20, 2026

Purpose

Fuse the packing of the top-k ids and weights in the TRT-LLM MoE path using torch.compile. In benchmarks, this gives a ~2% speedup on MiniMax M2.5 (TP=2, concurrency 64, 1K input / 1K output) for both FP8 and NVFP4.

Test Plan

Use MiniMax M2.5 FP8 and NVFP4 (private checkpoint) for accuracy and performance testing.

Test Result

FP8:

vllm serve MiniMaxAI/MiniMax-M2.5 \
    --trust-remote-code \
    --stream-interval 20 \
    --tensor-parallel-size 2

vllm bench serve \
    --endpoint /v1/completions \
    --port 8000 \
    --model MiniMaxAI/MiniMax-M2.5 \
    --dataset-name random \
    --random-input 1024 \
    --random-output 1024 \
    --num-prompt 1280 \
    --max-concurrency 64 \
    --num-warmups 100  \
    --ignore-eos    \
    --trust-remote-code

Accuracy (gsm8k):

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9303|±  |0.0070|
|     |       |strict-match    |     5|exact_match|↑  |0.9287|±  |0.0071|

Main:
============ Serving Benchmark Result ============
Successful requests:                     1280      
Failed requests:                         0         
Maximum request concurrency:             64        
Benchmark duration (s):                  482.50    
Total input tokens:                      1310720   
Total generated tokens:                  1310720   
Request throughput (req/s):              2.65      
Output token throughput (tok/s):         2716.55   
Peak output token throughput (tok/s):    173.00    
Peak concurrent requests:                128.00    
Total token throughput (tok/s):          5433.09   
---------------Time to First Token----------------
Mean TTFT (ms):                          390.04    
Median TTFT (ms):                        350.18    
P99 TTFT (ms):                           1283.75   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          23.19     
Median TPOT (ms):                        23.23     
P99 TPOT (ms):                           23.58     
---------------Inter-token Latency----------------
Mean ITL (ms):                           456.29    
Median ITL (ms):                         446.20    
P99 ITL (ms):                            1108.22   
==================================================

Branch:
============ Serving Benchmark Result ============
Successful requests:                     1280      
Failed requests:                         0         
Maximum request concurrency:             64        
Benchmark duration (s):                  474.09    
Total input tokens:                      1310720   
Total generated tokens:                  1310720   
Request throughput (req/s):              2.70      
Output token throughput (tok/s):         2764.70   
Peak output token throughput (tok/s):    199.00    
Peak concurrent requests:                127.00    
Total token throughput (tok/s):          5529.39   
---------------Time to First Token----------------
Mean TTFT (ms):                          468.97    
Median TTFT (ms):                        373.37    
P99 TTFT (ms):                           1358.45   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          22.71     
Median TPOT (ms):                        22.73     
P99 TPOT (ms):                           23.16     
---------------Inter-token Latency----------------
Mean ITL (ms):                           446.70    
Median ITL (ms):                         437.64    
P99 ITL (ms):                            1055.59   
==================================================

NVFP4:

vllm serve MiniMax-M2.5-NVFP4  \
    --trust-remote-code \
    --stream-interval 20 \
    --tensor-parallel-size 2

vllm bench serve \
    --endpoint /v1/completions \
    --port 8000 \
    --model MiniMax-M2.5-NVFP4 \
    --dataset-name random \
    --random-input 1024 \
    --random-output 1024 \
    --num-prompt 1280 \
    --max-concurrency 64 \
    --num-warmups 100  \
    --ignore-eos    \
    --trust-remote-code

Accuracy (gsm8k):

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9257|±  |0.0072|
|     |       |strict-match    |     5|exact_match|↑  |0.9212|±  |0.0074|

Main:
============ Serving Benchmark Result ============
Successful requests:                     1280      
Failed requests:                         0         
Maximum request concurrency:             64        
Benchmark duration (s):                  318.45    
Total input tokens:                      1310720   
Total generated tokens:                  1310720   
Request throughput (req/s):              4.02      
Output token throughput (tok/s):         4115.97   
Peak output token throughput (tok/s):    254.00    
Peak concurrent requests:                128.00    
Total token throughput (tok/s):          8231.94   
---------------Time to First Token----------------
Mean TTFT (ms):                          284.88    
Median TTFT (ms):                        231.55    
P99 TTFT (ms):                           940.15    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          15.28     
Median TPOT (ms):                        15.32     
P99 TPOT (ms):                           15.56     
---------------Inter-token Latency----------------
Mean ITL (ms):                           300.62    
Median ITL (ms):                         293.73    
P99 ITL (ms):                            737.18    
==================================================

Branch:
============ Serving Benchmark Result ============
Successful requests:                     1280      
Failed requests:                         0         
Maximum request concurrency:             64        
Benchmark duration (s):                  312.61    
Total input tokens:                      1310720   
Total generated tokens:                  1310720   
Request throughput (req/s):              4.09      
Output token throughput (tok/s):         4192.77   
Peak output token throughput (tok/s):    261.00    
Peak concurrent requests:                128.00    
Total token throughput (tok/s):          8385.54   
---------------Time to First Token----------------
Mean TTFT (ms):                          386.67    
Median TTFT (ms):                        413.34    
P99 TTFT (ms):                           927.91    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          14.90     
Median TPOT (ms):                        14.88     
P99 TPOT (ms):                           15.23     
---------------Inter-token Latency----------------
Mean ITL (ms):                           293.09    
Median ITL (ms):                         288.91    
P99 ITL (ms):                            582.89    
==================================================




@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def _pack_topk_ids_weights(
Collaborator

can this have a better function name that makes it clear it is for trtllm routed moe?

Contributor Author

I was thinking about that just now. Can you check the update?

@wzhao18 wzhao18 force-pushed the wzhao/fuse-trtllm-moe-pack-topk branch from dc254ad to 9e3fe99 on March 20, 2026 15:25

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a performance optimization for Mixture-of-Experts layers by using torch.compile to fuse the packing of top-k expert IDs and weights. The logic is encapsulated in a new _pack_topk_ids_weights utility function, which is then used to refactor the packing logic in trtllm_fp8_moe.py and trtllm_nvfp4_moe.py. This change not only improves performance as demonstrated by the benchmarks but also increases code robustness by ensuring correct data types during the packing operation. The implementation is clean and the provided test results confirm its correctness and performance benefits.


@mgoin mgoin left a comment


LGTM!

@github-project-automation github-project-automation bot moved this to Ready in NVIDIA Mar 20, 2026
@mgoin mgoin added the performance and ready labels Mar 20, 2026
@mgoin mgoin enabled auto-merge (squash) March 20, 2026 21:32
return is_torch_equal_or_newer("2.9")


@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
Collaborator

I'd recommend marking as dynamic only the variables that are actually dynamic.

Contributor Author

@zou3519 I am not an expert on this. Could you comment on the best practice?


Labels

nvidia · performance (Performance-related issues) · ready (ONLY add when PR is ready to merge/full CI is needed)

Projects

Status: Ready


5 participants