|
32 | 32 | from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import ( |
33 | 33 | GPTInferenceWrapper, |
34 | 34 | ) |
| 35 | +from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import ( |
| 36 | + InferenceWrapperConfig, |
| 37 | +) |
| 38 | +from megatron.core.utils import get_model_config |
35 | 39 | from megatron.core.inference.text_generation_controllers.text_generation_controller import ( |
36 | 40 | TextGenerationController, |
37 | 41 | ) |
@@ -524,7 +528,16 @@ def create_mcore_engine( |
524 | 528 | max_batch_size=max_batch_size, |
525 | 529 | max_sequence_length=inference_max_seq_length, |
526 | 530 | ) |
527 | | - model_inference_wrapper = GPTInferenceWrapper(model, inference_context) |
| 531 | + model_config = get_model_config(model) |
| 532 | + inference_wrapper_config = InferenceWrapperConfig( |
| 533 | + hidden_size=model_config.hidden_size, |
| 534 | + params_dtype=params_dtype, |
| 535 | + inference_batch_times_seqlen_threshold=inference_batch_times_seqlen_threshold, |
| 536 | + padded_vocab_size=tokenizer.vocab_size, |
| 537 | + fp32_residual_connection=getattr(model_config, 'fp32_residual_connection', False), |
| 538 | + inference_max_seq_length=inference_max_seq_length, |
| 539 | + ) |
| 540 | + model_inference_wrapper = GPTInferenceWrapper(model, inference_wrapper_config, inference_context) |
528 | 541 | text_generation_controller = TextGenerationController( |
529 | 542 | inference_wrapped_model=model_inference_wrapper, tokenizer=tokenizer |
530 | 543 | ) |
|
0 commit comments