Skip to content

Commit ea63cde

Browse files
committed
[GraphTrainer] Remove inductor_region from RMSNorm annotation pass
Address review feedback: 1. Default rmsnorm_compile_config to None (was required kwarg). 2. Remove per-norm inductor_region assignment — the CapabilityBasedPartitioner in regional_inductor already separates disconnected norm subgraphs into distinct partitions. Explicit region IDs caused each norm to compile as a separate subgraph, inflating memory by ~17 GiB. Benchmark (Llama3 8B, 8×H100, FSDP4+TP2, c4_test, 20 steps): | Branch | tps | MFU | Memory | |--------------------------------|-------|--------|-----------| | main | 6,538 | 38.29% | 30.92 GiB | | branch (with inductor_region) | 6,559 | 38.41% | 47.74 GiB | | branch (without inductor_region)| 6,600 | 38.65% | 30.92 GiB |
1 parent 2e6baa8 commit ea63cde

2 files changed

Lines changed: 14 additions & 59 deletions

File tree

torchtitan/experiments/graph_trainer/passes.py

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -539,7 +539,7 @@ def annotate_rmsnorm_for_regional_inductor_pass(
539539
gm: torch.fx.GraphModule,
540540
example_inputs: tuple | None = None,
541541
*,
542-
rmsnorm_compile_config: dict | None,
542+
rmsnorm_compile_config: dict | None = None,
543543
) -> torch.fx.GraphModule:
544544
"""Tag RMSNorm ops with compile_with_inductor for regional_inductor.
545545
@@ -548,10 +548,6 @@ def annotate_rmsnorm_for_regional_inductor_pass(
548548
users) so that ``regional_inductor_pass`` compiles each norm as a
549549
fused Inductor kernel.
550550
551-
Each ``_fused_rms_norm`` / ``_fused_rms_norm_backward`` call is
552-
assigned a separate ``inductor_region`` so that norms at different
553-
points in the graph are compiled independently.
554-
555551
Args:
556552
gm: The graph module to annotate.
557553
example_inputs: Example inputs (unused, required by pass interface).
@@ -570,31 +566,26 @@ def annotate_rmsnorm_for_regional_inductor_pass(
570566
torch.ops.aten._fused_rms_norm_backward.default,
571567
}
572568

573-
next_region_id = 0
574569
num_tagged = 0
575570

576571
for node in gm.graph.nodes:
577572
if node.op != "call_function" or node.target not in _RMSNORM_TARGETS:
578573
continue
579574

580-
annotation = dict(compile_annotation)
581-
annotation["inductor_region"] = next_region_id
582-
next_region_id += 1
583-
584-
# Tag the fused norm node itself.
585-
node.meta.setdefault("custom", {})["compile_with_inductor"] = annotation
575+
node.meta.setdefault("custom", {})["compile_with_inductor"] = compile_annotation
586576
num_tagged += 1
587577

588578
# Tag getitem users that extract outputs from the fused op.
589579
for user in node.users:
590580
if user.op == "call_function" and user.target is operator.getitem:
591-
user.meta.setdefault("custom", {})["compile_with_inductor"] = annotation
581+
user.meta.setdefault("custom", {})[
582+
"compile_with_inductor"
583+
] = compile_annotation
592584
num_tagged += 1
593585

594586
if num_tagged > 0:
595587
logger.info(
596-
f"Tagged {num_tagged} RMSNorm nodes across "
597-
f"{next_region_id} regions for regional Inductor compilation"
588+
f"Tagged {num_tagged} RMSNorm nodes for regional Inductor compilation"
598589
)
599590

600591
return gm

torchtitan/experiments/graph_trainer/tests/test_passes.py

Lines changed: 8 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1356,17 +1356,6 @@ def _count_tagged_nodes(self, gm):
13561356
count += 1
13571357
return count
13581358

1359-
def _get_tagged_regions(self, gm):
1360-
"""Return a dict mapping inductor_region -> list of node names."""
1361-
regions = {}
1362-
for node in gm.graph.nodes:
1363-
custom = node.meta.get("custom", {})
1364-
annotation = custom.get("compile_with_inductor")
1365-
if annotation is not None:
1366-
region = annotation.get("inductor_region")
1367-
regions.setdefault(region, []).append(node.name)
1368-
return regions
1369-
13701359
def test_tags_fused_rmsnorm_and_getitems(self):
13711360
"""_fused_rms_norm nodes and their getitem users are tagged."""
13721361
gm = self._build_rmsnorm_gm(
@@ -1401,8 +1390,8 @@ def test_does_not_tag_non_rmsnorm_nodes(self):
14011390

14021391
self.assertEqual(self._count_tagged_nodes(gm), 0)
14031392

1404-
def test_separate_regions_per_fused_norm(self):
1405-
"""Each _fused_rms_norm call gets its own inductor_region ID."""
1393+
def test_multiple_fused_norms_all_tagged(self):
1394+
"""Multiple _fused_rms_norm calls and their getitems are all tagged."""
14061395
gm = self._build_rmsnorm_gm(
14071396
[
14081397
torch.ops.aten._fused_rms_norm.default,
@@ -1416,12 +1405,8 @@ def test_separate_regions_per_fused_norm(self):
14161405
rmsnorm_compile_config={"max_autotune": True},
14171406
)
14181407

1419-
regions = self._get_tagged_regions(gm)
1420-
# Should have exactly 2 distinct regions
1421-
self.assertEqual(len(regions), 2)
1422-
# Each region: 1 fused node + 2 getitem users = 3 nodes
1423-
for region_id, node_names in regions.items():
1424-
self.assertEqual(len(node_names), 3)
1408+
# 2 fused nodes + 2*2 getitem users = 6 tagged nodes
1409+
self.assertEqual(self._count_tagged_nodes(gm), 6)
14251410

14261411
def test_backward_fused_norm_tagged(self):
14271412
"""_fused_rms_norm_backward nodes and their getitem users are tagged."""
@@ -1438,8 +1423,6 @@ def test_backward_fused_norm_tagged(self):
14381423

14391424
# 1 fused backward node + 2 getitem users = 3 tagged nodes
14401425
self.assertEqual(self._count_tagged_nodes(gm), 3)
1441-
regions = self._get_tagged_regions(gm)
1442-
self.assertEqual(len(regions), 1)
14431426

14441427
def test_no_fused_norm_is_noop(self):
14451428
"""When no fused RMSNorm ops exist, no nodes are tagged."""
@@ -1475,7 +1458,6 @@ def test_none_compile_config_tags_with_empty_annotation(self):
14751458
annotation = node.meta.get("custom", {}).get("compile_with_inductor")
14761459
if annotation is not None:
14771460
self.assertNotIn("inductor_configs", annotation)
1478-
self.assertIn("inductor_region", annotation)
14791461

14801462
def test_compile_config_propagated_to_annotation(self):
14811463
"""The compile config dict is wrapped under inductor_configs in the annotation."""
@@ -1496,8 +1478,8 @@ def test_compile_config_propagated_to_annotation(self):
14961478
if annotation is not None:
14971479
self.assertEqual(annotation["inductor_configs"], config)
14981480

1499-
def test_fwd_and_bwd_get_distinct_regions(self):
1500-
"""Forward and backward fused norms get distinct region IDs."""
1481+
def test_fwd_and_bwd_both_tagged(self):
1482+
"""Forward and backward fused norms are both tagged."""
15011483
gm = self._build_rmsnorm_gm(
15021484
[
15031485
torch.ops.aten._fused_rms_norm.default,
@@ -1511,26 +1493,8 @@ def test_fwd_and_bwd_get_distinct_regions(self):
15111493
rmsnorm_compile_config={"max_autotune": True},
15121494
)
15131495

1514-
regions = self._get_tagged_regions(gm)
1515-
self.assertEqual(len(regions), 2)
1516-
1517-
def test_multiple_fused_norms_get_distinct_regions(self):
1518-
"""Multiple fused norm calls each get their own region ID."""
1519-
gm = self._build_rmsnorm_gm(
1520-
[
1521-
torch.ops.aten._fused_rms_norm.default,
1522-
torch.ops.aten._fused_rms_norm.default,
1523-
torch.ops.aten._fused_rms_norm.default,
1524-
]
1525-
)
1526-
1527-
annotate_rmsnorm_for_regional_inductor_pass(
1528-
gm,
1529-
rmsnorm_compile_config={"max_autotune": True},
1530-
)
1531-
1532-
regions = self._get_tagged_regions(gm)
1533-
self.assertEqual(len(regions), 3)
1496+
# 2 fused nodes + 2*2 getitem users = 6 tagged nodes
1497+
self.assertEqual(self._count_tagged_nodes(gm), 6)
15341498

15351499

15361500
if __name__ == "__main__":

0 commit comments

Comments
 (0)