@@ -15,8 +15,10 @@
 import megatron.training.utils as training_util
 from megatron.core import config
 from megatron.core.distributed import DistributedDataParallel, DistributedDataParallelConfig
+from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec
 from megatron.core.optimizer import OptimizerConfig, get_megatron_optimizer
 from megatron.core.transformer import TransformerConfig
+from megatron.core.transformer.moe.moe_layer import MoELayer
 from tests.unit_tests.test_utilities import Utils
 
 success_string = "hello,world"
@@ -238,7 +240,7 @@ def test_cross_check_param_hashes_across_dp_replicas():
 @pytest.mark.flaky
 @pytest.mark.flaky_in_dev
 @pytest.mark.internal
-def test_param_norm(use_distributed_optimizer: bool):
+def test_param_norm_linear(use_distributed_optimizer: bool):
     world = int(os.getenv('WORLD_SIZE', '1'))
     rank = int(os.getenv('RANK', '0'))
 
@@ -286,6 +288,73 @@ def test_param_norm(use_distributed_optimizer: bool):
     _deinit_distributed()
 
 
+@pytest.mark.parametrize("use_distributed_optimizer", [False, True])
+@pytest.mark.flaky
+@pytest.mark.flaky_in_dev
+@pytest.mark.internal
+def test_param_norm_moe(use_distributed_optimizer: bool):
+    world = int(os.getenv('WORLD_SIZE', '1'))
+    rank = int(os.getenv('RANK', '0'))
+
+    # Setup: distributed, model, mock_args.
+    _init_distributed(world, rank)
+    Utils.initialize_model_parallel()
+    transformer_config = TransformerConfig(
+        num_layers=1,
+        hidden_size=12,
+        num_attention_heads=4,
+        num_moe_experts=2,
+        use_cpu_initialization=True,
+        moe_token_dispatcher_type="alltoall",
+        moe_router_topk=2,
+        moe_aux_loss_coeff=0.01,
+        moe_grouped_gemm=True,
+        moe_ffn_hidden_size=128,
+        add_bias_linear=False,
+        bf16=True,
+    )
+    transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(
+        num_experts=2, moe_grouped_gemm=True
+    )
+    model = MoELayer(transformer_config, transformer_layer_spec.submodules.mlp.submodules).to(
+        device='cuda'
+    )
+    model.requires_grad_(True)
+    # Initialize all model weights to 1.0.
+    for param in model.parameters():
+        param.data.fill_(1.0)
+    ddp_config = DistributedDataParallelConfig(use_distributed_optimizer=use_distributed_optimizer)
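+    # Wrap with Megatron core's DistributedDataParallel so the optimizer setup and norm
+    # checks below see the model the same way the training loop does.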
+    model = DistributedDataParallel(transformer_config, ddp_config, model)
+    for param in model.parameters():
+        assert param.requires_grad
+    mock_args = SimpleNamespace(bf16=True)
+
+    with mock.patch('megatron.training.utils.get_args', new=lambda: mock_args):
+        # Make sure norm is correct when `main_param` attribute is not available.
+        norm_no_fp32_copy = training_util.calc_params_l2_norm(model, force_create_fp32_copy=False)
+        norm_fp32_copy = training_util.calc_params_l2_norm(model, force_create_fp32_copy=True)
+        assert norm_no_fp32_copy == pytest.approx(norm_fp32_copy)
+
+        # Make sure norm is correct when `main_param` attribute is available.
+        optimizer_config = OptimizerConfig(
+            bf16=True, use_distributed_optimizer=use_distributed_optimizer
+        )
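+        # Creating the optimizer is done only for its side effect of attaching fp32
+        # `main_param` copies to the bf16/fp16 parameters; the returned optimizer is unused.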
+        _ = get_megatron_optimizer(optimizer_config, [model])
+        for param in model.parameters():
+            # Only bf16/fp16 parameters get the `main_param` attribute.
+            # Router weights are always fp32, so they won't have `main_param`.
+            if param.dtype in [torch.bfloat16, torch.float16]:
+                assert hasattr(param, 'main_param')
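+                # With the distributed optimizer, the fp32 copy is sharded across
+                # data-parallel ranks, which `main_param_sharded` indicates.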
+                if use_distributed_optimizer:
+                    assert getattr(param, 'main_param_sharded', False)
+        norm_no_fp32_copy = training_util.calc_params_l2_norm(model, force_create_fp32_copy=False)
+        norm_fp32_copy = training_util.calc_params_l2_norm(model, force_create_fp32_copy=True)
+        assert norm_no_fp32_copy == pytest.approx(norm_fp32_copy)
+
+    # Teardown.
+    _deinit_distributed()
+
+
 @pytest.mark.flaky
 @pytest.mark.flaky_in_dev
 def test_straggler_detector():