|
23 | 23 | try: |
24 | 24 | import transformer_engine as te # pylint: disable=unused-import |
25 | 25 |
|
26 | | - from megatron.core.extensions.transformer_engine import te_checkpoint |
| 26 | + from megatron.core.extensions.transformer_engine import TELinear, te_checkpoint |
27 | 27 |
|
28 | 28 | HAVE_TE = True |
29 | 29 | except ImportError: |
@@ -120,9 +120,35 @@ def __init__( |
120 | 120 | and "shared_experts" in config.recompute_modules |
121 | 121 | ) |
122 | 122 |
|
123 | | - # Initialize router |
| 123 | + # Initialize router. |
124 | 124 | self.router = TopKRouter(config=self.config, pg_collection=pg_collection) |
125 | 125 |
|
| 126 | + # Initialize latent projections. |
| 127 | + if self.config.moe_latent_size: |
| 128 | + assert HAVE_TE, "TransformerEngine is required for MoE latent projections." |
| 129 | + self.fc1_latent_proj = TELinear( |
| 130 | + self.config.hidden_size, |
| 131 | + self.config.moe_latent_size, |
| 132 | + parallel_mode="duplicated", |
| 133 | + config=self.config, |
| 134 | + init_method=self.config.init_method, |
| 135 | + bias=self.config.add_bias_linear, |
| 136 | + skip_bias_add=False, |
| 137 | + skip_weight_param_allocation=False, |
| 138 | + is_expert=False, |
| 139 | + ) |
| 140 | + self.fc2_latent_proj = TELinear( |
| 141 | + self.config.moe_latent_size, |
| 142 | + self.config.hidden_size, |
| 143 | + parallel_mode="duplicated", |
| 144 | + config=self.config, |
| 145 | + init_method=self.config.output_layer_init_method, |
| 146 | + bias=self.config.add_bias_linear, |
| 147 | + skip_bias_add=False, |
| 148 | + skip_weight_param_allocation=False, |
| 149 | + is_expert=False, |
| 150 | + ) |
| 151 | + |
126 | 152 | # Initialize token dispatcher |
127 | 153 | if config.moe_token_dispatcher_type == "allgather": |
128 | 154 | self.token_dispatcher = MoEAllGatherTokenDispatcher( |
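
For reference, a minimal sketch of what the two latent projections added above compute, using plain torch.nn.Linear stand-ins for TELinear with parallel_mode="duplicated" (the real layers are replicated rather than tensor-parallel-sharded and honor config options such as add_bias_linear; hidden_size=4096 and moe_latent_size=1024 below are made-up values for illustration only):

import torch
import torch.nn as nn

# Made-up sizes for illustration only.
hidden_size, moe_latent_size, num_tokens = 4096, 1024, 8

# Stand-ins for fc1_latent_proj / fc2_latent_proj (TELinear, parallel_mode="duplicated").
fc1_latent_proj = nn.Linear(hidden_size, moe_latent_size)  # hidden -> latent, before dispatch
fc2_latent_proj = nn.Linear(moe_latent_size, hidden_size)  # latent -> hidden, after combine

hidden_states = torch.randn(num_tokens, hidden_size)
latent = fc1_latent_proj(hidden_states)    # [num_tokens, moe_latent_size]
restored = fc2_latent_proj(latent)         # [num_tokens, hidden_size]
assert latent.shape == (num_tokens, moe_latent_size)
assert restored.shape == (num_tokens, hidden_size)
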
@@ -176,6 +202,12 @@ def router_and_preprocess(self, hidden_states: torch.Tensor): |
176 | 202 | """ |
177 | 203 | residual = hidden_states |
178 | 204 | probs, routing_map = self.router(hidden_states) |
| 205 | + # Project the hidden_states from the hidden dimension down to the latent dimension. |
| 206 | + if self.config.moe_latent_size: |
| 207 | + assert ( |
| 208 | + not self.shared_expert_overlap |
| 209 | + ), "Shared expert overlap not supported when MoE latent projections are used." |
| 210 | + hidden_states, _ = self.fc1_latent_proj(hidden_states) |
179 | 211 | hidden_states, probs = self.token_dispatcher.dispatch_preprocess( |
180 | 212 | hidden_states, routing_map, probs |
181 | 213 | ) |
@@ -243,6 +275,10 @@ def combine(self, output: torch.Tensor, shared_expert_output: Optional[torch.Ten |
243 | 275 | """ |
244 | 276 | output = self.token_dispatcher.token_combine(output) |
245 | 277 | output = self.token_dispatcher.combine_postprocess(output) |
| 278 | + # Project the output from the latent dimension back to the hidden dimension after |
| 279 | + # the combine step, which runs in the latent dimension. |
| 280 | + if self.config.moe_latent_size: |
| 281 | + output, _ = self.fc2_latent_proj(output) |
246 | 282 | if shared_expert_output is not None: |
247 | 283 | output = output + shared_expert_output |
248 | 284 | return output |
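
Note that the up-projection runs after token_combine/combine_postprocess but before shared_expert_output is added, since the shared experts keep operating in the full hidden dimension. A toy sketch of that ordering in combine, with plain tensors standing in for the dispatcher outputs and made-up sizes:

import torch
import torch.nn as nn

hidden_size, moe_latent_size, num_tokens = 4096, 1024, 8  # made-up sizes

fc2_latent_proj = nn.Linear(moe_latent_size, hidden_size)

combined_latent = torch.randn(num_tokens, moe_latent_size)   # stand-in for the combined routed-expert output
shared_expert_output = torch.randn(num_tokens, hidden_size)  # shared experts stay in the hidden dimension

output = fc2_latent_proj(combined_latent)   # latent -> hidden, after combine
output = output + shared_expert_output      # residual add only type-checks in the hidden dimension
assert output.shape == (num_tokens, hidden_size)
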
@@ -274,7 +310,9 @@ def custom_forward(hidden_states): |
274 | 310 | hidden_states, probs, residual = self.router_and_preprocess(hidden_states) |
275 | 311 | dispatched_input, probs = self.dispatch(hidden_states, probs) |
276 | 312 | output, mlp_bias = self.routed_experts_compute(dispatched_input, probs, residual) |
| 313 | + assert mlp_bias is None, f"mlp_bias is not supported for {type(self.token_dispatcher)}" |
277 | 314 | output = self.combine(output, shared_expert_output) |
| 315 | + |
278 | 316 | return output, mlp_bias |
279 | 317 |
|
280 | 318 | if self.moe_layer_recompute: |
|