
[cpu-offloading] Implement prefetching for cpu offloading pass #3166

Open
mlazos wants to merge 5 commits into gh/mlazos/2/base from gh/mlazos/2/head

Conversation


@mlazos mlazos commented Apr 29, 2026

@meta-cla bot added the CLA Signed label on Apr 29, 2026

mlazos commented Apr 29, 2026

@claude can you review this and ensure that the prefetching per layer does not miss any edge cases and that the tests do not mask any behavior


claude Bot commented Apr 29, 2026

Claude finished @mlazos's task in 2m 44s


Review: CPU Offload Prefetching Implementation

  • Read PR diff and changed files
  • Analyze prefetching logic for edge cases
  • Analyze test coverage for masked behavior
  • Post review findings

Prefetching Edge Cases

1. prefetch_offloads: Backward layer direction mismatch (cpu_offload.py:744)

The prefetch logic uses target_layer = min(layer_id + n_layers, max_layer) — this moves reloads to a higher layer ID. But in backward execution, layers run in descending order (layer N-1 → 0). So "earlier in backward" means higher layer ID, which is correct.

However, there is one subtle point to check: layer_first_wait is built by iterating reload_info sorted by node_to_idx[wait] (the graph position of the wait node). Since backward runs layer N-1 first, the first waits in graph order belong to the highest layers, and layer_first_wait[lid] correctly captures the first wait for each layer (it is only set when lid is not already in layer_first_wait). This is correct.
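A minimal sketch of that mapping, for illustration only (reload_info and node_to_idx are the names used in this review; the actual construction in cpu_offload.py may differ):

def build_layer_first_wait(reload_info, node_to_idx):
    # reload_info: iterable of (reload_node, wait_node, layer_id) tuples.
    # node_to_idx: mapping from graph node to its position in graph order.
    layer_first_wait = {}
    for _reload, wait, layer_id in sorted(reload_info, key=lambda t: node_to_idx[t[1]]):
        # Only the first wait (in graph order) is recorded per layer.
        layer_first_wait.setdefault(layer_id, wait)
    return layer_first_wait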

2. prefetch_offloads: max_layer capping misses reloads at the highest layer (cpu_offload.py:744-745)

target_layer = min(layer_id + n_layers, max_layer)
if target_layer == layer_id or target_layer not in layer_first_wait:
    continue

When layer_id == max_layer, the condition target_layer == layer_id is true and the reload is not moved. This is correct — the highest layer's backward runs first, so there's no earlier layer to prefetch into.

But consider n_layers=2 with layers {0, 1, 2} (last layer = 2, so only layers 0 and 1 are offloaded). Reloads for layer 0 target layer min(0+2, 1) = 1, and for layer 1 target min(1+2, 1) = 1 which equals layer_id, so it's skipped. Correct. With more layers this works as expected.
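A small worked illustration of the target_layer arithmetic, in plain Python and independent of the graph code:

def prefetch_targets(layer_ids, n_layers):
    # Mirrors target_layer = min(layer_id + n_layers, max_layer) from the snippet above.
    max_layer = max(layer_ids)
    return {lid: min(lid + n_layers, max_layer) for lid in layer_ids}

# Offloaded layers {0, 1} (layer 2 is the skipped last layer), n_layers=2:
# prefetch_targets({0, 1}, 2) == {0: 1, 1: 1}; layer 1 maps to itself and is skipped.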

3. prefetch_offloads: Potential issue with multiple reloads prepending to the same anchor (cpu_offload.py:748)

layer_first_wait[target_layer].prepend(reload_node)

When multiple reloads target the same layer_first_wait[target_layer], each prepend places the reload before the anchor. After multiple prepends, the last-processed reload ends up closest to the anchor. This means the order of prefetched reloads is determined by iteration order of reload_info (sorted by wait position). This is functionally correct — all reloads end up before the target wait — but the relative ordering among co-located reloads is reversed from their original order. Not a correctness bug (H2D transfers are async and independent), but worth noting.

4. _get_reload_layer: Fragile chain assumption after wait sinking (cpu_offload.py:675-690)

_get_reload_layer traces reload → fwd_wait → offload → original. But _sink_forward_waits_per_group moves the wait_offload_node in the graph. The key insight is that node.args references are preserved even after prepend/append — only the position in the graph changes, not the args. So fwd_wait.args[0] still points to offload_node, and the chain is intact. Correct.

However, this makes the code order-dependent: prefetch_offloads is called (step 6 in apply_cpu_offload_pass) after _sink_forward_waits_per_group (step 4). If this ordering were reversed, the _get_reload_layer would still work (since args don't change), but the position-based sort would use stale indices. The current ordering is safe.
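For reference, a hypothetical reconstruction of that chain walk (the helper, the node_to_layer mapping, and the sentinel value are illustrative, not the actual cpu_offload.py code):

_NOT_IN_LAYERS = -1  # illustrative sentinel; the real value lives in cpu_offload.py

def get_reload_layer_sketch(reload_node, node_to_layer):
    # Walk reload -> fwd_wait -> offload -> original forward node via node.args,
    # which survive prepend/append repositioning, then read the original's layer tag.
    fwd_wait = reload_node.args[0]
    offload = fwd_wait.args[0]
    original = offload.args[0]
    return node_to_layer.get(original, _NOT_IN_LAYERS)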

5. _sink_forward_waits_per_group: Last group handling (cpu_offload.py:449-451)

When target_idx == idx (last group or lookahead exceeds remaining groups), waits are placed at fallback_anchor.prepend(wait_node). The fallback is the first backward node. This means last-group waits are placed just before the first backward op. Correct — this ensures D2H completes before backward begins to reload.

6. Edge case: sink_lookahead=0 in the env var path

In cpu_offload_pass (line 807):

functorch_config.activation_offload_sink_wait = sink_la == 0

When OFFLOAD_SINK_LOOKAHEAD=0, this sets sink_wait = True, which tells the partitioner to sink waits immediately (no lookahead). Meanwhile activation_offload_sink_wait_group_lookahead = 0. This looks intentional — sink_wait=True means "use the legacy global sink" while group_lookahead > 0 means "use per-group sinking with N lookahead". But verify the upstream semantics — if both are set, which takes priority?

Similarly for activation_reload_prefetch = prefetch_la == 0 (line 809). When OFFLOAD_PREFETCH_LOOKAHEAD=0, this sets prefetch = True and prefetch_group_lookahead = 0. Verify upstream interprets this as "enable prefetch but with 0 group lookahead" (i.e., no-op or immediate prefetch).
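Putting the two mappings side by side as a sketch (functorch_config here is a stand-in namespace, the defaults are illustrative, and the group-lookahead field names are abbreviated in this review, so treat them as assumptions to verify upstream):

import os
from types import SimpleNamespace

functorch_config = SimpleNamespace()  # stand-in for the real config object
sink_la = int(os.environ.get("OFFLOAD_SINK_LOOKAHEAD", "1"))          # default is illustrative
prefetch_la = int(os.environ.get("OFFLOAD_PREFETCH_LOOKAHEAD", "1"))  # default is illustrative
# A lookahead of 0 flips the legacy boolean on while the per-group lookahead stays 0.
functorch_config.activation_offload_sink_wait = sink_la == 0
functorch_config.activation_offload_sink_wait_group_lookahead = sink_la
functorch_config.activation_reload_prefetch = prefetch_la == 0
functorch_config.activation_reload_prefetch_group_lookahead = prefetch_la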


7. Edge case: Layers with offloadable nodes but no reloads in backward

If a layer has offloaded activations but none of its reloads appear in reload_info (e.g., because _get_reload_layer returns _NOT_IN_LAYERS for them), that layer won't appear in layer_first_wait. Other layers targeting it via target_layer not in layer_first_wait will be skipped. This is safe but could mean some reloads are never prefetched. This could happen if the graph structure doesn't match the expected reload → wait → offload → original chain shape.


Test Coverage Gaps

1. test_prefetch_moves_reloads_earlier applies apply_cpu_offload_pass twice (test_cpu_offload.py:360-361)

gm = apply_cpu_offload_pass(gm)  # inserts ops AND calls prefetch_offloads(gm, 1) by default
prefetch_offloads(gm, n_layers=1)  # applies prefetch AGAIN

apply_cpu_offload_pass already calls prefetch_offloads(gm, prefetch_lookahead) at line 663, and prefetch_lookahead defaults to 1. So the test applies prefetching twice. The second call finds the already-moved reloads and may try to move them again (they'd already be at the target layer, so target_layer == layer_id would skip them). The test still passes because the second prefetch is effectively a no-op, but it masks the behavior — it's testing the idempotency of prefetch rather than the initial prefetch behavior in isolation.
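One way to make the test exercise a single prefetch application (sketch only; whether passing prefetch_lookahead=0 fully disables the built-in call depends on the pass's semantics and should be checked against the actual signature):

gm = apply_cpu_offload_pass(gm, prefetch_lookahead=0)  # insert offload/reload ops without prefetching
prefetch_offloads(gm, n_layers=1)                      # the behavior under test, applied exactly once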


2. No test for n_layers > 1 prefetch

All prefetch tests use n_layers=1. There's no test verifying that n_layers=2 correctly skips an intermediate layer. With only 4 layers (and last layer skipped), there are 3 offloadable layers (0, 1, 2). n_layers=2 should move layer 0's reloads to layer 2's position. A test with num_layers=5 and n_layers=2 would better exercise the multi-layer skip.
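A possible shape for such a test (hypothetical; the graph-building and lookup helpers are illustrative and would need to follow the conventions of the existing tests in test_cpu_offload.py):

def test_prefetch_skips_intermediate_layer(self):
    gm = self._build_offloaded_graph(num_layers=5)  # illustrative helper: layers 0-3 offloadable, layer 4 skipped
    prefetch_offloads(gm, n_layers=2)
    order = {n: i for i, n in enumerate(gm.graph.nodes)}
    first_wait_layer2 = self._first_wait_for_layer(gm, layer=2)  # illustrative helper
    # Layer 0's reloads should now precede layer 2's first wait node.
    for reload_node in self._reloads_for_layer(gm, layer=0):  # illustrative helper
        assert order[reload_node] < order[first_wait_layer2]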

3. No test for sink_lookahead behavior

The _sink_forward_waits_per_group function is new and has no dedicated test. The existing tests exercise it indirectly (via default sink_lookahead=1), but there's no test validating:

  • Wait nodes are correctly positioned after the target group's last offload
  • Fallback behavior when target group has no offload nodes
  • sink_lookahead > 1 behavior

4. No test for env var configuration path

The cpu_offload_pass function reads from OFFLOAD_SINK_LOOKAHEAD, OFFLOAD_PREFETCH_LOOKAHEAD, OFFLOAD_CPU_MEMORY_BUDGET_GB, and OFFLOAD_DEALLOC env vars. None of these paths are tested. A test with os.environ patching would verify the env-var-driven configuration doesn't break.
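A sketch of an env-var-driven smoke test, assuming the pass can be invoked directly on a prepared graph module (the helper name is illustrative):

import os
from unittest import mock

def test_env_var_configuration(self):
    env = {
        "OFFLOAD_SINK_LOOKAHEAD": "2",
        "OFFLOAD_PREFETCH_LOOKAHEAD": "2",
        "OFFLOAD_CPU_MEMORY_BUDGET_GB": "4",
        "OFFLOAD_DEALLOC": "1",
    }
    with mock.patch.dict(os.environ, env):
        gm = self._build_offloaded_graph(num_layers=4)  # illustrative helper
        cpu_offload_pass(gm)  # should honor the env vars and not raise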

5. test_single_layer_tagged may mask behavior (test_cpu_offload.py:534-549)

The test asserts nodes are tagged with a single layer because last_layer_id = _NOT_IN_LAYERS when there's only one layer (due to max(all_layer_ids) if len(all_layer_ids) > 1 else _NOT_IN_LAYERS). This means single-layer models skip the last-layer optimization entirely and offload everything. The test comment says "last-layer skip only applies with multiple layers" — this is correct behavior, but it's worth noting that in production, single-layer models would always offload (no skip), which adds overhead without benefit since backward immediately follows forward.

6. test_prefetch_via_cpu_offload_pass doesn't verify prefetch distance (test_cpu_offload.py:407-432)

The test only checks that reloads exist and that each reload precedes its wait. It doesn't verify that reloads were actually moved earlier — a no-op prefetch would also pass this test.


Minor Issues

  1. Duplicated env var reading (cpu_offload.py:799-803 and 819-821): The env vars OFFLOAD_SINK_LOOKAHEAD, OFFLOAD_PREFETCH_LOOKAHEAD, and OFFLOAD_CPU_MEMORY_BUDGET_GB are read in both the has_should_offload and else branches. Consider hoisting these reads above the conditional to DRY up the code.

  2. prefetch_n_layers parameter is unused in the joint-graph path (cpu_offload.py:761): The cpu_offload_pass accepts prefetch_n_layers as a keyword arg but ignores it in favor of the OFFLOAD_PREFETCH_LOOKAHEAD env var. The config field cpu_offload_prefetch_n_layers in configs.py and the functools.partial in graph_utils.py:588-590 wire up this parameter, but it's never consulted. The env var always wins.


  3. _has_recompute_consumer unbounded recursion (cpu_offload.py:209-222): The recursive call through view chains has no cycle detection. While view chains in practice are DAGs (a view's users are topologically later), a malformed graph could cause infinite recursion. Low risk but worth noting.
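If hardening is desired, a minimal visited-set guard might look like the sketch below (_is_view_op and the overall recursion shape are assumptions based on the description above, not the actual cpu_offload.py code):

def has_recompute_consumer(node, visited=None):
    # Guard against revisiting nodes so a malformed (cyclic) view chain cannot recurse forever.
    visited = visited if visited is not None else set()
    if node in visited:
        return False
    visited.add(node)
    for user in node.users:
        if user.meta.get("recompute") is not None:
            return True
        if _is_view_op(user) and has_recompute_consumer(user, visited):  # _is_view_op: assumed helper
            return True
    return False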


Summary

The prefetching logic is correct for standard layer-structured models. The key design — moving ao.reload earlier while keeping ao.wait_tensor in place — properly overlaps H2D transfers with backward compute without introducing data hazards.

The main actionable items are:

  1. Fix test double-prefetch: test_prefetch_moves_reloads_earlier applies prefetch twice, masking the actual behavior
  2. prefetch_n_layers param is dead code: env var always wins over the config parameter
  3. Add tests for sink_lookahead and multi-layer prefetch (n_layers > 1)

@mlazos mlazos requested a review from sanketpurandare as a code owner April 30, 2026 01:27
mlazos added a commit that referenced this pull request Apr 30, 2026

mlazos commented Apr 30, 2026

@claude can you review to see if we are missing any overlap opportunities or edge cases?


claude Bot commented Apr 30, 2026

Claude finished @mlazos's task in 3m 41s


  • Read PR diff and changed files
  • Analyze prefetching logic for missed overlap opportunities
  • Analyze edge cases in the implementation
  • Post review findings

Review: Overlap Opportunities & Edge Cases

The core design is sound — decoupling ao.reload (H2D start) from ao.wait_tensor (H2D sync) and moving reloads earlier to overlap H2D with compute. I checked all the interaction paths and found one missed overlap opportunity and two edge cases worth noting.


Missed Overlap Opportunity

Highest offloaded layer gets zero prefetch benefit (cpu_offload.py:656-657)

target_layer = min(layer_id + n_layers, max_layer)
if target_layer == layer_id or target_layer not in layer_first_wait:
    continue

When layer_id == max_layer (the layer whose backward runs first among offloaded layers), the reload can't be moved earlier because there's no higher layer to target. This layer's H2D transfer adds directly to backward latency with no overlap.

Fix: Add a special case that moves the highest-layer reloads before the first backward node (or the boundary between forward and backward). This would overlap them with loss computation or the last forward layer's execution:

# After building layer_first_wait, find backward boundary
first_backward_node = next(
    (n for n in gm.graph.nodes if n.op == "call_function" and _is_backward_node(n)),
    None,
)
for reload_node, _wait_node, layer_id in reload_info:
    target_layer = min(layer_id + n_layers, max_layer)
    if target_layer == layer_id:
        # Highest layer: move before first backward op
        if first_backward_node is not None:
            first_backward_node.prepend(reload_node)
            moved += 1
        continue
    ...

This is the single biggest overlap gap — on a trace with 32 layers, the highest offloaded layer (layer 30) has zero H2D/compute overlap. For large activations, this can be 1-2ms of exposed latency per iteration.


Edge Cases

1. Positional-only scheduling dep when _find_tensor_dep returns None (cpu_offload.py:565-570)

dep_node = _find_tensor_dep(last_consumer, node_positions)
if dep_node is not None:
    wait_node.args = (*wait_node.args[:2], dep_node)
    dep_node.append(wait_node)
else:
    last_consumer.append(wait_node)

When dep_node is None (last_consumer doesn't produce a tensor and has no tensor-valued getitem), the forward wait is placed positionally after last_consumer but without a data dependency. Subsequent graph optimization passes could reorder the wait before the last consumer, freeing GPU storage prematurely.

This can occur with ops like aten._foreach_add_ or multi-output ops without getitem extraction. Low probability in standard transformer graphs, but worth a defensive assertion or fallback (e.g., searching backwards from last_consumer for a tensor-producing node).
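A sketch of the broader-search fallback suggested above (it assumes nodes carry an example value under node.meta["val"], which is typical for AOT-produced FX graphs but should be confirmed for this pass):

import torch

def find_tensor_dep_fallback(last_consumer, node_positions):
    # Walk backwards in graph order from last_consumer until a tensor-producing node is found.
    ordered = sorted(node_positions, key=node_positions.get)
    for node in reversed(ordered[: node_positions[last_consumer] + 1]):
        val = node.meta.get("val")
        if isinstance(val, torch.Tensor):
            return node
    return None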

2. Budget filtering reorders insertion, affecting co-located reload ordering (cpu_offload.py:432-445)

Step 2 sorts by tensor size descending. After filtering, offloadable is processed in size order, not graph order. When multiple offloaded tensors share the same first_consumer in backward, their reloads are stacked in size order rather than topological order. This doesn't affect correctness (H2D copies are independent), but could affect CUDA stream scheduling if the runtime processes them in graph order.

If you want deterministic scheduling behavior, re-sort offloadable back to graph order after budget filtering:

offloadable.sort(key=lambda t: node_to_index.get(t[0], 0))

3. Forward wait repositioning skips nodes with has_bwd=True — but this should never trigger (cpu_offload.py:562)

if has_bwd or not real_consumers:
    continue

After step 4 redirects ALL backward users via replace_input_with, _get_storage_chain should always return has_bwd=False for offloaded nodes. The has_bwd check is dead code in practice. If it ever triggers, it means a backward user wasn't redirected, which would be a bug in the redirect logic. Consider converting this to an assertion:

assert not has_bwd, f"Unexpected un-redirected backward user for {node.name}"

Verified as Correct

The following aspects were verified and are handled correctly:

  1. _get_reload_layer after step 4b modifies wait args: fwd_wait.args[0] still points to offload_node since dep is appended as third arg ✓
  2. prepend doesn't invalidate layer_first_wait references — the target node stays in place, reload is inserted before it ✓
  3. View replay arg substitution: replay_map correctly chains multi-level views (node → view1 → view2) ✓
  4. Mixed direct + view-chain backward users — both paths converge correctly through wait_node / replay ✓
  5. _get_storage_chain walks only view edges — non-view consumers create new storage, so storage-sharing analysis is correct ✓
  6. Moving reload before forward data dependency — reload reads from wait_offload_node which is in forward (always earlier than any backward position) ✓

Summary

The implementation is correct for all standard cases. The one actionable missed overlap is the highest-layer prefetch gap — adding a "move before first backward node" fallback would eliminate exposed H2D latency for that layer. The _find_tensor_dep returning None edge case is low-probability but could be hardened with a broader search.



f"Available joint passes: {list(AVAILABLE_JOINT_PASSES.keys())}"
)

# Skip inductor_decomposition - it's handled in joint_graph_builder
mlazos (Contributor Author) commented on the diff hunk above:

Not sure if this should be skipped

A reviewer (Contributor) replied:

don't need this here.

mlazos changed the title from "[cpu-offloading] Implement prefetching via env var configs" to "[cpu-offloading] Implement prefetching for cpu offloading pass" on Apr 30, 2026
)


def cpu_offload_pass(
A reviewer (Contributor) commented on the diff hunk above:

we can remove this pass now.

tag_all_offloadable_activations is now run via tag_with_memory_policy_pass

continue
existing = node.meta.get("recompute")

if sac_active:
A reviewer (Contributor) commented on the diff hunk above:

don't need to detect sac_active?
just do defensive programming, and assume some tagging has been applied

mlazos added a commit that referenced this pull request Apr 30, 2026
@mlazos mlazos requested a review from SherlockNoMad April 30, 2026 19:27
mlazos added a commit that referenced this pull request Apr 30, 2026
mlazos added a commit that referenced this pull request Apr 30, 2026

Labels

ciflow/8gpu, CLA Signed
