
[GraphTrainer][AutoDev] Fuse RMSNorm kernels via regional Inductor compilation #3132

Open

SherlockNoMad wants to merge 4 commits into main from graph_trainer/rmsnorm-regional-inductor

Conversation

SherlockNoMad (Contributor) commented Apr 28, 2026

Summary

  • Adds annotate_rmsnorm_for_regional_inductor_pass, a graph pass that tags RMSNorm forward and backward ops for regional Inductor compilation via compile_with_inductor node metadata
  • RMSNorm appears as fused _fused_rms_norm / _fused_rms_norm_backward ops in the traced graph. Compiling each norm region with Inductor fuses them into a single Triton kernel, closing the performance gap vs Megatron's fused TransformerEngine RMSNorm
  • Identifies RMSNorm nodes by node.target (simple and direct) and tags their getitem users so the full fused op is compiled together

How it works

  1. annotate_rmsnorm_for_regional_inductor_pass walks the traced graph and tags nodes whose node.target is _fused_rms_norm.default or _fused_rms_norm_backward.default (a sketch of the pass follows this list)
  2. getitem users of each fused norm op are also tagged (they extract the tuple outputs)
  3. The existing regional_inductor_pass compiles tagged regions with Inductor's standalone_compile
  4. The CapabilityBasedPartitioner in regional_inductor automatically separates disconnected norm subgraphs into distinct compiled partitions
  5. Inductor configs match FlexAttention: max_autotune, coordinate_descent_tuning, wrap_inductor_compiled_regions
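A minimal sketch of what such an annotation pass can look like, based only on the description above. This is an illustration, not the actual torchtitan pass: the exact custom-metadata convention (a compile_with_inductor key under node.meta["custom"]) and the tag value "rmsnorm" are assumptions.

import operator

import torch
from torch import fx

_RMSNORM_TARGETS = (
    torch.ops.aten._fused_rms_norm.default,
    torch.ops.aten._fused_rms_norm_backward.default,
)

def annotate_rmsnorm_for_regional_inductor_pass(gm: fx.GraphModule) -> fx.GraphModule:
    # Tag fused RMSNorm ops, plus their getitem users, so that the
    # downstream regional_inductor_pass compiles each tagged region
    # with Inductor's standalone_compile.
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target in _RMSNORM_TARGETS:
            node.meta.setdefault("custom", {})["compile_with_inductor"] = "rmsnorm"
            for user in node.users:
                # The fused op returns a tuple; its getitem extractions
                # must land in the same compiled region as the op itself.
                if user.op == "call_function" and user.target is operator.getitem:
                    user.meta.setdefault("custom", {})["compile_with_inductor"] = "rmsnorm"
    return gm

Disconnected tagged subgraphs (one per norm instance, forward and backward) are then split into distinct compiled partitions by the CapabilityBasedPartitioner, per step 4 above.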

Sample TLP: https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmp27Bxjy/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000

Microbenchmark

Isolated RMSNorm fwd+bwd (single layer, H100):

Path                   ms/iter (run 1)   ms/iter (run 2)
Eager                  26.13             25.73
make_fx (uncompiled)   21.99             22.03
regional_inductor      14.35             14.17
Speedup vs eager       1.82x             1.82x
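For context, numbers like these can be reproduced with a CUDA-event timing loop over an isolated fwd+bwd step. The harness below is a sketch, not the one used in this PR; the shapes and the use of F.rms_norm as the workload are assumptions.

import torch
import torch.nn.functional as F

def bench_ms(step, iters=100, warmup=10):
    # Average wall-clock time per iteration, measured with CUDA events.
    for _ in range(warmup):
        step()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        step()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

# Hypothetical single-layer RMSNorm workload (shapes are not specified in the PR).
x = torch.randn(8, 2048, 4096, device="cuda", requires_grad=True)
w = torch.randn(4096, device="cuda", requires_grad=True)

def step():
    out = F.rms_norm(x, (4096,), weight=w)  # forward
    out.sum().backward()                    # backward
    x.grad = None
    w.grad = None

print(f"{bench_ms(step):.2f} ms/iter")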

Benchmark

Llama3 8B, 8×H100, FSDP4+TP2, c4_test, 20 steps (last step steady-state):

Branch    tps     tflops   MFU      Memory
main      6,538   378.66   38.29%   30.92 GiB
this PR   6,600   382.26   38.65%   30.92 GiB

Delta: +0.36 percentage points MFU, +0.9% tps, no memory regression.

Numerics

This may change numerics slightly (Inductor kernel numerics vs native ATen fused ops), which is expected and acceptable.
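Concretely, eager and Inductor-compiled outputs should agree within dtype-appropriate tolerances rather than bitwise, since Inductor may fuse and reorder the reductions. A minimal sanity check along those lines (plain torch.compile stands in here for the regional pass):

import torch
import torch.nn.functional as F

x = torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16)
w = torch.randn(4096, device="cuda", dtype=torch.bfloat16)

eager_out = F.rms_norm(x, (4096,), weight=w)
compiled_out = torch.compile(F.rms_norm)(x, (4096,), weight=w)

# Bitwise equality is not expected; closeness within bf16 tolerances is.
torch.testing.assert_close(compiled_out, eager_out, rtol=1.6e-2, atol=1e-4)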

Test plan

  • 8 CPU unit tests in TestAnnotateRMSNormForRegionalInductorPass covering:
    • Correct tagging of fused RMSNorm nodes and their getitem users (one such test is sketched after this list)
    • Non-RMSNorm nodes are not tagged
    • Multiple norms all tagged
    • Forward and backward both tagged
    • Empty graph is a no-op
    • None compile config produces empty annotation
    • Config dict propagation
  • Bitwise deterministic tests pass (RMSNorm Inductor compilation patched to no-op for eager-vs-traced comparisons since the pass intentionally changes numerics)
  • GPU benchmark comparison showing fused kernel performance
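As a sketch of the tagging test mentioned in the first bullet: build a graph whose node.target is the fused op, run the pass, and assert that both the op and its getitem user carry the annotation. The import path and the fused op's argument layout are assumptions (the graph is never executed).

import operator

import torch
from torch import fx

# Assumed import path, based on the review threads on passes.py.
from torchtitan.experiments.graph_trainer.passes import (
    annotate_rmsnorm_for_regional_inductor_pass,
)

def test_fused_rmsnorm_and_getitem_users_tagged():
    g = fx.Graph()
    x = g.placeholder("x")
    w = g.placeholder("w")
    # Argument layout is illustrative; only node.target matters here.
    norm = g.call_function(
        torch.ops.aten._fused_rms_norm.default, args=(x, [4096], w, 1e-6)
    )
    out = g.call_function(operator.getitem, args=(norm, 0))
    g.output(out)
    gm = fx.GraphModule(torch.nn.Module(), g)

    annotate_rmsnorm_for_regional_inductor_pass(gm)

    tagged = {
        n.target
        for n in gm.graph.nodes
        if n.meta.get("custom", {}).get("compile_with_inductor")
    }
    assert tagged == {torch.ops.aten._fused_rms_norm.default, operator.getitem}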

meta-cla bot added the CLA Signed label (managed by the Meta Open Source bot) Apr 28, 2026
Three review threads on torchtitan/experiments/graph_trainer/passes.py (outdated)
SherlockNoMad (author) commented Apr 28, 2026

Benchmarked llama3_8b, comparing MFU before and after this pass, as proof of work:

NGPU=8 MODULE=graph_trainer.llama3 CONFIG=graph_trainer_llama3_8b ./run_train.sh \
    --compile.mode aot_fx_trace \
    --parallelism.data_parallel_shard_degree=4 \
    --parallelism.tensor_parallel_degree=2 \
    --dataloader.dataset c4_test \
    --metrics.no-enable_tensorboard \
    --profiler.no-enable_profiling \
    --comm.trace_buf_size=0 \
    --training.steps 20

SherlockNoMad added a commit that referenced this pull request Apr 28, 2026
…rget

Address review feedback on PR #3132:

1. Use node.target (torch.ops.aten._fused_rms_norm.default and
   _fused_rms_norm_backward.default) to identify RMSNorm nodes instead
   of module_fqn matching. This is simpler and more direct since the
   traced graph contains the fused ops, not decomposed small ATen ops.

2. Remove _collect_rmsnorm_fqns and the rmsnorm_fqns parameter entirely.

3. Tag getitem user nodes of each fused norm op (they extract the tuple
   outputs and must be in the same compiled region).

4. Fix incorrect comment that claimed RMSNorm decomposes into small ops
   (pow, mean, rsqrt, mul) — it actually appears as _fused_rms_norm.

5. Update tests to construct graphs with _fused_rms_norm targets and
   verify getitem users are tagged correctly.
SherlockNoMad force-pushed the graph_trainer/rmsnorm-regional-inductor branch from 865e09a to ea63cde on April 29, 2026 06:47
SherlockNoMad marked this pull request as ready for review on April 29, 2026 15:39
SherlockNoMad force-pushed the graph_trainer/rmsnorm-regional-inductor branch from ea63cde to aec61bf on April 29, 2026 19:34
SherlockNoMad (author) commented:

Add test for run-to-run bitwise determinism.

RMSNorm decomposes into small ATen ops (pow, mean, rsqrt, mul) that
launch many separate CUDA kernels. This adds a graph pass that tags
RMSNorm nodes for regional Inductor compilation, fusing each norm
into a single Triton kernel to close the performance gap vs Megatron's
fused TransformerEngine RMSNorm kernels.

The pass identifies RMSNorm nodes via their module_fqn metadata,
assigns each distinct norm instance its own inductor_region for
independent compilation, and is registered in the compile_time_passes
pipeline right before regional_inductor_pass.
…s_changing_optim

Add test_numerics_changing_optim_run_to_run to all four test classes
(Llama3, DSv3, Llama3 FlexAttn, DSv3 FlexAttn) to verify that two
GraphTrainer runs with numerics_changing_optim=True produce bitwise
identical loss, model weights, and gradients.

Also add test_precompile_vs_trace to the Flex* classes, remove
test_precompile_vs_eager (redundant with precompile_vs_trace +
trace_vs_eager), and remove unnecessary _set_deterministic() calls
from tests where setUp already handles it.
SherlockNoMad force-pushed the graph_trainer/rmsnorm-regional-inductor branch from aec61bf to 7c05a1d on April 30, 2026 04:45
SherlockNoMad added the ciflow/h100.8 (Trigger H100.8 CI) label Apr 30, 2026