# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14+ import copy
1415import logging
16+ import re
1517import time
1618from abc import abstractmethod
1719from enum import Enum
@@ -89,8 +91,29 @@ class PV_NAME(Enum):
8991]
9092
9193
def uuid4_to_rfc1123(uuid_str: str) -> str:
    """Sanitize a UUID-like string into a valid RFC 1123 (DNS-1123) label.

    Kubernetes pod names must be DNS-1123 labels: lowercase alphanumerics and
    hyphens, at most 63 characters, starting AND ending with an alphanumeric
    character.

    Args:
        uuid_str: raw identifier, typically a uuid4 string.

    Returns:
        A sanitized name usable as a Kubernetes pod name.  May be empty if
        the input contains no usable characters.
    """
    name = uuid_str.lower()
    # Strip any chars that aren't alphanumeric or hyphen.
    name = re.sub(r"[^a-z0-9-]", "", name)
    # A label must start with an alphanumeric, so drop leading hyphens
    # (previously "-abc" slipped through as an invalid name).
    name = name.lstrip("-")
    # Prefix with a letter if it starts with a digit (stricter than DNS-1123
    # requires, but kept for DNS-1035 contexts such as service names).
    if name and name[0].isdigit():
        name = "j" + name
    # Kubernetes label limit: 63 chars; strip trailing hyphens after truncation
    # (truncation can expose a hyphen that was interior before slicing).
    return name[:63].rstrip("-")
