Skip to content

Commit 2d9eb0f

Browse files
committed
Fix reference error
Address comments
1 parent 999f057 commit 2d9eb0f

File tree

7 files changed

+782
-744
lines changed

7 files changed

+782
-744
lines changed

docs/design/JobLauncher_and_JobHandle.md

Lines changed: 124 additions & 184 deletions
Large diffs are not rendered by default.

nvflare/apis/job_launcher_spec.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
from abc import abstractmethod
14+
from abc import ABC, abstractmethod
1515

1616
from nvflare.apis.fl_component import FLComponent
1717
from nvflare.apis.fl_constant import FLContextKey
@@ -56,7 +56,7 @@ def add_launcher(launcher, fl_ctx: FLContext):
5656
fl_ctx.set_prop(FLContextKey.JOB_LAUNCHER, job_launcher, private=True, sticky=False)
5757

5858

59-
class JobHandleSpec:
59+
class JobHandleSpec(ABC):
6060
@abstractmethod
6161
def terminate(self):
6262
"""To terminate the job run.
@@ -85,7 +85,7 @@ def wait(self):
8585
raise NotImplementedError()
8686

8787

88-
class JobLauncherSpec(FLComponent):
88+
class JobLauncherSpec(FLComponent, ABC):
8989
@abstractmethod
9090
def launch_job(self, job_meta: dict, fl_ctx: FLContext) -> JobHandleSpec:
9191
"""To launch a job run.

nvflare/app_opt/job_launcher/docker_launcher.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,10 @@ class DOCKER_STATE:
4545

4646

4747
class DockerJobHandle(JobHandleSpec):
48-
def __init__(self, timeout=None):
48+
def __init__(self, container, timeout=None):
4949
super().__init__()
5050

51-
self.container = None
51+
self.container = container
5252
self.timeout = timeout
5353
self.logger = logging.getLogger(self.__class__.__name__)
5454

@@ -68,9 +68,6 @@ def wait(self):
6868
if self.container:
6969
self.enter_states([DOCKER_STATE.EXITED, DOCKER_STATE.DEAD], self.timeout)
7070

71-
def _set_container(self, container):
72-
self.container = container
73-
7471
def _get_container(self):
7572
try:
7673
docker_client = docker.from_env()
@@ -123,7 +120,6 @@ def launch_job(self, job_meta: dict, fl_ctx: FLContext) -> JobHandleSpec:
123120
docker_workspace = os.environ.get("NVFL_DOCKER_WORKSPACE")
124121
self.logger.info(f"launch_job {job_id} in docker_workspace: {docker_workspace}")
125122
docker_client = docker.from_env()
126-
handle = DockerJobHandle()
127123
try:
128124
container = docker_client.containers.run(
129125
job_image,
@@ -141,22 +137,24 @@ def launch_job(self, job_meta: dict, fl_ctx: FLContext) -> JobHandleSpec:
141137
# ports=ports, # Map container ports to host ports (optional)
142138
)
143139
self.logger.info(f"Launch the job in DockerJobLauncher using image: {job_image}")
144-
handle._set_container(container)
140+
141+
handle = DockerJobHandle(container)
145142
try:
146-
launched = handle.enter_states([DOCKER_STATE.RUNNING], timeout=self.timeout)
147-
if not launched:
143+
if handle.enter_states([DOCKER_STATE.RUNNING], timeout=self.timeout):
144+
return handle
145+
else:
148146
handle.terminate()
149-
return handle
147+
return None
150148
except:
151149
handle.terminate()
152-
return handle
150+
return None
153151

154152
except docker.errors.ImageNotFound:
155153
self.logger.error(f"Failed to launcher job: {job_id} in DockerJobLauncher. Image '{job_image}' not found.")
156-
return handle
154+
return None
157155
except docker.errors.APIError as e:
158156
self.logger.error(f"Error starting container: {e}")
159-
return handle
157+
return None
160158

161159
def handle_event(self, event_type: str, fl_ctx: FLContext):
162160
if event_type == EventType.BEFORE_JOB_LAUNCH:

nvflare/app_opt/job_launcher/k8s_launcher.py

Lines changed: 91 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import copy
1415
import logging
16+
import re
1517
import time
1618
from abc import abstractmethod
1719
from enum import Enum
@@ -89,8 +91,29 @@ class PV_NAME(Enum):
8991
]
9092

9193

