[GraphTrainer][AutoDev] Fuse RMSNorm kernels via regional Inductor compilation#3132
Open
SherlockNoMad wants to merge 4 commits into main from
Conversation
SherlockNoMad
commented
Apr 28, 2026
Contributor
Author
Benchmark llama3_8b, compare MFU before and after this pass. Show proof of work.
SherlockNoMad
added a commit
that referenced
this pull request
Apr 28, 2026
…rget

Address review feedback on PR #3132:

1. Use node.target (torch.ops.aten._fused_rms_norm.default and _fused_rms_norm_backward.default) to identify RMSNorm nodes instead of module_fqn matching. This is simpler and more direct since the traced graph contains the fused ops, not decomposed small ATen ops.
2. Remove _collect_rmsnorm_fqns and the rmsnorm_fqns parameter entirely.
3. Tag getitem user nodes of each fused norm op (they extract the tuple outputs and must be in the same compiled region).
4. Fix incorrect comment that claimed RMSNorm decomposes into small ops (pow, mean, rsqrt, mul) — it actually appears as _fused_rms_norm.
5. Update tests to construct graphs with _fused_rms_norm targets and verify getitem users are tagged correctly.
SherlockNoMad
added a commit
that referenced
this pull request
Apr 29, 2026
865e09a to
ea63cde
SherlockNoMad
added a commit
that referenced
this pull request
Apr 29, 2026
ea63cde to
aec61bf
Contributor
Author
Add a test for run-to-run bitwise determinism.
RMSNorm decomposes into small ATen ops (pow, mean, rsqrt, mul) that launch many separate CUDA kernels. This adds a graph pass that tags RMSNorm nodes for regional Inductor compilation, fusing each norm into a single Triton kernel to close the performance gap vs Megatron's fused TransformerEngine RMSNorm kernels. The pass identifies RMSNorm nodes via their module_fqn metadata, assigns each distinct norm instance its own inductor_region for independent compilation, and is registered in the compile_time_passes pipeline right before regional_inductor_pass.
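To make the kernel-count argument concrete, here is the RMSNorm math written out step by step. This is a hedged illustration only: plain Python lists stand in for tensors, and (per the later review feedback on this PR) the traced graph actually carries a fused `_fused_rms_norm` op rather than these small ops — but run eagerly without fusion, each step below would map to a separate op (pow, mean, rsqrt, mul) and hence a separate kernel launch.

```python
import math

def rms_norm(x, weight, eps=1e-6):
    # Each line corresponds to one of the small ops named above;
    # a fused kernel performs all of them in a single pass over x.
    squared = [v * v for v in x]                          # pow
    mean_sq = sum(squared) / len(squared)                 # mean
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)              # rsqrt
    return [v * inv_rms * w for v, w in zip(x, weight)]  # mul (twice)
```

For example, `rms_norm([3.0, 4.0], [1.0, 1.0])` normalizes by `sqrt((9 + 16) / 2 + eps)`.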
…s_changing_optim

Add test_numerics_changing_optim_run_to_run to all four test classes (Llama3, DSv3, Llama3 FlexAttn, DSv3 FlexAttn) to verify that two GraphTrainer runs with numerics_changing_optim=True produce bitwise identical loss, model weights, and gradients. Also add test_precompile_vs_trace to the Flex* classes, remove test_precompile_vs_eager (redundant with precompile_vs_trace + trace_vs_eager), and remove unnecessary _set_deterministic() calls from tests where setUp already handles it.
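The bitwise check those tests perform can be sketched as follows. `train_step_losses` is a hypothetical stand-in for a full GraphTrainer run; the real test runs the trainer twice and compares losses, model weights, and gradients.

```python
import random

def train_step_losses(seed, steps=5):
    # Hypothetical stand-in for a training run: a seeded RNG drives a
    # toy loss curve, so two runs with the same seed are reproducible.
    rng = random.Random(seed)
    loss, losses = 1.0, []
    for _ in range(steps):
        loss *= 0.9 + 0.01 * rng.random()
        losses.append(loss)
    return losses

# Run-to-run bitwise determinism: exact float equality is required,
# not merely agreement within a floating-point tolerance.
assert train_step_losses(seed=0) == train_step_losses(seed=0)
```

The point of asserting exact equality (rather than allclose) is that any nondeterministic kernel or reduction order would show up as a hard failure.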
aec61bf to
7c05a1d
Summary
- Adds annotate_rmsnorm_for_regional_inductor_pass, a graph pass that tags RMSNorm forward and backward ops for regional Inductor compilation via compile_with_inductor node metadata.
- RMSNorm appears as _fused_rms_norm / _fused_rms_norm_backward ops in the traced graph. Compiling each norm region with Inductor fuses them into a single Triton kernel, closing the performance gap vs Megatron's fused TransformerEngine RMSNorm kernels.
- The pass identifies nodes by node.target (simple and direct) and tags their getitem users so the full fused op is compiled together.

How it works
- annotate_rmsnorm_for_regional_inductor_pass walks the traced graph and tags nodes whose node.target is _fused_rms_norm.default or _fused_rms_norm_backward.default.
- getitem users of each fused norm op are also tagged (they extract the tuple outputs).
- regional_inductor_pass compiles tagged regions with Inductor's standalone_compile.
- CapabilityBasedPartitioner in regional_inductor automatically separates disconnected norm subgraphs into distinct compiled partitions.
- Inductor config: max_autotune, coordinate_descent_tuning, wrap_inductor_compiled_regions.

Sample TLP: https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmp27Bxjy/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000
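The tagging logic described above can be sketched roughly as follows. This is a toy illustration, not the PR's implementation: a minimal Node class stands in for torch.fx nodes, and the "inductor_region" metadata key follows the description in this PR while everything else is hypothetical.

```python
# Toy stand-in for torch.fx nodes; the real pass walks an fx.Graph.
class Node:
    def __init__(self, name, target):
        self.name, self.target = name, target
        self.users, self.meta = [], {}

# Targets named in the PR description (string-keyed here for the sketch).
RMSNORM_TARGETS = {
    "_fused_rms_norm.default",
    "_fused_rms_norm_backward.default",
}

def annotate_rmsnorm_nodes(nodes):
    """Tag each fused RMSNorm op, and its getitem users, with a region id."""
    region = 0
    for node in nodes:
        if node.target in RMSNORM_TARGETS:
            region += 1  # each norm instance gets its own region,
                         # so it is compiled independently
            node.meta["inductor_region"] = region
            for user in node.users:
                # getitem users extract the fused op's tuple outputs
                # and must land in the same compiled region.
                if user.target == "getitem":
                    user.meta["inductor_region"] = region
    return nodes
```

A downstream pass would then hand each tagged region to Inductor; untagged nodes (matmuls, attention, etc.) are left alone.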
Microbenchmark
Isolated RMSNorm fwd+bwd (single layer, H100):
Benchmark
Llama3 8B, 8×H100, FSDP4+TP2, c4_test, 20 steps (last step steady-state), compared against main.

Delta: +0.36% MFU, +0.9% tps, no memory regression.
Numerics
This may change numerics slightly (Inductor kernel numerics vs native ATen fused ops), which is expected and acceptable.
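In practice, "slightly different numerics" means fused and eager outputs are compared with a tolerance rather than bitwise. A hedged sketch of such a check, mirroring the usual torch.allclose semantics (the rtol/atol values and the sample numbers are illustrative, not taken from this PR's CI):

```python
def allclose(a, b, rtol=1e-5, atol=1e-8):
    # Elementwise |x - y| <= atol + rtol * |y|, the standard
    # tolerance test for comparing two floating-point kernels.
    return all(abs(x - y) <= atol + rtol * abs(y) for x, y in zip(a, b))

eager = [0.577350, 1.154700]
fused = [0.577351, 1.154699]  # tiny drift from a different kernel
assert allclose(eager, fused, rtol=1e-4)  # acceptable numeric change
assert eager != fused                     # but not bitwise identical
```

Contrast this with the run-to-run determinism tests, which demand exact equality between two runs of the same (fused) configuration.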
Test plan
TestAnnotateRMSNormForRegionalInductorPass covering: