 # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

-from megatron.core.inference.inference_client import InferenceClient
-from examples.inference.gpt.utils import add_common_inference_args
 import asyncio
-import torch.distributed as dist
-from examples.inference.gpt.gpt_dynamic_inference import get_model, get_inference_context, get_inference_controller, add_dynamic_inference_args
-from megatron.core.inference.inference_request import DynamicInferenceRequest
-from megatron.training import initialize_megatron
-import torch
-import os
-from megatron.training import get_args, get_tokenizer
-from megatron.core.inference.sampling_params import SamplingParams
-from examples.inference.gpt.utils import build_requests, build_dynamic_engine_setup_prefix, Request
-from megatron.core.inference.engines import DynamicInferenceEngine
+import json
+import os
 import time
+import torch
+import torch.distributed as dist
+from collections import defaultdict
 from tqdm import tqdm
 from typing import List
-import json
-from megatron.training.arguments import parse_args
+import warnings
+import logging
+
+from examples.inference.gpt.gpt_dynamic_inference import (
+    add_dynamic_inference_args,
+    get_inference_context,
+    get_inference_controller,
+    get_model,
+)
+from examples.inference.gpt.utils import (
+    Request,
+    build_dynamic_engine_setup_prefix,
+    build_requests,
+    add_common_inference_args
+)
+
 from megatron.core import parallel_state
+from megatron.core.inference.engines import DynamicInferenceEngine
+from megatron.core.inference.inference_client import InferenceClient
+from megatron.core.inference.inference_request import DynamicInferenceRequestRecord
+from megatron.core.inference.sampling_params import SamplingParams
+from megatron.core.utils import get_mamba_inference_state_config_from_model

-import logging
+from megatron.training import get_args, get_tokenizer, initialize_megatron
+from megatron.training.arguments import parse_args
+
+# pylint: disable=line-too-long

 logging.basicConfig(level=logging.INFO, force=True)

@@ -38,81 +53,150 @@ async def main(
 )
     # once you call engine.start_listening_to_data_parallel_coordinator,
     # the engine will start accepting requests from the data parallel coordinator.
-    # and processing them in an asyncio coroutine.
-    await engine.start_listening_to_data_parallel_coordinator(
-        inference_coordinator_port=port, launch_inference_coordinator=True
+    # and processing them in an asyncio coroutine.
+
+    await engine.start_listening_to_data_parallel_coordinator(
+        inference_coordinator_port=port,
+        launch_inference_coordinator=True,
+        verbose=True,
     )
-    # if you want to use your own inference coordinator -
+
+    # if you want to use your own inference coordinator -
     # 1. set launch_inference_coordinator to False
     # 2. setup a router socket at tcp://MASTER_ADDR:PORT
     # 3. wait for data parallel groups to establish connection (BasicInferenceCoordinator.__init__)
     # 4. look at InferenceCoordinator.start() to see how we can route requests from users <-> data parallel groups
-    # based on headers.
-    # 5. look at InferenceClient to see how we create requests with headers.
-    if dist.get_rank() == 0:
-        client = InferenceClient(port)  # submits requests to the inference coordinator
+    # based on headers.
+    # 5. look at InferenceClient to see how we create requests with headers.
+
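The five steps above describe wiring up a custom coordinator in place of the built-in one. A minimal sketch of the ROUTER-socket piece (step 2), assuming pyzmq; the actual frame and header layout expected by InferenceCoordinator and InferenceClient lives in those classes and is not reproduced here:

# Sketch only: binds a ROUTER socket and echoes frames back to the sender.
# A real coordinator would forward user requests to a data-parallel group and
# route replies back, keyed on the identity/header frames.
import zmq

def run_minimal_router(master_addr: str, port: int) -> None:
    ctx = zmq.Context()
    sock = ctx.socket(zmq.ROUTER)
    sock.bind(f"tcp://{master_addr}:{port}")
    while True:
        identity, *frames = sock.recv_multipart()  # ROUTER prepends the sender identity
        sock.send_multipart([identity, *frames])   # route back to that peer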
+    args = get_args()
+
+    # Test suspend/resume intervals.
+    if args.suspend_resume_interval is not None:
+        # Since the client doesn't directly call engine.async_step here, we test
+        # the suspend-resume system ~4 times.
+        suspend_resume_interval = max(1, len(requests) // 4)
+        suspend_idxs = set(range(
+            suspend_resume_interval,
+            len(requests) + 1,
+            suspend_resume_interval,
+        ))
+        resume_idxs = set(
+            min(len(requests), i + suspend_resume_interval // 2)
+            for i in suspend_idxs
+        )
+    else:
+        suspend_idxs = set()
+        resume_idxs = set()
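For concreteness, a worked version of the interval arithmetic above, using a made-up count of 100 requests:

# Worked example only (100 requests is an assumed value, not from the script).
n = 100
interval = max(1, n // 4)                                        # 25
suspend_idxs = set(range(interval, n + 1, interval))             # {25, 50, 75, 100}
resume_idxs = {min(n, i + interval // 2) for i in suspend_idxs}  # {37, 62, 87, 100}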
+
+    # Create client and run example.
+    if dist.get_rank() == 0:
+        client = InferenceClient(port)  # submits requests to the inference coordinator
         await client.start()
         base_arrival_time = time.time_ns() / 10**9
         for request in requests:
             request.time_arrival = request.time_offset + base_arrival_time
         futures = []
         num_requests_total = len(requests)
         num_requests_added = 0
-        #tbar = tqdm(total=num_requests_total)
+
         while True:
             current_time = time.time_ns() / 10**9
-            # Only add requests that have arrived at the current time.
-            while num_requests_added < num_requests_total and requests[num_requests_added].time_arrival <= current_time:
-                request = requests[num_requests_added]
-                # These add-request calls will queue up the request on a zmq socket and return
-                # instantaneously. They will return an asyncio future which can be awaited for
-                # request completion.
-                futures.append(client.add_request(request.prompt_text, request.sampling_params))
-                num_requests_added += 1
-                #tbar.update(1)
+            if args.incoming_requests_per_step is None:
+                # Only add requests that have arrived at the current time.
+                while num_requests_added < num_requests_total and requests[num_requests_added].time_arrival <= current_time:
+                    request = requests[num_requests_added]
+                    # These add-request calls will queue up the request on a zmq socket and return
+                    # instantaneously. They will return an asyncio future which can be awaited for
+                    # request completion.
+                    futures.append(client.add_request(request.prompt_text, request.sampling_params))
+                    num_requests_added += 1
+
+                    # Test suspend/resume.
+                    if num_requests_added in suspend_idxs:
+                        client.suspend_engines()
+                    if num_requests_added in resume_idxs:
+                        client.resume_engines()
+
+            else:
+                # Add deterministic number of requests (generally used for debugging).
+                for i in range(min(
+                    args.incoming_requests_per_step,
+                    num_requests_total - num_requests_added
+                )):
+                    # Change sampling parameters to force different generation lengths.
+                    request = requests[num_requests_added]
+                    n = request.sampling_params.num_tokens_to_generate
+                    request.sampling_params.num_tokens_to_generate = n + i
+                    futures.append(client.add_request(request.prompt_text, request.sampling_params))
+                    num_requests_added += 1
+
+                    # Test suspend/resume.
+                    if num_requests_added in suspend_idxs:
+                        client.suspend_engines()
+                    if num_requests_added in resume_idxs:
+                        client.resume_engines()
+
             if num_requests_added == num_requests_total:
                 break
-            # Relinquish control since there are no more requests to add at the moment. This allows the engine to run.
+            # Relinquish control since there are no more requests to add at the moment. This allows the engine to run.
             await asyncio.sleep(0)
-        # While we wait for the requests to complete, the engine runs in the background.
-        results: List[DynamicInferenceRequest] = await asyncio.gather(*futures)

+        # While we wait for the requests to complete, the engine runs in the background.
+        results: List[DynamicInferenceRequestRecord] = await asyncio.gather(*futures)

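The add_request calls above return asyncio futures immediately, and awaiting them with asyncio.gather is what lets the engine coroutine make progress in the background. A self-contained sketch of that submit-now/await-later pattern, independent of the Megatron client (all names below are illustrative only):

# Sketch only: futures resolved by a background task, mirroring how
# client.add_request() returns a future that completes once the coordinator
# reports the finished request.
import asyncio

async def _demo() -> None:
    loop = asyncio.get_running_loop()
    futures = [loop.create_future() for _ in range(3)]

    async def fake_engine() -> None:
        for i, fut in enumerate(futures):
            await asyncio.sleep(0.01)          # pretend generation happens here
            fut.set_result(f"generated-{i}")   # completion resolves the future

    task = asyncio.create_task(fake_engine())
    await asyncio.sleep(0)                     # yield control so the task can run
    print(await asyncio.gather(*futures))      # blocks until every future resolves
    await task

asyncio.run(_demo())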
     if dist.get_rank() == 0:
         # Write results to JSON. Primarily used for functional testing.
         if args.output_path:
             json_results = {}
+            throughputs = []

-            for req in results:
+            for record in results:
+                req = record.merge(engine.controller.tokenizer)
                 result_dict = {
                     "input_prompt": req.prompt,
                     "generated_text": req.generated_text.replace("\n", "\\n"),
                     "generated_tokens": req.generated_tokens,
-                    "latency": req.latency,  # InferenceClient populates this field in the returned future.
+                    "latency": req.latency,  # InferenceClient populates this field in the returned future.
                 }
                 if req.sampling_params["return_log_probs"]:
                     result_dict["logprobs"] = req.prompt_log_probs + req.generated_log_probs
+                throughput = len(req.generated_tokens) / req.latency
+                throughputs.append(throughput)
                 json_results[req.request_id] = result_dict
+            throughput_dict = {"throughput": throughputs}
+            if args.throughput_check_only:
+                json_results = throughput_dict
             with open(args.output_path, "w") as fp:
                 json.dump(json_results, fp, indent=4)
         else:
             print("Results:")
-            for req in results:
-                print(f"rid: {req.request_id}\nprompt: {req.prompt!r}\noutput: {req.generated_text!r}\n\n")
-
+            unique_prompt_map = defaultdict(list)
+            for record in results:
+                req = record.merge(engine.controller.tokenizer)
+                unique_prompt_map[req.prompt].append(req)
+            for idx, (prompt_text, reqs) in enumerate(unique_prompt_map.items()):
+                print(f"%d/%d. prompt '%s' ... [%d] output '%s'." % (
+                    idx,
+                    len(unique_prompt_map),
+                    prompt_text.replace("\n", "\\n"),
+                    len(reqs),
+                    reqs[0].generated_text.replace("\n", "\\n"),
+                ))
+
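For reference, the file written by the output-path branch above has roughly the following shape; every value below is invented for illustration, and a "logprobs" key appears only when return_log_probs is set.

# Illustrative only: approximate shape of the JSON written when args.output_path is set.
example_output = {
    "0": {                                    # keyed by request id
        "input_prompt": "Hello, my name is",
        "generated_text": "Hello, my name is ...",
        "generated_tokens": [318, 257, 1332],
        "latency": 0.42,
    },
}
# With args.throughput_check_only, the file instead holds a single entry such as
# {"throughput": [210.5, 198.7]} (generated tokens per second for each request).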
         # kill the engines and suspend the client
         client.stop_engines()
         client.stop()
-
+
     # once the stop signal eventually makes its way to each GPU, the engines will stop.
     await asyncio.gather(engine.engine_loop_task)

+
 if __name__ == "__main__":
-    # enable inference mode in the very beginning as some fp-8 optimizations
+    # enable inference mode in the very beginning as some fp-8 optimizations
     # check for it.
     with torch.inference_mode():
         initialize_megatron(
-            #parsed_args=args
             extra_args_provider=add_dynamic_inference_args,
             args_defaults={'no_load_rng': True, 'no_load_optim': True},
         )
@@ -131,17 +215,25 @@ async def main(
             top_p=args.top_p,
             return_log_probs=args.return_log_probs,
             num_tokens_to_generate=args.num_tokens_to_generate,
-            termination_id=args.termination_id if args.termination_id is not None else tokenizer.eod,
+            termination_id=(
+                args.termination_id if args.termination_id is not None else tokenizer.eod
+            ),
         )

         # Requests, context, controller.
         model = get_model()
-        requests = build_requests(args, tokenizer, sampling_params) if dist.get_rank() == 0 else None
+        mamba_inference_state_config = get_mamba_inference_state_config_from_model(model)
+        requests = (
+            build_requests(args, tokenizer, sampling_params) if dist.get_rank() == 0 else None
+        )
+
+        context = get_inference_context(
+            None,
+            None,
+            calculate_max_sequence_length_from_requests=False,
+            mamba_inference_state_config=mamba_inference_state_config,
+        )

-        context = get_inference_context(None,
-            None,
-            calculate_max_sequence_length_from_requests=False)
-
         controller = get_inference_controller(model, context)

         # Inference engine.
@@ -150,17 +242,19 @@ async def main(
             context,
             enable_cuda_graph=args.cuda_graph_impl == "local",
             random_seed=args.seed,
-            enable_chunked_prefill=not args.disable_chunked_prefill
+            enable_chunked_prefill=not args.disable_chunked_prefill,
         )

-
         if dist.get_rank() == 0:
             setup_prefix = build_dynamic_engine_setup_prefix(args, model, context, requests)
             print("~~~")
             print(setup_prefix)
             print("~~~")
-
-        asyncio.run(main(engine,
-            requests,
-            args.inference_coordinator_port))

+        asyncio.run(
+            main(
+                engine,
+                requests,
+                args.inference_coordinator_port,
+            )
+        )