
Commit b51db3e

deepakn94, venmugil, and ericharper authored
Changes to support latent MoEs (#2296)
Signed-off-by: Deepak Narayanan <[email protected]>
Co-authored-by: Venmugil Elango <[email protected]>
Co-authored-by: Eric Harper <[email protected]>
1 parent 01aad93 commit b51db3e

8 files changed: +95 additions, −18 deletions

megatron/core/extensions/transformer_engine.py

Lines changed: 1 addition & 1 deletion
@@ -535,7 +535,7 @@ def __init__(
         )

         if is_te_min_version("0.8.0"):
-            if self.config.tp_comm_overlap:
+            if self.config.tp_comm_overlap and parallel_mode != "duplicated":
                 if is_te_min_version("1.5.0"):
                     # Use old overlap flags if they were supplied instead
                     extra_kwargs["ub_overlap_ag"] = (
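The new parallel_mode check matters because the latent projections introduced later in this commit are TELinear modules built with parallel_mode="duplicated", i.e. their weights are replicated on every tensor-parallel rank, so there is no tensor-parallel collective to overlap. A minimal sketch of the intent, using a hypothetical helper name rather than any Megatron-Core API:

def should_enable_tp_comm_overlap(config, parallel_mode: str) -> bool:
    # Illustrative only: userbuffer-based comm overlap applies to linears that
    # are sharded across tensor-parallel ranks ("column"/"row" modes); a
    # "duplicated" linear keeps a full weight copy per rank, so the overlap
    # flags must not be set for it.
    return bool(config.tp_comm_overlap) and parallel_mode != "duplicated"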

megatron/core/transformer/mlp.py

Lines changed: 6 additions & 2 deletions
@@ -104,9 +104,13 @@ def __init__(
         if self.config.gated_linear_unit:
             ffn_hidden_size *= 2

+        # Use moe_latent_size only for routed experts. 'is_expert' is false for
+        # shared_experts.
+        use_latent_size = (self.config.moe_latent_size is not None) and is_expert
+
         self.linear_fc1 = build_module(
             submodules.linear_fc1,
-            self.input_size,
+            self.input_size if not use_latent_size else self.config.moe_latent_size,
             ffn_hidden_size,
             config=self.config,
             init_method=self.config.init_method,
@@ -126,7 +130,7 @@ def __init__(
         self.linear_fc2 = build_module(
             submodules.linear_fc2,
             self.config.ffn_hidden_size,
-            self.config.hidden_size,
+            self.config.hidden_size if not use_latent_size else self.config.moe_latent_size,
             config=self.config,
             init_method=self.config.output_layer_init_method,
             bias=self.config.add_bias_linear,
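To see what the use_latent_size switch does to the expert MLP shapes, here is a small standalone sketch (assuming input_size equals hidden_size, as is the usual default for MLP; the helper name and the concrete numbers are illustrative only):

def expert_mlp_io_sizes(hidden_size, ffn_hidden_size, moe_latent_size, is_expert):
    # Mirrors the dimension choice above: only routed experts (is_expert=True)
    # read from and write back to the latent dimension.
    use_latent_size = (moe_latent_size is not None) and is_expert
    fc1_in = moe_latent_size if use_latent_size else hidden_size
    fc2_out = moe_latent_size if use_latent_size else hidden_size
    return (fc1_in, ffn_hidden_size), (ffn_hidden_size, fc2_out)

# Routed expert, hidden_size=4096, moe_latent_size=1024, ffn_hidden_size=2048:
#   fc1: 1024 -> 2048, fc2: 2048 -> 1024
# Shared expert (is_expert=False) with the same config:
#   fc1: 4096 -> 2048, fc2: 2048 -> 4096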

megatron/core/transformer/moe/experts.py

Lines changed: 9 additions & 2 deletions
@@ -118,6 +118,9 @@ def __init__(
         assert (
             config.add_bias_linear == False
         ), "bias not supported in Grouped GEMM yet, please set '--disable-bias-linear' instead."
+        assert (
+            config.moe_latent_size is None
+        ), "MoE latent projection not supported in GroupedMLP yet."

         self.expert_parallel = config.expert_model_parallel_size > 1
         if self.config.gated_linear_unit:
@@ -778,7 +781,7 @@ def __init__(
         self.linear_fc1 = build_module(
             submodules.linear_fc1,
             self.num_local_experts,
-            self.input_size,
+            self.input_size if self.config.moe_latent_size is None else self.config.moe_latent_size,
             ffn_hidden_size,
             config=self.config,
             init_method=self.config.init_method,
@@ -799,7 +802,11 @@ def __init__(
             submodules.linear_fc2,
             self.num_local_experts,
             self.config.moe_ffn_hidden_size,
-            self.config.hidden_size,
+            (
+                self.config.hidden_size
+                if self.config.moe_latent_size is None
+                else self.config.moe_latent_size
+            ),
             config=self.config,
             init_method=self.config.output_layer_init_method,
             bias=self.config.add_bias_linear,

megatron/core/transformer/moe/moe_layer.py

Lines changed: 40 additions & 2 deletions
@@ -23,7 +23,7 @@
 try:
     import transformer_engine as te  # pylint: disable=unused-import

-    from megatron.core.extensions.transformer_engine import te_checkpoint
+    from megatron.core.extensions.transformer_engine import TELinear, te_checkpoint

     HAVE_TE = True
 except ImportError:
@@ -120,9 +120,35 @@ def __init__(
             and "shared_experts" in config.recompute_modules
         )

-        # Initialize router
+        # Initialize router.
         self.router = TopKRouter(config=self.config, pg_collection=pg_collection)

+        # Initialize latent projections.
+        if self.config.moe_latent_size:
+            assert HAVE_TE, "TransformerEngine is required for MoE latent projections."
+            self.fc1_latent_proj = TELinear(
+                self.config.hidden_size,
+                self.config.moe_latent_size,
+                parallel_mode="duplicated",
+                config=self.config,
+                init_method=self.config.init_method,
+                bias=self.config.add_bias_linear,
+                skip_bias_add=False,
+                skip_weight_param_allocation=False,
+                is_expert=False,
+            )
+            self.fc2_latent_proj = TELinear(
+                self.config.moe_latent_size,
+                self.config.hidden_size,
+                parallel_mode="duplicated",
+                config=self.config,
+                init_method=self.config.output_layer_init_method,
+                bias=self.config.add_bias_linear,
+                skip_bias_add=False,
+                skip_weight_param_allocation=False,
+                is_expert=False,
+            )
+
         # Initialize token dispatcher
         if config.moe_token_dispatcher_type == "allgather":
             self.token_dispatcher = MoEAllGatherTokenDispatcher(
@@ -176,6 +202,12 @@ def router_and_preprocess(self, hidden_states: torch.Tensor):
         """
         residual = hidden_states
         probs, routing_map = self.router(hidden_states)
+        # Project the hidden_states from hidden dimension down to latent dimension.
+        if self.config.moe_latent_size:
+            assert (
+                not self.shared_expert_overlap
+            ), "Shared expert overlap not supported when MoE latent projections are used."
+            hidden_states, _ = self.fc1_latent_proj(hidden_states)
         hidden_states, probs = self.token_dispatcher.dispatch_preprocess(
             hidden_states, routing_map, probs
         )
@@ -243,6 +275,10 @@ def combine(self, output: torch.Tensor, shared_expert_output: Optional[torch.Tensor]
         """
         output = self.token_dispatcher.token_combine(output)
         output = self.token_dispatcher.combine_postprocess(output)
+        # Project the output back from latent dimension to hidden dimension after combine
+        # in latent dimension.
+        if self.config.moe_latent_size:
+            output, _ = self.fc2_latent_proj(output)
         if shared_expert_output is not None:
             output = output + shared_expert_output
         return output
@@ -274,7 +310,9 @@ def custom_forward(hidden_states):
             hidden_states, probs, residual = self.router_and_preprocess(hidden_states)
             dispatched_input, probs = self.dispatch(hidden_states, probs)
             output, mlp_bias = self.routed_experts_compute(dispatched_input, probs, residual)
+            assert mlp_bias is None, f"mlp_bias is not supported for {type(self.token_dispatcher)}"
             output = self.combine(output, shared_expert_output)
+
             return output, mlp_bias

         if self.moe_layer_recompute:
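Putting the moe_layer.py pieces together: tokens are projected down to moe_latent_size right after routing and before dispatch, the experts run entirely in the latent dimension, and the output is projected back up to hidden_size only after combine. A minimal sketch of that flow, using plain torch.nn.Linear in place of TELinear and stand-in callables for the router, experts, and token dispatcher (none of these names are Megatron-Core API):

import torch
import torch.nn as nn

class LatentMoESketch(nn.Module):
    """Illustrative sketch of the latent MoE data flow added in this commit."""

    def __init__(self, hidden_size, latent_size, router, experts):
        super().__init__()
        self.router = router      # callable: hidden_states -> (probs, routing_map)
        self.experts = experts    # callable operating in the latent dimension
        self.fc1_latent_proj = nn.Linear(hidden_size, latent_size)  # down-projection
        self.fc2_latent_proj = nn.Linear(latent_size, hidden_size)  # up-projection

    def forward(self, hidden_states, dispatch, combine):
        probs, routing_map = self.router(hidden_states)
        # Down-project before dispatch, so all-to-all traffic and expert GEMMs
        # run on latent_size rather than hidden_size.
        latent = self.fc1_latent_proj(hidden_states)
        dispatched = dispatch(latent, routing_map, probs)
        expert_output = self.experts(dispatched, probs)
        combined = combine(expert_output)
        # Up-project back to hidden_size only after combine.
        return self.fc2_latent_proj(combined)

One design point visible in the diff: the two projections are ordinary dense layers (parallel_mode="duplicated", is_expert=False), so their weights are replicated across tensor-parallel ranks rather than sharded with the experts.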

megatron/core/transformer/transformer_config.py

Lines changed: 10 additions & 7 deletions
@@ -209,13 +209,6 @@ class TransformerConfig(ModelParallelConfig):
     A list of integers: Defines a custom pattern where 1 means skip RoPE and 0 means apply RoPE.
     For example, [0,1,1,0] means: apply RoPE, skip RoPE, skip RoPE, apply RoPE."""

-    moe_deepep_num_sms: int = 20
-    """Number of SMs to use for DeepEP."""
-
-    moe_hybridep_num_sms: int = 16
-    """Number of SMs to use for HybridEP. In pure NVL scenarios,
-    16 SMs can generally achieve good bandwidth."""
-
     ####################
     # initialization
     ####################
@@ -609,6 +602,16 @@ class TransformerConfig(ModelParallelConfig):
     moe_apply_probs_on_input: bool = False
     """Apply probs on input of experts instead of applying after activation and glu."""

+    moe_latent_size: Optional[int] = None
+    """Latent projection dimension for MoE. If None, MoE latent projections are not used."""
+
+    moe_deepep_num_sms: int = 20
+    """Number of SMs to use for DeepEP."""
+
+    moe_hybridep_num_sms: int = 16
+    """Number of SMs to use for HybridEP. In pure NVL scenarios,
+    16 SMs can generally achieve good bandwidth."""
+
     ##################
     # Context Parallel
     ##################
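A hedged example of how the new field might be set; the surrounding fields are the usual TransformerConfig/MoE settings, and the exact set of required fields and consistency checks may differ by Megatron-Core version:

from megatron.core.transformer.transformer_config import TransformerConfig

config = TransformerConfig(
    num_layers=4,
    hidden_size=4096,
    num_attention_heads=32,
    num_moe_experts=64,
    moe_ffn_hidden_size=2048,
    moe_latent_size=1024,  # new in this commit: experts run in this latent dimension
)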

megatron/training/arguments.py

Lines changed: 11 additions & 0 deletions
@@ -1249,6 +1249,13 @@ def validate_args(args, defaults={}):
         args.recompute_granularity != 'full'
     ), 'recompute_granularity must not be full when CUDA Graphs are enabled.'

+    # MoE latent projections
+    if args.moe_latent_size is not None:
+        assert args.moe_latent_size > 0, "MoE latent projection dimension has to be greater than zero."
+        assert args.num_experts is not None, "MoE latent projections are applicable only for MoE models."
+        assert not args.use_legacy_models, "MoE latent projections are only supported for mcore models."
+        assert not args.moe_use_legacy_grouped_gemm, "MoE latent projection is not supported yet with legacy grouped GEMM."
+
     if args.tiktoken_special_tokens and not args.tokenizer_special_tokens:
         warn_rank_0(
             "--tiktoken-special-tokens argument is deprecated and will be removed soon. "
@@ -1355,6 +1362,8 @@ def core_transformer_config_from_args(args, config_class=None):
         kw_args['use_kitchen'] = True
         kw_args['quant_recipe'] = kitchen_quantization_recipe_config(args.kitchen_recipe_number)

+    kw_args['moe_latent_size'] = args.moe_latent_size
+
     if args.te_precision_config_file:
         assert not 'quant_recipe' in kw_args, "Quantization recipe already configured."
         # TODO(kwyss): Prohibit fp8_params or fp4_params with this flexibility
@@ -1743,6 +1752,8 @@ def _add_network_size_args(parser):
                        'We compute the average of the MTP losses across all depths, '
                        'and multiply it the scaling factor to obtain the overall MTP loss, '
                        'which serves as an additional training objective.')
+    group.add_argument('--moe-latent-size', type=int, default=None,
+                       help='Latent projection dimension for MoE. If None, MoE latent projections are not used.')
     return parser
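A short sketch of how the new flag travels from the command line into the config, mirroring the three arguments.py changes above in simplified form; the argparse setup here is standalone and illustrative, not the real Megatron parser:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--moe-latent-size', type=int, default=None,
                    help='Latent projection dimension for MoE.')
parser.add_argument('--num-experts', type=int, default=None)
args = parser.parse_args(['--moe-latent-size', '1024', '--num-experts', '64'])

# validate_args-style checks from this commit:
if args.moe_latent_size is not None:
    assert args.moe_latent_size > 0
    assert args.num_experts is not None, 'latent projections only apply to MoE models'

# core_transformer_config_from_args then forwards the value to TransformerConfig:
kw_args = {'moe_latent_size': args.moe_latent_size}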

megatron/training/checkpointing.py

Lines changed: 3 additions & 0 deletions
@@ -1341,6 +1341,9 @@ def _set_arg(arg_name, old_arg_name=None, force=False):
     _set_arg('heterogeneous_layers_config_path', force=True)
     _set_arg('heterogeneous_layers_config_encoded_json', force=True)

+    # MoE latent projection.
+    _set_arg('moe_latent_size', force=True)
+
     # Tokenizer args.
     _set_arg('tokenizer_type', force=True)
     # Using checkpoint version might not always be safe (e.g., if running on different cluster).

megatron/training/training.py

Lines changed: 15 additions & 4 deletions
@@ -180,11 +180,19 @@ def mlp_layer_flops(batch_size, seq_len, hidden_size, expansion=4.0, swiglu=False):
     return 4 * expansion * scale_factor * batch_size * seq_len * hidden_size**2

 def moe_layer_flops(batch_size, seq_len, hidden_size, moe_ffn_hidden_size,
-                    shared_expert_ffn_hidden_size, num_experts_routed_to, swiglu=False):
+                    shared_expert_ffn_hidden_size, num_experts_routed_to,
+                    moe_latent_size=None, swiglu=False):
     """Calculate FLOPs for an MoE layer."""
     scale_factor = 3.0 / 2.0 if swiglu else 1.0
-    routed_flops = (4 * batch_size * seq_len * hidden_size *
-                    moe_ffn_hidden_size * num_experts_routed_to * scale_factor)
+    if moe_latent_size is None:
+        routed_flops = (4 * batch_size * seq_len * hidden_size *
+                        moe_ffn_hidden_size * num_experts_routed_to * scale_factor)
+    else:
+        # Routed experts run on moe_latent_size.
+        routed_flops = (4 * batch_size * seq_len * moe_latent_size *
+                        moe_ffn_hidden_size * num_experts_routed_to * scale_factor)
+        # Up proj and down proj.
+        routed_flops += (4 * batch_size * seq_len * hidden_size * moe_latent_size)
     shared_flops = 4 * batch_size * seq_len * hidden_size * shared_expert_ffn_hidden_size * scale_factor
     return routed_flops + shared_flops
@@ -232,6 +240,7 @@ def hybrid_flops(batch_size, seq_len, hidden_size,
                  num_attn_heads=32, gqa=True,
                  gqa_groups=8, kv_channels=None,
                  mlp_expansion=4.0, swiglu=False,
+                 moe_latent_size=None,
                  moe_ffn_hidden_size=2048, shared_expert_ffn_hidden_size=2048, num_experts_routed_to=1,
                  vocab_size=256000):
     """Calculate total FLOPs for the hybrid model."""
@@ -244,7 +253,8 @@
                                           mamba_state_dim, mamba_head_dim,
                                           mamba_num_groups, mamba_num_heads) +
         num_moe_layers * moe_layer_flops(batch_size, seq_len, hidden_size, moe_ffn_hidden_size,
-                                         shared_expert_ffn_hidden_size, num_experts_routed_to, swiglu) +
+                                         shared_expert_ffn_hidden_size, num_experts_routed_to,
+                                         moe_latent_size, swiglu) +
         (2 * batch_size * seq_len * hidden_size * vocab_size)  # logits computation
     )
     return flops_fwd * 3
@@ -449,6 +459,7 @@ def transformer_flops():
             kv_channels=args.kv_channels,
             mlp_expansion=args.ffn_hidden_size / args.hidden_size,
             swiglu=args.swiglu,
+            moe_latent_size=args.moe_latent_size,
             moe_ffn_hidden_size=(args.moe_ffn_hidden_size if args.moe_ffn_hidden_size is not None
                                  else args.ffn_hidden_size),
             shared_expert_ffn_hidden_size=(0 if args.moe_shared_expert_intermediate_size is None