Commit 01aad93

Authored by BestJuly and ko3n1g
Save memory using main_param for moe in param_l2_norm (#2249)
Signed-off-by: lit <[email protected]>
Co-authored-by: oliver könig <[email protected]>
1 parent f6e0d42 commit 01aad93
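
The change swaps the fp32 upcast in `calc_params_l2_norm` for a reuse of the optimizer's master weights: for bf16 MoE (expert) parameters, `param.data.float()` materializes a fresh fp32 copy of every expert weight just to compute the norm, whereas the mixed-precision optimizer already keeps an fp32 master copy reachable through `param.main_param` (a full tensor, or a shard when the distributed optimizer is used). Referencing that existing copy avoids the temporary allocation. A minimal sketch of the memory trade-off, using made-up shapes and plain PyTorch rather than any Megatron API:

import torch

# Hypothetical bf16 expert weight; in Megatron these are the MoE expert weights
# collected by calc_params_l2_norm.
bf16_param = torch.ones(128, 12, dtype=torch.bfloat16)

# Old path: the upcast allocates a brand-new fp32 tensor (twice the bf16 footprint)
# every time the norm is computed.
fp32_copy = bf16_param.float()

# New path (sketch): reuse an fp32 master copy that the optimizer already owns, so no
# extra memory is allocated for the norm. Here `main_param` is just a stand-in created
# locally for illustration.
main_param = bf16_param.float()

# Either source yields the same norm.
assert torch.allclose(torch.norm(fp32_copy), torch.norm(main_param))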

File tree: 2 files changed (+83, -3 lines)


megatron/training/utils.py

Lines changed: 13 additions & 2 deletions
@@ -80,10 +80,21 @@ def calc_params_l2_norm(model, force_create_fp32_copy=False):
                 continue
             assert is_not_tp_duplicate
             if not getattr(param, 'allreduce', True):
-                # TODO: Implement memory optimization for MoE parameters.
                 assert param_is_not_shared(param)
                 param = to_local_if_dtensor(param)
-                moe_params_data.append(param.data.float() if args.bf16 else param.data)
+                if args.bf16:
+                    if not force_create_fp32_copy and hasattr(param, 'main_param'):
+                        if getattr(param, 'main_param_sharded', False):
+                            if param.main_param is not None:
+                                sharded_params_data.append(param.main_param)
+                        else:
+                            moe_params_data.append(param.main_param)
+                    else:
+                        # Fallback to original logic of making a fp32 copy of the
+                        # parameter if `.main_param` attribute is not available.
+                        moe_params_data.append(param.data.float())
+                else:
+                    moe_params_data.append(param.data)
             else:
                 if param_is_not_shared(param):
                     param = to_local_if_dtensor(param)
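
This hunk only decides which bucket an expert parameter lands in: `moe_params_data` receives full fp32 tensors (the optimizer's `main_param` when available, otherwise a fresh `.float()` copy), while `sharded_params_data` receives the local shard of the fp32 master weights when the distributed optimizer is in use and a shard exists on this rank (the `param.main_param is not None` check). Turning those buckets into a single scalar happens later in `calc_params_l2_norm` and is not part of this diff; below is a rough single-process sketch of that arithmetic, with hypothetical shapes and no Megatron APIs:

import torch

# Stand-ins for the two buckets filled by the hunk above.
full_fp32_params = [torch.ones(4, 4), torch.ones(2, 8)]   # e.g. moe_params_data
local_master_shards = [torch.ones(8)]                     # e.g. sharded_params_data on this rank

# Sum of squares over everything this rank holds; in the full implementation the squared
# sums are reduced across the relevant process groups before the final square root, so
# sharded and replicated parameters are each counted exactly once.
local_sq_sum = sum(float(torch.sum(p ** 2)) for p in full_fp32_params + local_master_shards)
global_norm = local_sq_sum ** 0.5
print(global_norm)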

tests/unit_tests/test_utils.py

Lines changed: 70 additions & 1 deletion
@@ -15,8 +15,10 @@
 import megatron.training.utils as training_util
 from megatron.core import config
 from megatron.core.distributed import DistributedDataParallel, DistributedDataParallelConfig
+from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec
 from megatron.core.optimizer import OptimizerConfig, get_megatron_optimizer
 from megatron.core.transformer import TransformerConfig
+from megatron.core.transformer.moe.moe_layer import MoELayer
 from tests.unit_tests.test_utilities import Utils
 
 success_string = "hello,world"
@@ -238,7 +240,7 @@ def test_cross_check_param_hashes_across_dp_replicas():
 @pytest.mark.flaky
 @pytest.mark.flaky_in_dev
 @pytest.mark.internal
-def test_param_norm(use_distributed_optimizer: bool):
+def test_param_norm_linear(use_distributed_optimizer: bool):
     world = int(os.getenv('WORLD_SIZE', '1'))
     rank = int(os.getenv('RANK', '0'))
 
@@ -286,6 +288,73 @@ def test_param_norm(use_distributed_optimizer: bool):
     _deinit_distributed()
 
 
+@pytest.mark.parametrize("use_distributed_optimizer", [False, True])
+@pytest.mark.flaky
+@pytest.mark.flaky_in_dev
+@pytest.mark.internal
+def test_param_norm_moe(use_distributed_optimizer: bool):
+    world = int(os.getenv('WORLD_SIZE', '1'))
+    rank = int(os.getenv('RANK', '0'))
+
+    # Setup: distributed, model, mock_args.
+    _init_distributed(world, rank)
+    Utils.initialize_model_parallel()
+    transformer_config = TransformerConfig(
+        num_layers=1,
+        hidden_size=12,
+        num_attention_heads=4,
+        num_moe_experts=2,
+        use_cpu_initialization=True,
+        moe_token_dispatcher_type="alltoall",
+        moe_router_topk=2,
+        moe_aux_loss_coeff=0.01,
+        moe_grouped_gemm=True,
+        moe_ffn_hidden_size=128,
+        add_bias_linear=False,
+        bf16=True,
+    )
+    transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(
+        num_experts=2, moe_grouped_gemm=True
+    )
+    model = MoELayer(transformer_config, transformer_layer_spec.submodules.mlp.submodules).to(
+        device='cuda'
+    )
+    model.requires_grad_(True)
+    # Initialize the model with all 1.0 for weights.
+    for param in model.parameters():
+        param.data.fill_(1.0)
+    ddp_config = DistributedDataParallelConfig(use_distributed_optimizer=use_distributed_optimizer)
+    model = DistributedDataParallel(transformer_config, ddp_config, model)
+    for param in model.parameters():
+        assert param.requires_grad
+    mock_args = SimpleNamespace(bf16=True)
+
+    with mock.patch('megatron.training.utils.get_args', new=lambda: mock_args):
+        # Make sure norm is correct when `main_param` attribute is not available.
+        norm_no_fp32_copy = training_util.calc_params_l2_norm(model, force_create_fp32_copy=False)
+        norm_fp32_copy = training_util.calc_params_l2_norm(model, force_create_fp32_copy=True)
+        assert norm_no_fp32_copy == pytest.approx(norm_fp32_copy)
+
+        # Make sure norm is correct when `main_param` attribute is available.
+        optimizer_config = OptimizerConfig(
+            bf16=True, use_distributed_optimizer=use_distributed_optimizer
+        )
+        _ = get_megatron_optimizer(optimizer_config, [model])
+        for param in model.parameters():
+            # Only bf16/fp16 parameters get main_param attribute.
+            # Router weights are always fp32, so they won't have main_param.
+            if param.dtype in [torch.bfloat16, torch.float16]:
+                assert hasattr(param, 'main_param')
+                if use_distributed_optimizer:
+                    assert getattr(param, 'main_param_sharded', False)
+        norm_no_fp32_copy = training_util.calc_params_l2_norm(model, force_create_fp32_copy=False)
+        norm_fp32_copy = training_util.calc_params_l2_norm(model, force_create_fp32_copy=True)
+        assert norm_no_fp32_copy == pytest.approx(norm_fp32_copy)
+
+    # Teardown.
+    _deinit_distributed()
+
+
 @pytest.mark.flaky
 @pytest.mark.flaky_in_dev
 def test_straggler_detector():
