Commit 4185aac

persist cudagraphs
1 parent c8903d9 commit 4185aac

3 files changed: +7 -3 lines changed

megatron/core/inference/contexts/dynamic_context.py

Lines changed: 2 additions & 0 deletions
@@ -270,6 +270,7 @@ def __init__(
         cuda_graph_mixed_prefill_count: Optional[int] = 16,
         metrics_writer: Optional['WandbModule'] = None,
         num_request_metadata: Optional[int] = None,
+        persist_cuda_graphs: Optional[bool] = False,
     ):
         super().__init__(materialize_only_last_token_logits=materialize_only_last_token_logits)

@@ -360,6 +361,7 @@ def __init__(

         # Unified memory.
         self.unified_memory_level = unified_memory_level
+        self.persist_cuda_graphs = persist_cuda_graphs
         if unified_memory_level > 0:
             try:
                 self.unified_memory_mempool = create_unified_mempool()
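
The context change is pure plumbing: the new keyword argument defaults to False and is stored alongside the UVM level so the engine can consult both. A minimal runnable sketch of that default-and-store pattern, using a toy stand-in rather than the real context class (whose constructor takes many more arguments):

    from typing import Optional

    class ToyContext:
        # Toy stand-in for the inference context; the real class takes many
        # more constructor arguments than shown here.
        def __init__(
            self,
            unified_memory_level: int = 0,
            persist_cuda_graphs: Optional[bool] = False,
        ):
            # Mirrors the commit: both knobs live on the context so the engine
            # can read them when deciding whether CUDA graphs survive suspend.
            self.unified_memory_level = unified_memory_level
            self.persist_cuda_graphs = persist_cuda_graphs

    ctx = ToyContext(persist_cuda_graphs=True)
    assert ctx.persist_cuda_graphs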

megatron/core/inference/engines/dynamic_engine.py

Lines changed: 4 additions & 3 deletions
@@ -167,6 +167,7 @@ def __init__(
         self.enable_chunked_prefill = enable_chunked_prefill
         self.inference_logging_step_interval = inference_logging_step_interval
         self.unified_memory_level = context.unified_memory_level
+        self.persist_cuda_graphs = context.persist_cuda_graphs

         if enable_cuda_graph is not None:
             self.cuda_graph_impl = "local" if enable_cuda_graph else "none"
@@ -552,9 +553,9 @@ def suspend(self):
         self.context.deallocate_all_tensors()

         # Delete cuda graphs when not using unified memory at all (level 0) and
-        # `rl-persist-cuda-graphs` is not passed. For UVM levels 1 and 2, the context's tensors
+        # `--rl-persist-cuda-graphs` is not passed. For UVM levels 1 and 2, the context's tensors
         # maintain static memory addresses, so the cuda graphs are re-used.
-        if self.unified_memory_level == 0 and not args.rl_persist_cuda_graphs:
+        if self.unified_memory_level == 0 and not self.persist_cuda_graphs:
             delete_cuda_graphs()

         # Maintain references to requests before reset.
@@ -596,7 +597,7 @@ def resume(self):
         # 0). For levels 1 and 2, the context's tensors maintain static
         # memory addresses, so the cuda graphs are re-used.
         capture_time = time.time()
-        if self.unified_memory_level == 0:
+        if self.unified_memory_level == 0 and not self.persist_cuda_graphs:
             self.create_cuda_graphs()
         capture_time = time.time() - capture_time
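
Two things change in the engine: suspend() now reads the flag from engine state (self.persist_cuda_graphs, copied off the context in __init__) rather than from an args object, and resume() now skips recapture when graphs were persisted, where it previously recaptured unconditionally at UVM level 0. A minimal sketch of the resulting gating predicate; the helper name and asserts are illustrative, not part of the diff:

    def should_drop_cuda_graphs(unified_memory_level: int,
                                persist_cuda_graphs: bool) -> bool:
        # Graphs are deleted on suspend (and recaptured on resume) only when
        # UVM is off (level 0) and persistence was not requested; at UVM
        # levels 1 and 2, tensor addresses stay static and graphs are reused.
        return unified_memory_level == 0 and not persist_cuda_graphs

    assert should_drop_cuda_graphs(0, False)      # old behavior: recapture each cycle
    assert not should_drop_cuda_graphs(0, True)   # new flag: graphs persist
    assert not should_drop_cuda_graphs(1, False)  # UVM already keeps graphs valid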

megatron/rl/inference/megatron.py

Lines changed: 1 addition & 0 deletions
@@ -136,6 +136,7 @@ def get_dynamic_inference_engine(args: Namespace, model: MegatronModule, inferen
         use_flashinfer_fused_rope=None,
         unified_memory_level=args.inference_dynamic_batching_unified_memory_level,
         metrics_writer=metrics_writer,
+        persist_cuda_graphs=args.rl_persist_cuda_graphs
     )

     inference_wrapped_model = GPTInferenceWrapper(model, args, inference_context)
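
Finally, the RL entry point forwards the CLI flag into the context at construction time. The flag definition itself is not part of this commit; a hypothetical argparse sketch of how --rl-persist-cuda-graphs could map onto args.rl_persist_cuda_graphs:

    import argparse

    # Hypothetical definition; the real flag is registered in Megatron's
    # argument setup, which this diff does not touch.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--rl-persist-cuda-graphs',
        action='store_true',
        help='Keep captured CUDA graphs alive across engine suspend/resume '
             'instead of recapturing them.',
    )

    args = parser.parse_args(['--rl-persist-cuda-graphs'])
    assert args.rl_persist_cuda_graphs  # argparse maps dashes to underscores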
