[MoE][SAC] Use deterministic ops in MoE routing #3146
songhappy wants to merge 1 commit into pytorch:main
Conversation
Two related changes that fix non-deterministic MoE routing under selective activation checkpointing (SAC) on backends where torch.histc / torch.topk are not guaranteed deterministic (e.g. XPU):

1. Replace torch.histc with torch.bincount in TokenChoiceTopKRouter and TokenReorderer. histc can produce different counts between the forward and recompute passes on some backends, while bincount is deterministic and functionally equivalent for integer expert indices.

2. Add aten.topk.default to the SAC save list. topk can also be non-deterministic on recompute on some backends. Saving its output (top-k scores + indices per token) is cheap and guarantees stable expert assignments across forward and recompute.

Both changes are no-ops on backends where these ops are already deterministic, and avoid silent gradient corruption on those that aren't.

Signed-off-by: guoqiong song <guoqiong.song@intel.com>
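A minimal sketch of change 1, with illustrative values (the tensor name and shapes are assumptions, not taken from the torchtitan code): for integer expert indices, torch.bincount yields the same per-expert token counts as torch.histc binned over [0, num_experts), but bincount is deterministic across forward and SAC recompute.

```python
import torch

num_experts = 4
# Illustrative flattened expert assignments, shape (num_tokens * top_k,)
selected_experts_indices = torch.tensor([0, 2, 2, 1, 3, 0, 2])

# Old: histc over float values, one unit-width bin per expert
counts_histc = torch.histc(
    selected_experts_indices.float(), bins=num_experts, min=0, max=num_experts
)

# New: bincount directly over the integer indices
counts_bincount = torch.bincount(selected_experts_indices, minlength=num_experts)

print(counts_histc)     # tensor([2., 1., 3., 1.])
print(counts_bincount)  # tensor([2, 1, 3, 1])
```

Note minlength=num_experts, which keeps trailing zero counts for experts that received no tokens, matching histc's fixed bin count.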
The following ciflow label(s) have been added but CI has not been triggered yet because the workflows are awaiting approval:

Once a maintainer approves the workflows (scroll to the bottom of the PR page), the corresponding CI jobs will be triggered automatically. Please ping one of the reviewers if you do not have access to approve and run workflows.
for the torch.histc -> bincount change, wondering could you measure speed / MFU before and after? another approach we've been seeing used is
# FlexAttention (torch.ops.higher_order.flex_attention is the same object)
torch._higher_order_ops.flex_attention,
torch.ops.aten.linear.default,
# topk can be non-deterministic on some backends; save to keep MoE
Wondering, did you try turning on torch determinism mode? Is torch.topk still non-deterministic even under determinism mode?
Wondering, is it possible to add this to the save list only for non-deterministic backends (e.g. XPU) instead of doing it for all backends?
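A hedged sketch of the backend-conditional approach this comment suggests. The names here are illustrative, not the actual torchtitan save list: the idea is simply to guard the aten.topk.default entry behind a backend probe so CUDA/CPU behavior is unchanged.

```python
import torch

# Illustrative base save list (assumed contents, not torchtitan's real one).
_base_save_list = {
    torch.ops.aten.mm.default,  # matmuls are commonly saved under SAC
}

def build_save_list():
    save_list = set(_base_save_list)
    # Hypothetical backend probe: only save topk output where its recompute
    # is known to be non-deterministic, e.g. on Intel XPU.
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        save_list.add(torch.ops.aten.topk.default)
    return save_list
```

On a CUDA or CPU box, build_save_list() returns the base list untouched, so this would keep the change scoped to the affected backend.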
MoE routing uses torch.histc and relies on torch.topk being recomputed under selective activation checkpointing. Neither op is deterministic on all backends (notably Intel XPU), so SAC recompute can produce different expert assignments than the original forward, silently corrupting gradients.
This PR:
Replaces torch.histc with torch.bincount in TokenChoiceTopKRouter and TokenReorderer — bincount is deterministic and equivalent for integer indices.
Adds aten.topk.default to the SAC save list so its output is reused on recompute.
Both changes are no-ops on backends where these ops are already deterministic (e.g. CUDA).
Changes
torchtitan/models/common/moe/moe.py: histc → bincount (×2)
torchtitan/distributed/activation_checkpoint.py: save aten.topk.default
tests/unit_tests/test_moe_routing.py: routing-count and recompute-consistency tests
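A sketch in the spirit of the recompute-consistency test (function and variable names are illustrative, not from the actual test file): gradients through a checkpointed top-k routing step should match the non-checkpointed reference, which only holds if topk picks the same elements when the region is recomputed on backward.

```python
import torch
from torch.utils.checkpoint import checkpoint

def route(scores):
    # Toy stand-in for top-k routing: pick 2 of 4 "experts" per token.
    top_scores, top_indices = torch.topk(scores, k=2, dim=-1)
    # Reduce through the selected scores so backward depends on the choice.
    return (top_scores ** 2).sum(), top_indices

torch.manual_seed(0)
base = torch.randn(8, 4)
a = base.clone().requires_grad_(True)
b = base.clone().requires_grad_(True)

loss_ref, idx_ref = route(a)
loss_ref.backward()

loss_ckpt, idx_ckpt = checkpoint(route, b, use_reentrant=False)
loss_ckpt.backward()  # recomputes route(); topk must repeat its choice

assert torch.equal(idx_ref, idx_ckpt)
torch.testing.assert_close(a.grad, b.grad)
```

On a backend where topk recompute can diverge, the gradient comparison is the assertion that would fail without saving topk's output.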
Testing
python -m unittest tests.unit_tests.test_moe_routing: 3/3 pass on CUDA and Intel XPU. DeepSeek-V3 training on CUDA/XPU now runs cleanly under SAC.

Risk
No API changes. bincount and histc produce identical integer counts for valid inputs (covered by tests). Saving topk adds negligible memory.
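A back-of-envelope check of the "negligible memory" claim, under assumed shapes that are not taken from the PR (8192 tokens per microbatch, top_k = 8, fp32 scores at 4 B plus int64 indices at 8 B per selection):

```python
# Assumed shapes for illustration only.
tokens, top_k = 8192, 8
bytes_saved = tokens * top_k * (4 + 8)  # 4 B score + 8 B index per selection
print(bytes_saved / 2**20)  # 0.75 (MiB per MoE layer)
```

Under these assumptions, well under a mebibyte per MoE layer, which is small next to the activations SAC already saves.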