
[dist] bug: distribute_tensor uses src_data_rank=0 redundantly in load_model_weights when all ranks load from disk #637

@kahlun

Description


Hi, I noticed the load_model_weights path fires distribute_tensor(src_data_rank=0) → dist.scatter() during startup, even though every rank already reads the full checkpoint from disk. Passing src_data_rank=None eliminates the scatter with zero correctness impact — I A/B tested across 5 models (Qwen2/Qwen3/MoE, 0.5B–14B) and got bit-for-bit identical losses.

The startup time improvement scales with model size (−10% at 0.5B, −22% at 3B, −18% at MoE-14B). Details and fix below.


What does this bug do?

In _build_fsdp2_model, load_model_weights is called with the default dtensor_factory=distribute_tensor. This calls distribute_tensor(full_tensor, mesh, [Shard(0)]) with the default src_data_rank=0, which triggers:

load_model_weights()
  → _dispatch_parameter()
    → distribute_tensor(full_tensor, mesh, [Shard(0)])   # src_data_rank=0 (default)
      → Shard._shard_tensor()
        → mesh_scatter(output, chunks, group_src=0)       # ← fires dist.scatter()

However, load_model_weights already reads the full checkpoint on every rank independently — the existing log even says:

"Every rank would read weights from disk and expect this to be slow!"

So every rank holds the complete tensor before distribute_tensor is called. The scatter from rank 0 is 100% redundant: rank 0 sends chunks to ranks that already have identical data locally.
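
To make the redundancy concrete, here is a minimal standalone sketch of the same pattern (illustrative only, not VeOmni code; import paths assume a recent PyTorch where torch.distributed.tensor is public). Every rank builds the identical full tensor, the default distribute_tensor scatters it from rank 0, and the chunk each rank receives is exactly the slice it already held:

```python
# Minimal sketch of the redundant pattern (illustrative, not VeOmni code).
# Run under torchrun with 2 processes; the gloo/CPU backend is enough.
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

dist.init_process_group("gloo")
mesh = init_device_mesh("cpu", (dist.get_world_size(),))

# Stand-in for "every rank reads the full checkpoint from disk":
# each rank builds the identical full tensor locally.
torch.manual_seed(0)
full_tensor = torch.randn(8, 4)

# Default path: src_data_rank=0, so rank 0 scatters chunks to all ranks.
dtensor = distribute_tensor(full_tensor, mesh, [Shard(0)])

# The chunk this rank just received over the wire is exactly the slice it
# already held locally, so the scatter moved no new information.
rank, world = dist.get_rank(), dist.get_world_size()
local_slice = full_tensor.chunk(world, dim=0)[rank]
assert torch.equal(dtensor.to_local(), local_slice)

dist.destroy_process_group()
```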

Impact

On all backends: the redundant scatter wastes up to (world_size-1)/world_size × model_size of network bandwidth at every training startup, and the overhead grows linearly with parameter count (a back-of-envelope sketch of the formula follows the table):

| Model | Family | Scatter calls | Wire bytes wasted | Load time (DEFAULT → FIX) |
| --- | --- | --- | --- | --- |
| Qwen2.5-0.5B | qwen2 (dense) | 290 | 0.988 GB | 0.936 s → 0.841 s (−10%) |
| Qwen2.5-1.5B | qwen2 (dense) | 338 | 3.087 GB | 2.746 s → 2.369 s (−14%) |
| Qwen3-0.6B | qwen3 (dense) | 310 | 1.192 GB | 1.261 s → 0.976 s (−23%) |
| Qwen2.5-3B | qwen2 (dense) | 434 | 6.172 GB | 5.832 s → 4.520 s (−22%) |
| Qwen1.5-MoE-A2.7B | qwen2_moe (MoE) | 4,659 | 28.632 GB | 22.4 s → 18.4 s (−18%) |
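
Back-of-envelope sketch of the formula (the checkpoint size and world size below are hypothetical illustrations, not numbers from the table):

```python
# Hypothetical numbers only -- illustrates the (world_size-1)/world_size * model_size formula.
def wasted_scatter_bytes(checkpoint_bytes: int, world_size: int) -> float:
    # Rank 0 sends world_size - 1 of the world_size chunks to peers that
    # already hold identical data locally.
    return (world_size - 1) / world_size * checkpoint_bytes

# e.g. a hypothetical 12 GB checkpoint on 4 GPUs: 9 GB of redundant traffic per startup
print(wasted_scatter_bytes(12 * 1024**3, 4) / 1024**3)  # -> 9.0
```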

On backends without P2P IPC support (e.g. XCCL on PCIe): dist.scatter() causes a hard hang.


Fix

Pass src_data_rank=None so PyTorch performs a local tensor split with zero communication. Per PyTorch DTensor docs: when src_data_rank=None, each rank slices its own local chunk — which is correct because all ranks already hold the full tensor.

In veomni/distributed/torch_parallelize.py, inside _build_fsdp2_model:

BEFORE (VeOmni 0.1.4):

```python
load_model_weights(model, weights_path, get_device_type(), dtensor_factory=distribute_tensor)
```

AFTER:

```python
import functools

# Every rank already read the full checkpoint from disk, so the scatter is redundant.
# src_data_rank=None → local split only, zero collective communication.
_dt_local_split = functools.partial(distribute_tensor, src_data_rank=None)
load_model_weights(model, weights_path, get_device_type(), dtensor_factory=_dt_local_split)
```
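
Because the factory is invoked as dtensor_factory(full_tensor, mesh, [Shard(0)]), wrapping distribute_tensor with functools.partial only pins the src_data_rank keyword and leaves that call signature untouched, so the change is a drop-in replacement.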

⚠️ Scope: this fix applies only to the load_model_weights (every-rank-reads) path. The rank0_load_and_broadcast_weights path should keep the default src_data_rank=0 — that path legitimately has only rank 0 reading the checkpoint.

Correctness Verification — A100 PCIe, 2-GPU

A/B tested on 5 models across 2 architecture families, in both DP=2 and SP=2 modes. Losses are identical between DEFAULT and FIX to all six printed decimal places (representative runs below):

| Model | Step 1 | Step 2 | Step 3 | Step 4 | Step 5 | Match |
| --- | --- | --- | --- | --- | --- | --- |
| Qwen2.5-0.5B (DP=2) | 2.515763 | 0.822851 | 3.626765 | 2.051105 | 0.154655 | ✓ |
| Qwen2.5-0.5B (SP=2) | 4.134645 | 1.723882 | 0.514100 | 0.124871 | 0.009825 | ✓ |

Why src_data_rank=None is the correct API here

From PyTorch's DTensor source, Shard._shard_tensor (simplified):

```python
if src_data_rank is None:
    # NO communication — local split only
    chunks = self._split_tensor(tensor, num_chunks, ...)
    return chunks[my_rank]  # ← local slice, zero network I/O
else:
    # COMMUNICATION — scatter from src_data_rank
    mesh_scatter(output, chunks, mesh, group_src=src_data_rank)  # ← fires dist.scatter()
```

src_data_rank=None is a first-class parameter of the distribute_tensor API in recent PyTorch releases, explicitly designed for the case where every rank already holds the full tensor. It is not a workaround.
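
The tensor-level A/B check is then a short extension of the sketch in the "What does this bug do?" section above (same mesh and full_tensor; requires a PyTorch build whose distribute_tensor accepts src_data_rank):

```python
# Continuing the sketch above: compare the DEFAULT and FIX code paths directly.
default_path = distribute_tensor(full_tensor, mesh, [Shard(0)])                       # scatter from rank 0
local_split = distribute_tensor(full_tensor, mesh, [Shard(0)], src_data_rank=None)    # no communication
assert torch.equal(default_path.to_local(), local_split.to_local())  # identical local shards
```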


Cross-framework comparison

No other major training framework uses scatter for weight loading:

| Framework | Who reads from disk | How weights reach each rank | Uses scatter? |
| --- | --- | --- | --- |
| TorchTitan | Each rank reads its own shard (DCP) | Direct per-rank shard read | No |
| VERL FSDP | Rank 0 only | set_model_state_dict(broadcast_from_rank0=True) | No |
| Megatron-LM | Each rank reads its own shard (pre-sharded files) | Direct per-rank file read | No |
| VeOmni | Every rank reads the full model | distribute_tensor(src_data_rank=0) → scatter | Yes — the bug |

VeOmni is unique in the "every rank reads full checkpoint" pattern. For that pattern, src_data_rank=None is the correct API to use.
