
Commit f6e0d42

Fix aux loss scale when CP is enabled. (#2237)

1 parent: e2199af

File tree: 2 files changed (+3, −2 lines)
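
A likely reading of the one-line change: with context parallelism (CP), each CP rank computes the MoE auxiliary (load-balancing) loss on its own sequence slice, and the aux-loss gradient is averaged across the CP group during reduction, so without compensation the effective scale ends up a factor of cp_group_size too small relative to the non-CP case. Multiplying the scale by cp_group_size restores parity. A minimal sketch of the arithmetic, using a hypothetical standalone helper (the real code sets the scale on MoEAuxLossAutoScaler directly):

```python
def effective_aux_loss_scale(
    loss_scale: float,
    num_microbatches: int,
    cp_group_size: int,
    calculate_per_token_loss: bool,
) -> float:
    """Sketch of the aux-loss scale logic in forward_step_calc_loss."""
    if calculate_per_token_loss:
        # Per-token loss: normalization by token counts already accounts
        # for microbatches and CP slices, so the scale is used as-is.
        return loss_scale
    # Average over microbatches; multiply by the CP group size to undo the
    # implicit averaging of the aux-loss gradient across CP ranks.
    return loss_scale * cp_group_size / num_microbatches
```

With cp_group_size == 1 this reduces to the previous behavior, loss_scale / num_microbatches, so the fix is a no-op when CP is disabled.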

megatron/core/pipeline_parallel/schedules.py (1 addition, 1 deletion)

```diff
@@ -268,7 +268,7 @@ def forward_step_calc_loss(
         if config.calculate_per_token_loss:
             MoEAuxLossAutoScaler.set_loss_scale(loss_scale)
         else:
-            MoEAuxLossAutoScaler.set_loss_scale(loss_scale / num_microbatches)
+            MoEAuxLossAutoScaler.set_loss_scale(loss_scale * cp_group_size / num_microbatches)

     # Set the loss scale for Multi-Token Prediction (MTP) loss.
     if hasattr(config, 'mtp_num_layers') and config.mtp_num_layers is not None:
```
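
For context, MoEAuxLossAutoScaler attaches the auxiliary loss to the main output via a custom autograd function: the forward pass is an identity on the main tensor, and the backward pass injects the aux-loss gradient multiplied by a class-level scale, which is what set_loss_scale updates. A simplified sketch of that pattern (not the exact Megatron-LM implementation):

```python
import torch

class AuxLossAutoScaler(torch.autograd.Function):
    """Sketch of the aux-loss auto-scaler pattern: identity in forward,
    scaled aux-loss gradient injected in backward."""

    # Class-level scale, updated via set_loss_scale.
    main_loss_backward_scale: torch.Tensor = torch.tensor(1.0)

    @staticmethod
    def forward(ctx, output: torch.Tensor, aux_loss: torch.Tensor):
        ctx.save_for_backward(aux_loss)
        return output  # identity on the main path

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        (aux_loss,) = ctx.saved_tensors
        scale = AuxLossAutoScaler.main_loss_backward_scale
        # d(scale * aux_loss)/d(aux_loss) is just `scale`.
        return grad_output, torch.ones_like(aux_loss) * scale

    @staticmethod
    def set_loss_scale(scale: torch.Tensor):
        AuxLossAutoScaler.main_loss_backward_scale = scale
```

Because the scale is applied only in backward, any correction (such as the cp_group_size factor above) must be folded into the value passed to set_loss_scale before the backward pass runs.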

tests/unit_tests/transformer/moe/test_aux_loss.py (2 additions, 1 deletion)

```diff
@@ -331,7 +331,7 @@ def test_seq_aux_loss(self, tp_size, ep_size, cp_size):
         not torch.cuda.is_available() or not HAVE_ROUTER_FUSION,
         reason="CUDA or TE fused router ops not available",
     )
-    @pytest.mark.parametrize("aux_type", ["aux_loss", "seq_aux_loss"])
+    @pytest.mark.parametrize("aux_type", ["aux_loss", "seq_aux_loss", "global_aux_loss"])
     def test_aux_loss_fusion_equivalence(self, aux_type):
         # Compare fused vs unfused aux loss path to ensure numerical equivalence
         router_ref = self.new_router(
@@ -350,6 +350,7 @@ def test_aux_loss_fusion_equivalence(self, aux_type):
         loss_name_map = {
             "aux_loss": "load_balancing_loss",
             "seq_aux_loss": "seq_load_balancing_loss",
+            "global_aux_loss": "global_load_balancing_loss",
         }
         loss_name = loss_name_map[aux_type]
```
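
The test change extends the fused-vs-unfused equivalence check to the global aux loss: the parametrize list now covers all three aux-loss types, and each type is mapped to the name under which the router tracks its loss. The comparison itself boils down to a numerical-closeness assertion; a minimal sketch of that pattern, with hypothetical tensors standing in for the two paths:

```python
import torch

def assert_fused_matches_unfused(
    loss_fused: torch.Tensor, loss_unfused: torch.Tensor
) -> None:
    # A fused kernel must reproduce the reference (unfused) computation
    # within floating-point tolerance; assert_close raises otherwise.
    torch.testing.assert_close(loss_fused, loss_unfused, rtol=1e-5, atol=1e-6)

# Hypothetical usage: losses produced by the fused and reference router
# paths for the same input should agree.
assert_fused_matches_unfused(torch.tensor(0.1234), torch.tensor(0.1234))
```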
