Skip to content

Commit 0d2bdd5

Browse files
committed
Fix for residual fp32 issue
Signed-off-by: Pranav Prashant Thombre <pthombre@nvidia.com>
1 parent 649fa54 commit 0d2bdd5

1 file changed

Lines changed: 14 additions & 1 deletion

File tree

nemo_deploy/llm/inference/inference_base.py

Lines changed: 14 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -32,6 +32,10 @@
 from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import (
     GPTInferenceWrapper,
 )
+from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import (
+    InferenceWrapperConfig,
+)
+from megatron.core.utils import get_model_config
 from megatron.core.inference.text_generation_controllers.text_generation_controller import (
     TextGenerationController,
 )
@@ -524,7 +528,16 @@ def create_mcore_engine(
         max_batch_size=max_batch_size,
         max_sequence_length=inference_max_seq_length,
     )
-    model_inference_wrapper = GPTInferenceWrapper(model, inference_context)
+    model_config = get_model_config(model)
+    inference_wrapper_config = InferenceWrapperConfig(
+        hidden_size=model_config.hidden_size,
+        params_dtype=params_dtype,
+        inference_batch_times_seqlen_threshold=inference_batch_times_seqlen_threshold,
+        padded_vocab_size=tokenizer.vocab_size,
+        fp32_residual_connection=getattr(model_config, 'fp32_residual_connection', False),
+        inference_max_seq_length=inference_max_seq_length,
+    )
+    model_inference_wrapper = GPTInferenceWrapper(model, inference_wrapper_config, inference_context)
     text_generation_controller = TextGenerationController(
         inference_wrapped_model=model_inference_wrapper, tokenizer=tokenizer
     )

0 commit comments

Comments (0)