94+
def uuid4_to_rfc1123(uuid_str: str) -> str:
95+
name = uuid_str.lower()
96+
# Strip any chars that aren't alphanumeric or hyphen
97+
name = re.sub(r"[^a-z0-9-]", "", name)
98+
# Prefix with a letter if it starts with a digit
99+
if name and name[0].isdigit():
100+
name = "j" + name
101+
# Kubernetes label limit: 63 chars; strip trailing hyphens after truncation
102+
# (truncation can expose a hyphen that was interior before slicing)
103+
return name[:63].rstrip("-")
104+
105+
92106
class K8sJobHandle(JobHandleSpec):
93-
def __init__(self, job_id: str, api_instance: core_v1_api, job_config: dict, namespace="default", timeout=None):
107+
def __init__(
108+
self,
109+
job_id: str,
110+
api_instance: core_v1_api,
111+
job_config: dict,
112+
namespace="default",
113+
timeout=None,
114+
pending_timeout=30,
115+
python_path="/usr/local/bin/python",
116+
):
94117
super().__init__()
95118
self.job_id = job_id
96119
self.timeout = timeout
@@ -113,8 +136,7 @@ def __init__(self, job_id: str, api_instance: core_v1_api, job_config: dict, nam
113136
{
114137
"image": None,
115138
"name": None,
116-
"resources": None,
117-
"command": ["/usr/local/bin/python"],
139+
"command": [python_path],
118140
"args": None, # args_list + args_dict + args_sets
119141
"volumeMounts": None, # volume_mount_list
120142
"imagePullPolicy": "Always",
@@ -127,14 +149,13 @@ def __init__(self, job_id: str, api_instance: core_v1_api, job_config: dict, nam
127149
self.container_volume_mount_list = []
128150
self._make_manifest(job_config)
129151
self._stuck_count = 0
130-
self._stuck_grace_period = 10 # seconds to wait before counting Pending as stuck
131-
self._max_stuck_count = (self.timeout + self._stuck_grace_period) if self.timeout is not None else None
152+
self._max_stuck_count = self.timeout if self.timeout is not None else pending_timeout
132153
self.logger = logging.getLogger(self.__class__.__name__)
133154

134155
def _make_manifest(self, job_config):
135156
self.container_volume_mount_list.extend(job_config.get("volume_mount_list", []))
136157
set_list = job_config.get("set_list")
137-
if set_list is None:
158+
if not set_list:
138159
self.container_args_module_args_sets = list()
139160
else:
140161
self.container_args_module_args_sets = ["--set"] + set_list
@@ -147,57 +168,64 @@ def _make_manifest(self, job_config):
147168
if v is None:
148169
continue
149170
self.container_args_module_args_dict_as_list.append(k)
150-
self.container_args_module_args_dict_as_list.append(v)
171+
self.container_args_module_args_dict_as_list.append(str(v))
151172
self.volume_list.extend(job_config.get("volume_list", []))
152173
self.pod_manifest["metadata"]["name"] = job_config.get("name")
153174
self.pod_manifest["spec"]["containers"] = self.container_list
154175
self.pod_manifest["spec"]["volumes"] = self.volume_list
155176

156-
self.container_list[0]["image"] = job_config.get("image", "nvflare/nvflare:2.8.0")
177+
image = job_config.get("image")
178+
if not image:
179+
raise ValueError("job_config must contain a non-empty 'image' key")
180+
self.container_list[0]["image"] = image
157181
self.container_list[0]["name"] = job_config.get("container_name", "nvflare_job")
158182
self.container_list[0]["args"] = (
159183
self.container_args_python_args_list
160184
+ self.container_args_module_args_dict_as_list
161185
+ self.container_args_module_args_sets
162186
)
163187
self.container_list[0]["volumeMounts"] = self.container_volume_mount_list
164-
if job_config.get("resources", {}).get("limits", {}).get("nvidia.com/gpu") is not None:
188+
if job_config.get("resources", {}).get("limits", {}).get("nvidia.com/gpu"):
165189
self.container_list[0]["resources"] = job_config.get("resources")
166190

167191
def get_manifest(self):
168-
return self.pod_manifest
192+
return copy.deepcopy(self.pod_manifest)
169193

170-
def enter_states(self, job_states_to_enter: list, timeout=None):
194+
def enter_states(self, job_states_to_enter: list):
171195
starting_time = time.time()
172196
if not isinstance(job_states_to_enter, (list, tuple)):
173197
job_states_to_enter = [job_states_to_enter]
174198
if not all([isinstance(js, JobState) for js in job_states_to_enter]):
175199
raise ValueError(f"expect job_states_to_enter with valid values, but get {job_states_to_enter}")
176200
while True:
177201
pod_phase = self._query_phase()
178-
if self._stuck(pod_phase):
202+
if self._stuck_in_pending(pod_phase):
179203
self.terminate()
180204
return False
181205
job_state = POD_STATE_MAPPING.get(pod_phase, JobState.UNKNOWN)
182206
if job_state in job_states_to_enter:
183207
return True
184-
elif timeout is not None and time.time() - starting_time > timeout:
208+
elif pod_phase in [POD_Phase.FAILED.value, POD_Phase.SUCCEEDED.value]: # terminal state
209+
self.terminal_state = POD_STATE_MAPPING.get(pod_phase, JobState.UNKNOWN)
210+
return False
211+
elif self.timeout is not None and time.time() - starting_time > self.timeout:
212+
self.terminate()
185213
return False
186214
time.sleep(1)
187215

188216
def terminate(self):
189217
try:
190-
resp = self.api_instance.delete_namespaced_pod(
191-
name=self.job_id, namespace=self.namespace, grace_period_seconds=0
192-
)
218+
self.api_instance.delete_namespaced_pod(name=self.job_id, namespace=self.namespace, grace_period_seconds=0)
193219
self.terminal_state = JobState.TERMINATED
194220
except ApiException as e:
195-
# If the pod is already gone, treat it as terminated; otherwise, leave state unchanged.
196221
if getattr(e, "status", None) == 404:
197222
self.logger.info(f"job {self.job_id} pod not found during termination; assuming terminated")
198-
self.terminal_state = JobState.TERMINATED
199223
else:
200224
self.logger.error(f"failed to terminate job {self.job_id}: {e}")
225+
self.terminal_state = JobState.TERMINATED
226+
except Exception as e:
227+
self.logger.error(f"unexpected error terminating job {self.job_id}: {e}")
228+
self.terminal_state = JobState.TERMINATED
201229
return None
202230

203231
def poll(self):
@@ -210,20 +238,24 @@ def _query_phase(self):
210238
try:
211239
resp = self.api_instance.read_namespaced_pod(name=self.job_id, namespace=self.namespace)
212240
except ApiException as e:
241+
self.logger.warning(f"failed to query pod phase {self.job_id}: {e}")
242+
return POD_Phase.UNKNOWN.value
243+
except Exception as e:
244+
self.logger.warning(f"unexpected error querying pod phase {self.job_id}: {e}")
213245
return POD_Phase.UNKNOWN.value
214246
return resp.status.phase
215247

216248
def _query_state(self):
217249
pod_phase = self._query_phase()
218250
return POD_STATE_MAPPING.get(pod_phase, JobState.UNKNOWN)
219251

220-
def _stuck(self, current_phase):
221-
if self._max_stuck_count is None:
222-
return False
252+
def _stuck_in_pending(self, current_phase):
223253
if current_phase == POD_Phase.PENDING.value:
224254
self._stuck_count += 1
225-
if self._stuck_count > self._max_stuck_count:
255+
if self._max_stuck_count is not None and self._stuck_count >= self._max_stuck_count:
226256
return True
257+
else:
258+
self._stuck_count = 0
227259
return False
228260

229261
def wait(self):
@@ -246,6 +278,8 @@ def __init__(
246278
data_pvc_file_path: str,
247279
timeout=None,
248280
namespace="default",
281+
pending_timeout=30,
282+
python_path="/usr/local/bin/python",
249283
):
250284
super().__init__()
251285
self.logger = logging.getLogger(self.__class__.__name__)
@@ -255,15 +289,18 @@ def __init__(
255289
self.data_pvc_file_path = data_pvc_file_path
256290
self.timeout = timeout
257291
self.namespace = namespace
292+
self.pending_timeout = pending_timeout
293+
self.python_path = python_path
258294
with open(data_pvc_file_path, "rt") as f:
259295
data_pvc_dict = yaml.safe_load(f)
260296
if not data_pvc_dict:
261297
raise ValueError(f"data_pvc_file_path '{data_pvc_file_path}' is empty or contains no PVC entries.")
262298
# data_pvc_dict will be pvc: mountPath
263299
# currently, support one pvc and always mount to /var/tmp/nvflare/data
264300
# ie, ignore the mountPath in data_pvc_dict
301+
if not isinstance(data_pvc_dict, dict):
302+
raise ValueError(f"file at data_pvc_file_path '{data_pvc_file_path}' does not contain a dictionary.")
265303
self.data_pvc = list(data_pvc_dict.keys())[0]
266-
267304
config.load_kube_config(config_file_path)
268305
try:
269306
c = Configuration().get_default_copy()
@@ -276,17 +313,22 @@ def __init__(
276313

277314
def launch_job(self, job_meta: dict, fl_ctx: FLContext) -> JobHandleSpec:
278315
site_name = fl_ctx.get_identity_name()
279-
job_id = job_meta.get(JobConstants.JOB_ID)
316+
raw_job_id = job_meta.get(JobConstants.JOB_ID)
317+
if not raw_job_id:
318+
raise RuntimeError(f"missing {JobConstants.JOB_ID} in job_meta")
319+
job_id = uuid4_to_rfc1123(raw_job_id)
280320
args = fl_ctx.get_prop(FLContextKey.ARGS)
281321
job_image = extract_job_image(job_meta, site_name)
282322
site_resources = job_meta.get(JobMetaKey.RESOURCE_SPEC.value, {}).get(site_name, {})
283323
job_resource = site_resources.get("num_of_gpus", None)
284-
285324
job_args = fl_ctx.get_prop(FLContextKey.JOB_PROCESS_ARGS)
286325
if not job_args:
287326
raise RuntimeError(f"missing {FLContextKey.JOB_PROCESS_ARGS} in FLContext")
288327

289-
_, job_cmd = job_args[JobProcessArgs.EXE_MODULE]
328+
exe_module_entry = job_args.get(JobProcessArgs.EXE_MODULE)
329+
if not exe_module_entry:
330+
raise RuntimeError(f"missing {JobProcessArgs.EXE_MODULE} in {FLContextKey.JOB_PROCESS_ARGS}")
331+
_, job_cmd = exe_module_entry
290332
job_config = {
291333
"name": job_id,
292334
"image": job_image,
@@ -299,21 +341,36 @@ def launch_job(self, job_meta: dict, fl_ctx: FLContext) -> JobHandleSpec:
299341
{"name": PV_NAME.ETC.value, "persistentVolumeClaim": {"claimName": self.etc_pvc}},
300342
],
301343
"module_args": self.get_module_args(job_id, fl_ctx),
302-
"set_list": args.set,
303-
"resources": {"limits": {"nvidia.com/gpu": job_resource}},
304344
}
305-
306-
job_handle = K8sJobHandle(job_id, self.core_v1, job_config, namespace=self.namespace, timeout=self.timeout)
345+
if args is not None and getattr(args, "set", None) is not None:
346+
job_config.update({"set_list": args.set})
347+
if job_resource:
348+
job_config.update({"resources": {"limits": {"nvidia.com/gpu": job_resource}}})
349+
job_handle = K8sJobHandle(
350+
job_id,
351+
self.core_v1,
352+
job_config,
353+
namespace=self.namespace,
354+
timeout=self.timeout,
355+
pending_timeout=self.pending_timeout,
356+
python_path=self.python_path,
357+
)
307358
pod_manifest = job_handle.get_manifest()
308359
self.logger.debug(f"launch job with k8s_launcher. {pod_manifest=}")
309360
try:
310361
self.core_v1.create_namespaced_pod(body=pod_manifest, namespace=self.namespace)
311-
job_handle.enter_states([JobState.RUNNING], timeout=self.timeout)
362+
except Exception as e:
363+
self.logger.error(f"failed to launch job {job_id}: {e}")
364+
job_handle.terminal_state = JobState.TERMINATED
312365
return job_handle
313-
except ApiException as e:
314-
self.logger.error(f"failed to launch job {self.job_id}: {e}")
366+
try:
367+
entered_running = job_handle.enter_states([JobState.RUNNING])
368+
except BaseException:
315369
job_handle.terminate()
316-
return job_handle
370+
raise
371+
if not entered_running:
372+
self.logger.warning(f"unable to enter running phase {job_id}")
373+
return job_handle
317374

318375
def handle_event(self, event_type: str, fl_ctx: FLContext):
319376
if event_type == EventType.BEFORE_JOB_LAUNCH:

nvflare/private/fed/client/communicator.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@ def __init__(
5959
client_register_interval=2,
6060
timeout=5.0,
6161
maint_msg_timeout=5.0,
62-
cell_creation_timeout=15.0,
6362
):
6463
"""To init the Communicator.
6564
@@ -80,7 +79,7 @@ def __init__(
8079
self.client_register_interval = client_register_interval
8180
self.timeout = timeout
8281
self.maint_msg_timeout = maint_msg_timeout
83-
self.creation_timeout = cell_creation_timeout
82+
8483
# token and token_signature are issued by the Server after the client is authenticated
8584
# they are added to every message going to the server as proof of authentication
8685
self.token = None
@@ -274,9 +273,9 @@ def client_registration(self, client_name, project_name, fl_ctx: FLContext):
274273
start = time.time()
275274
while not self.cell:
276275
self.logger.info("Waiting for the client cell to be created.")
277-
if time.time() - start > self.creation_timeout:
276+
if time.time() - start > 15.0:
278277
raise RuntimeError("Client cell could not be created. Failed to login the client.")
279-
time.sleep(1)
278+
time.sleep(0.5)
280279

281280
shared_fl_ctx = gen_new_peer_ctx(fl_ctx)
282281
private_key_file = None

0 commit comments

Comments
 (0)