104+
105+
92106class K8sJobHandle (JobHandleSpec ):
93- def __init__ (self , job_id : str , api_instance : core_v1_api , job_config : dict , namespace = "default" , timeout = None ):
107+ def __init__ (
108+ self ,
109+ job_id : str ,
110+ api_instance : core_v1_api ,
111+ job_config : dict ,
112+ namespace = "default" ,
113+ timeout = None ,
114+ pending_timeout = 30 ,
115+ python_path = "/usr/local/bin/python" ,
116+ ):
94117 super ().__init__ ()
95118 self .job_id = job_id
96119 self .timeout = timeout
@@ -113,8 +136,7 @@ def __init__(self, job_id: str, api_instance: core_v1_api, job_config: dict, nam
113136 {
114137 "image" : None ,
115138 "name" : None ,
116- "resources" : None ,
117- "command" : ["/usr/local/bin/python" ],
139+ "command" : [python_path ],
118140 "args" : None , # args_list + args_dict + args_sets
119141 "volumeMounts" : None , # volume_mount_list
120142 "imagePullPolicy" : "Always" ,
@@ -127,14 +149,13 @@ def __init__(self, job_id: str, api_instance: core_v1_api, job_config: dict, nam
127149 self .container_volume_mount_list = []
128150 self ._make_manifest (job_config )
129151 self ._stuck_count = 0
130- self ._stuck_grace_period = 10 # seconds to wait before counting Pending as stuck
131- self ._max_stuck_count = (self .timeout + self ._stuck_grace_period ) if self .timeout is not None else None
152+ self ._max_stuck_count = self .timeout if self .timeout is not None else pending_timeout
132153 self .logger = logging .getLogger (self .__class__ .__name__ )
133154
134155 def _make_manifest (self , job_config ):
135156 self .container_volume_mount_list .extend (job_config .get ("volume_mount_list" , []))
136157 set_list = job_config .get ("set_list" )
137- if set_list is None :
158+ if not set_list :
138159 self .container_args_module_args_sets = list ()
139160 else :
140161 self .container_args_module_args_sets = ["--set" ] + set_list
@@ -147,57 +168,64 @@ def _make_manifest(self, job_config):
147168 if v is None :
148169 continue
149170 self .container_args_module_args_dict_as_list .append (k )
150- self .container_args_module_args_dict_as_list .append (v )
171+ self .container_args_module_args_dict_as_list .append (str ( v ) )
151172 self .volume_list .extend (job_config .get ("volume_list" , []))
152173 self .pod_manifest ["metadata" ]["name" ] = job_config .get ("name" )
153174 self .pod_manifest ["spec" ]["containers" ] = self .container_list
154175 self .pod_manifest ["spec" ]["volumes" ] = self .volume_list
155176
156- self .container_list [0 ]["image" ] = job_config .get ("image" , "nvflare/nvflare:2.8.0" )
177+ image = job_config .get ("image" )
178+ if not image :
179+ raise ValueError ("job_config must contain a non-empty 'image' key" )
180+ self .container_list [0 ]["image" ] = image
157181 self .container_list [0 ]["name" ] = job_config .get ("container_name" , "nvflare_job" )
158182 self .container_list [0 ]["args" ] = (
159183 self .container_args_python_args_list
160184 + self .container_args_module_args_dict_as_list
161185 + self .container_args_module_args_sets
162186 )
163187 self .container_list [0 ]["volumeMounts" ] = self .container_volume_mount_list
164- if job_config .get ("resources" , {}).get ("limits" , {}).get ("nvidia.com/gpu" ) is not None :
188+ if job_config .get ("resources" , {}).get ("limits" , {}).get ("nvidia.com/gpu" ):
165189 self .container_list [0 ]["resources" ] = job_config .get ("resources" )
166190
def get_manifest(self):
    """Return a deep copy of the pod manifest.

    A copy is handed out so callers cannot mutate the handle's internal
    manifest state.
    """
    manifest_snapshot = copy.deepcopy(self.pod_manifest)
    return manifest_snapshot
169193
def enter_states(self, job_states_to_enter: list):
    """Poll the pod until it enters one of the requested job states.

    Returns True once a requested state is observed.  Returns False when the
    pod is stuck in Pending (pod is terminated), when it reaches a terminal
    phase (Failed/Succeeded) that was not requested, or when the configured
    timeout elapses (pod is terminated).

    Raises:
        ValueError: if any entry is not a JobState.
    """
    poll_start = time.time()
    if not isinstance(job_states_to_enter, (list, tuple)):
        job_states_to_enter = [job_states_to_enter]
    if any(not isinstance(js, JobState) for js in job_states_to_enter):
        raise ValueError(f"expect job_states_to_enter with valid values, but get {job_states_to_enter}")
    terminal_phases = (POD_Phase.FAILED.value, POD_Phase.SUCCEEDED.value)
    while True:
        pod_phase = self._query_phase()
        if self._stuck_in_pending(pod_phase):
            # Pending for too long: give up and clean up the pod.
            self.terminate()
            return False
        job_state = POD_STATE_MAPPING.get(pod_phase, JobState.UNKNOWN)
        if job_state in job_states_to_enter:
            return True
        if pod_phase in terminal_phases:
            # Pod finished in a state the caller was not waiting for.
            self.terminal_state = job_state
            return False
        if self.timeout is not None and time.time() - poll_start > self.timeout:
            self.terminate()
            return False
        time.sleep(1)
187215
def terminate(self):
    """Delete the pod and mark this handle as terminated.

    The handle is marked TERMINATED on every path: a 404 means the pod is
    already gone, and any other failure is logged but still treated as
    terminated so callers do not keep waiting on a pod we can no longer
    manage.
    """
    try:
        self.api_instance.delete_namespaced_pod(name=self.job_id, namespace=self.namespace, grace_period_seconds=0)
    except ApiException as e:
        if getattr(e, "status", None) == 404:
            self.logger.info(f"job {self.job_id} pod not found during termination; assuming terminated")
        else:
            self.logger.error(f"failed to terminate job {self.job_id}: {e}")
    except Exception as e:
        self.logger.error(f"unexpected error terminating job {self.job_id}: {e}")
    # Every branch above ends the same way, so assign once here.
    self.terminal_state = JobState.TERMINATED
    return None
202230
203231 def poll (self ):
def _query_phase(self):
    """Query Kubernetes for the pod's current phase.

    Returns:
        The phase string reported by the API server, or
        ``POD_Phase.UNKNOWN.value`` when the pod cannot be read for any
        reason (the error is logged at warning level).
    """
    try:
        pod = self.api_instance.read_namespaced_pod(name=self.job_id, namespace=self.namespace)
    except ApiException as e:
        self.logger.warning(f"failed to query pod phase {self.job_id}: {e}")
        return POD_Phase.UNKNOWN.value
    except Exception as e:
        self.logger.warning(f"unexpected error querying pod phase {self.job_id}: {e}")
        return POD_Phase.UNKNOWN.value
    else:
        return pod.status.phase
215247
def _query_state(self):
    """Map the pod's current phase onto a JobState (UNKNOWN if unmapped)."""
    return POD_STATE_MAPPING.get(self._query_phase(), JobState.UNKNOWN)
219251
def _stuck_in_pending(self, current_phase):
    """Track consecutive Pending observations and report when the limit hits.

    Increments an internal counter each time the pod is observed Pending and
    resets it on any other phase.  Returns True once the counter reaches the
    configured maximum (when one is set), meaning the pod should be treated
    as stuck.
    """
    if current_phase != POD_Phase.PENDING.value:
        # Any phase other than Pending clears the streak.
        self._stuck_count = 0
        return False
    self._stuck_count += 1
    limit = self._max_stuck_count
    return limit is not None and self._stuck_count >= limit
228260
229261 def wait (self ):
@@ -246,6 +278,8 @@ def __init__(
246278 data_pvc_file_path : str ,
247279 timeout = None ,
248280 namespace = "default" ,
281+ pending_timeout = 30 ,
282+ python_path = "/usr/local/bin/python" ,
249283 ):
250284 super ().__init__ ()
251285 self .logger = logging .getLogger (self .__class__ .__name__ )
@@ -255,15 +289,18 @@ def __init__(
255289 self .data_pvc_file_path = data_pvc_file_path
256290 self .timeout = timeout
257291 self .namespace = namespace
292+ self .pending_timeout = pending_timeout
293+ self .python_path = python_path
258294 with open (data_pvc_file_path , "rt" ) as f :
259295 data_pvc_dict = yaml .safe_load (f )
260296 if not data_pvc_dict :
261297 raise ValueError (f"data_pvc_file_path '{ data_pvc_file_path } ' is empty or contains no PVC entries." )
262298 # data_pvc_dict will be pvc: mountPath
263299 # currently, support one pvc and always mount to /var/tmp/nvflare/data
264300 # ie, ignore the mountPath in data_pvc_dict
301+ if not isinstance (data_pvc_dict , dict ):
302+ raise ValueError (f"file at data_pvc_file_path '{ data_pvc_file_path } ' does not contain a dictionary." )
265303 self .data_pvc = list (data_pvc_dict .keys ())[0 ]
266-
267304 config .load_kube_config (config_file_path )
268305 try :
269306 c = Configuration ().get_default_copy ()
@@ -276,17 +313,22 @@ def __init__(
276313
277314 def launch_job (self , job_meta : dict , fl_ctx : FLContext ) -> JobHandleSpec :
278315 site_name = fl_ctx .get_identity_name ()
279- job_id = job_meta .get (JobConstants .JOB_ID )
316+ raw_job_id = job_meta .get (JobConstants .JOB_ID )
317+ if not raw_job_id :
318+ raise RuntimeError (f"missing { JobConstants .JOB_ID } in job_meta" )
319+ job_id = uuid4_to_rfc1123 (raw_job_id )
280320 args = fl_ctx .get_prop (FLContextKey .ARGS )
281321 job_image = extract_job_image (job_meta , site_name )
282322 site_resources = job_meta .get (JobMetaKey .RESOURCE_SPEC .value , {}).get (site_name , {})
283323 job_resource = site_resources .get ("num_of_gpus" , None )
284-
285324 job_args = fl_ctx .get_prop (FLContextKey .JOB_PROCESS_ARGS )
286325 if not job_args :
287326 raise RuntimeError (f"missing { FLContextKey .JOB_PROCESS_ARGS } in FLContext" )
288327
289- _ , job_cmd = job_args [JobProcessArgs .EXE_MODULE ]
328+ exe_module_entry = job_args .get (JobProcessArgs .EXE_MODULE )
329+ if not exe_module_entry :
330+ raise RuntimeError (f"missing { JobProcessArgs .EXE_MODULE } in { FLContextKey .JOB_PROCESS_ARGS } " )
331+ _ , job_cmd = exe_module_entry
290332 job_config = {
291333 "name" : job_id ,
292334 "image" : job_image ,
@@ -299,21 +341,36 @@ def launch_job(self, job_meta: dict, fl_ctx: FLContext) -> JobHandleSpec:
299341 {"name" : PV_NAME .ETC .value , "persistentVolumeClaim" : {"claimName" : self .etc_pvc }},
300342 ],
301343 "module_args" : self .get_module_args (job_id , fl_ctx ),
302- "set_list" : args .set ,
303- "resources" : {"limits" : {"nvidia.com/gpu" : job_resource }},
304344 }
305-
306- job_handle = K8sJobHandle (job_id , self .core_v1 , job_config , namespace = self .namespace , timeout = self .timeout )
345+ if args is not None and getattr (args , "set" , None ) is not None :
346+ job_config .update ({"set_list" : args .set })
347+ if job_resource :
348+ job_config .update ({"resources" : {"limits" : {"nvidia.com/gpu" : job_resource }}})
349+ job_handle = K8sJobHandle (
350+ job_id ,
351+ self .core_v1 ,
352+ job_config ,
353+ namespace = self .namespace ,
354+ timeout = self .timeout ,
355+ pending_timeout = self .pending_timeout ,
356+ python_path = self .python_path ,
357+ )
307358 pod_manifest = job_handle .get_manifest ()
308359 self .logger .debug (f"launch job with k8s_launcher. { pod_manifest = } " )
309360 try :
310361 self .core_v1 .create_namespaced_pod (body = pod_manifest , namespace = self .namespace )
311- job_handle .enter_states ([JobState .RUNNING ], timeout = self .timeout )
362+ except Exception as e :
363+ self .logger .error (f"failed to launch job { job_id } : { e } " )
364+ job_handle .terminal_state = JobState .TERMINATED
312365 return job_handle
313- except ApiException as e :
314- self .logger .error (f"failed to launch job { self .job_id } : { e } " )
366+ try :
367+ entered_running = job_handle .enter_states ([JobState .RUNNING ])
368+ except BaseException :
315369 job_handle .terminate ()
316- return job_handle
370+ raise
371+ if not entered_running :
372+ self .logger .warning (f"unable to enter running phase { job_id } " )
373+ return job_handle
317374
318375 def handle_event (self , event_type : str , fl_ctx : FLContext ):
319376 if event_type == EventType .BEFORE_JOB_LAUNCH :
0 commit comments