Hi, I noticed the `load_model_weights` path fires `distribute_tensor(src_data_rank=0)` → `dist.scatter()` during startup, even though every rank already reads the full checkpoint from disk. Passing `src_data_rank=None` eliminates the scatter with zero correctness impact: I A/B tested across 5 models (Qwen2/Qwen3/MoE, 0.5B–14B) and got bit-for-bit identical losses.
The startup-time improvement scales with model size (−10% at 0.5B, −22% at 3B, −18% at MoE-14B). Details and fix below.
## What does this bug do?
In `_build_fsdp2_model`, `load_model_weights` is called with the default `dtensor_factory=distribute_tensor`. This calls `distribute_tensor(full_tensor, mesh, [Shard(0)])` with the default `src_data_rank=0`, which triggers:
```
load_model_weights()
  → _dispatch_parameter()
    → distribute_tensor(full_tensor, mesh, [Shard(0)])   # src_data_rank=0 (default)
      → Shard._shard_tensor()
        → mesh_scatter(output, chunks, group_src=0)      # ← fires dist.scatter()
```
However, `load_model_weights` already reads the full checkpoint on every rank independently; the existing log even says:

> "Every rank would read weights from disk and expect this to be slow!"
So every rank holds the complete tensor before `distribute_tensor` is called. The scatter from rank 0 is entirely redundant: rank 0 sends chunks to ranks that already have identical data locally.
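To make the redundancy concrete, here is a minimal standalone repro (not VeOmni code; launch with something like `torchrun --nproc-per-node=2 repro.py`). It assumes CUDA devices and a PyTorch build whose `distribute_tensor` accepts `src_data_rank`:

```python
import os

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

local_rank = int(os.environ.get("LOCAL_RANK", 0))
world_size = int(os.environ.get("WORLD_SIZE", 1))
torch.cuda.set_device(local_rank)

# init_device_mesh sets up the default process group if it is not initialized yet.
mesh = init_device_mesh("cuda", (world_size,))

# Every rank builds the same full tensor, mimicking the every-rank-reads checkpoint path.
full = torch.arange(16, dtype=torch.float32, device="cuda").reshape(4, 4)

# Default path: rank 0 scatters chunks that the other ranks already hold.
dt_scatter = distribute_tensor(full, mesh, [Shard(0)])

# Proposed path: each rank slices its own chunk locally, no collective at all.
dt_local = distribute_tensor(full, mesh, [Shard(0)], src_data_rank=None)

assert torch.equal(dt_scatter.to_local(), dt_local.to_local())
dist.destroy_process_group()
```

Both calls produce the same local shard on every rank; the only difference is that the default path issues a `dist.scatter()` per tensor while the `src_data_rank=None` path issues no collective at all.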
## Impact
**On all backends:** wastes up to `(world_size - 1) / world_size × model_size` of network bandwidth per training startup, and the overhead grows with model size and parameter count:
| Model | Family | Scatter calls | Wire bytes wasted | Load time (DEFAULT → FIX) |
|---|---|---|---|---|
| Qwen2.5-0.5B | qwen2 dense | 290 | 0.988 GB | 0.936 s → 0.841 s (−10%) |
| Qwen2.5-1.5B | qwen2 dense | 338 | 3.087 GB | 2.746 s → 2.369 s (−14%) |
| Qwen3-0.6B | qwen3 dense | 310 | 1.192 GB | 1.261 s → 0.976 s (−23%) |
| Qwen2.5-3B | qwen2 dense | 434 | 6.172 GB | 5.832 s → 4.520 s (−22%) |
| Qwen1.5-MoE-A2.7B | qwen2_moe MoE | 4,659 | 28.632 GB | 22.4 s → 18.4 s (−18%) |
**On backends without P2P IPC support** (e.g. XCCL on PCIe): `dist.scatter()` causes a hard hang.
## Fix
Pass `src_data_rank=None` so PyTorch performs a local tensor split with zero communication. Per the PyTorch DTensor docs, when `src_data_rank=None` each rank slices its own local chunk, which is correct here because every rank already holds the full tensor.
In `veomni/distributed/torch_parallelize.py`, inside `_build_fsdp2_model`:
**BEFORE** (VeOmni 0.1.4):

```python
load_model_weights(model, weights_path, get_device_type(), dtensor_factory=distribute_tensor)
```

**AFTER:**

```python
import functools

# Every rank already read the full checkpoint from disk, so the scatter is redundant.
# src_data_rank=None → local split only, zero collective communication.
_dt_local_split = functools.partial(distribute_tensor, src_data_rank=None)
load_model_weights(model, weights_path, get_device_type(), dtensor_factory=_dt_local_split)
```
⚠️ Scope: this fix applies only to the `load_model_weights` (every-rank-reads) path. The `rank0_load_and_broadcast_weights` path should keep the default `src_data_rank=0`, since that path legitimately has only rank 0 reading the checkpoint.
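If both paths ever share a call site, one way to keep them consistent is to pick the factory per path. A minimal sketch under that assumption (the `reads_full_checkpoint_on_every_rank` flag is hypothetical, not an existing VeOmni variable):

```python
import functools

from torch.distributed.tensor import distribute_tensor

def pick_dtensor_factory(reads_full_checkpoint_on_every_rank: bool):
    # Hypothetical selector: only the every-rank-reads path skips the scatter.
    if reads_full_checkpoint_on_every_rank:
        # load_model_weights path: all ranks hold the full tensor, split locally.
        return functools.partial(distribute_tensor, src_data_rank=None)
    # rank0_load_and_broadcast_weights path: only rank 0 has real data, keep the scatter.
    return distribute_tensor
```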
## Correctness verification (A100 PCIe, 2-GPU)
A/B tested on 5 models across 2 architecture families, in DP=2 and SP=2 modes. Losses match bit-for-bit (shown here to 6 decimal places) between DEFAULT and FIX; a minimal diff sketch follows the table:
| Model | Step 1 | Step 2 | Step 3 | Step 4 | Step 5 | Match |
|---|---|---|---|---|---|---|
| Qwen2.5-0.5B (DP=2) | 2.515763 | 0.822851 | 3.626765 | 2.051105 | 0.154655 | ✅ |
| Qwen2.5-0.5B (SP=2) | 4.134645 | 1.723882 | 0.514100 | 0.124871 | 0.009825 | ✅ |
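The comparison boils down to a per-step equality check on the logged losses. A sketch of such a diff, assuming each run writes one loss per line to a text file (the file names below are placeholders, not VeOmni output paths):

```python
# Hypothetical loss-trace diff: one loss per line in each file, compared exactly.
def read_losses(path: str) -> list[float]:
    with open(path) as f:
        return [float(line) for line in f if line.strip()]

default_losses = read_losses("losses_default.txt")
fix_losses = read_losses("losses_src_data_rank_none.txt")

assert len(default_losses) == len(fix_losses), "runs logged different step counts"
for step, (a, b) in enumerate(zip(default_losses, fix_losses), start=1):
    assert a == b, f"step {step}: {a} != {b}"
print(f"all {len(fix_losses)} steps identical")
```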
## Why `src_data_rank=None` is the correct API here
From PyTorch's DTensor source, `Shard._shard_tensor` (simplified):
```python
if src_data_rank is None:
    # NO communication: local split only
    chunks = self._split_tensor(tensor, num_chunks, ...)
    return chunks[my_rank]  # ← local slice, zero network I/O
else:
    # COMMUNICATION: scatter from src_data_rank
    mesh_scatter(output, chunks, mesh, group_src=src_data_rank)  # ← fires dist.scatter()
```
`src_data_rank=None` is a first-class parameter of the public `distribute_tensor` API, explicitly designed for the case where every rank already has the full tensor. It is not a workaround.
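Because older PyTorch builds predate this parameter, the patch could guard on it before wiring in the partial. A hedged sketch, not existing VeOmni code:

```python
# Feature check: only apply the fix when the installed PyTorch's distribute_tensor
# actually accepts src_data_rank; otherwise fall back to the default behavior.
import functools
import inspect

from torch.distributed.tensor import distribute_tensor

if "src_data_rank" in inspect.signature(distribute_tensor).parameters:
    dtensor_factory = functools.partial(distribute_tensor, src_data_rank=None)
else:
    dtensor_factory = distribute_tensor  # older PyTorch: keep the scatter path
```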
## Cross-framework comparison
No other major training framework uses scatter for weight loading:
| Framework | Who reads from disk | How weights reach each rank | Uses scatter? |
|---|---|---|---|
| TorchTitan | Each rank reads its shard only (DCP) | Direct per-rank shard read | No |
| VERL FSDP | Rank 0 only | `set_model_state_dict(broadcast_from_rank0=True)` | No |
| Megatron-LM | Each rank reads its shard only (pre-sharded files) | Direct per-rank file read | No |
| VeOmni | Every rank reads the full model | `distribute_tensor(src_data_rank=0)` → scatter | Yes (this bug) |
VeOmni is unique in the "every rank reads the full checkpoint" pattern. For that pattern, `src_data_rank=None` is the correct API to use.