Commit b74d869

Merge branch 'dev' into optimize_hybrid_ep

2 parents: 7b3f1e5 + ed804b4

File tree

150 files changed: +10126, -3849 lines


.github/copy-pr-bot.yaml

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
 enabled: true
 auto_sync_draft: false
 auto_sync_ready: true
-trustees_override: ["AAnoosheh", "ArEsKay3", "Autumn1998", "BestJuly", "BoxiangW", "ChenhanYu", "FDecaYed", "HaochenYuan", "ISEEKYAN", "JRD971000", "QiZhangNV", "ShriyaRishab", "Victarry", "Wohox", "ZhiyuLi-Nvidia", "aklife97", "ananthsub", "asolergi-nv", "buptzyb", "chtruong814", "cspades", "cuichenx", "deepakn94", "dimapihtar", "duncanriach", "erhoo82", "ericharper", "fanshiqing", "gautham-kollu", "hxbai", "jaredcasper", "jiemingz", "jkamalu", "jon-barker", "kanz-nv", "kevalmorabia97", "ko3n1g", "kunlunl", "kvareddy", "layalir", "lhb8125", "lmcafee-nvidia", "maanug-nv", "mathemakitten", "matthieule", "mehraakash", "mkhona-nvidia", "pablo-garay", "parthmannan", "pthombre", "rogerwaleffe", "sanandaraj5597", "santhnm2", "sbak5", "shanmugamr1992", "shifangx", "shjwudp", "sidsingh-nvidia", "skyw", "tdene", "theothermike", "thomasdhc", "trintamaki", "tylerpoon", "wdykas", "xiaoyao0115", "xuwchen", "yanring", "yaox12", "yaoyu-33", "yashaswikarnati", "yobibyte", "youngeunkwon0405", "yuzhongw-nvidia", "zhongbozhu"]
+trustees_override: ["AAnoosheh", "ArEsKay3", "Autumn1998", "BestJuly", "BoxiangW", "ChenhanYu", "FDecaYed", "HaochenYuan", "ISEEKYAN", "JRD971000", "QiZhangNV", "ShriyaRishab", "Victarry", "Wohox", "ZhiyuLi-Nvidia", "aklife97", "ananthsub", "asolergi-nv", "buptzyb", "chtruong814", "cspades", "cuichenx", "deepakn94", "dimapihtar", "duncanriach", "erhoo82", "ericharper", "fanshiqing", "gautham-kollu", "guyueh1", "hxbai", "jaredcasper", "jiemingz", "jkamalu", "jon-barker", "kanz-nv", "kevalmorabia97", "ko3n1g", "kunlunl", "kvareddy", "layalir", "lhb8125", "lmcafee-nvidia", "maanug-nv", "mathemakitten", "matthieule", "mehraakash", "mkhona-nvidia", "pablo-garay", "parthmannan", "pthombre", "rogerwaleffe", "sanandaraj5597", "santhnm2", "sbak5", "shanmugamr1992", "shifangx", "shjwudp", "sidsingh-nvidia", "skyw", "tdene", "theothermike", "thomasdhc", "trintamaki", "tylerpoon", "wdykas", "xiaoyao0115", "xuwchen", "yanring", "yaox12", "yaoyu-33", "yashaswikarnati", "yeyu-nvidia", "yobibyte", "youngeunkwon0405", "yuzhongw-nvidia", "zhongbozhu"]

.github/workflows/auto-update-copy-pr-bot.yml

Lines changed: 4 additions & 2 deletions
@@ -48,8 +48,10 @@ jobs:
           mv .github/copy-pr-bot.yaml.new .github/copy-pr-bot.yaml

       - name: Commit changes
+        env:
+          GH_TOKEN: ${{ secrets.PAT }}
         run: |
-          git remote set-url origin https://x-access-token:${{ secrets.PAT }}@github.com/NVIDIA/Megatron-LM.git
+          git remote set-url origin https://x-access-token:${GH_TOKEN}@github.com/NVIDIA/Megatron-LM.git
           git config --global user.name "GitHub Actions"
           git config --global user.email "github-actions[bot]@users.noreply.github.com"
           git add .github/copy-pr-bot.yaml
@@ -58,4 +60,4 @@ jobs:
             exit 0
           fi
           git commit -m "Update copy-pr-bot.yaml [skip ci]"
-          git push
+          git push -u origin main

.github/workflows/community-bot.yml

Lines changed: 2 additions & 1 deletion
@@ -22,7 +22,8 @@ on:
 jobs:
   community-bot:
     uses: NVIDIA-NeMo/FW-CI-templates/.github/workflows/[email protected]
+    with:
+      community_project_id: ${{ vars.COMMUNITY_PROJECT_ID }}
     if: github.repository == 'NVIDIA/Megatron-LM'
     secrets:
       GH_TOKEN: ${{ secrets.PAT }}
-    environment: main

examples/inference/gpt/gpt_dynamic_inference.py

Lines changed: 171 additions & 104 deletions
Large diffs are not rendered by default.

examples/inference/gpt/gpt_dynamic_inference_12b.sh

Lines changed: 4 additions & 6 deletions
@@ -24,13 +24,9 @@ export CUDA_DEVICE_MAX_CONNECTIONS=1

 # Dynamic context.
 : ${BUFFER_SIZE_GB=50.}
-: ${BUFFER_OVERFLOW_FACTOR=1.}
-: ${BUFFER_GUARANTEED_FRACTION=0.05}

 # Cuda graphs.
-: ${CUDA_GRAPH_IMPL=local}
 : ${NUM_CUDA_GRAPHS=16}
-: ${CUDA_GRAPH_SHARE_IO_BUFFERS=1}

 # Miscellaneous.
 : ${USE_COORDINATOR=0}
@@ -79,8 +75,6 @@ ARGS=" \
     \
     --inference-dynamic-batching \
     --inference-dynamic-batching-buffer-size-gb ${BUFFER_SIZE_GB} \
-    --inference-dynamic-batching-buffer-overflow-factor ${BUFFER_OVERFLOW_FACTOR} \
-    --inference-dynamic-batching-buffer-guaranteed-fraction ${BUFFER_GUARANTEED_FRACTION} \
     \
     ${EXTRA_ARGS} \
 "
@@ -91,6 +85,10 @@ if [ "${NUM_CUDA_GRAPHS}" != "0" ]; then
         --cuda-graph-impl local \
         --inference-dynamic-batching-num-cuda-graphs ${NUM_CUDA_GRAPHS} \
     "
+else
+    ARGS+=" \
+        --cuda-graph-impl none \
+    "
 fi

 # Prompts.

examples/inference/gpt/gpt_dynamic_inference_357m.sh

Lines changed: 4 additions & 6 deletions
@@ -25,13 +25,9 @@ export CUDA_DEVICE_MAX_CONNECTIONS=1

 # Dynamic context.
 : ${BUFFER_SIZE_GB=50.}
-: ${BUFFER_OVERFLOW_FACTOR=1.}
-: ${BUFFER_GUARANTEED_FRACTION=0.05}

 # Cuda graphs.
-: ${CUDA_GRAPH_IMPL=local}
 : ${NUM_CUDA_GRAPHS=16}
-: ${CUDA_GRAPH_SHARE_IO_BUFFERS=1}

 # Miscellaneous.
 : ${USE_COORDINATOR=0}
@@ -65,8 +61,6 @@ ARGS=" \
     \
     --inference-dynamic-batching \
     --inference-dynamic-batching-buffer-size-gb ${BUFFER_SIZE_GB} \
-    --inference-dynamic-batching-buffer-overflow-factor ${BUFFER_OVERFLOW_FACTOR} \
-    --inference-dynamic-batching-buffer-guaranteed-fraction ${BUFFER_GUARANTEED_FRACTION} \
     \
     ${EXTRA_ARGS} \
 "
@@ -77,6 +71,10 @@ if [ "${NUM_CUDA_GRAPHS}" != "0" ]; then
         --cuda-graph-impl local \
         --inference-dynamic-batching-num-cuda-graphs ${NUM_CUDA_GRAPHS} \
     "
+else
+    ARGS+=" \
+        --cuda-graph-impl none \
+    "
 fi

 # Prompts.

Lines changed: 150 additions & 56 deletions

@@ -1,26 +1,41 @@
 # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

-from megatron.core.inference.inference_client import InferenceClient
-from examples.inference.gpt.utils import add_common_inference_args
 import asyncio
-import torch.distributed as dist
-from examples.inference.gpt.gpt_dynamic_inference import get_model, get_inference_context, get_inference_controller, add_dynamic_inference_args
-from megatron.core.inference.inference_request import DynamicInferenceRequest
-from megatron.training import initialize_megatron
-import torch
-import os
-from megatron.training import get_args, get_tokenizer
-from megatron.core.inference.sampling_params import SamplingParams
-from examples.inference.gpt.utils import build_requests, build_dynamic_engine_setup_prefix, Request
-from megatron.core.inference.engines import DynamicInferenceEngine
+import json
+import os
 import time
+import torch
+import torch.distributed as dist
+from collections import defaultdict
 from tqdm import tqdm
 from typing import List
-import json
-from megatron.training.arguments import parse_args
+import warnings
+import logging
+
+from examples.inference.gpt.gpt_dynamic_inference import (
+    add_dynamic_inference_args,
+    get_inference_context,
+    get_inference_controller,
+    get_model,
+)
+from examples.inference.gpt.utils import (
+    Request,
+    build_dynamic_engine_setup_prefix,
+    build_requests,
+    add_common_inference_args
+)
+
 from megatron.core import parallel_state
+from megatron.core.inference.engines import DynamicInferenceEngine
+from megatron.core.inference.inference_client import InferenceClient
+from megatron.core.inference.inference_request import DynamicInferenceRequestRecord
+from megatron.core.inference.sampling_params import SamplingParams
+from megatron.core.utils import get_mamba_inference_state_config_from_model

-import logging
+from megatron.training import get_args, get_tokenizer, initialize_megatron
+from megatron.training.arguments import parse_args
+
+# pylint: disable=line-too-long

 logging.basicConfig(level=logging.INFO, force=True)

@@ -38,81 +53,150 @@ async def main(
     )
     # once you call engine.start_listening_to_data_parallel_coordinator,
     # the engine will start accepting requests from the data parallel coordinator.
-    # and processing them in an asyncio coroutine.
-    await engine.start_listening_to_data_parallel_coordinator(
-        inference_coordinator_port=port, launch_inference_coordinator=True
+    # and processing them in an asyncio coroutine.
+
+    await engine.start_listening_to_data_parallel_coordinator(
+        inference_coordinator_port=port,
+        launch_inference_coordinator=True,
+        verbose=True,
     )
-    # if you want to use your own inference coordinator -
+
+    # if you want to use your own inference coordinator -
     # 1. set launch_inference_coordinator to False
     # 2. setup a router socket at tcp://MASTER_ADDR:PORT
     # 3. wait for data parallel groups to establish connection (BasicInferenceCoordinator.__init__)
     # 4. look at InferenceCoordinator.start() to see how we can route requests from users <-> data parallel groups
-    # based on headers.
-    # 5. look at InferenceClient to see how we create requests with headers.
-    if dist.get_rank() == 0:
-        client = InferenceClient(port) # submits requests to the inference coordinator
+    # based on headers.
+    # 5. look at InferenceClient to see how we create requests with headers.
+
+    args = get_args()
+
+    # Test suspend/resume intervals.
+    if args.suspend_resume_interval is not None:
+        # Since the client doesn't directly call engine.async_step here, we test
+        # the suspend-resume system ~4 times.
+        suspend_resume_interval = max(1, len(requests) // 4)
+        suspend_idxs = set(range(
+            suspend_resume_interval,
+            len(requests) + 1,
+            suspend_resume_interval,
+        ))
+        resume_idxs = set(
+            min(len(requests), i + suspend_resume_interval // 2)
+            for i in suspend_idxs
+        )
+    else:
+        suspend_idxs = set()
+        resume_idxs = set()
+
+    # Create client and run example.
+    if dist.get_rank() == 0:
+        client = InferenceClient(port)  # submits requests to the inference coordinator
         await client.start()
         base_arrival_time = time.time_ns() / 10**9
         for request in requests:
             request.time_arrival = request.time_offset + base_arrival_time
         futures = []
         num_requests_total = len(requests)
         num_requests_added = 0
-        #tbar = tqdm(total=num_requests_total)
+
         while True:
             current_time = time.time_ns() / 10**9
-            # Only add requests that have arrived at the current time.
-            while num_requests_added < num_requests_total and requests[num_requests_added].time_arrival <= current_time:
-                request = requests[num_requests_added]
-                # These add-request calls will queue up the request on a zmq socket and return
-                # instantaneously. They will return an asyncio future which can be awaited for
-                # request completion.
-                futures.append(client.add_request(request.prompt_text, request.sampling_params))
-                num_requests_added += 1
-                #tbar.update(1)
+            if args.incoming_requests_per_step is None:
+                # Only add requests that have arrived at the current time.
+                while num_requests_added < num_requests_total and requests[num_requests_added].time_arrival <= current_time:
+                    request = requests[num_requests_added]
+                    # These add-request calls will queue up the request on a zmq socket and return
+                    # instantaneously. They will return an asyncio future which can be awaited for
+                    # request completion.
+                    futures.append(client.add_request(request.prompt_text, request.sampling_params))
+                    num_requests_added += 1
+
+                    # Test suspend/resume.
+                    if num_requests_added in suspend_idxs:
+                        client.suspend_engines()
+                    if num_requests_added in resume_idxs:
+                        client.resume_engines()
+
+            else:
+                # Add deterministic number of requests (generally used for debugging).
+                for i in range(min(
+                    args.incoming_requests_per_step,
+                    num_requests_total - num_requests_added
+                )):
+                    # Change sampling parameters to force different generation lengths.
+                    request = requests[num_requests_added]
+                    n = request.sampling_params.num_tokens_to_generate
+                    request.sampling_params.num_tokens_to_generate = n + i
+                    futures.append(client.add_request(request.prompt_text, request.sampling_params))
+                    num_requests_added += 1
+
+                    # Test suspend/resume.
+                    if num_requests_added in suspend_idxs:
+                        client.suspend_engines()
+                    if num_requests_added in resume_idxs:
+                        client.resume_engines()
+
             if num_requests_added == num_requests_total:
                 break
-            # Relinquish control since there are no more requests to add at the moment. This allows the engine to run.
+            # Relinquish control since there are no more requests to add at the moment. This allows the engine to run.
             await asyncio.sleep(0)
-        # While we wait for the requests to complete, the engine runs in the background.
-        results: List[DynamicInferenceRequest] = await asyncio.gather(*futures)

+        # While we wait for the requests to complete, the engine runs in the background.
+        results: List[DynamicInferenceRequestRecord] = await asyncio.gather(*futures)

     if dist.get_rank() == 0:
         # Write results to JSON. Primarily used for functional testing.
         if args.output_path:
             json_results = {}
+            throughputs = []

-            for req in results:
+            for record in results:
+                req = record.merge(engine.controller.tokenizer)
                 result_dict = {
                     "input_prompt": req.prompt,
                     "generated_text": req.generated_text.replace("\n", "\\n"),
                     "generated_tokens": req.generated_tokens,
-                    "latency": req.latency, #InferenceClient populates this field in the returned future.
+                    "latency": req.latency,  # InferenceClient populates this field in the returned future.
                 }
                 if req.sampling_params["return_log_probs"]:
                     result_dict["logprobs"] = req.prompt_log_probs + req.generated_log_probs
+                throughput = len(req.generated_tokens) / req.latency
+                throughputs.append(throughput)
                 json_results[req.request_id] = result_dict
+            throughput_dict = {"throughput": throughputs}
+            if args.throughput_check_only:
+                json_results = throughput_dict
             with open(args.output_path, "w") as fp:
                 json.dump(json_results, fp, indent=4)
         else:
             print("Results:")
-            for req in results:
-                print(f"rid: {req.request_id}\nprompt: {req.prompt!r}\noutput: {req.generated_text!r}\n\n")
-
+            unique_prompt_map = defaultdict(list)
+            for record in results:
+                req = record.merge(engine.controller.tokenizer)
+                unique_prompt_map[req.prompt].append(req)
+            for idx, (prompt_text, reqs) in enumerate(unique_prompt_map.items()):
+                print(f"%d/%d. prompt '%s' ... [%d] output '%s'." % (
+                    idx,
+                    len(unique_prompt_map),
+                    prompt_text.replace("\n", "\\n"),
+                    len(reqs),
+                    reqs[0].generated_text.replace("\n", "\\n"),
+                ))
+
         # kill the engines and suspend the client
         client.stop_engines()
         client.stop()
-
+
     # once the stop signal eventually makes its way to each GPU, the engines will stop.
     await asyncio.gather(engine.engine_loop_task)

+
 if __name__ == "__main__":
-    # enable inference mode in the very beginning as some fp-8 optimizations
+    # enable inference mode in the very beginning as some fp-8 optimizations
     # check for it.
     with torch.inference_mode():
         initialize_megatron(
-            #parsed_args=args
             extra_args_provider=add_dynamic_inference_args,
             args_defaults={'no_load_rng': True, 'no_load_optim': True},
         )
@@ -131,17 +215,25 @@ async def main(
             top_p=args.top_p,
             return_log_probs=args.return_log_probs,
             num_tokens_to_generate=args.num_tokens_to_generate,
-            termination_id=args.termination_id if args.termination_id is not None else tokenizer.eod,
+            termination_id=(
+                args.termination_id if args.termination_id is not None else tokenizer.eod
+            ),
         )

         # Requests, context, conroller.
         model = get_model()
-        requests = build_requests(args, tokenizer, sampling_params) if dist.get_rank() == 0 else None
+        mamba_inference_state_config = get_mamba_inference_state_config_from_model(model)
+        requests = (
+            build_requests(args, tokenizer, sampling_params) if dist.get_rank() == 0 else None
+        )
+
+        context = get_inference_context(
+            None,
+            None,
+            calculate_max_sequence_length_from_requests=False,
+            mamba_inference_state_config=mamba_inference_state_config,
+        )

-        context = get_inference_context(None,
-                                        None,
-                                        calculate_max_sequence_length_from_requests=False)
-
         controller = get_inference_controller(model, context)

         # Inference engine.
@@ -150,17 +242,19 @@ async def main(
             context,
             enable_cuda_graph=args.cuda_graph_impl == "local",
             random_seed=args.seed,
-            enable_chunked_prefill=not args.disable_chunked_prefill
+            enable_chunked_prefill=not args.disable_chunked_prefill,
         )

-
         if dist.get_rank() == 0:
             setup_prefix = build_dynamic_engine_setup_prefix(args, model, context, requests)
             print("~~~")
             print(setup_prefix)
             print("~~~")
-
-        asyncio.run(main(engine,
-                         requests,
-                         args.inference_coordinator_port))

+        asyncio.run(
+            main(
+                engine,
+                requests,
+                args.inference_coordinator_port,
+            )
+        )
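
Note (not part of the commit): for readers skimming the diff, the sketch below condenses the rank-0 client-side flow that the updated example implements. It is an illustrative outline only, using the APIs exactly as they appear in the diff above; it assumes an already-initialized DynamicInferenceEngine that is listening on `port` (set up earlier via engine.start_listening_to_data_parallel_coordinator) and a rank-0 `requests` list from build_requests(), and it omits arrival-time pacing, the suspend/resume test paths, and JSON output.

# Sketch only: condensed from the example above. Assumes `engine` is a running
# DynamicInferenceEngine already listening on `port`, and `requests` is the
# rank-0 list built by build_requests(); both come from the surrounding script.
import asyncio

import torch.distributed as dist

from megatron.core.inference.inference_client import InferenceClient


async def submit_all(engine, requests, port):
    if dist.get_rank() != 0:
        return []
    client = InferenceClient(port)   # talks to the inference coordinator
    await client.start()

    # Each add_request() enqueues the prompt over a zmq socket and returns an
    # asyncio future that resolves when the request completes.
    futures = [
        client.add_request(r.prompt_text, r.sampling_params) for r in requests
    ]
    await asyncio.sleep(0)           # yield so the engine loop can run

    # Futures resolve to DynamicInferenceRequestRecord objects; merge() turns a
    # record into a single request view with prompt/generated_text/latency.
    records = await asyncio.gather(*futures)
    results = [rec.merge(engine.controller.tokenizer) for rec in records]

    client.stop_engines()            # signal the data-parallel engines to stop
    client.stop()
    return results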

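Note (not part of the commit): the suspend/resume test path added above derives its suspend and resume points from the request count. The standalone snippet below (plain Python; the request count of 100 is hypothetical) mirrors the arithmetic from the diff and prints the resulting indices.

# Worked example of the suspend/resume index arithmetic added above.
# The request count (100) is hypothetical; the formulas mirror the diff.
num_requests = 100

# The example suspends roughly 4 times over the request stream ...
suspend_resume_interval = max(1, num_requests // 4)   # -> 25
suspend_idxs = set(range(
    suspend_resume_interval,
    num_requests + 1,
    suspend_resume_interval,
))                                                    # -> {25, 50, 75, 100}

# ... and resumes half an interval after each suspend point, clamped
# to the total number of requests.
resume_idxs = set(
    min(num_requests, i + suspend_resume_interval // 2)
    for i in suspend_idxs
)                                                     # -> {37, 62, 87, 100}

print(sorted(suspend_idxs), sorted(resume_idxs))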