@@ -167,6 +167,7 @@ def __init__(
167167 self .enable_chunked_prefill = enable_chunked_prefill
168168 self .inference_logging_step_interval = inference_logging_step_interval
169169 self .unified_memory_level = context .unified_memory_level
170+ self .persist_cuda_graphs = context .persist_cuda_graphs
170171
171172 if enable_cuda_graph is not None :
172173 self .cuda_graph_impl = "local" if enable_cuda_graph else "none"
@@ -552,9 +553,9 @@ def suspend(self):
552553 self .context .deallocate_all_tensors ()
553554
554555 # Delete cuda graphs when not using unified memory at all (level 0) and
555- # `rl-persist-cuda-graphs` is not passed. For UVM levels 1 and 2, the context's tensors
556+ # `--rl-persist-cuda-graphs` is not passed. For UVM levels 1 and 2, the context's tensors
556557 # maintain static memory addresses, so the cuda graphs are re-used.
557- if self .unified_memory_level == 0 and not args . rl_persist_cuda_graphs :
558+ if self .unified_memory_level == 0 and not self . persist_cuda_graphs :
558559 delete_cuda_graphs ()
559560
560561 # Maintain references to requests before reset.
@@ -596,7 +597,7 @@ def resume(self):
596597 # 0). For levels 1 and 2, the context's tensors maintain static
597598 # memory addresses, so the cuda graphs are re-used.
598599 capture_time = time .time ()
599- if self .unified_memory_level == 0 :
600+ if self .unified_memory_level == 0 and not self . persist_cuda_graphs :
600601 self .create_cuda_graphs ()
601602 capture_time = time .time () - capture_time
602603
0 commit comments