
[MoE][SAC] Use deterministic ops in MoE routing #3146

Open
songhappy wants to merge 1 commit into pytorch:main from songhappy:xpu-moe-determinism

Conversation

@songhappy

MoE routing uses torch.histc and relies on torch.topk being recomputed under selective activation checkpointing. Neither op is deterministic on all backends (notably Intel XPU), so SAC recompute can produce different expert assignments than the original forward, silently corrupting gradients.

This PR:

Replaces torch.histc with torch.bincount in TokenChoiceTopKRouter and TokenReorderer — bincount is deterministic and equivalent for integer indices.
Adds aten.topk.default to the SAC save list so its output is reused on recompute.
Both changes are no-ops on backends where these ops are already deterministic (e.g. CUDA).
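To see why the swap is safe, here is a minimal sketch (with hypothetical shapes, not torchtitan's actual tensors) showing that histc and bincount agree when the input is integer expert indices:

```python
import torch

# Hypothetical routing setup: 8 experts, 1024 tokens, one expert id each.
num_experts = 8
selected_experts = torch.randint(0, num_experts, (1024,))  # integer expert ids

# Old path: histc bins the float-cast ids into num_experts unit-width bins.
counts_histc = torch.histc(
    selected_experts.float(), bins=num_experts, min=0, max=num_experts
)

# New path: bincount counts integer occurrences directly and deterministically.
counts_bincount = torch.bincount(selected_experts, minlength=num_experts)

# For valid integer indices in [0, num_experts), the two agree exactly.
assert torch.equal(counts_histc.long(), counts_bincount)
```

The equivalence holds because each unit-width histc bin [i, i+1) captures exactly the tokens with expert id i; bincount computes the same counts without float binning.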

Changes
torchtitan/models/common/moe/moe.py: histc → bincount (×2)
torchtitan/distributed/activation_checkpoint.py: save aten.topk.default
tests/unit_tests/test_moe_routing.py: routing-count and recompute-consistency tests

Testing
python -m unittest tests.unit_tests.test_moe_routing — 3/3 pass on CUDA and Intel XPU. DeepSeek-V3 training on CUDA/XPU now runs cleanly under SAC.

Risk
No API changes. bincount and histc produce identical integer counts for valid inputs (covered by tests). Saving topk adds negligible memory.

…lity

Two related changes that fix non-deterministic MoE routing under selective activation checkpointing on backends where torch.histc / torch.topk are not guaranteed deterministic (e.g. XPU):

1. Replace torch.histc with torch.bincount in TokenChoiceTopKRouter and TokenReorderer. histc can produce different counts between the forward and recompute passes on some backends, while bincount is deterministic and functionally equivalent for integer expert indices.

2. Add aten.topk.default to the SAC save list. topk can also be non-deterministic on recompute on some backends. Saving its output (top-k scores + indices per token) is cheap and guarantees stable expert assignments across forward and recompute.

Both changes are no-ops on backends where these ops are already deterministic, and avoid silent gradient corruption on those that aren't.

Signed-off-by: guoqiong song <guoqiong.song@intel.com>
@meta-cla bot added the CLA Signed label (this label is managed by the Meta Open Source bot) on Apr 28, 2026
@pytorch-bot

pytorch-bot Bot commented Apr 28, 2026

The following ciflow label(s) have been added but CI has not been triggered yet because the workflows are awaiting approval:

  • ciflow/8gpu

Once a maintainer approves the workflows (scroll to the bottom of the PR page), the corresponding CI jobs will be triggered automatically. Please ping one of the reviewers if you do not have access to approve and run workflows.

@tianyu-l requested a review from acisseJZhong on April 29, 2026 00:04
@acisseJZhong
Contributor

acisseJZhong commented Apr 29, 2026

For the torch.histc -> bincount change, wondering if you could measure speed / MFU before and after?

Another approach we've seen used is torch.scatter(zeros, -1, topk_ids, 1) to replace torch.histc; curious what the performance comparison is between these three options.
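For reference, the three counting strategies mentioned in this thread can be compared for equivalence (this is an illustrative sketch with made-up shapes, not a performance benchmark; timing would be needed to answer the speed/MFU question):

```python
import torch

num_experts, k, tokens = 8, 2, 1024
# Top-k expert ids per token; ids within a row are distinct, as topk returns.
topk_ids = torch.stack(
    [torch.randperm(num_experts)[:k] for _ in range(tokens)]
)  # shape (tokens, k)

# 1) histc on float-cast ids (original approach)
c_histc = torch.histc(
    topk_ids.float(), bins=num_experts, min=0, max=num_experts
).long()

# 2) bincount on flattened integer ids (this PR)
c_bincount = torch.bincount(topk_ids.reshape(-1), minlength=num_experts)

# 3) scatter a one-hot per token, then sum over tokens (reviewer's suggestion)
one_hot = torch.zeros(tokens, num_experts).scatter(-1, topk_ids, 1.0)
c_scatter = one_hot.sum(0).long()

# All three produce identical per-expert token counts.
assert torch.equal(c_histc, c_bincount) and torch.equal(c_bincount, c_scatter)
```

Note the scatter variant relies on the per-row ids being distinct (true for topk output); duplicate ids in a row would collapse into a single one-hot entry.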

# FlexAttention (torch.ops.higher_order.flex_attention is the same object)
torch._higher_order_ops.flex_attention,
torch.ops.aten.linear.default,
# topk can be non-deterministic on some backends; save to keep MoE

Wondering, did you try turning on torch determinism mode? Is torch.topk non-deterministic even under determinism mode?

Also wondering if it's possible to add this to the save list only for non-deterministic backends (e.g. XPU) instead of doing it for all backends.
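For context on the determinism-mode question, the switch being referred to is torch.use_deterministic_algorithms. A minimal sketch (CPU-only here; whether it covers topk on a given backend is exactly the open question above):

```python
import torch

# Determinism mode makes ops error out or switch to deterministic kernels
# where alternatives exist, but coverage is backend-dependent, so an op
# like topk may still behave differently on backends it does not cover.
torch.use_deterministic_algorithms(True)

x = torch.tensor([3.0, 1.0, 3.0, 2.0])
vals, idx = torch.topk(x, k=2)  # top-2 values; note the tie between the 3.0s

torch.use_deterministic_algorithms(False)  # restore default
```

Tie-breaking among equal values is where run-to-run variation can show up, which is why saving topk's output under SAC sidesteps the issue entirely.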


Labels

ciflow/8gpu · CLA Signed
