diff --git a/docs/design/JobLauncher_and_JobHandle.md b/docs/design/JobLauncher_and_JobHandle.md new file mode 100644 index 0000000000..8ddeeae839 --- /dev/null +++ b/docs/design/JobLauncher_and_JobHandle.md @@ -0,0 +1,581 @@ +# JobLauncher and JobHandle Design Document + +## 1. Overview + +NVFlare runs each federated job as an isolated execution unit -- a subprocess or Kubernetes pod. Two abstractions govern this: + +- **JobLauncherSpec** -- starts a job and returns a handle. +- **JobHandleSpec** -- represents the running job and provides lifecycle control (poll, wait, terminate). + +The upper layers (server engine, client executor) program exclusively against these two interfaces. The concrete backend is selected at runtime through an event-based mechanism, so the engine never imports or names a specific launcher type. + +``` +┌──────────────────────────────────────────────────────────┐ +│ Upper Layer │ +│ ServerEngine / ClientExecutor │ +│ │ +│ 1. get_job_launcher(job_meta, fl_ctx) → launcher │ +│ 2. Build JOB_PROCESS_ARGS │ +│ 3. launcher.launch_job(job_meta, fl_ctx) → job_handle │ +│ 4. job_handle.wait() / job_handle.terminate() │ +└──────────┬──────────────────────┬────────────────────────┘ + │ BEFORE_JOB_LAUNCH │ + │ event selects one │ + ▼ ▼ +┌─────────────────┐ ┌─────────────────┐ +│ ProcessJob │ │ K8sJob │ +│ Launcher │ │ Launcher │ +│ ─────────────── │ │ ─────────────── │ +│ ProcessHandle │ │ K8sJobHandle │ +└─────────────────┘ └─────────────────┘ + subprocess pod +``` + +--- + +## 2. Specification Layer (`nvflare/apis/job_launcher_spec.py`) + +### 2.1 JobHandleSpec + +Abstract base class representing a running job (`class JobHandleSpec(ABC)`). All methods are decorated with `@abstractmethod`; Python enforces that any concrete subclass implements all three. Attempting to instantiate a subclass with any abstract method unimplemented raises `TypeError` at instantiation time. 
+ +| Method | Signature | Semantics | +|--------|-----------|-----------| +| `terminate()` | `() -> None` | Stop the job immediately. | +| `poll()` | `() -> JobReturnCode` | Non-blocking query for the job's current return code. Returns `UNKNOWN` while still running. | +| `wait()` | `() -> None` | Block until the job finishes (or is terminated). | + +### 2.2 JobLauncherSpec + +Abstract base class for launching jobs (`class JobLauncherSpec(FLComponent, ABC)`). Extends `FLComponent` for event-system access. Adding `ABC` to the bases means Python selects `ABCMeta` as the metaclass automatically (since `ABCMeta` is a subclass of `type`), so `@abstractmethod` on `launch_job` is enforced at runtime for all subclasses. + +| Method | Signature | Semantics | +|--------|-----------|-----------| +| `launch_job(job_meta, fl_ctx)` | `(dict, FLContext) -> JobHandleSpec` | Start a job and return its handle. | + +### 2.3 Supporting Types + +**JobProcessArgs** -- String constants for the keys the upper layer places in `FLContextKey.JOB_PROCESS_ARGS`. Each value in the dict is a `(flag, value)` tuple, e.g. `JobProcessArgs.JOB_ID → ("-n", job_id)`. These are the standardized parameters every job process needs (workspace path, auth token, job ID, parent URL, etc.). 
+ +| Constant | Value | Used by | +|----------|-------|---------| +| `EXE_MODULE` | `"exe_module"` | Server, Client | +| `WORKSPACE` | `"workspace"` | Server, Client | +| `STARTUP_DIR` | `"startup_dir"` | Client | +| `APP_ROOT` | `"app_root"` | Server | +| `AUTH_TOKEN` | `"auth_token"` | Server, Client | +| `TOKEN_SIGNATURE` | `"auth_signature"` | Server, Client | +| `SSID` | `"ssid"` | Server, Client | +| `JOB_ID` | `"job_id"` | Server, Client | +| `CLIENT_NAME` | `"client_name"` | Client | +| `ROOT_URL` | `"root_url"` | Server | +| `PARENT_URL` | `"parent_url"` | Server, Client | +| `PARENT_CONN_SEC` | `"parent_conn_sec"` | Client | +| `SERVICE_HOST` | `"service_host"` | Server | +| `SERVICE_PORT` | `"service_port"` | Server | +| `HA_MODE` | `"ha_mode"` | Server | +| `TARGET` | `"target"` | Client | +| `SCHEME` | `"scheme"` | Client | +| `STARTUP_CONFIG_FILE` | `"startup_config_file"` | Server, Client | +| `RESTORE_SNAPSHOT` | `"restore_snapshot"` | (defined but not set as a standalone entry; server embeds `restore_snapshot=` into `OPTIONS` via `--set`) | +| `OPTIONS` | `"options"` | Server, Client | + +**JobReturnCode** -- Standard exit semantics (extends `ProcessExitCode`): + +| Code | Value | Meaning | +|------|-------|---------| +| `SUCCESS` | 0 | Job completed successfully. | +| `EXECUTION_ERROR` | 1 | Job failed during execution. | +| `ABORTED` | 9 | Job was terminated/aborted. | +| `UNKNOWN` | 127 | Status cannot be determined (still running, or lost). | + +**`add_launcher(launcher, fl_ctx)`** -- Appends a launcher to the `FLContextKey.JOB_LAUNCHER` list on `fl_ctx`. Called by launchers during the `BEFORE_JOB_LAUNCH` event to volunteer for the current job. + +--- + +## 3. How the Upper Layer Uses Launchers + +### 3.1 Event-Based Launcher Selection + +The engine never directly instantiates a launcher. 
Instead, it calls `get_job_launcher()` from `nvflare/private/fed/utils/fed_utils.py`: + +```python +def get_job_launcher(job_meta, fl_ctx) -> JobLauncherSpec: + engine = fl_ctx.get_engine() + with engine.new_context() as job_launcher_ctx: + job_launcher_ctx.remove_prop(FLContextKey.JOB_LAUNCHER) + job_launcher_ctx.set_prop(FLContextKey.JOB_META, job_meta, private=True, sticky=False) + engine.fire_event(EventType.BEFORE_JOB_LAUNCH, job_launcher_ctx) + job_launcher = job_launcher_ctx.get_prop(FLContextKey.JOB_LAUNCHER) + if not (job_launcher and isinstance(job_launcher, list)): + raise RuntimeError(f"There's no job launcher can handle this job: {job_meta}.") + launcher = job_launcher[0] + if not isinstance(launcher, JobLauncherSpec): + raise RuntimeError(f"The job launcher must be JobLauncherSpec but got {type(launcher)}") + return job_launcher[0] +``` + +Every registered `FLComponent` receives the `BEFORE_JOB_LAUNCH` event. Each launcher inspects `job_meta` and, if it can handle the job, calls `add_launcher(self, fl_ctx)`. The first launcher to register wins. + +**Selection rule in practice:** + +| Condition | Launcher selected | +|-----------|-------------------| +| `extract_job_image(job_meta, site_name)` returns `None` | **ProcessJobLauncher** (no container image → run as subprocess) | +| `extract_job_image(job_meta, site_name)` returns an image | **K8sJobLauncher** (configured as a component) | + +### 3.2 Server Side (`ServerEngine`) + +Location: `nvflare/private/fed/server/server_engine.py` + +``` +_start_runner_process(job, job_clients, snapshot, fl_ctx) +│ +├─ 1. job_launcher = get_job_launcher(job.meta, fl_ctx) +│ (fires BEFORE_JOB_LAUNCH; JOB_PROCESS_ARGS not yet set) +│ +├─ 2. Build job_args dict with server-specific JobProcessArgs +│ (EXE_MODULE, JOB_ID, WORKSPACE, STARTUP_CONFIG_FILE, +│ APP_ROOT, HA_MODE, AUTH_TOKEN, TOKEN_SIGNATURE, +│ PARENT_URL, ROOT_URL, SERVICE_HOST, SERVICE_PORT, +│ SSID, OPTIONS) +│ +├─ 3. 
fl_ctx.set_prop(FLContextKey.JOB_PROCESS_ARGS, job_args) +│ +├─ 4. job_handle = job_launcher.launch_job(job.meta, fl_ctx) +│ +├─ 5. Store in run_processes[job_id] +│ {JOB_HANDLE: job_handle, JOB_ID: job_id, PARTICIPANTS: job_clients} +│ +└─ 6. Start background thread → wait_for_complete(workspace, job_id, job_handle) + │ + ├─ job_handle.wait() # blocks until job finishes + ├─ wait up to 2s for UPDATE_RUN_STATUS message to arrive + └─ get_return_code(job_handle, job_id, workspace, logger) +``` + +**Abort path** (`abort_app_on_server`): + +1. Attempt to send an abort command to the child via the cell messaging system. +2. On failure, retrieve `job_handle` from `run_processes` and call `job_handle.terminate()`. + +### 3.3 Client Side (`ClientExecutor`) + +Location: `nvflare/private/fed/client/client_executor.py` + +``` +start_app(job_id, job_meta, ...) +│ +├─ 1. job_launcher = get_job_launcher(job_meta, fl_ctx) +│ (fires BEFORE_JOB_LAUNCH; JOB_PROCESS_ARGS not yet set) +│ +├─ 2. Build job_args dict with client-specific JobProcessArgs +│ (EXE_MODULE, JOB_ID, CLIENT_NAME, AUTH_TOKEN, TOKEN_SIGNATURE, +│ SSID, WORKSPACE, STARTUP_DIR, PARENT_URL, SCHEME, TARGET, +│ STARTUP_CONFIG_FILE, OPTIONS, optionally PARENT_CONN_SEC) +│ +├─ 3. fl_ctx.set_prop(FLContextKey.JOB_PROCESS_ARGS, job_args) +│ +├─ 4. job_handle = job_launcher.launch_job(job_meta, fl_ctx) +│ +├─ 5. Fire EventType.AFTER_JOB_LAUNCH event +│ +├─ 6. Store in run_processes[job_id] +│ {JOB_HANDLE: job_handle, STATUS: ClientStatus.STARTING} +│ +└─ 7. Start background thread → _wait_child_process_finish(...) + │ + ├─ job_handle.wait() + └─ get_return_code(job_handle, job_id, workspace, logger) +``` + +**Abort path** (`_terminate_job`): + +1. Poll `self.run_processes.get(job_id)` every 50 ms for up to 10 seconds; if the entry disappears the job finished gracefully. +2. Always call `job_handle.terminate()` regardless of whether the graceful exit was detected. 
+ +### 3.4 Return Code Resolution + +`get_return_code()` in `fed_utils.py` uses a two-tier strategy: + +1. **File-based** -- Check for `FLMetaKey.PROCESS_RC_FILE` in the job's run directory. The child process writes its own return code to this file before exiting. This is the preferred source because it carries the child's own assessment. +2. **Handle-based** -- Fall back to `job_handle.poll()`, which maps the underlying execution unit's status to a `JobReturnCode`. + +--- + +## 4. The Two Implementations + +### 4.1 Process Launcher (Subprocess) + +**Files:** + +| File | Class | +|------|-------| +| `nvflare/app_common/job_launcher/process_launcher.py` | `ProcessHandle`, `ProcessJobLauncher` | +| `nvflare/app_common/job_launcher/server_process_launcher.py` | `ServerProcessJobLauncher` | +| `nvflare/app_common/job_launcher/client_process_launcher.py` | `ClientProcessJobLauncher` | + +**Class hierarchy:** + +``` +JobHandleSpec + └── ProcessHandle + +JobLauncherSpec (FLComponent) + └── ProcessJobLauncher + ├── ServerProcessJobLauncher + └── ClientProcessJobLauncher +``` + +#### ProcessHandle + +Wraps a `ProcessAdapter` (from `nvflare/utils/process_utils.py`) that manages a `subprocess.Popen` or a PID. The constructor accepts any one of: a `ProcessAdapter` directly, a `subprocess.Popen` object, or an integer `pid`. + +| Method | Implementation | +|--------|---------------| +| `terminate()` | Delegates to `adapter.terminate()` (sends SIGTERM/SIGKILL). | +| `poll()` | Calls `adapter.poll()`. Maps exit code 0 → `SUCCESS`, 1 → `EXECUTION_ERROR`, 9 → `ABORTED`, other → `EXECUTION_ERROR`, `None` → `UNKNOWN`. | +| `wait()` | Delegates to `adapter.wait()` (blocks on `subprocess.Popen.wait()`). | + +#### ProcessJobLauncher + +| Step | Action | +|------|--------| +| 1 | Copy `os.environ`. If `app_custom_folder` is non-empty, call `add_custom_dir_to_path()`: appends the folder to `sys.path` and serializes the result into `PYTHONPATH` in the child environment. 
| +| 2 | Call `self.get_command(job_meta, fl_ctx)` (abstract -- implemented by server/client subclasses). | +| 3 | Parse command with `shlex.split(command, posix=True)`, spawn the process via `spawn_process(argv, new_env)`. | +| 4 | Return `ProcessHandle(process_adapter=...)`. | + +**Event registration:** + +```python +def handle_event(self, event_type, fl_ctx): + if event_type == EventType.BEFORE_JOB_LAUNCH: + job_meta = fl_ctx.get_prop(FLContextKey.JOB_META) + job_image = extract_job_image(job_meta, fl_ctx.get_identity_name()) + if not job_image: # no container image → use subprocess + add_launcher(self, fl_ctx) +``` + +**Server/Client subclasses** only override `get_command()`: + +- `ServerProcessJobLauncher.get_command()` → `generate_server_command(fl_ctx)` → `sys.executable -m <exe_module> -w <workspace> ...` +- `ClientProcessJobLauncher.get_command()` → `generate_client_command(fl_ctx)` → `sys.executable -m <exe_module> -w <workspace> -n <job_id> ...` + +--- + +### 4.2 Kubernetes Launcher + +**File:** `nvflare/app_opt/job_launcher/k8s_launcher.py` + +**Class hierarchy:** + +``` +JobHandleSpec (ABC) + └── K8sJobHandle + +JobLauncherSpec (FLComponent, ABC) + └── K8sJobLauncher + ├── ClientK8sJobLauncher + └── ServerK8sJobLauncher +``` + +#### Pod Name Sanitization + +`launch_job` calls `uuid4_to_rfc1123(job_meta.get(JobConstants.JOB_ID))` before constructing the pod. This converts a raw UUID4 job ID into an RFC 1123-compliant Kubernetes pod name: + +1. Lowercase the string. +2. Strip characters that are not alphanumeric or hyphens (`[^a-z0-9-]`). +3. Prefix with `"j"` if the first character is a digit (Kubernetes pod names must start with a letter). +4. Strip trailing hyphens. +5. Truncate to 63 characters. + +The sanitized name is used as both the pod name (`metadata.name`) and the `job_id` stored in `K8sJobHandle` for all subsequent API calls (`terminate`, `_query_phase`). + +#### K8sJobHandle + +Wraps a Kubernetes Pod managed through the `CoreV1Api`. 
+ +| Method | Implementation | +|--------|---------------| +| `terminate()` | Calls `delete_namespaced_pod(grace_period_seconds=0)`. `terminal_state = TERMINATED` is always set regardless of outcome: on success, on 404 `ApiException` (pod already gone, logged at `info`), on any other `ApiException` (logged at `error`), and on any other `Exception` such as network or serialization errors (also logged at `error`). This guarantees that callers holding a handle never poll indefinitely after calling `terminate()`, even when the K8s API is unreachable. | +| `poll()` | If `terminal_state` is set, maps it through `JOB_RETURN_CODE_MAPPING` and returns a `JobReturnCode`. Otherwise calls `_query_state()` and maps the result the same way. Both paths consistently return `JobReturnCode`. | +| `wait()` | Direct while loop: returns immediately if `terminal_state` is set; otherwise calls `_query_state()` and when `SUCCEEDED` or `TERMINATED` is reached, persists that state into `terminal_state` (so subsequent `poll()` calls remain accurate) and returns. Sleeps 1 second per iteration. No timeout. | +| `_query_phase()` | Calls `read_namespaced_pod` and returns the raw pod phase string (e.g. `"Pending"`, `"Running"`). On `ApiException`, logs the error and returns `POD_Phase.UNKNOWN.value`. On any other `Exception` (e.g. network errors, `urllib3.exceptions.MaxRetryError`), logs the error and also returns `POD_Phase.UNKNOWN.value` — preventing unhandled exceptions from propagating through `enter_states`/`wait`/`poll` and orphaning running pods. | +| `_query_state()` | Calls `_query_phase()` and maps the raw phase through `POD_STATE_MAPPING` to a `JobState`. Used by `poll()` and `wait()`. | +| `enter_states()` | Takes only `job_states_to_enter`; no `timeout` parameter — reads `self.timeout` from the instance. Per iteration: calls `_query_phase()` once, passes the raw phase to `_stuck_in_pending()` and to `POD_STATE_MAPPING.get()` — single K8s API call per poll cycle. 
Three early-exit paths, evaluated in this order: (1) **Stuck detection** — calls `self.terminate()` (sets `terminal_state = TERMINATED`) then returns `False`; (2) **Terminal phase** — if `pod_phase` is `"Failed"` or `"Succeeded"`, sets `terminal_state` from `POD_STATE_MAPPING` then returns `False` *without* calling `terminate()`. The terminal-phase check is evaluated before the plain-timeout check so that a job completing exactly when the timeout expires is recorded as `SUCCEEDED`/`TERMINATED` (per `POD_STATE_MAPPING`) rather than being misclassified as `TERMINATED` by `terminate()`; (3) **Plain timeout** (`self.timeout` elapsed) — calls `self.terminate()`, sets `terminal_state = TERMINATED`, returns `False`. Setting `terminal_state` in all `False`-return paths prevents `wait()` from looping indefinitely if the pod is GC'd before `wait()` runs. Returns `True` when target state is reached. | + +Pod phase mapping: + +| Pod Phase | JobState | JobReturnCode | +|-----------|----------|---------------| +| `Pending` | `STARTING` | `UNKNOWN` | +| `Running` | `RUNNING` | `UNKNOWN` | +| `Succeeded` | `SUCCEEDED` | `SUCCESS` | +| `Failed` | `TERMINATED` | `ABORTED` | +| `Unknown` | `UNKNOWN` | `UNKNOWN` | + +**Stuck detection:** `_stuck_count` starts at `0`. `_max_stuck_count` is set in the constructor as: `timeout if timeout is not None else pending_timeout`. So stuck detection is **always active** — if `timeout` is provided, `_max_stuck_count = timeout`; if `timeout` is `None`, `_max_stuck_count = pending_timeout` (default 30). The method `_stuck_in_pending(current_phase)` increments `_stuck_count` each time `current_phase == "Pending"` and returns `True` when `_stuck_count > _max_stuck_count`. When the phase is **not** `Pending`, `_stuck_count` is **reset to 0** — so a pod that transitions out of Pending (e.g. briefly reaches Running then goes back to Pending) starts its stuck count fresh. When it returns `True`, `enter_states()` calls `terminate()` and returns `False`. 
Plain startup timeout (wall-clock elapsed > `timeout`) also calls `terminate()` before returning `False`. In both cases `terminal_state` is always set to `TERMINATED` by `terminate()`, regardless of whether the API call succeeds or raises. Note: `_stuck_count` and `_max_stuck_count` are poll-iteration counts (~1 second each). + +#### K8sJobHandle Pod Manifest + +The handle constructs the pod manifest internally from a `job_config` dict: + +```yaml +apiVersion: v1 +kind: Pod +metadata: + name: <job_id> # RFC 1123-sanitized job ID +spec: + restartPolicy: Never + containers: + - name: container-<job_id> + image: <job_image> + command: ["<python_path>"] # default: /usr/local/bin/python; configurable via K8sJobLauncher.python_path + args: ["-u", "-m", "<exe_module>", "-w", "<workspace>", ...] + volumeMounts: + - name: nvflws + mountPath: /var/tmp/nvflare/workspace + - name: nvfldata + mountPath: /var/tmp/nvflare/data + - name: nvfletc + mountPath: /var/tmp/nvflare/etc + resources: + limits: + nvidia.com/gpu: <num_of_gpus> # omitted if None + imagePullPolicy: Always + volumes: + - name: nvflws + persistentVolumeClaim: + claimName: <workspace_pvc> + - name: nvfldata + persistentVolumeClaim: + claimName: <data_pvc> + - name: nvfletc + persistentVolumeClaim: + claimName: <etc_pvc> +``` + +#### K8sJobLauncher + +Constructor parameters: + +| Parameter | Default | Purpose | +|-----------|---------|---------| +| `config_file_path` | (required) | Path to kubeconfig file. Loaded via `config.load_kube_config()` at init time. | +| `workspace_pvc` | (required) | PVC claim name for the NVFlare workspace. | +| `etc_pvc` | (required) | PVC claim name for configuration/etc data. | +| `data_pvc_file_path` | (required) | Path to a YAML file mapping PVC names to mount paths for training data. Read and validated at init time; raises `ValueError` if the file is empty/contains no entries, or if the parsed YAML is not a `dict`. Only the first key (PVC name) is used. | +| `timeout` | `None` | Maximum wall-clock seconds for `enter_states([RUNNING])`. Also used as `_max_stuck_count`. 
If `None`, `pending_timeout` governs stuck detection instead. | +| `namespace` | `"default"` | Kubernetes namespace. | +| `pending_timeout` | `30` | Stuck-detection threshold (poll iterations) when `timeout` is `None`. Passed to `K8sJobHandle`. | +| `python_path` | `"/usr/local/bin/python"` | Absolute path to the Python executable used as the container `command`. Override for non-standard images where Python is not at `/usr/local/bin/python` (e.g. `/usr/bin/python3`). Passed through to `K8sJobHandle`. | + +Launch sequence: + +| Step | Action | +|------|--------| +| 1 | Validate and sanitize job ID: `raw_job_id = job_meta.get(JobConstants.JOB_ID)`; raises `RuntimeError` if missing or falsy. Then `job_id = uuid4_to_rfc1123(raw_job_id)`. All subsequent operations use this RFC 1123-compliant name. Extract `job_image`, `site_name`, and optional `num_of_gpus` from `job_meta`. | +| 2 | Read `JOB_PROCESS_ARGS` from `fl_ctx`; raises `RuntimeError` if the dict is absent. Raises `RuntimeError` if `EXE_MODULE` is missing or falsy. Extract the module name via `_, job_cmd = exe_module_entry`. | +| 3 | Build `job_config` dict: name (`job_id`), image, container name (`container-{job_id}`), command, volume mounts/PVCs, `module_args` from `get_module_args()`. `set_list` is conditionally set: if `fl_ctx.get_prop(FLContextKey.ARGS)` is non-None and `getattr(args, "set", None)` is non-None, `set_list = args.set` is added (see note below). Using `getattr` rather than direct attribute access guards against non-standard `ARGS` objects that lack a `set` attribute. If `num_of_gpus` is truthy (non-None **and** non-zero), adds `job_config["resources"] = {"limits": {"nvidia.com/gpu": num_of_gpus}}`; a value of `0` is intentionally excluded to avoid injecting `nvidia.com/gpu: 0` into the pod spec, which would explicitly request zero GPUs and can affect scheduling on GPU-enabled clusters differently than omitting the limit entirely. 
| +| 4 | Create `K8sJobHandle(job_id, core_v1, job_config, namespace=self.namespace, timeout=self.timeout, pending_timeout=self.pending_timeout)` which builds the pod manifest. | +| 5 | `core_v1.create_namespaced_pod(body=pod_manifest, namespace)` in a `try/except Exception` block. On any exception — including `ApiException` (K8s API error) and lower-level errors such as network timeouts or serialization failures — `job_handle.terminate()` is called (always sets `terminal_state = TERMINATED`) and the handle is returned. The scope of this handler is intentionally limited to this single API call; it does not swallow exceptions from the polling loop in step 6. | +| 6 | Call `job_handle.enter_states([RUNNING])` in a separate `try/except BaseException` block. On any exception (including `KeyboardInterrupt`) → `job_handle.terminate()` then re-raise. This ensures a pod already created in step 5 is not orphaned if the blocking poll loop is interrupted. The return value is captured: if `False`, logs a warning `"unable to enter running phase {job_id}"`. Inside `enter_states`: stuck detection and plain timeout both call `terminate()` (always sets `terminal_state = TERMINATED`) then return `False`; if the pod reaches a terminal phase (`Failed`/`Succeeded`), `terminal_state` is set from `POD_STATE_MAPPING` and `enter_states` returns `False` without calling `terminate()`. Setting `terminal_state` in all `False`-return paths prevents `wait()` from looping if the pod is GC'd before `wait()` runs. | +| 7 | Return `job_handle`. The K8s launcher always returns a handle; callers detect launch failure when `poll()` or `wait()` resolves. | + +> **`set_list` note:** `args.set` is the CLI `--set` items stored in `FLContextKey.ARGS` at the time `launch_job` is called. The server and client both make a deep copy of `FLContextKey.ARGS`, append `print_conf=True` (and server also appends `restore_snapshot=`) to that copy, and embed the expanded string into `JOB_PROCESS_ARGS[OPTIONS]`. 
They do **not** write the modified copy back to `FLContextKey.ARGS`. As a result, the K8s launcher's `set_list` contains only the original CLI `--set` items — **without** `print_conf=True` or `restore_snapshot=...`. The Process launcher receives those extra flags through `OPTIONS`, which K8s excludes from `module_args` (`get_*_job_args(include_set_options=False)`). + +**Server/Client subclasses** override `get_module_args()`: + +- `ClientK8sJobLauncher` → Calls `_job_args_dict(job_args, get_client_job_args(False, False))` — filters `JOB_PROCESS_ARGS` excluding `EXE_MODULE` and `OPTIONS`, producing a `{flag: value}` dict for the container `args` list. +- `ServerK8sJobLauncher` → Same pattern with `get_server_job_args(False, False)`. + +**Key difference from Process:** The K8s launcher does not build a shell command string. Instead, it passes the Python executable as `command` and constructs a structured `args` list (`["-u", "-m", "", "-w", "", ...]`) directly in the pod spec. + +--- + +## 5. Object-Oriented Design Summary + +### 5.1 Full Class Hierarchy + +``` +JobHandleSpec (ABC) +├── ProcessHandle (wraps ProcessAdapter / subprocess.Popen) +└── K8sJobHandle (wraps CoreV1Api + pod name) + +JobLauncherSpec (FLComponent, ABC) +├── ProcessJobLauncher (abstract: get_command) +│ ├── ServerProcessJobLauncher +│ └── ClientProcessJobLauncher +└── K8sJobLauncher (abstract: get_module_args; inherits ABCMeta from JobLauncherSpec) + ├── ServerK8sJobLauncher + └── ClientK8sJobLauncher +``` + +### 5.2 Design Patterns + +**Strategy Pattern** -- Each launcher is a strategy for running jobs. The engine programs against `JobLauncherSpec`; the concrete strategy is selected at runtime through the event system. 
+ +**Template Method Pattern** -- Each base launcher (`ProcessJobLauncher`, `K8sJobLauncher`) implements `launch_job()` with a fixed algorithm, delegating the variable part to an abstract method: + +| Base Launcher | Template method calls | Abstract hook | +|---------------|----------------------|---------------| +| `ProcessJobLauncher` | `launch_job()` → `get_command()` | `get_command(job_meta, fl_ctx) -> str` | +| `K8sJobLauncher` | `launch_job()` → `get_module_args()` | `get_module_args(job_id, fl_ctx) -> dict` | + +Server and client subclasses provide the implementation of these hooks, producing the correct command-line arguments for each role. + +**Observer Pattern** -- Launchers register for the `BEFORE_JOB_LAUNCH` event through the `FLComponent` event system. This decouples launcher registration from the engine's control flow entirely. + +--- + +## 6. Comparison: Process vs Kubernetes + +| Aspect | Process | Kubernetes | +|--------|---------|------------| +| **When selected** | No `job_image` for site | `job_image` present | +| **Execution unit** | OS subprocess | Kubernetes Pod | +| **Isolation** | Shared host, inherited env | Pod isolation, PVC-backed volumes | +| **Command format** | Shell command string (`sys.executable -m ...`) | Structured `command: ["/usr/local/bin/python"]` + `args` list in pod spec | +| **Workspace access** | Direct filesystem (same host) | PersistentVolumeClaims | +| **Data access** | Direct filesystem | Via PVC (configured in YAML) | +| **Start verification** | None (spawn returns immediately) | `enter_states([RUNNING])` called (reads `self.timeout`); return value is checked — `False` logs a debug message but handle is still always returned; on stuck detection or plain timeout, `terminate()` is called (sets `terminal_state = TERMINATED`); callers detect failure via `poll()` | +| **Wait mechanism** | `subprocess.Popen.wait()` (OS-level block) | Direct while loop via `_query_state()`; no timeout; exits when `terminal_state` set or 
`SUCCEEDED`/`TERMINATED` reached | +| **Terminate** | `SIGTERM`/`SIGKILL` via `ProcessAdapter` | `delete_namespaced_pod(grace_period=0)`; `terminal_state` always set to `TERMINATED` — on success, 404, other `ApiException`, or any lower-level `Exception` (network/serialization). Errors are logged but never propagated. | +| **Return code source** | Process exit code or RC file | Pod phase mapping or RC file; `poll()` consistently returns `JobReturnCode` via `JOB_RETURN_CODE_MAPPING` | +| **GPU support** | Inherited from host environment | `nvidia.com/gpu` resource limit set in pod spec from `job_meta` resource spec (`num_of_gpus`); omitted if not specified | +| **Dependencies** | stdlib only | `kubernetes` Python client + kubeconfig | +| **Typical use case** | Simulator, single-machine POC | Production cluster with shared storage | + +--- + +## 7. Backend (BE) Resource Management + +When using the K8s launcher, the NVFlare process managing jobs may not itself run on a GPU node. The K8s scheduler handles actual resource allocation externally. Two passthrough classes support this case: + +### 7.1 BEResourceManager (`nvflare/app_common/resource_managers/BE_resource_manager.py`) + +`BEResourceManager(ResourceManagerSpec, FLComponent)` is a "Best Effort" resource manager: it always approves resource allocation requests and performs no local tracking, allowing jobs to attempt to run and fail at runtime if resources are genuinely unavailable: + +| Method | Behavior | +|--------|----------| +| `check_resources()` | Always returns `(True, <token>)` where `<token>` is a fresh `uuid4` string -- never rejects a job. | +| `cancel_resources()` | No-op. | +| `allocate_resources()` | Returns empty dict `{}`. | +| `free_resources()` | No-op. | +| `report_resources()` | Returns `{}` (empty dict, conforming to the `ResourceManagerSpec` contract). | + +Use this when the container orchestration backend (K8s) is responsible for all resource accounting. 
+ +### 7.2 BEResourceConsumer (`nvflare/app_common/resource_consumers/BE_resource_consumer.py`) + +`BEResourceConsumer(ResourceConsumerSpec)` implements a no-op `consume()`. Use this alongside `BEResourceManager` when no local resource consumption reporting is needed. + +### 7.3 GPUResourceManager `ignore_host` flag (`nvflare/app_common/resource_managers/gpu_resource_manager.py`) + +`GPUResourceManager` gained an `ignore_host=False` parameter. When `True`, the constructor skips the startup validation that checks whether the declared `num_of_gpus` and `mem_per_gpu_in_GiB` match actual host hardware. This is needed in K8s deployments where the NVFlare process runs on a node without GPUs but still needs to track a GPU resource pool for job scheduling purposes. + +--- + +## 8. Sequence Diagram + +The following shows the end-to-end flow for launching and managing a job, applicable to both server and client: + +``` + Engine fed_utils Launcher(s) Handle + │ │ │ │ + │ get_job_launcher() │ │ │ + │───────────────────────>│ │ │ + │ │ fire BEFORE_JOB_LAUNCH │ + │ │─────────────────────>│ │ + │ │ │ check job_meta │ + │ │ │ add_launcher(self) │ + │ │<─────────────────────│ │ + │ return launcher │ │ │ + │<───────────────────────│ │ │ + │ │ │ │ + │ launcher.launch_job(job_meta, fl_ctx) │ │ + │─────────────────────────────────────────────->│ │ + │ │ │ create exec unit │ + │ │ │ (process/pod) │ + │ │ │─────────────────────>│ + │ │ │ return handle │ + │<─────────────────────────────────────────────────────────────────────│ + │ │ │ │ + │ store handle in run_processes │ │ + │ │ │ │ + │ [background thread] │ │ │ + │ handle.wait() │ │ │ + │─────────────────────────────────────────────────────────────────────>│ + │ │ │ blocks/polls │ + │ │ │ │ + │ ... (on abort) ... │ │ │ + │ handle.terminate() │ │ │ + │─────────────────────────────────────────────────────────────────────>│ + │ │ │ │ + │ ... (on completion) . 
│ │ │ + │ get_return_code() │ │ │ + │───────────────────────>│ │ │ + │ │ check RC file │ │ + │ │ or handle.poll() │ │ + │ │─────────────────────────────────────────────>│ + │ return_code │ │ │ + │<───────────────────────│ │ │ +``` + +--- + +## 9. Configuration + +Launchers are registered as FL components in the site's `resources.json`. The configurator loads them at startup so they receive events. + +### 9.1 Process Launcher (default) + +```json +{ + "id": "job_launcher", + "path": "nvflare.app_common.job_launcher.server_process_launcher.ServerProcessJobLauncher", + "args": {} +} +``` + +### 9.2 Kubernetes Launcher + +```json +{ + "id": "job_launcher", + "path": "nvflare.app_opt.job_launcher.k8s_launcher.ClientK8sJobLauncher", + "args": { + "config_file_path": "/path/to/kubeconfig", + "workspace_pvc": "nvflare-workspace-pvc", + "etc_pvc": "nvflare-etc-pvc", + "data_pvc_file_path": "/path/to/data_pvc.yaml", + "timeout": 120, + "namespace": "nvflare" + } +} +``` + +The `data_pvc_file_path` YAML file maps PVC names to mount paths: + +```yaml +my-data-pvc: /var/tmp/nvflare/data +``` + +> Note: Currently only the PVC name (the YAML key) is used. The mount path value is ignored — the data volume is always mounted at the hardcoded path `/var/tmp/nvflare/data`. + +--- + +## 10. Future Improvements + +1. **Explicit launcher selection** -- Today "has image" → K8s, "no image" → Process. Allow an explicit `launcher_type` field in job meta or deploy map so a site can support multiple backends or provide fallback ordering. + +2. **Consistent GPU handling** -- The K8s launcher applies `num_of_gpus` from the job resource spec as a pod `nvidia.com/gpu` limit; the Process launcher ignores it entirely. Standardize resource declaration so job definitions remain portable across backends. + +3. **Unified cleanup** -- Standardize pod cleanup policy (auto-remove on exit, configurable retention for debugging) and centralize it in the handle or engine. + +4. 
**Consistent timeout policy and failure semantics** -- The Process launcher has no start timeout. K8s checks the `enter_states` return value and logs a debug message on failure (termination already happens inside `enter_states` for stuck/timeout paths), always returning a handle. Consider exposing a distinct "startup failed" state on the handle so callers can react without polling. + +5. **Observability** -- Add an optional `get_info()` method to `JobHandleSpec` so the engine can log launcher-specific details (pod name, namespace, PID) for debugging and operations. + +6. **Testing** -- Provide `MockJobLauncher` and `MockJobHandle` implementations for unit tests that verify server/client flow without starting real processes or pods. diff --git a/nvflare/apis/job_launcher_spec.py b/nvflare/apis/job_launcher_spec.py index cb75130e6a..f8cd4a7192 100644 --- a/nvflare/apis/job_launcher_spec.py +++ b/nvflare/apis/job_launcher_spec.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from abc import abstractmethod +from abc import ABC, abstractmethod from nvflare.apis.fl_component import FLComponent from nvflare.apis.fl_constant import FLContextKey @@ -56,12 +56,12 @@ def add_launcher(launcher, fl_ctx: FLContext): fl_ctx.set_prop(FLContextKey.JOB_LAUNCHER, job_launcher, private=True, sticky=False) -class JobHandleSpec: +class JobHandleSpec(ABC): @abstractmethod def terminate(self): """To terminate the job run. - Returns: the job run return code. + Returns: None """ raise NotImplementedError() @@ -85,7 +85,7 @@ def wait(self): raise NotImplementedError() -class JobLauncherSpec(FLComponent): +class JobLauncherSpec(FLComponent, ABC): @abstractmethod def launch_job(self, job_meta: dict, fl_ctx: FLContext) -> JobHandleSpec: """To launch a job run. 
@@ -94,7 +94,7 @@ def launch_job(self, job_meta: dict, fl_ctx: FLContext) -> JobHandleSpec: job_meta: job metadata fl_ctx: FLContext - Returns: boolean to indicates the job launch success or fail. + Returns: a JobHandle instance. """ raise NotImplementedError() diff --git a/nvflare/app_common/resource_consumers/BE_resource_consumer.py b/nvflare/app_common/resource_consumers/BE_resource_consumer.py new file mode 100644 index 0000000000..08059ec621 --- /dev/null +++ b/nvflare/app_common/resource_consumers/BE_resource_consumer.py @@ -0,0 +1,21 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nvflare.apis.resource_manager_spec import ResourceConsumerSpec + + +class BEResourceConsumer(ResourceConsumerSpec): + + def consume(self, resources: dict): + pass diff --git a/nvflare/app_common/resource_managers/BE_resource_manager.py b/nvflare/app_common/resource_managers/BE_resource_manager.py new file mode 100644 index 0000000000..89b92e850c --- /dev/null +++ b/nvflare/app_common/resource_managers/BE_resource_manager.py @@ -0,0 +1,50 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import uuid + +from nvflare.apis.fl_component import FLComponent +from nvflare.apis.fl_context import FLContext +from nvflare.apis.resource_manager_spec import ResourceManagerSpec + + +class BEResourceManager(ResourceManagerSpec, FLComponent): + def __init__(self): + """Best Effort Resource Manager implementation. + + It will accept all resource allocation requests + and let the job fail when the requested resources are unavailable + at runtime. + + """ + super().__init__() + + def check_resources(self, resource_requirement: dict, fl_ctx: FLContext): + if not isinstance(resource_requirement, dict): + raise TypeError(f"resource_requirement should be of type dict, but got {type(resource_requirement)}.") + + token = str(uuid.uuid4()) + return True, token + + def cancel_resources(self, resource_requirement: dict, token: str, fl_ctx: FLContext): + return None + + def allocate_resources(self, resource_requirement: dict, token: str, fl_ctx: FLContext) -> dict: + return {} + + def free_resources(self, resources: dict, token: str, fl_ctx: FLContext): + pass + + def report_resources(self, fl_ctx): + return {} diff --git a/nvflare/app_common/resource_managers/gpu_resource_manager.py b/nvflare/app_common/resource_managers/gpu_resource_manager.py index f14f2f0ccd..c4ea20b52c 100644 --- a/nvflare/app_common/resource_managers/gpu_resource_manager.py +++ b/nvflare/app_common/resource_managers/gpu_resource_manager.py @@ -35,6 +35,7 @@ def __init__( num_gpu_key: str = "num_of_gpus", gpu_mem_key: str = "mem_per_gpu_in_GiB", expiration_period: Union[int, float] = 30, + 
ignore_host: bool = False, ): """Resource manager for GPUs. @@ -46,6 +47,9 @@ def __init__( expiration_period: Number of seconds to hold the resources reserved. If check_resources is called but after "expiration_period" no allocate resource is called, then the reserved resources will be released. + ignore_host: Whether to skip validation against GPUs present on the local host. Set to True in + environments where the NVFlare process runs on a node without GPUs (for example, some + Kubernetes deployments) but GPU resources are managed externally. """ if not isinstance(num_of_gpus, int): raise ValueError(f"num_of_gpus should be of type int, but got {type(num_of_gpus)}.") @@ -62,17 +66,21 @@ def __init__( if expiration_period < 0: raise ValueError("expiration_period should be greater than or equal to 0.") - if num_of_gpus > 0: - num_host_gpus = len(get_host_gpu_ids()) - if num_of_gpus > num_host_gpus: - raise ValueError(f"num_of_gpus specified ({num_of_gpus}) exceeds available GPUs: {num_host_gpus}.") - - host_gpu_mem = get_host_gpu_memory_total() - for i in host_gpu_mem: - if mem_per_gpu_in_GiB * 1024 > i: - raise ValueError( - f"Memory per GPU specified ({mem_per_gpu_in_GiB * 1024}) exceeds available GPU memory: {i}." - ) + if not isinstance(ignore_host, bool): + raise ValueError(f"ignore_host should be of type bool, but got {type(ignore_host)}.") + + if not ignore_host: + if num_of_gpus > 0: + num_host_gpus = len(get_host_gpu_ids()) + if num_of_gpus > num_host_gpus: + raise ValueError(f"num_of_gpus specified ({num_of_gpus}) exceeds available GPUs: {num_host_gpus}.") + + host_gpu_mem = get_host_gpu_memory_total() + for i in host_gpu_mem: + if mem_per_gpu_in_GiB * 1024 > i: + raise ValueError( + f"Memory per GPU specified ({mem_per_gpu_in_GiB * 1024}) exceeds available GPU memory: {i}." 
+ ) self.num_gpu_key = num_gpu_key self.gpu_mem_key = gpu_mem_key diff --git a/nvflare/app_opt/job_launcher/k8s_launcher.py b/nvflare/app_opt/job_launcher/k8s_launcher.py index 19e6716bc2..5fe5b58e46 100644 --- a/nvflare/app_opt/job_launcher/k8s_launcher.py +++ b/nvflare/app_opt/job_launcher/k8s_launcher.py @@ -11,11 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import copy import logging +import re import time from abc import abstractmethod from enum import Enum +import yaml from kubernetes import config from kubernetes.client import Configuration from kubernetes.client.api import core_v1_api @@ -24,6 +27,7 @@ from nvflare.apis.event_type import EventType from nvflare.apis.fl_constant import FLContextKey, JobConstants from nvflare.apis.fl_context import FLContext +from nvflare.apis.job_def import JobMetaKey from nvflare.apis.job_launcher_spec import JobHandleSpec, JobLauncherSpec, JobProcessArgs, JobReturnCode, add_launcher from nvflare.utils.job_launcher_utils import extract_job_image, get_client_job_args, get_server_job_args @@ -36,12 +40,20 @@ class JobState(Enum): UNKNOWN = "unknown" +class POD_Phase(Enum): + PENDING = "Pending" + RUNNING = "Running" + SUCCEEDED = "Succeeded" + FAILED = "Failed" + UNKNOWN = "Unknown" + + POD_STATE_MAPPING = { - "Pending": JobState.STARTING, - "Running": JobState.RUNNING, - "Succeeded": JobState.SUCCEEDED, - "Failed": JobState.TERMINATED, - "Unknown": JobState.UNKNOWN, + POD_Phase.PENDING.value: JobState.STARTING, + POD_Phase.RUNNING.value: JobState.RUNNING, + POD_Phase.SUCCEEDED.value: JobState.SUCCEEDED, + POD_Phase.FAILED.value: JobState.TERMINATED, + POD_Phase.UNKNOWN.value: JobState.UNKNOWN, } JOB_RETURN_CODE_MAPPING = { @@ -52,13 +64,60 @@ class JobState(Enum): JobState.UNKNOWN: JobReturnCode.UNKNOWN, } +DEFAULT_CONTAINER_ARGS_MODULE_ARGS_DICT = { + "-m": None, + "-w": None, + 
"-t": None, + "-d": None, + "-n": None, + "-c": None, + "-p": None, + "-g": None, + "-scheme": None, + "-s": None, +} + + +class PV_NAME(Enum): + WORKSPACE = "nvflws" + DATA = "nvfldata" + ETC = "nvfletc" + + +VOLUME_MOUNT_LIST = [ + {"name": PV_NAME.WORKSPACE.value, "mountPath": "/var/tmp/nvflare/workspace"}, + {"name": PV_NAME.DATA.value, "mountPath": "/var/tmp/nvflare/data"}, + {"name": PV_NAME.ETC.value, "mountPath": "/var/tmp/nvflare/etc"}, +] + + +def uuid4_to_rfc1123(uuid_str: str) -> str: + name = uuid_str.lower() + # Strip any chars that aren't alphanumeric or hyphen + name = re.sub(r"[^a-z0-9-]", "", name) + # Prefix with a letter if it starts with a digit + if name and name[0].isdigit(): + name = "j" + name + # Kubernetes label limit: 63 chars; strip trailing hyphens after truncation + # (truncation can expose a hyphen that was interior before slicing) + return name[:63].rstrip("-") + class K8sJobHandle(JobHandleSpec): - def __init__(self, job_id: str, api_instance: core_v1_api, job_config: dict, namespace="default", timeout=None): + def __init__( + self, + job_id: str, + api_instance: core_v1_api, + job_config: dict, + namespace="default", + timeout=None, + pending_timeout=30, + python_path="/usr/local/bin/python", + ): super().__init__() self.job_id = job_id self.timeout = timeout - + self.terminal_state = None self.api_instance = api_instance self.namespace = namespace self.pod_manifest = { @@ -68,78 +127,57 @@ def __init__(self, job_id: str, api_instance: core_v1_api, job_config: dict, nam "spec": { "containers": None, # link to container_list "volumes": None, # link to volume_list - "restartPolicy": "OnFailure", + "restartPolicy": "Never", }, } - self.volume_list = [{"name": None, "hostPath": {"path": None, "type": "Directory"}}] + self.volume_list = [] + self.container_list = [ { "image": None, "name": None, - "command": ["/usr/local/bin/python"], + "command": [python_path], "args": None, # args_list + args_dict + args_sets "volumeMounts": None, # 
volume_mount_list "imagePullPolicy": "Always", } ] - self.container_args_python_args_list = ["-u", "-m", job_config.get("command")] - self.container_args_module_args_dict = { - "-m": None, - "-w": None, - "-t": None, - "-d": None, - "-n": None, - "-c": None, - "-p": None, - "-g": None, - "-scheme": None, - "-s": None, - } - self.container_volume_mount_list = [ - { - "name": None, - "mountPath": None, - } - ] + command = job_config.get("command") + if not command: + raise ValueError("job_config must contain a non-empty 'command' key") + self.container_args_python_args_list = ["-u", "-m", command] + self.container_volume_mount_list = [] self._make_manifest(job_config) + self._stuck_count = 0 + self._max_stuck_count = self.timeout if self.timeout is not None else pending_timeout + self.logger = logging.getLogger(self.__class__.__name__) def _make_manifest(self, job_config): - self.container_volume_mount_list = job_config.get( - "volume_mount_list", [{"name": "workspace-nvflare", "mountPath": "/workspace/nvflare"}] - ) + self.container_volume_mount_list.extend(job_config.get("volume_mount_list", [])) set_list = job_config.get("set_list") - if set_list is None: + if not set_list: self.container_args_module_args_sets = list() else: self.container_args_module_args_sets = ["--set"] + set_list - self.container_args_module_args_dict = job_config.get( - "module_args", - { - "-m": None, - "-w": None, - "-t": None, - "-d": None, - "-n": None, - "-c": None, - "-p": None, - "-g": None, - "-scheme": None, - "-s": None, - }, - ) + if job_config.get("module_args") is None: + self.container_args_module_args_dict = DEFAULT_CONTAINER_ARGS_MODULE_ARGS_DICT.copy() + else: + self.container_args_module_args_dict = job_config.get("module_args") self.container_args_module_args_dict_as_list = list() for k, v in self.container_args_module_args_dict.items(): + if v is None: + continue self.container_args_module_args_dict_as_list.append(k) - self.container_args_module_args_dict_as_list.append(v) 
- self.volume_list = job_config.get( - "volume_list", [{"name": None, "hostPath": {"path": None, "type": "Directory"}}] - ) - + self.container_args_module_args_dict_as_list.append(str(v)) + self.volume_list.extend(job_config.get("volume_list", [])) self.pod_manifest["metadata"]["name"] = job_config.get("name") self.pod_manifest["spec"]["containers"] = self.container_list self.pod_manifest["spec"]["volumes"] = self.volume_list - self.container_list[0]["image"] = job_config.get("image", "nvflare/nvflare:2.5.0") + image = job_config.get("image") + if not image: + raise ValueError("job_config must contain a non-empty 'image' key") + self.container_list[0]["image"] = image self.container_list[0]["name"] = job_config.get("container_name", "nvflare_job") self.container_list[0]["args"] = ( self.container_args_python_args_list @@ -147,62 +185,122 @@ def _make_manifest(self, job_config): + self.container_args_module_args_sets ) self.container_list[0]["volumeMounts"] = self.container_volume_mount_list + if job_config.get("resources", {}).get("limits", {}).get("nvidia.com/gpu"): + self.container_list[0]["resources"] = job_config.get("resources") def get_manifest(self): - return self.pod_manifest + return copy.deepcopy(self.pod_manifest) - def enter_states(self, job_states_to_enter: list, timeout=None): + def enter_states(self, job_states_to_enter: list): starting_time = time.time() if not isinstance(job_states_to_enter, (list, tuple)): job_states_to_enter = [job_states_to_enter] - if not all([isinstance(js, JobState)] for js in job_states_to_enter): + if not all([isinstance(js, JobState) for js in job_states_to_enter]): raise ValueError(f"expect job_states_to_enter with valid values, but get {job_states_to_enter}") while True: - job_state = self._query_state() + pod_phase = self._query_phase() + if self._stuck_in_pending(pod_phase): + self.terminate() + return False + job_state = POD_STATE_MAPPING.get(pod_phase, JobState.UNKNOWN) if job_state in job_states_to_enter: return 
True - elif timeout is not None and time.time() - starting_time > timeout: + elif pod_phase in [POD_Phase.FAILED.value, POD_Phase.SUCCEEDED.value]: # terminal state + self.terminal_state = POD_STATE_MAPPING.get(pod_phase, JobState.UNKNOWN) + return False + elif self.timeout is not None and time.time() - starting_time > self.timeout: + self.terminate() return False time.sleep(1) def terminate(self): - resp = self.api_instance.delete_namespaced_pod( - name=self.job_id, namespace=self.namespace, grace_period_seconds=0 - ) - return self.enter_states([JobState.TERMINATED], timeout=self.timeout) + try: + self.api_instance.delete_namespaced_pod(name=self.job_id, namespace=self.namespace, grace_period_seconds=0) + self.terminal_state = JobState.TERMINATED + except ApiException as e: + if getattr(e, "status", None) == 404: + self.logger.info(f"job {self.job_id} pod not found during termination; assuming terminated") + else: + self.logger.error(f"failed to terminate job {self.job_id}: {e}") + self.terminal_state = JobState.TERMINATED + except Exception as e: + self.logger.error(f"unexpected error terminating job {self.job_id}: {e}") + self.terminal_state = JobState.TERMINATED + return None def poll(self): + if self.terminal_state is not None: + return JOB_RETURN_CODE_MAPPING.get(self.terminal_state) job_state = self._query_state() return JOB_RETURN_CODE_MAPPING.get(job_state, JobReturnCode.UNKNOWN) - def _query_state(self): + def _query_phase(self): try: resp = self.api_instance.read_namespaced_pod(name=self.job_id, namespace=self.namespace) - except ApiException: - return JobState.UNKNOWN - return POD_STATE_MAPPING.get(resp.status.phase, JobState.UNKNOWN) + except ApiException as e: + self.logger.warning(f"failed to query pod phase {self.job_id}: {e}") + return POD_Phase.UNKNOWN.value + except Exception as e: + self.logger.warning(f"unexpected error querying pod phase {self.job_id}: {e}") + return POD_Phase.UNKNOWN.value + return resp.status.phase + + def 
_query_state(self): + pod_phase = self._query_phase() + return POD_STATE_MAPPING.get(pod_phase, JobState.UNKNOWN) + + def _stuck_in_pending(self, current_phase): + if current_phase == POD_Phase.PENDING.value: + self._stuck_count += 1 + if self._max_stuck_count is not None and self._stuck_count >= self._max_stuck_count: + return True + else: + self._stuck_count = 0 + return False def wait(self): - self.enter_states([JobState.SUCCEEDED, JobState.TERMINATED]) + while True: + if self.terminal_state is not None: + return + job_state = self._query_state() + if job_state in (JobState.SUCCEEDED, JobState.TERMINATED): + self.terminal_state = job_state # persist so poll() stays accurate + return + time.sleep(1) class K8sJobLauncher(JobLauncherSpec): def __init__( self, - config_file_path, - root_hostpath: str, - workspace: str, - mount_path: str, + config_file_path: str, + workspace_pvc: str, + etc_pvc: str, + data_pvc_file_path: str, timeout=None, namespace="default", + pending_timeout=30, + python_path="/usr/local/bin/python", ): super().__init__() + self.logger = logging.getLogger(self.__class__.__name__) - self.root_hostpath = root_hostpath - self.workspace = workspace - self.mount_path = mount_path + self.workspace_pvc = workspace_pvc + self.etc_pvc = etc_pvc + self.data_pvc_file_path = data_pvc_file_path self.timeout = timeout - + self.namespace = namespace + self.pending_timeout = pending_timeout + self.python_path = python_path + with open(data_pvc_file_path, "rt") as f: + data_pvc_dict = yaml.safe_load(f) + if not data_pvc_dict: + raise ValueError(f"data_pvc_file_path '{data_pvc_file_path}' is empty or contains no PVC entries.") + # data_pvc_dict will be pvc: mountPath + # currently, support one pvc and always mount to /var/tmp/nvflare/data + # ie, ignore the mountPath in data_pvc_dict + if not isinstance(data_pvc_dict, dict): + raise ValueError(f"file at data_pvc_file_path '{data_pvc_file_path}' does not contain a dictionary.") + self.data_pvc = 
list(data_pvc_dict.keys())[0] config.load_kube_config(config_file_path) try: c = Configuration().get_default_copy() @@ -211,47 +309,68 @@ def __init__( c.assert_hostname = False Configuration.set_default(c) self.core_v1 = core_v1_api.CoreV1Api() - self.namespace = namespace - self.job_handle = None - self.logger = logging.getLogger(self.__class__.__name__) def launch_job(self, job_meta: dict, fl_ctx: FLContext) -> JobHandleSpec: - - job_id = job_meta.get(JobConstants.JOB_ID) + site_name = fl_ctx.get_identity_name() + raw_job_id = job_meta.get(JobConstants.JOB_ID) + if not raw_job_id: + raise RuntimeError(f"missing {JobConstants.JOB_ID} in job_meta") + job_id = uuid4_to_rfc1123(raw_job_id) args = fl_ctx.get_prop(FLContextKey.ARGS) - job_image = extract_job_image(job_meta, fl_ctx.get_identity_name()) - self.logger.info(f"launch job use image: {job_image}") - + job_image = extract_job_image(job_meta, site_name) + site_resources = job_meta.get(JobMetaKey.RESOURCE_SPEC.value, {}).get(site_name, {}) + job_resource = site_resources.get("num_of_gpus", None) job_args = fl_ctx.get_prop(FLContextKey.JOB_PROCESS_ARGS) if not job_args: raise RuntimeError(f"missing {FLContextKey.JOB_PROCESS_ARGS} in FLContext") - _, job_cmd = job_args[JobProcessArgs.EXE_MODULE] + exe_module_entry = job_args.get(JobProcessArgs.EXE_MODULE) + if not exe_module_entry: + raise RuntimeError(f"missing {JobProcessArgs.EXE_MODULE} in {FLContextKey.JOB_PROCESS_ARGS}") + _, job_cmd = exe_module_entry job_config = { "name": job_id, "image": job_image, "container_name": f"container-{job_id}", "command": job_cmd, - "volume_mount_list": [{"name": self.workspace, "mountPath": self.mount_path}], - "volume_list": [{"name": self.workspace, "hostPath": {"path": self.root_hostpath, "type": "Directory"}}], + "volume_mount_list": VOLUME_MOUNT_LIST, + "volume_list": [ + {"name": PV_NAME.WORKSPACE.value, "persistentVolumeClaim": {"claimName": self.workspace_pvc}}, + {"name": PV_NAME.DATA.value, "persistentVolumeClaim": 
{"claimName": self.data_pvc}}, + {"name": PV_NAME.ETC.value, "persistentVolumeClaim": {"claimName": self.etc_pvc}}, + ], "module_args": self.get_module_args(job_id, fl_ctx), - "set_list": args.set, } - - self.logger.info(f"launch job with k8s_launcher. Job_id:{job_id}") - - job_handle = K8sJobHandle(job_id, self.core_v1, job_config, namespace=self.namespace, timeout=self.timeout) + if args is not None and getattr(args, "set", None) is not None: + job_config.update({"set_list": args.set}) + if job_resource: + job_config.update({"resources": {"limits": {"nvidia.com/gpu": job_resource}}}) + job_handle = K8sJobHandle( + job_id, + self.core_v1, + job_config, + namespace=self.namespace, + timeout=self.timeout, + pending_timeout=self.pending_timeout, + python_path=self.python_path, + ) + pod_manifest = job_handle.get_manifest() + self.logger.debug(f"launch job with k8s_launcher. {pod_manifest=}") try: - self.core_v1.create_namespaced_pod(body=job_handle.get_manifest(), namespace=self.namespace) - if job_handle.enter_states([JobState.RUNNING], timeout=self.timeout): - return job_handle - else: - job_handle.terminate() - return None - except ApiException: + self.core_v1.create_namespaced_pod(body=pod_manifest, namespace=self.namespace) + except Exception as e: + self.logger.error(f"failed to launch job {job_id}: {e}") + job_handle.terminal_state = JobState.TERMINATED + return job_handle + try: + entered_running = job_handle.enter_states([JobState.RUNNING]) + except BaseException: job_handle.terminate() - return None + raise + if not entered_running: + self.logger.warning(f"unable to enter running phase {job_id}") + return job_handle def handle_event(self, event_type: str, fl_ctx: FLContext): if event_type == EventType.BEFORE_JOB_LAUNCH: diff --git a/tests/unit_test/app_opt/job_launcher/__init__.py b/tests/unit_test/app_opt/job_launcher/__init__.py new file mode 100644 index 0000000000..4fc25d0d3c --- /dev/null +++ b/tests/unit_test/app_opt/job_launcher/__init__.py @@ -0,0 
+1,13 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit_test/app_opt/job_launcher/k8s_launcher_test.py b/tests/unit_test/app_opt/job_launcher/k8s_launcher_test.py new file mode 100644 index 0000000000..b1ff897255 --- /dev/null +++ b/tests/unit_test/app_opt/job_launcher/k8s_launcher_test.py @@ -0,0 +1,1174 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import sys +from types import ModuleType +from unittest.mock import MagicMock, Mock, patch + +import pytest + +_k8s_mock = ModuleType("kubernetes") +_k8s_client = ModuleType("kubernetes.client") +_k8s_config = ModuleType("kubernetes.config") +_k8s_rest = ModuleType("kubernetes.client.rest") +_k8s_api = ModuleType("kubernetes.client.api") +_k8s_core = ModuleType("kubernetes.client.api.core_v1_api") + + +class _FakeApiException(Exception): + def __init__(self, status=None, reason=None, http_resp=None): + self.status = status + self.reason = reason + + +_k8s_rest.ApiException = _FakeApiException +_k8s_client.Configuration = MagicMock +_k8s_client.rest = _k8s_rest +_k8s_client.api = _k8s_api +_k8s_core.CoreV1Api = MagicMock +_k8s_api.core_v1_api = _k8s_core +_k8s_mock.config = _k8s_config +_k8s_mock.client = _k8s_client + +for _mod_name, _mod_obj in [ + ("kubernetes", _k8s_mock), + ("kubernetes.config", _k8s_config), + ("kubernetes.client", _k8s_client), + ("kubernetes.client.rest", _k8s_rest), + ("kubernetes.client.api", _k8s_api), + ("kubernetes.client.api.core_v1_api", _k8s_core), +]: + sys.modules.setdefault(_mod_name, _mod_obj) + +from nvflare.apis.event_type import EventType +from nvflare.apis.fl_constant import FLContextKey, JobConstants, ReservedKey +from nvflare.apis.fl_context import FLContext +from nvflare.apis.job_def import JobMetaKey +from nvflare.apis.job_launcher_spec import JobProcessArgs, JobReturnCode +from nvflare.app_opt.job_launcher.k8s_launcher import ( + DEFAULT_CONTAINER_ARGS_MODULE_ARGS_DICT, + JOB_RETURN_CODE_MAPPING, + POD_STATE_MAPPING, + PV_NAME, + VOLUME_MOUNT_LIST, + JobState, + K8sJobHandle, + POD_Phase, + _job_args_dict, + uuid4_to_rfc1123, +) + + +def _make_job_config(**overrides): + cfg = { + "name": "test-job-123", + "image": "nvflare/nvflare:test", + "container_name": "container-test-job-123", + "command": "nvflare.private.fed.app.client.worker_process", + "volume_mount_list": VOLUME_MOUNT_LIST, + "volume_list": [ + {"name": 
PV_NAME.WORKSPACE.value, "persistentVolumeClaim": {"claimName": "ws-pvc"}}, + {"name": PV_NAME.DATA.value, "persistentVolumeClaim": {"claimName": "data-pvc"}}, + {"name": PV_NAME.ETC.value, "persistentVolumeClaim": {"claimName": "etc-pvc"}}, + ], + "module_args": {"-m": "val_m", "-w": "val_w"}, + "set_list": ["key1=val1", "key2=val2"], + "resources": {"limits": {"nvidia.com/gpu": 1}}, + } + cfg.update(overrides) + return cfg + + +def _make_api_instance(): + return MagicMock() + + +def _make_handle(job_id="job-1", api=None, cfg=None, **kwargs): + if api is None: + api = _make_api_instance() + if cfg is None: + cfg = _make_job_config() + handle = K8sJobHandle(job_id, api, cfg, **kwargs) + handle.job_id = job_id + return handle + + +# --------------------------------------------------------------------------- +# uuid4_to_rfc1123 +# --------------------------------------------------------------------------- +class TestUuid4ToRfc1123: + def test_lowercase(self): + assert uuid4_to_rfc1123("ABCD-1234") == "abcd-1234" + + def test_strips_invalid_chars(self): + assert uuid4_to_rfc1123("abc_def.ghi") == "abcdefghi" + + def test_prefixes_leading_digit(self): + result = uuid4_to_rfc1123("1234-abcd") + assert result[0] == "j" + assert result == "j1234-abcd" + + def test_strips_trailing_hyphens(self): + assert uuid4_to_rfc1123("abc-") == "abc" + + def test_strips_trailing_hyphen_exposed_by_truncation(self): + # 62 'a's followed by '-' followed by more chars: truncation exposes the hyphen + name = "a" * 62 + "-" + "b" * 10 + result = uuid4_to_rfc1123(name) + assert not result.endswith("-"), f"trailing hyphen in {result!r}" + assert len(result) == 62 + + def test_truncates_to_63_chars(self): + long_str = "a" * 100 + assert len(uuid4_to_rfc1123(long_str)) == 63 + + def test_typical_uuid_gets_prefixed(self): + result = uuid4_to_rfc1123("550e8400-e29b-41d4-a716-446655440000") + assert result == "j550e8400-e29b-41d4-a716-446655440000" + + def test_letter_leading_uuid_no_prefix(self): 
+ result = uuid4_to_rfc1123("abcd1234-e29b-41d4-a716-446655440000") + assert result == "abcd1234-e29b-41d4-a716-446655440000" + + def test_empty_string(self): + assert uuid4_to_rfc1123("") == "" + + +# --------------------------------------------------------------------------- +# Mapping tables +# --------------------------------------------------------------------------- +class TestPodStateMapping: + def test_all_phases_mapped(self): + for phase in POD_Phase: + assert phase.value in POD_STATE_MAPPING + + def test_pending_maps_to_starting(self): + assert POD_STATE_MAPPING[POD_Phase.PENDING.value] == JobState.STARTING + + def test_running_maps_to_running(self): + assert POD_STATE_MAPPING[POD_Phase.RUNNING.value] == JobState.RUNNING + + def test_succeeded_maps_to_succeeded(self): + assert POD_STATE_MAPPING[POD_Phase.SUCCEEDED.value] == JobState.SUCCEEDED + + def test_failed_maps_to_terminated(self): + assert POD_STATE_MAPPING[POD_Phase.FAILED.value] == JobState.TERMINATED + + +class TestJobReturnCodeMapping: + def test_all_job_states_mapped(self): + for state in JobState: + assert state in JOB_RETURN_CODE_MAPPING + + def test_succeeded_maps_to_success(self): + assert JOB_RETURN_CODE_MAPPING[JobState.SUCCEEDED] == JobReturnCode.SUCCESS + + def test_terminated_maps_to_aborted(self): + assert JOB_RETURN_CODE_MAPPING[JobState.TERMINATED] == JobReturnCode.ABORTED + + def test_running_maps_to_unknown(self): + assert JOB_RETURN_CODE_MAPPING[JobState.RUNNING] == JobReturnCode.UNKNOWN + + +# --------------------------------------------------------------------------- +# K8sJobHandle +# --------------------------------------------------------------------------- +class TestK8sJobHandle: + # -- construction --------------------------------------------------------- + def test_init_raises_on_missing_command(self): + cfg = _make_job_config() + del cfg["command"] + with pytest.raises(ValueError, match="command"): + K8sJobHandle("job-1", _make_api_instance(), cfg) + + def 
test_init_raises_on_empty_command(self): + cfg = _make_job_config(command="") + with pytest.raises(ValueError, match="command"): + K8sJobHandle("job-1", _make_api_instance(), cfg) + + def test_stuck_count_starts_at_zero(self): + cfg = _make_job_config() + handle = K8sJobHandle("job-1", _make_api_instance(), cfg, timeout=30) + assert handle._stuck_count == 0 + + def test_max_stuck_count_equals_timeout(self): + cfg = _make_job_config() + handle = K8sJobHandle("job-1", _make_api_instance(), cfg, timeout=30) + assert handle._max_stuck_count == 30 + + def test_max_stuck_count_uses_pending_timeout_when_no_timeout(self): + cfg = _make_job_config() + handle = K8sJobHandle("job-1", _make_api_instance(), cfg, timeout=None) + assert handle._max_stuck_count == 30 + + def test_max_stuck_count_uses_custom_pending_timeout(self): + cfg = _make_job_config() + handle = K8sJobHandle("job-1", _make_api_instance(), cfg, timeout=None, pending_timeout=60) + assert handle._max_stuck_count == 60 + + # -- manifest ------------------------------------------------------------- + def test_manifest_metadata_name(self): + cfg = _make_job_config() + handle = K8sJobHandle("job-1", _make_api_instance(), cfg) + assert handle.get_manifest()["metadata"]["name"] == "test-job-123" + + def test_manifest_container_image(self): + cfg = _make_job_config() + handle = K8sJobHandle("job-1", _make_api_instance(), cfg) + container = handle.get_manifest()["spec"]["containers"][0] + assert container["image"] == "nvflare/nvflare:test" + + def test_manifest_container_name(self): + cfg = _make_job_config() + handle = K8sJobHandle("job-1", _make_api_instance(), cfg) + container = handle.get_manifest()["spec"]["containers"][0] + assert container["name"] == "container-test-job-123" + + def test_manifest_raises_on_missing_image(self): + cfg = _make_job_config() + del cfg["image"] + with pytest.raises(ValueError, match="image"): + K8sJobHandle("job-1", _make_api_instance(), cfg) + + def 
test_manifest_raises_on_empty_image(self): + cfg = _make_job_config(image="") + with pytest.raises(ValueError, match="image"): + K8sJobHandle("job-1", _make_api_instance(), cfg) + + def test_manifest_default_container_name(self): + cfg = _make_job_config() + del cfg["container_name"] + handle = K8sJobHandle("job-1", _make_api_instance(), cfg) + container = handle.get_manifest()["spec"]["containers"][0] + assert container["name"] == "nvflare_job" + + def test_manifest_restart_policy(self): + cfg = _make_job_config() + handle = K8sJobHandle("job-1", _make_api_instance(), cfg) + assert handle.get_manifest()["spec"]["restartPolicy"] == "Never" + + def test_manifest_volumes(self): + cfg = _make_job_config() + handle = K8sJobHandle("job-1", _make_api_instance(), cfg) + volumes = handle.get_manifest()["spec"]["volumes"] + assert len(volumes) == 3 + pvc_names = [v["persistentVolumeClaim"]["claimName"] for v in volumes] + assert "ws-pvc" in pvc_names + assert "data-pvc" in pvc_names + assert "etc-pvc" in pvc_names + + def test_manifest_volume_mounts(self): + cfg = _make_job_config() + handle = K8sJobHandle("job-1", _make_api_instance(), cfg) + container = handle.get_manifest()["spec"]["containers"][0] + assert container["volumeMounts"] == VOLUME_MOUNT_LIST + + def test_manifest_args_contain_command(self): + cfg = _make_job_config() + handle = K8sJobHandle("job-1", _make_api_instance(), cfg) + args = handle.get_manifest()["spec"]["containers"][0]["args"] + assert "-u" in args + assert "-m" in args + assert "nvflare.private.fed.app.client.worker_process" in args + + def test_manifest_args_contain_module_args(self): + cfg = _make_job_config() + handle = K8sJobHandle("job-1", _make_api_instance(), cfg) + args = handle.get_manifest()["spec"]["containers"][0]["args"] + assert "val_m" in args + assert "val_w" in args + + def test_manifest_args_contain_set_list(self): + cfg = _make_job_config() + handle = K8sJobHandle("job-1", _make_api_instance(), cfg) + args = 
handle.get_manifest()["spec"]["containers"][0]["args"] + assert "--set" in args + assert "key1=val1" in args + assert "key2=val2" in args + + def test_manifest_no_set_list(self): + cfg = _make_job_config() + cfg["set_list"] = None + handle = K8sJobHandle("job-1", _make_api_instance(), cfg) + args = handle.get_manifest()["spec"]["containers"][0]["args"] + assert "--set" not in args + + def test_manifest_none_module_args_skipped(self): + cfg = _make_job_config(module_args={"-a": "keep", "-b": None}) + handle = K8sJobHandle("job-1", _make_api_instance(), cfg) + args = handle.get_manifest()["spec"]["containers"][0]["args"] + assert "-a" in args + assert "keep" in args + assert "-b" not in args + + def test_manifest_non_string_module_arg_values_are_stringified(self): + cfg = _make_job_config(module_args={"-p": 8080, "-n": 42}) + handle = K8sJobHandle("job-1", _make_api_instance(), cfg) + args = handle.get_manifest()["spec"]["containers"][0]["args"] + assert "8080" in args + assert "42" in args + assert all(isinstance(a, str) for a in args), "all container args must be str" + + def test_manifest_default_module_args_copies_dict(self): + cfg = _make_job_config() + cfg["module_args"] = None + handle = K8sJobHandle("job-1", _make_api_instance(), cfg) + assert handle.container_args_module_args_dict is not DEFAULT_CONTAINER_ARGS_MODULE_ARGS_DICT + assert handle.container_args_module_args_dict == DEFAULT_CONTAINER_ARGS_MODULE_ARGS_DICT + + def test_manifest_default_module_args_all_none_produces_empty_args_list(self): + cfg = _make_job_config() + cfg["module_args"] = None + handle = K8sJobHandle("job-1", _make_api_instance(), cfg) + assert handle.container_args_module_args_dict_as_list == [] + + def test_manifest_gpu_resources(self): + cfg = _make_job_config() + handle = K8sJobHandle("job-1", _make_api_instance(), cfg) + container = handle.get_manifest()["spec"]["containers"][0] + assert container["resources"]["limits"]["nvidia.com/gpu"] == 1 + + def 
test_manifest_no_gpu_resources(self): + cfg = _make_job_config(resources={"limits": {"nvidia.com/gpu": None}}) + handle = K8sJobHandle("job-1", _make_api_instance(), cfg) + container = handle.get_manifest()["spec"]["containers"][0] + assert "resources" not in container + + def test_manifest_no_resources_key(self): + cfg = _make_job_config() + del cfg["resources"] + handle = K8sJobHandle("job-1", _make_api_instance(), cfg) + container = handle.get_manifest()["spec"]["containers"][0] + assert "resources" not in container + + def test_get_manifest_returns_independent_copy(self): + handle = K8sJobHandle("job-1", _make_api_instance(), _make_job_config()) + manifest = handle.get_manifest() + # Mutate the returned copy at every mutable level + manifest["metadata"]["name"] = "MUTATED" + manifest["spec"]["containers"][0]["image"] = "MUTATED" + manifest["spec"]["volumes"].clear() + # Internal state must be unchanged + internal = handle.pod_manifest + assert internal["metadata"]["name"] == "test-job-123" + assert internal["spec"]["containers"][0]["image"] == "nvflare/nvflare:test" + assert len(internal["spec"]["volumes"]) > 0 + + # -- poll ----------------------------------------------------------------- + def test_poll_returns_unknown_when_running(self): + api = _make_api_instance() + resp = Mock() + resp.status.phase = POD_Phase.RUNNING.value + api.read_namespaced_pod.return_value = resp + handle = _make_handle(api=api) + assert handle.poll() == JobReturnCode.UNKNOWN + + def test_poll_returns_success_when_succeeded(self): + api = _make_api_instance() + resp = Mock() + resp.status.phase = POD_Phase.SUCCEEDED.value + api.read_namespaced_pod.return_value = resp + handle = _make_handle(api=api) + assert handle.poll() == JobReturnCode.SUCCESS + + def test_poll_returns_aborted_when_failed(self): + api = _make_api_instance() + resp = Mock() + resp.status.phase = POD_Phase.FAILED.value + api.read_namespaced_pod.return_value = resp + handle = _make_handle(api=api) + assert 
handle.poll() == JobReturnCode.ABORTED + + def test_poll_uses_terminal_state_if_set(self): + api = _make_api_instance() + handle = _make_handle(api=api) + handle.terminal_state = JobState.TERMINATED + assert handle.poll() == JobReturnCode.ABORTED + api.read_namespaced_pod.assert_not_called() + + # -- terminate ------------------------------------------------------------ + def test_terminate_deletes_pod_and_sets_terminated(self): + api = _make_api_instance() + handle = _make_handle(api=api) + handle.terminate() + api.delete_namespaced_pod.assert_called_once_with(name="job-1", namespace="default", grace_period_seconds=0) + assert handle.terminal_state == JobState.TERMINATED + + def test_terminate_sets_terminated_on_404(self): + api = _make_api_instance() + api.delete_namespaced_pod.side_effect = _FakeApiException(status=404, reason="Not Found") + handle = _make_handle(api=api) + handle.terminate() + assert handle.terminal_state == JobState.TERMINATED + + def test_terminate_sets_terminated_on_non_404_api_error(self): + api = _make_api_instance() + api.delete_namespaced_pod.side_effect = _FakeApiException(status=500, reason="Internal") + handle = _make_handle(api=api) + handle.terminate() + assert handle.terminal_state == JobState.TERMINATED + + def test_terminate_sets_terminated_on_network_error(self): + api = _make_api_instance() + api.delete_namespaced_pod.side_effect = ConnectionError("network unreachable") + handle = _make_handle(api=api) + handle.terminate() + assert handle.terminal_state == JobState.TERMINATED + + def test_terminate_custom_namespace(self): + api = _make_api_instance() + handle = _make_handle(api=api, namespace="custom-ns") + handle.terminate() + api.delete_namespaced_pod.assert_called_once_with(name="job-1", namespace="custom-ns", grace_period_seconds=0) + + # -- _query_phase --------------------------------------------------------- + def test_query_phase_returns_unknown_on_api_error(self): + api = _make_api_instance() + 
api.read_namespaced_pod.side_effect = _FakeApiException(status=500, reason="Error") + handle = _make_handle(api=api) + assert handle._query_phase() == POD_Phase.UNKNOWN.value + + def test_query_phase_returns_unknown_on_generic_exception(self): + api = _make_api_instance() + api.read_namespaced_pod.side_effect = RuntimeError("connection lost") + handle = _make_handle(api=api) + assert handle._query_phase() == POD_Phase.UNKNOWN.value + + # -- _stuck_in_pending ---------------------------------------------------- + def test_stuck_in_pending_returns_true_at_max_count(self): + # With _stuck_count seeded to max-1, one more PENDING call increments to + # exactly max, which should fire (>=, not >). + api = _make_api_instance() + handle = K8sJobHandle("job-1", api, _make_job_config(), timeout=5) + handle._stuck_count = handle._max_stuck_count - 1 + assert handle._stuck_in_pending(POD_Phase.PENDING.value) is True + + def test_stuck_in_pending_returns_false_one_before_max(self): + # One iteration before the threshold must not fire. 
+ api = _make_api_instance() + handle = K8sJobHandle("job-1", api, _make_job_config(), timeout=5) + handle._stuck_count = handle._max_stuck_count - 2 + assert handle._stuck_in_pending(POD_Phase.PENDING.value) is False + + def test_stuck_in_pending_returns_false_for_non_pending(self): + api = _make_api_instance() + handle = K8sJobHandle("job-1", api, _make_job_config(), timeout=5) + handle._stuck_count = 9999 + assert handle._stuck_in_pending(POD_Phase.RUNNING.value) is False + + def test_stuck_in_pending_resets_count_on_non_pending(self): + api = _make_api_instance() + handle = K8sJobHandle("job-1", api, _make_job_config(), timeout=5) + handle._stuck_count = 3 + handle._stuck_in_pending(POD_Phase.RUNNING.value) + assert handle._stuck_count == 0 + + def test_stuck_in_pending_increments_count(self): + api = _make_api_instance() + handle = K8sJobHandle("job-1", api, _make_job_config(), timeout=100) + initial = handle._stuck_count + handle._stuck_in_pending(POD_Phase.PENDING.value) + assert handle._stuck_count == initial + 1 + + def test_stuck_in_pending_returns_false_when_under_max(self): + api = _make_api_instance() + handle = K8sJobHandle("job-1", api, _make_job_config(), timeout=None, pending_timeout=100) + assert handle._stuck_in_pending(POD_Phase.PENDING.value) is False + + def test_stuck_in_pending_never_fires_when_pending_timeout_none(self): + # pending_timeout=None with timeout=None → _max_stuck_count=None → stuck detection disabled + api = _make_api_instance() + handle = K8sJobHandle("job-1", api, _make_job_config(), timeout=None, pending_timeout=None) + assert handle._max_stuck_count is None + # Drive _stuck_count very high — must not raise and must return False + handle._stuck_count = 10_000 + assert handle._stuck_in_pending(POD_Phase.PENDING.value) is False + + # -- wait ----------------------------------------------------------------- + def test_wait_returns_immediately_if_terminal_state_set(self): + api = _make_api_instance() + handle = 
_make_handle(api=api) + handle.terminal_state = JobState.TERMINATED + handle.wait() + api.read_namespaced_pod.assert_not_called() + + def test_wait_persists_succeeded_terminal_state(self): + api = _make_api_instance() + resp = Mock() + resp.status.phase = POD_Phase.SUCCEEDED.value + api.read_namespaced_pod.return_value = resp + handle = _make_handle(api=api) + handle.wait() + assert handle.terminal_state == JobState.SUCCEEDED + + def test_wait_persists_terminated_terminal_state(self): + api = _make_api_instance() + resp = Mock() + resp.status.phase = POD_Phase.FAILED.value + api.read_namespaced_pod.return_value = resp + handle = _make_handle(api=api) + handle.wait() + assert handle.terminal_state == JobState.TERMINATED + + def test_wait_poll_consistent_after_wait(self): + api = _make_api_instance() + resp = Mock() + resp.status.phase = POD_Phase.SUCCEEDED.value + api.read_namespaced_pod.return_value = resp + handle = _make_handle(api=api) + handle.wait() + assert handle.poll() == JobReturnCode.SUCCESS + + # -- enter_states --------------------------------------------------------- + def test_enter_states_returns_true_when_state_matches(self): + api = _make_api_instance() + resp = Mock() + resp.status.phase = POD_Phase.RUNNING.value + api.read_namespaced_pod.return_value = resp + handle = _make_handle(api=api) + assert handle.enter_states([JobState.RUNNING]) is True + + def test_enter_states_accepts_single_state(self): + api = _make_api_instance() + resp = Mock() + resp.status.phase = POD_Phase.SUCCEEDED.value + api.read_namespaced_pod.return_value = resp + handle = _make_handle(api=api) + assert handle.enter_states(JobState.SUCCEEDED) is True + + def test_enter_states_returns_false_on_timeout(self): + api = _make_api_instance() + resp = Mock() + resp.status.phase = POD_Phase.PENDING.value + api.read_namespaced_pod.return_value = resp + handle = _make_handle(api=api, timeout=0) + assert handle.enter_states([JobState.RUNNING]) is False + + def 
test_enter_states_terminates_on_timeout(self): + api = _make_api_instance() + resp = Mock() + resp.status.phase = POD_Phase.PENDING.value + api.read_namespaced_pod.return_value = resp + handle = _make_handle(api=api, timeout=0) + handle.enter_states([JobState.RUNNING]) + api.delete_namespaced_pod.assert_called_once() + assert handle.terminal_state == JobState.TERMINATED + + def test_enter_states_returns_false_on_terminal_pod_phase(self): + api = _make_api_instance() + resp = Mock() + resp.status.phase = POD_Phase.FAILED.value + api.read_namespaced_pod.return_value = resp + handle = _make_handle(api=api) + assert handle.enter_states([JobState.RUNNING]) is False + assert handle.terminal_state == JobState.TERMINATED + + def test_enter_states_returns_false_on_succeeded_when_waiting_for_running(self): + api = _make_api_instance() + resp = Mock() + resp.status.phase = POD_Phase.SUCCEEDED.value + api.read_namespaced_pod.return_value = resp + handle = _make_handle(api=api) + assert handle.enter_states([JobState.RUNNING]) is False + assert handle.terminal_state == JobState.SUCCEEDED + + def test_enter_states_raises_on_invalid_state(self): + handle = _make_handle() + with pytest.raises(ValueError, match="expect job_states_to_enter"): + handle.enter_states(["not_a_state"]) + + # -- enter_states: wall-clock timeout branch ------------------------------ + # The pod is placed in UNKNOWN phase (non-pending so stuck detection does + # not fire, non-terminal so the terminal-phase branch is skipped) and + # time.time is mocked so the elapsed-time check fires on the first + # iteration. This is the branch that existing tests miss because they + # use timeout=0 with a PENDING pod, which hits stuck detection instead. 
+ + @patch("nvflare.app_opt.job_launcher.k8s_launcher.time") + def test_enter_states_wall_clock_timeout_returns_false(self, mock_time): + mock_time.time.side_effect = [0.0, 100.0] # start=0, check=100 → 100 > timeout=10 + mock_time.sleep = Mock() + api = _make_api_instance() + resp = Mock() + resp.status.phase = POD_Phase.UNKNOWN.value + api.read_namespaced_pod.return_value = resp + handle = _make_handle(api=api, timeout=10) + assert handle.enter_states([JobState.RUNNING]) is False + + @patch("nvflare.app_opt.job_launcher.k8s_launcher.time") + def test_enter_states_wall_clock_timeout_calls_terminate_and_sets_terminal_state(self, mock_time): + mock_time.time.side_effect = [0.0, 100.0] + mock_time.sleep = Mock() + api = _make_api_instance() + resp = Mock() + resp.status.phase = POD_Phase.UNKNOWN.value + api.read_namespaced_pod.return_value = resp + handle = _make_handle(api=api, timeout=10) + handle.enter_states([JobState.RUNNING]) + api.delete_namespaced_pod.assert_called_once() + assert handle.terminal_state == JobState.TERMINATED + + @patch("nvflare.app_opt.job_launcher.k8s_launcher.time") + def test_enter_states_wall_clock_not_fired_when_timeout_none(self, mock_time): + # With timeout=None the wall-clock guard (self.timeout is not None) is + # unconditionally False; the loop exits through the terminal-phase path. + mock_time.time.return_value = 9999.0 + mock_time.sleep = Mock() + api = _make_api_instance() + resp = Mock() + resp.status.phase = POD_Phase.SUCCEEDED.value + api.read_namespaced_pod.return_value = resp + handle = _make_handle(api=api, timeout=None) + handle.enter_states([JobState.RUNNING]) + api.delete_namespaced_pod.assert_not_called() + + @patch("nvflare.app_opt.job_launcher.k8s_launcher.time") + def test_enter_states_wall_clock_not_fired_before_elapsed(self, mock_time): + # First iteration: time not yet elapsed → wall-clock skipped, loop continues. + # Second iteration: pod completes → exits via terminal-phase path, no terminate(). 
+ mock_time.time.side_effect = [0.0, 0.5] # start=0, first check=0.5 < timeout=10 + mock_time.sleep = Mock() + api = _make_api_instance() + resp_unknown = Mock() + resp_unknown.status.phase = POD_Phase.UNKNOWN.value + resp_succeeded = Mock() + resp_succeeded.status.phase = POD_Phase.SUCCEEDED.value + api.read_namespaced_pod.side_effect = [resp_unknown, resp_succeeded] + handle = _make_handle(api=api, timeout=10) + result = handle.enter_states([JobState.RUNNING]) + assert result is False + api.delete_namespaced_pod.assert_not_called() + assert handle.terminal_state == JobState.SUCCEEDED + + +# --------------------------------------------------------------------------- +# _job_args_dict helper +# --------------------------------------------------------------------------- +class TestJobArgsDict: + def test_basic(self): + job_args = { + "workspace": ("-w", "/workspace"), + "job_id": ("-j", "job-1"), + } + result = _job_args_dict(job_args, ["workspace", "job_id"]) + assert result == {"-w": "/workspace", "-j": "job-1"} + + def test_skips_missing_keys(self): + job_args = {"workspace": ("-w", "/workspace")} + result = _job_args_dict(job_args, ["workspace", "missing_key"]) + assert result == {"-w": "/workspace"} + + def test_empty_args(self): + assert _job_args_dict({}, ["a", "b"]) == {} + + def test_empty_arg_names(self): + assert _job_args_dict({"workspace": ("-w", "/workspace")}, []) == {} + + +# --------------------------------------------------------------------------- +# K8sJobLauncher handle_event +# --------------------------------------------------------------------------- +def _make_k8s_launcher_patches(): + return [ + patch("nvflare.app_opt.job_launcher.k8s_launcher.config"), + patch("nvflare.app_opt.job_launcher.k8s_launcher.Configuration"), + patch("nvflare.app_opt.job_launcher.k8s_launcher.core_v1_api"), + patch("builtins.open", create=True), + patch("nvflare.app_opt.job_launcher.k8s_launcher.yaml"), + ] + + +def _enter_patches(patches): + mocks = [p.start() 
for p in patches] + return mocks + + +def _exit_patches(patches): + for p in patches: + p.stop() + + +def _setup_launcher(mock_yaml, mock_conf, launcher_cls): + mock_yaml.safe_load.return_value = {"data-pvc": "/data"} + mock_conf_instance = MagicMock() + mock_conf.return_value = mock_conf_instance + mock_conf.get_default_copy = Mock(return_value=mock_conf_instance) + return launcher_cls( + config_file_path="/fake/kube/config", + workspace_pvc="ws-pvc", + etc_pvc="etc-pvc", + data_pvc_file_path="/fake/data_pvc.yaml", + ) + + +class TestK8sJobLauncherHandleEvent: + def test_adds_launcher_when_image_present(self): + from nvflare.app_opt.job_launcher.k8s_launcher import ClientK8sJobLauncher + + patches = _make_k8s_launcher_patches() + mock_cfg, mock_conf, mock_core, mock_open, mock_yaml = _enter_patches(patches) + try: + launcher = _setup_launcher(mock_yaml, mock_conf, ClientK8sJobLauncher) + fl_ctx = FLContext() + job_meta = {JobMetaKey.DEPLOY_MAP.value: {"app": [{"sites": ["site-1"], "image": "nvflare/custom:latest"}]}} + fl_ctx.set_prop(FLContextKey.JOB_META, job_meta, private=True, sticky=False) + fl_ctx.set_prop(ReservedKey.IDENTITY_NAME, "site-1", private=False, sticky=True) + + launcher.handle_event(EventType.BEFORE_JOB_LAUNCH, fl_ctx) + + launchers = fl_ctx.get_prop(FLContextKey.JOB_LAUNCHER) + assert launchers is not None + assert launcher in launchers + finally: + _exit_patches(patches) + + def test_skips_when_no_image(self): + from nvflare.app_opt.job_launcher.k8s_launcher import ClientK8sJobLauncher + + patches = _make_k8s_launcher_patches() + mock_cfg, mock_conf, mock_core, mock_open, mock_yaml = _enter_patches(patches) + try: + launcher = _setup_launcher(mock_yaml, mock_conf, ClientK8sJobLauncher) + fl_ctx = FLContext() + job_meta = {JobMetaKey.DEPLOY_MAP.value: {}} + fl_ctx.set_prop(FLContextKey.JOB_META, job_meta, private=True, sticky=False) + fl_ctx.set_prop(ReservedKey.IDENTITY_NAME, "site-1", private=False, sticky=True) + + 
launcher.handle_event(EventType.BEFORE_JOB_LAUNCH, fl_ctx) + + assert fl_ctx.get_prop(FLContextKey.JOB_LAUNCHER) is None + finally: + _exit_patches(patches) + + +# --------------------------------------------------------------------------- +# K8sJobLauncher __init__ +# --------------------------------------------------------------------------- +class TestK8sJobLauncherInit: + def test_init_reads_data_pvc(self): + from nvflare.app_opt.job_launcher.k8s_launcher import ClientK8sJobLauncher + + patches = _make_k8s_launcher_patches() + mock_cfg, mock_conf, mock_core, mock_open, mock_yaml = _enter_patches(patches) + try: + mock_yaml.safe_load.return_value = {"my-data-pvc": "/mount/data"} + mock_conf_instance = MagicMock() + mock_conf.return_value = mock_conf_instance + mock_conf.get_default_copy = Mock(return_value=mock_conf_instance) + + launcher = ClientK8sJobLauncher( + config_file_path="/fake/kube/config", + workspace_pvc="ws-pvc", + etc_pvc="etc-pvc", + data_pvc_file_path="/fake/data_pvc.yaml", + timeout=60, + namespace="test-ns", + ) + + assert launcher.workspace_pvc == "ws-pvc" + assert launcher.etc_pvc == "etc-pvc" + assert launcher.data_pvc == "my-data-pvc" + assert launcher.timeout == 60 + assert launcher.namespace == "test-ns" + finally: + _exit_patches(patches) + + def test_init_raises_on_empty_pvc_file(self): + from nvflare.app_opt.job_launcher.k8s_launcher import ClientK8sJobLauncher + + patches = _make_k8s_launcher_patches() + mock_cfg, mock_conf, mock_core, mock_open, mock_yaml = _enter_patches(patches) + try: + mock_yaml.safe_load.return_value = {} + mock_conf_instance = MagicMock() + mock_conf.return_value = mock_conf_instance + mock_conf.get_default_copy = Mock(return_value=mock_conf_instance) + + with pytest.raises(ValueError, match="empty"): + ClientK8sJobLauncher( + config_file_path="/fake/kube/config", + workspace_pvc="ws-pvc", + etc_pvc="etc-pvc", + data_pvc_file_path="/fake/data_pvc.yaml", + ) + finally: + _exit_patches(patches) + + def 
test_init_raises_on_non_dict_pvc_file(self): + from nvflare.app_opt.job_launcher.k8s_launcher import ClientK8sJobLauncher + + patches = _make_k8s_launcher_patches() + mock_cfg, mock_conf, mock_core, mock_open, mock_yaml = _enter_patches(patches) + try: + mock_yaml.safe_load.return_value = "not-a-dict" + mock_conf_instance = MagicMock() + mock_conf.return_value = mock_conf_instance + mock_conf.get_default_copy = Mock(return_value=mock_conf_instance) + + with pytest.raises(ValueError, match="dictionary"): + ClientK8sJobLauncher( + config_file_path="/fake/kube/config", + workspace_pvc="ws-pvc", + etc_pvc="etc-pvc", + data_pvc_file_path="/fake/data_pvc.yaml", + ) + finally: + _exit_patches(patches) + + +# --------------------------------------------------------------------------- +# ClientK8sJobLauncher.get_module_args +# --------------------------------------------------------------------------- +class TestClientK8sJobLauncherGetModuleArgs: + def test_returns_dict(self): + from nvflare.app_opt.job_launcher.k8s_launcher import ClientK8sJobLauncher + + patches = _make_k8s_launcher_patches() + mock_cfg, mock_conf, mock_core, mock_open, mock_yaml = _enter_patches(patches) + try: + launcher = _setup_launcher(mock_yaml, mock_conf, ClientK8sJobLauncher) + fl_ctx = FLContext() + job_args = { + JobProcessArgs.WORKSPACE: ("-w", "/workspace"), + JobProcessArgs.JOB_ID: ("-j", "job-1"), + } + fl_ctx.set_prop(FLContextKey.JOB_PROCESS_ARGS, job_args, private=True, sticky=False) + + result = launcher.get_module_args("job-1", fl_ctx) + assert isinstance(result, dict) + assert result.get("-w") == "/workspace" + finally: + _exit_patches(patches) + + def test_raises_when_no_args(self): + from nvflare.app_opt.job_launcher.k8s_launcher import ClientK8sJobLauncher + + patches = _make_k8s_launcher_patches() + mock_cfg, mock_conf, mock_core, mock_open, mock_yaml = _enter_patches(patches) + try: + launcher = _setup_launcher(mock_yaml, mock_conf, ClientK8sJobLauncher) + fl_ctx = FLContext() + 
with pytest.raises(RuntimeError, match="job_process_args"): + launcher.get_module_args("job-1", fl_ctx) + finally: + _exit_patches(patches) + + +# --------------------------------------------------------------------------- +# ServerK8sJobLauncher.get_module_args +# --------------------------------------------------------------------------- +class TestServerK8sJobLauncherGetModuleArgs: + def test_returns_dict(self): + from nvflare.app_opt.job_launcher.k8s_launcher import ServerK8sJobLauncher + + patches = _make_k8s_launcher_patches() + mock_cfg, mock_conf, mock_core, mock_open, mock_yaml = _enter_patches(patches) + try: + launcher = _setup_launcher(mock_yaml, mock_conf, ServerK8sJobLauncher) + fl_ctx = FLContext() + job_args = { + JobProcessArgs.WORKSPACE: ("-w", "/workspace"), + JobProcessArgs.JOB_ID: ("-j", "job-1"), + JobProcessArgs.ROOT_URL: ("--root_url", "grpc://server:8003"), + } + fl_ctx.set_prop(FLContextKey.JOB_PROCESS_ARGS, job_args, private=True, sticky=False) + + result = launcher.get_module_args("job-1", fl_ctx) + assert isinstance(result, dict) + assert result.get("-w") == "/workspace" + finally: + _exit_patches(patches) + + def test_raises_when_no_args(self): + from nvflare.app_opt.job_launcher.k8s_launcher import ServerK8sJobLauncher + + patches = _make_k8s_launcher_patches() + mock_cfg, mock_conf, mock_core, mock_open, mock_yaml = _enter_patches(patches) + try: + launcher = _setup_launcher(mock_yaml, mock_conf, ServerK8sJobLauncher) + fl_ctx = FLContext() + with pytest.raises(RuntimeError, match="job_process_args"): + launcher.get_module_args("job-1", fl_ctx) + finally: + _exit_patches(patches) + + +# --------------------------------------------------------------------------- +# K8sJobLauncher launch_job — integration-style happy path +# --------------------------------------------------------------------------- + +_WORKER_MODULE = "nvflare.private.fed.app.client.worker_process" +_JOB_UUID = "550e8400-e29b-41d4-a716-446655440000" +_EXPECTED_JOB_ID 
= uuid4_to_rfc1123(_JOB_UUID) + + +def _make_launch_job_meta(site_name="site-1", image="nvflare/nvflare:latest", gpu=None): + meta = { + JobConstants.JOB_ID: _JOB_UUID, + JobMetaKey.DEPLOY_MAP.value: {"app": [{"sites": [site_name], "image": image}]}, + } + if gpu is not None: + meta[JobMetaKey.RESOURCE_SPEC.value] = {site_name: {"num_of_gpus": gpu}} + return meta + + +def _make_launch_fl_ctx(site_name="site-1", set_items=None): + fl_ctx = FLContext() + fl_ctx.set_prop(ReservedKey.IDENTITY_NAME, site_name, private=False, sticky=True) + job_args = { + JobProcessArgs.EXE_MODULE: ("-m", _WORKER_MODULE), + JobProcessArgs.WORKSPACE: ("-w", "/var/tmp/nvflare/workspace"), + JobProcessArgs.JOB_ID: ("-n", "job-abc"), + } + fl_ctx.set_prop(FLContextKey.JOB_PROCESS_ARGS, job_args, private=True, sticky=False) + if set_items is not None: + args_obj = Mock() + args_obj.set = set_items + fl_ctx.set_prop(FLContextKey.ARGS, args_obj, private=False, sticky=False) + return fl_ctx + + +class TestK8sJobLauncherLaunchJob: + """Integration-style tests that exercise the full launch_job() code path. + + The kubernetes API is mocked but the rest of the code — uuid sanitization, + manifest construction, enter_states polling, and handle construction — runs + for real. read_namespaced_pod is primed to return Running immediately so + enter_states returns True on the first iteration without sleeping. 
+ """ + + def _setup(self, patches, namespace="test-ns"): + from nvflare.app_opt.job_launcher.k8s_launcher import ClientK8sJobLauncher + + mock_cfg, mock_conf, mock_core_module, mock_open, mock_yaml = _enter_patches(patches) + mock_yaml.safe_load.return_value = {"data-pvc": "/data"} + mock_conf_instance = MagicMock() + mock_conf.return_value = mock_conf_instance + mock_conf.get_default_copy = Mock(return_value=mock_conf_instance) + mock_api = MagicMock() + mock_core_module.CoreV1Api.return_value = mock_api + launcher = ClientK8sJobLauncher( + config_file_path="/fake/kube/config", + workspace_pvc="ws-pvc", + etc_pvc="etc-pvc", + data_pvc_file_path="/fake/data_pvc.yaml", + namespace=namespace, + ) + return launcher, mock_api + + def _prime_running(self, mock_api): + resp = Mock() + resp.status.phase = POD_Phase.RUNNING.value + mock_api.read_namespaced_pod.return_value = resp + + # -- return value --------------------------------------------------------- + + def test_returns_k8s_job_handle(self): + from nvflare.app_opt.job_launcher.k8s_launcher import K8sJobHandle + + patches = _make_k8s_launcher_patches() + launcher, mock_api = self._setup(patches) + self._prime_running(mock_api) + try: + handle = launcher.launch_job(_make_launch_job_meta(), _make_launch_fl_ctx()) + assert isinstance(handle, K8sJobHandle) + finally: + _exit_patches(patches) + + def test_terminal_state_is_none_after_clean_launch(self): + patches = _make_k8s_launcher_patches() + launcher, mock_api = self._setup(patches) + self._prime_running(mock_api) + try: + handle = launcher.launch_job(_make_launch_job_meta(), _make_launch_fl_ctx()) + assert handle.terminal_state is None + finally: + _exit_patches(patches) + + # -- API call ------------------------------------------------------------- + + def test_create_namespaced_pod_called_once(self): + patches = _make_k8s_launcher_patches() + launcher, mock_api = self._setup(patches) + self._prime_running(mock_api) + try: + 
launcher.launch_job(_make_launch_job_meta(), _make_launch_fl_ctx()) + mock_api.create_namespaced_pod.assert_called_once() + finally: + _exit_patches(patches) + + def test_create_namespaced_pod_uses_launcher_namespace(self): + patches = _make_k8s_launcher_patches() + launcher, mock_api = self._setup(patches, namespace="prod-ns") + self._prime_running(mock_api) + try: + launcher.launch_job(_make_launch_job_meta(), _make_launch_fl_ctx()) + assert mock_api.create_namespaced_pod.call_args.kwargs["namespace"] == "prod-ns" + finally: + _exit_patches(patches) + + # -- pod manifest: identity fields ---------------------------------------- + + def test_pod_manifest_name_is_rfc1123_of_job_id(self): + patches = _make_k8s_launcher_patches() + launcher, mock_api = self._setup(patches) + self._prime_running(mock_api) + try: + launcher.launch_job(_make_launch_job_meta(), _make_launch_fl_ctx()) + manifest = mock_api.create_namespaced_pod.call_args.kwargs["body"] + assert manifest["metadata"]["name"] == _EXPECTED_JOB_ID + finally: + _exit_patches(patches) + + def test_pod_manifest_image_from_job_meta(self): + patches = _make_k8s_launcher_patches() + launcher, mock_api = self._setup(patches) + self._prime_running(mock_api) + try: + launcher.launch_job(_make_launch_job_meta(image="myrepo/myimage:v2"), _make_launch_fl_ctx()) + manifest = mock_api.create_namespaced_pod.call_args.kwargs["body"] + assert manifest["spec"]["containers"][0]["image"] == "myrepo/myimage:v2" + finally: + _exit_patches(patches) + + def test_pod_manifest_restart_policy_never(self): + patches = _make_k8s_launcher_patches() + launcher, mock_api = self._setup(patches) + self._prime_running(mock_api) + try: + launcher.launch_job(_make_launch_job_meta(), _make_launch_fl_ctx()) + manifest = mock_api.create_namespaced_pod.call_args.kwargs["body"] + assert manifest["spec"]["restartPolicy"] == "Never" + finally: + _exit_patches(patches) + + # -- pod manifest: container args ----------------------------------------- + + 
def test_pod_manifest_container_args_contain_worker_module(self): + patches = _make_k8s_launcher_patches() + launcher, mock_api = self._setup(patches) + self._prime_running(mock_api) + try: + launcher.launch_job(_make_launch_job_meta(), _make_launch_fl_ctx()) + manifest = mock_api.create_namespaced_pod.call_args.kwargs["body"] + args = manifest["spec"]["containers"][0]["args"] + assert "-u" in args + assert "-m" in args + assert _WORKER_MODULE in args + finally: + _exit_patches(patches) + + def test_pod_manifest_set_list_propagated_from_args(self): + patches = _make_k8s_launcher_patches() + launcher, mock_api = self._setup(patches) + self._prime_running(mock_api) + try: + fl_ctx = _make_launch_fl_ctx(set_items=["lr=0.01", "epochs=5"]) + launcher.launch_job(_make_launch_job_meta(), fl_ctx) + manifest = mock_api.create_namespaced_pod.call_args.kwargs["body"] + args = manifest["spec"]["containers"][0]["args"] + assert "--set" in args + assert "lr=0.01" in args + assert "epochs=5" in args + finally: + _exit_patches(patches) + + def test_pod_manifest_no_set_list_when_args_not_set(self): + patches = _make_k8s_launcher_patches() + launcher, mock_api = self._setup(patches) + self._prime_running(mock_api) + try: + launcher.launch_job(_make_launch_job_meta(), _make_launch_fl_ctx()) + manifest = mock_api.create_namespaced_pod.call_args.kwargs["body"] + args = manifest["spec"]["containers"][0]["args"] + assert "--set" not in args + finally: + _exit_patches(patches) + + # -- pod manifest: volumes ------------------------------------------------ + + def test_pod_manifest_pvcs_use_launcher_pvc_names(self): + patches = _make_k8s_launcher_patches() + launcher, mock_api = self._setup(patches) + self._prime_running(mock_api) + try: + launcher.launch_job(_make_launch_job_meta(), _make_launch_fl_ctx()) + manifest = mock_api.create_namespaced_pod.call_args.kwargs["body"] + claims = {v["name"]: v["persistentVolumeClaim"]["claimName"] for v in manifest["spec"]["volumes"]} + assert 
claims[PV_NAME.WORKSPACE.value] == "ws-pvc" + assert claims[PV_NAME.DATA.value] == "data-pvc" + assert claims[PV_NAME.ETC.value] == "etc-pvc" + finally: + _exit_patches(patches) + + # -- pod manifest: GPU resources ------------------------------------------ + + def test_pod_manifest_includes_gpu_limit_when_specified(self): + patches = _make_k8s_launcher_patches() + launcher, mock_api = self._setup(patches) + self._prime_running(mock_api) + try: + launcher.launch_job(_make_launch_job_meta(gpu=2), _make_launch_fl_ctx()) + manifest = mock_api.create_namespaced_pod.call_args.kwargs["body"] + assert manifest["spec"]["containers"][0]["resources"]["limits"]["nvidia.com/gpu"] == 2 + finally: + _exit_patches(patches) + + def test_pod_manifest_omits_gpu_limit_when_not_specified(self): + patches = _make_k8s_launcher_patches() + launcher, mock_api = self._setup(patches) + self._prime_running(mock_api) + try: + launcher.launch_job(_make_launch_job_meta(), _make_launch_fl_ctx()) + manifest = mock_api.create_namespaced_pod.call_args.kwargs["body"] + assert "resources" not in manifest["spec"]["containers"][0] + finally: + _exit_patches(patches) + + # -- create_namespaced_pod failure paths ---------------------------------- + + def test_network_error_on_create_returns_handle_with_terminal_state(self): + """Non-ApiException (e.g. 
network timeout) must not leave terminal_state=None.""" + patches = _make_k8s_launcher_patches() + launcher, mock_api = self._setup(patches) + mock_api.create_namespaced_pod.side_effect = ConnectionError("network unreachable") + try: + handle = launcher.launch_job(_make_launch_job_meta(), _make_launch_fl_ctx()) + assert isinstance(handle, K8sJobHandle) + assert handle.terminal_state == JobState.TERMINATED + finally: + _exit_patches(patches) + + def test_network_error_on_create_does_not_call_terminate_api(self): + """Pod was never created; no delete API call should be made.""" + patches = _make_k8s_launcher_patches() + launcher, mock_api = self._setup(patches) + mock_api.create_namespaced_pod.side_effect = ConnectionError("network unreachable") + try: + launcher.launch_job(_make_launch_job_meta(), _make_launch_fl_ctx()) + mock_api.delete_namespaced_pod.assert_not_called() + finally: + _exit_patches(patches) + + def test_network_error_on_create_does_not_call_enter_states(self): + """enter_states must not be reached when pod creation fails.""" + patches = _make_k8s_launcher_patches() + launcher, mock_api = self._setup(patches) + mock_api.create_namespaced_pod.side_effect = ConnectionError("network unreachable") + try: + launcher.launch_job(_make_launch_job_meta(), _make_launch_fl_ctx()) + # read_namespaced_pod is the backing call for _query_phase / enter_states + mock_api.read_namespaced_pod.assert_not_called() + finally: + _exit_patches(patches) diff --git a/tests/unit_test/tool/job/config/configer_test.py b/tests/unit_test/tool/job/config/configer_test.py index 792d88b315..2f3ac153e4 100644 --- a/tests/unit_test/tool/job/config/configer_test.py +++ b/tests/unit_test/tool/job/config/configer_test.py @@ -32,7 +32,7 @@ ( "launch_once", [ - ["app/config_fed_client.conf", "app_script=cifar10_fl.py", "launch_once=False"], + ["app/config_fed_client.conf", "app_script=cifar10_fl.py", "launch_once=false"], ["meta.conf", "min_clients=3"], ], "launch_everytime", @@ -50,7 +50,7 
@@ ( "launch_once", [ - ["app/config/config_fed_client.conf", "app_script=cifar10_fl.py", "launch_once=False"], + ["app/config/config_fed_client.conf", "app_script=cifar10_fl.py", "launch_once=false"], ["meta.conf", "min_clients=3"], ], "launch_everytime",