[Perf] Use torch compile to fuse pack topk in trtllm moe#37695
[Perf] Use torch compile to fuse pack topk in trtllm moe#37695wzhao18 wants to merge 2 commits intovllm-project:mainfrom
Conversation
Signed-off-by: wzhao18 <[email protected]>
|
|
||
|
|
||
| @torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) | ||
| def _pack_topk_ids_weights( |
There was a problem hiding this comment.
can this have a better function name that makes it clear it is for trtllm routed moe?
There was a problem hiding this comment.
Was thinking about it just now. Can you check the updated?
dc254ad to
9e3fe99
Compare
Signed-off-by: wzhao18 <[email protected]>
There was a problem hiding this comment.
Code Review
This pull request introduces a performance optimization for Mixture-of-Experts layers by using torch.compile to fuse the packing of top-k expert IDs and weights. The logic is encapsulated in a new _pack_topk_ids_weights utility function, which is then used to refactor the packing logic in trtllm_fp8_moe.py and trtllm_nvfp4_moe.py. This change not only improves performance as demonstrated by the benchmarks but also increases code robustness by ensuring correct data types during the packing operation. The implementation is clean and the provided test results confirm its correctness and performance benefits.
| return is_torch_equal_or_newer("2.9") | ||
|
|
||
|
|
||
| @torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) |
There was a problem hiding this comment.
I'd recommend to mark as dynamic only really dynamic variables.
There was a problem hiding this comment.
@zou3519 I am not an expert on this. Could you comment on the best practice?
Purpose
Fuse the packing topk ids and weights op in trtllm MoE using torch.compile. From benchmarking, this gives ~2% speedup in Minimax M2.5 TP=2 Concurrency 64 1K/1K across FP8 and NVFP4.
Test Plan
Use minimax m2.5 fp8 and nvfp4 (private checkpoint) for accuracy and performance testing.
Test Result
FP8:
NVFP4:
Essential Elements of an Effective PR Description Checklist
supported_models.mdandexamplesfor a new model.