From fb21e248cf4ca0579a3fc2aa9e0b8f3c42e0a375 Mon Sep 17 00:00:00 2001 From: Isaac Yang Date: Mon, 9 Mar 2026 14:58:21 -0700 Subject: [PATCH 1/3] Design document for Job Launcher and Job Handle (TA: NVFlare developers) Job Launcher for K8s environment GPU, image, pvc updated and working Add code Add unit tests --- docs/design/JobLauncher_and_JobHandle.md | 641 +++++++++++++++++ nvflare/apis/job_launcher_spec.py | 4 +- .../BE_resource_consumer.py | 21 + .../resource_managers/BE_resource_manager.py | 50 ++ .../resource_managers/gpu_resource_manager.py | 30 +- .../app_opt/job_launcher/docker_launcher.py | 24 +- nvflare/app_opt/job_launcher/k8s_launcher.py | 234 +++--- nvflare/private/fed/client/communicator.py | 7 +- .../app_opt/job_launcher/__init__.py | 13 + .../job_launcher/docker_launcher_test.py | 464 ++++++++++++ .../app_opt/job_launcher/k8s_launcher_test.py | 666 ++++++++++++++++++ 11 files changed, 2041 insertions(+), 113 deletions(-) create mode 100644 docs/design/JobLauncher_and_JobHandle.md create mode 100644 nvflare/app_common/resource_consumers/BE_resource_consumer.py create mode 100644 nvflare/app_common/resource_managers/BE_resource_manager.py create mode 100644 tests/unit_test/app_opt/job_launcher/__init__.py create mode 100644 tests/unit_test/app_opt/job_launcher/docker_launcher_test.py create mode 100644 tests/unit_test/app_opt/job_launcher/k8s_launcher_test.py diff --git a/docs/design/JobLauncher_and_JobHandle.md b/docs/design/JobLauncher_and_JobHandle.md new file mode 100644 index 0000000000..650b159785 --- /dev/null +++ b/docs/design/JobLauncher_and_JobHandle.md @@ -0,0 +1,641 @@ +# JobLauncher and JobHandle Design Document + +## 1. Overview + +NVFlare runs each federated job as an isolated execution unit -- a subprocess, Docker container, or Kubernetes pod. Two abstractions govern this: + +- **JobLauncherSpec** -- starts a job and returns a handle.
+- **JobHandleSpec** -- represents the running job and provides lifecycle control (poll, wait, terminate). + +The upper layers (server engine, client executor) program exclusively against these two interfaces. The concrete backend is selected at runtime through an event-based mechanism, so the engine never imports or names a specific launcher type. + +``` +┌──────────────────────────────────────────────────────────┐ +│ Upper Layer │ +│ ServerEngine / ClientExecutor │ +│ │ +│ 1. Build JOB_PROCESS_ARGS │ +│ 2. get_job_launcher(job_meta, fl_ctx) → launcher │ +│ 3. launcher.launch_job(job_meta, fl_ctx) → job_handle │ +│ 4. job_handle.wait() / job_handle.terminate() │ +└──────────┬──────────────────────┬────────────────────────┘ + │ BEFORE_JOB_LAUNCH │ + │ event selects one │ + ▼ ▼ +┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐ +│ ProcessJob │ │ DockerJob │ │ K8sJob │ +│ Launcher │ │ Launcher │ │ Launcher │ +│ ─────────────── │ │ ─────────────── │ │ ─────────────── │ +│ ProcessHandle │ │ DockerJobHandle │ │ K8sJobHandle │ +└─────────────────┘ └─────────────────┘ └─────────────────┘ + subprocess container pod +``` + +--- + +## 2. Specification Layer (`nvflare/apis/job_launcher_spec.py`) + +### 2.1 JobHandleSpec + +Abstract base class representing a running job. All methods are `@abstractmethod`. + +| Method | Signature | Semantics | +|--------|-----------|-----------| +| `terminate()` | `() -> None` | Stop the job immediately. | +| `poll()` | `() -> JobReturnCode` | Non-blocking query for the job's current return code. Returns `UNKNOWN` while still running. | +| `wait()` | `() -> None` | Block until the job finishes (or is terminated). | + +### 2.2 JobLauncherSpec + +Abstract base class for launching jobs. Extends `FLComponent`, which gives it access to the event system. + +| Method | Signature | Semantics | +|--------|-----------|-----------| +| `launch_job(job_meta, fl_ctx)` | `(dict, FLContext) -> JobHandleSpec` | Start a job and return its handle. 
| + +### 2.3 Supporting Types + +**JobProcessArgs** -- String constants for the keys the upper layer places in `FLContextKey.JOB_PROCESS_ARGS`. These are the standardized parameters every job process needs (workspace path, auth token, job ID, parent URL, etc.). + +| Constant | Value | Used by | +|----------|-------|---------| +| `EXE_MODULE` | `"exe_module"` | Server, Client | +| `WORKSPACE` | `"workspace"` | Server, Client | +| `STARTUP_DIR` | `"startup_dir"` | Client | +| `APP_ROOT` | `"app_root"` | Server | +| `AUTH_TOKEN` | `"auth_token"` | Client | +| `TOKEN_SIGNATURE` | `"auth_signature"` | Server, Client | +| `SSID` | `"ssid"` | Server, Client | +| `JOB_ID` | `"job_id"` | Server, Client | +| `CLIENT_NAME` | `"client_name"` | Client | +| `ROOT_URL` | `"root_url"` | Server | +| `PARENT_URL` | `"parent_url"` | Server, Client | +| `PARENT_CONN_SEC` | `"parent_conn_sec"` | Client | +| `SERVICE_HOST` | `"service_host"` | Server | +| `SERVICE_PORT` | `"service_port"` | Server | +| `HA_MODE` | `"ha_mode"` | Server | +| `TARGET` | `"target"` | Client | +| `SCHEME` | `"scheme"` | Client | +| `STARTUP_CONFIG_FILE` | `"startup_config_file"` | Server, Client | +| `RESTORE_SNAPSHOT` | `"restore_snapshot"` | Server | +| `OPTIONS` | `"options"` | Server, Client | + +**JobReturnCode** -- Standard exit semantics (extends `ProcessExitCode`): + +| Code | Value | Meaning | +|------|-------|---------| +| `SUCCESS` | 0 | Job completed successfully. | +| `EXECUTION_ERROR` | 1 | Job failed during execution. | +| `ABORTED` | 9 | Job was terminated/aborted. | +| `UNKNOWN` | 127 | Status cannot be determined (still running, or lost). | + +**`add_launcher(launcher, fl_ctx)`** -- Appends a launcher to the `FLContextKey.JOB_LAUNCHER` list on `fl_ctx`. Called by launchers during the `BEFORE_JOB_LAUNCH` event to volunteer for the current job. + +--- + +## 3. How the Upper Layer Uses Launchers + +### 3.1 Event-Based Launcher Selection + +The engine never directly instantiates a launcher. 
Instead, it calls `get_job_launcher()` from `nvflare/private/fed/utils/fed_utils.py`: + +```python +def get_job_launcher(job_meta, fl_ctx) -> JobLauncherSpec: + engine = fl_ctx.get_engine() + with engine.new_context() as job_launcher_ctx: + job_launcher_ctx.remove_prop(FLContextKey.JOB_LAUNCHER) + job_launcher_ctx.set_prop(FLContextKey.JOB_META, job_meta, ...) + engine.fire_event(EventType.BEFORE_JOB_LAUNCH, job_launcher_ctx) + job_launcher = job_launcher_ctx.get_prop(FLContextKey.JOB_LAUNCHER) + ... + return job_launcher[0] +``` + +Every registered `FLComponent` receives the `BEFORE_JOB_LAUNCH` event. Each launcher inspects `job_meta` and, if it can handle the job, calls `add_launcher(self, fl_ctx)`. The first launcher to register wins. + +**Selection rule in practice:** + +| Condition | Launcher selected | +|-----------|-------------------| +| `extract_job_image(job_meta, site_name)` returns `None` | **ProcessJobLauncher** (no container image → run as subprocess) | +| `extract_job_image(job_meta, site_name)` returns an image | **DockerJobLauncher** or **K8sJobLauncher** (whichever is configured as a component) | + +### 3.2 Server Side (`ServerEngine`) + +Location: `nvflare/private/fed/server/server_engine.py` + +``` +_start_runner_process(job, job_clients, snapshot, fl_ctx) +│ +├─ 1. Build job_args dict with server-specific JobProcessArgs +│ (WORKSPACE, APP_ROOT, PARENT_URL, AUTH_TOKEN, HA_MODE, ...) +│ +├─ 2. fl_ctx.set_prop(FLContextKey.JOB_PROCESS_ARGS, job_args) +│ +├─ 3. job_launcher = get_job_launcher(job.meta, fl_ctx) +│ +├─ 4. job_handle = job_launcher.launch_job(job.meta, fl_ctx) +│ +├─ 5. Store in run_processes[job_id][RunProcessKey.JOB_HANDLE] +│ +└─ 6. Start background thread → wait_for_complete(workspace, job_id, job_handle) + │ + ├─ job_handle.wait() # blocks until job finishes + └─ get_return_code(job_handle, job_id, workspace, logger) +``` + +**Abort path** (`abort_app_on_server`): + +1. 
Attempt to send an abort command to the child via the cell messaging system. +2. On failure, retrieve `job_handle` from `run_processes` and call `job_handle.terminate()`. + +### 3.3 Client Side (`ClientExecutor`) + +Location: `nvflare/private/fed/client/client_executor.py` + +``` +start_app(job_id, job_meta, ...) +│ +├─ 1. Build job_args dict with client-specific JobProcessArgs +│ (WORKSPACE, STARTUP_DIR, CLIENT_NAME, PARENT_URL, AUTH_TOKEN, ...) +│ +├─ 2. fl_ctx.set_prop(FLContextKey.JOB_PROCESS_ARGS, job_args) +│ +├─ 3. job_launcher = get_job_launcher(job_meta, fl_ctx) +│ +├─ 4. job_handle = job_launcher.launch_job(job_meta, fl_ctx) +│ +├─ 5. Store in run_processes[job_id][RunProcessKey.JOB_HANDLE] +│ +└─ 6. Start background thread → _wait_child_process_finish(...) + │ + ├─ job_handle.wait() + └─ get_return_code(job_handle, job_id, workspace, logger) +``` + +**Abort path** (`_terminate_job`): + +1. Wait up to 10 seconds for the child to exit gracefully (polling `job_handle.poll()`). +2. Call `job_handle.terminate()`. + +### 3.4 Return Code Resolution + +`get_return_code()` in `fed_utils.py` uses a two-tier strategy: + +1. **File-based** -- Check for `FLMetaKey.PROCESS_RC_FILE` in the job's run directory. The child process writes its own return code to this file before exiting. This is the preferred source because it carries the child's own assessment. +2. **Handle-based** -- Fall back to `job_handle.poll()`, which maps the underlying execution unit's status to a `JobReturnCode`. + +--- + +## 4. 
The Three Implementations + +### 4.1 Process Launcher (Subprocess) + +**Files:** + +| File | Class | +|------|-------| +| `nvflare/app_common/job_launcher/process_launcher.py` | `ProcessHandle`, `ProcessJobLauncher` | +| `nvflare/app_common/job_launcher/server_process_launcher.py` | `ServerProcessJobLauncher` | +| `nvflare/app_common/job_launcher/client_process_launcher.py` | `ClientProcessJobLauncher` | + +**Class hierarchy:** + +``` +JobHandleSpec + └── ProcessHandle + +JobLauncherSpec (FLComponent) + └── ProcessJobLauncher + ├── ServerProcessJobLauncher + └── ClientProcessJobLauncher +``` + +#### ProcessHandle + +Wraps a `ProcessAdapter` (from `nvflare/utils/process_utils.py`) that manages a `subprocess.Popen` or a PID. + +| Method | Implementation | +|--------|---------------| +| `terminate()` | Delegates to `adapter.terminate()` (sends SIGTERM/SIGKILL). | +| `poll()` | Calls `adapter.poll()`. Maps exit code 0 → `SUCCESS`, 1 → `EXECUTION_ERROR`, 9 → `ABORTED`, other → `EXECUTION_ERROR`, `None` → `UNKNOWN`. | +| `wait()` | Delegates to `adapter.wait()` (blocks on `subprocess.Popen.wait()`). | + +#### ProcessJobLauncher + +| Step | Action | +|------|--------| +| 1 | Copy `os.environ` and add `app_custom_folder` to `PYTHONPATH`. | +| 2 | Call `self.get_command(job_meta, fl_ctx)` (abstract -- implemented by server/client subclasses). | +| 3 | Parse command with `shlex.split()`, spawn the process via `spawn_process(argv, new_env)`. | +| 4 | Return `ProcessHandle(process_adapter=...)`. 
| + +**Event registration:** + +```python +def handle_event(self, event_type, fl_ctx): + if event_type == EventType.BEFORE_JOB_LAUNCH: + job_meta = fl_ctx.get_prop(FLContextKey.JOB_META) + job_image = extract_job_image(job_meta, fl_ctx.get_identity_name()) + if not job_image: # no container image → use subprocess + add_launcher(self, fl_ctx) +``` + +**Server/Client subclasses** only override `get_command()`: + +- `ServerProcessJobLauncher.get_command()` → `generate_server_command(fl_ctx)` → `python -m -w ...` +- `ClientProcessJobLauncher.get_command()` → `generate_client_command(fl_ctx)` → `python -m -w -n ...` + +--- + +### 4.2 Docker Launcher + +**File:** `nvflare/app_opt/job_launcher/docker_launcher.py` + +**Class hierarchy:** + +``` +JobHandleSpec + └── DockerJobHandle + +JobLauncherSpec (FLComponent) + └── DockerJobLauncher + ├── ClientDockerJobLauncher + └── ServerDockerJobLauncher +``` + +#### DockerJobHandle + +Wraps a Docker SDK `Container` object. + +| Method | Implementation | +|--------|---------------| +| `terminate()` | `container.stop()`. | +| `poll()` | Re-fetches container via `docker.from_env().containers.get(id)`. Maps status: `EXITED` → `SUCCESS`, `DEAD` → `ABORTED`, all others → `UNKNOWN`. Removes the container on terminal states. | +| `wait()` | `enter_states([EXITED, DEAD], timeout)` -- polls container status in a 1-second loop until a terminal state is reached. | + +Docker container states and their mappings: + +| Docker Status | JobReturnCode | +|---------------|---------------| +| `created` | `UNKNOWN` | +| `restarting` | `UNKNOWN` | +| `running` | `UNKNOWN` | +| `paused` | `UNKNOWN` | +| `exited` | `SUCCESS` | +| `dead` | `ABORTED` | + +#### DockerJobLauncher + +Constructor parameters: + +| Parameter | Default | Purpose | +|-----------|---------|---------| +| `mount_path` | `"/workspace"` | Container-side mount point for the host workspace. | +| `network` | `"nvflare-network"` | Docker network the container joins. 
| +| `timeout` | `None` | Maximum seconds to wait for the container to reach `RUNNING`. | + +Launch sequence: + +| Step | Action | +|------|--------| +| 1 | Extract `job_image` from `job_meta` via `extract_job_image()`. | +| 2 | Build `PYTHONPATH` with `app_custom_folder`. | +| 3 | Call `self.get_command(job_meta, fl_ctx)` → `(container_name, command_string)`. | +| 4 | Read `NVFL_DOCKER_WORKSPACE` env var for the host-side workspace path. | +| 5 | `docker_client.containers.run(image, command, name, network, volumes, detach=True)`. | +| 6 | `DockerJobHandle(container).enter_states([RUNNING], timeout)`. | +| 7 | If timeout or error → `handle.terminate()`, return `None`. Otherwise return handle. | + +**Event registration:** Same pattern as Process but with the opposite condition -- registers when `extract_job_image()` returns a truthy value. + +**Server/Client subclasses** override `get_command()`: + +- `ClientDockerJobLauncher` → returns `("{client_name}-{job_id}", generate_client_command(fl_ctx))`. +- `ServerDockerJobLauncher` → returns `("server-{job_id}", generate_server_command(fl_ctx))`. + +--- + +### 4.3 Kubernetes Launcher + +**File:** `nvflare/app_opt/job_launcher/k8s_launcher.py` + +**Class hierarchy:** + +``` +JobHandleSpec + └── K8sJobHandle + +JobLauncherSpec (FLComponent) + └── K8sJobLauncher + ├── ClientK8sJobLauncher + └── ServerK8sJobLauncher +``` + +#### K8sJobHandle + +Wraps a Kubernetes Pod managed through the `CoreV1Api`. + +| Method | Implementation | +|--------|---------------| +| `terminate()` | Calls `delete_namespaced_pod(grace_period_seconds=0)` in a try/except. `terminal_state = TERMINATED` is set when the delete succeeds, or when the `ApiException` has status 404 (pod already gone). For any other `ApiException`, the error is logged and `terminal_state` is left unchanged. | +| `poll()` | If `terminal_state` is set, maps it through `JOB_RETURN_CODE_MAPPING` and returns a `JobReturnCode`. 
Otherwise calls `_query_state()` and maps the result the same way. Both paths consistently return `JobReturnCode`. | +| `wait()` | Direct while loop: returns immediately if `terminal_state` is set; otherwise calls `_query_state()` and when `SUCCEEDED` or `TERMINATED` is reached, persists that state into `terminal_state` (so subsequent `poll()` calls remain accurate) and returns. Sleeps 1 second per iteration. No timeout. | +| `_query_phase()` | Calls `read_namespaced_pod` and returns the raw pod phase string (e.g. `"Pending"`, `"Running"`). On `ApiException`, returns `POD_Phase.UNKNOWN.value`. | +| `_query_state()` | Calls `_query_phase()` and maps the raw phase through `POD_STATE_MAPPING` to a `JobState`. Used by `poll()` and `wait()`. | +| `enter_states()` | Per iteration: calls `_query_phase()` once, passes the raw phase to both `_stuck()` and directly to `POD_STATE_MAPPING.get()` — single K8s API call per poll cycle. Returns `True` when target state is reached, `False` on timeout or stuck detection. | + +Pod phase mapping: + +| Pod Phase | JobState | JobReturnCode | +|-----------|----------|---------------| +| `Pending` | `STARTING` | `UNKNOWN` | +| `Running` | `RUNNING` | `UNKNOWN` | +| `Succeeded` | `SUCCEEDED` | `SUCCESS` | +| `Failed` | `TERMINATED` | `ABORTED` | +| `Unknown` | `UNKNOWN` | `UNKNOWN` | + +> Note: `POD_Phase.TERMINATED` has been removed from the enum. `POD_STATE_MAPPING` now covers only the five real Kubernetes pod phases: `Pending`, `Running`, `Succeeded`, `Failed`, `Unknown`. + +**Stuck detection:** `_stuck_count` starts at `0`. A separate `_stuck_grace_period = 10` is added to `timeout` to form `_max_stuck_count = timeout + _stuck_grace_period`, giving a grace window of ~10 extra poll cycles before stuck detection activates. If `timeout` is `None`, `_max_stuck_count` is also `None` and stuck detection is disabled entirely. `enter_states()` passes the raw phase string from `_query_phase()` directly to `_stuck()`. 
`_stuck()` compares `current_phase == POD_Phase.PENDING.value` (i.e. `"Pending" == "Pending"`), incrementing `_stuck_count` on each match. When `_stuck_count > _max_stuck_count`, `_stuck()` returns `True`, `enter_states()` calls `terminate()` (which sets `terminal_state = TERMINATED` when the delete call succeeds or returns 404) and returns `False`. Note: `_stuck_count` and `_max_stuck_count` are poll-iteration counts (each ~1 second), not wall-clock seconds — the semantics coincide only because each poll sleeps exactly 1 second. + +#### K8sJobHandle Pod Manifest + +The handle constructs the pod manifest internally from a `job_config` dict: + +```yaml +apiVersion: v1 +kind: Pod +metadata: + name: +spec: + restartPolicy: Never + containers: + - name: container- + image: + command: ["/usr/local/bin/python"] + args: ["-u", "-m", "", "-w", "", ...] + volumeMounts: + - name: nvflws + mountPath: /var/tmp/nvflare/workspace + - name: nvfldata + mountPath: /var/tmp/nvflare/data + - name: nvfletc + mountPath: /var/tmp/nvflare/etc + resources: + limits: + nvidia.com/gpu: # omitted if None + imagePullPolicy: Always + volumes: + - name: nvflws + persistentVolumeClaim: + claimName: + - name: nvfldata + persistentVolumeClaim: + claimName: + - name: nvfletc + persistentVolumeClaim: + claimName: +``` + +#### K8sJobLauncher + +Constructor parameters: + +| Parameter | Purpose | +|-----------|---------| +| `config_file_path` | Path to kubeconfig file. Loaded via `config.load_kube_config()`. | +| `workspace_pvc` | PVC claim name for the NVFlare workspace. | +| `etc_pvc` | PVC claim name for configuration/etc data. | +| `data_pvc_file_path` | Path to a YAML file mapping PVC names to mount paths for training data. | +| `timeout` | Maximum seconds to wait for pod to reach `Running` (also used as stuck threshold). | +| `namespace` | Kubernetes namespace (default: `"default"`). 
| + +Launch sequence: + +| Step | Action | +|------|--------| +| 1 | Extract `job_image`, `site_name`, and optional `num_of_gpus` from `job_meta`. | +| 2 | Read `JOB_PROCESS_ARGS` from `fl_ctx`; extract `EXE_MODULE` as the container command. | +| 3 | Build `job_config` dict: name, image, container name, command, volume mounts/PVCs, `module_args` from `get_module_args()`, set list, GPU resources. | +| 4 | Create `K8sJobHandle(job_id, core_v1, job_config, namespace, timeout)` which builds the pod manifest. | +| 5 | `core_v1.create_namespaced_pod(body=pod_manifest, namespace)`. | +| 6 | Call `job_handle.enter_states([RUNNING], timeout)`. The return value is not checked. If stuck detection fires, `terminate()` is called inside `enter_states` (sets `terminal_state = TERMINATED` via `finally`) before returning the handle, so the caller can detect failure via `poll()`. On plain timeout (no stuck), the handle is returned with `terminal_state` unset and the pod may still be starting. | +| 7 | On `ApiException` from `create_namespaced_pod` → `job_handle.terminate()` then return the handle. Unlike Docker (which returns `None` on failure), the K8s launcher always returns a handle; callers detect failure when `poll()` or `wait()` resolves. | + +**Server/Client subclasses** override `get_module_args()`: + +- `ClientK8sJobLauncher` → Filters `JOB_PROCESS_ARGS` through `get_client_job_args(include_exe_module=False, include_set_options=False)` to produce the dict of `-flag value` pairs for the container args list. +- `ServerK8sJobLauncher` → Same pattern with `get_server_job_args(...)`. + +**Key difference from Process/Docker:** The K8s launcher does not build a shell command string. Instead, it passes the Python executable as `command` and constructs a structured `args` list (`["-u", "-m", "", "-w", "", ...]`) directly in the pod spec. + +--- + +## 5. 
Object-Oriented Design Summary + +### 5.1 Full Class Hierarchy + +``` +JobHandleSpec (abstract) +├── ProcessHandle (wraps ProcessAdapter / subprocess.Popen) +├── DockerJobHandle (wraps docker.Container) +└── K8sJobHandle (wraps CoreV1Api + pod name) + +JobLauncherSpec (abstract, extends FLComponent) +├── ProcessJobLauncher (abstract: get_command) +│ ├── ServerProcessJobLauncher +│ └── ClientProcessJobLauncher +├── DockerJobLauncher (abstract: get_command) +│ ├── ServerDockerJobLauncher +│ └── ClientDockerJobLauncher +└── K8sJobLauncher (abstract: get_module_args) + ├── ServerK8sJobLauncher + └── ClientK8sJobLauncher +``` + +### 5.2 Design Patterns + +**Strategy Pattern** -- Each launcher is a strategy for running jobs. The engine programs against `JobLauncherSpec`; the concrete strategy is selected at runtime through the event system. + +**Template Method Pattern** -- Each base launcher (`ProcessJobLauncher`, `DockerJobLauncher`, `K8sJobLauncher`) implements `launch_job()` with a fixed algorithm, delegating the variable part to an abstract method: + +| Base Launcher | Template method calls | Abstract hook | +|---------------|----------------------|---------------| +| `ProcessJobLauncher` | `launch_job()` → `get_command()` | `get_command(job_meta, fl_ctx) -> str` | +| `DockerJobLauncher` | `launch_job()` → `get_command()` | `get_command(job_meta, fl_ctx) -> (str, str)` | +| `K8sJobLauncher` | `launch_job()` → `get_module_args()` | `get_module_args(job_id, fl_ctx) -> dict` | + +Server and client subclasses provide the implementation of these hooks, producing the correct command-line arguments for each role. + +**Observer Pattern** -- Launchers register for the `BEFORE_JOB_LAUNCH` event through the `FLComponent` event system. This decouples launcher registration from the engine's control flow entirely. + +--- + +## 6. 
Comparison: Process vs Docker vs Kubernetes + +| Aspect | Process | Docker | Kubernetes | +|--------|---------|--------|------------| +| **When selected** | No `job_image` for site | `job_image` present | `job_image` present | +| **Execution unit** | OS subprocess | Docker container | Kubernetes Pod | +| **Isolation** | Shared host, inherited env | Container isolation, mounted workspace | Pod isolation, PVC-backed volumes | +| **Command format** | Shell command string (`python -m ...`) | Shell command inside `/bin/bash -c` | Structured `command` + `args` list in pod spec | +| **Workspace access** | Direct filesystem (same host) | Host directory bind-mounted to container | PersistentVolumeClaims | +| **Data access** | Direct filesystem | Via bind mount | Via PVC (configured in YAML) | +| **Start verification** | None (spawn returns immediately) | Poll for `RUNNING` state with timeout | `enter_states([RUNNING], timeout)` with stuck detection; return value not checked — on stuck, `terminal_state` is set so caller can detect via `poll()` | +| **Wait mechanism** | `subprocess.Popen.wait()` (OS-level block) | Poll container status for `EXITED`/`DEAD` | Direct while loop via `_query_state()`; no timeout; exits when `terminal_state` set or `SUCCEEDED`/`TERMINATED` reached | +| **Terminate** | `SIGTERM`/`SIGKILL` via `ProcessAdapter` | `container.stop()` | `delete_namespaced_pod(grace_period=0)`; `terminal_state` set to `TERMINATED` on success or 404; left unchanged (error logged) for other exceptions | +| **Return code source** | Process exit code or RC file | Container status mapping or RC file | Pod phase mapping or RC file; `poll()` now consistently returns `JobReturnCode` via `JOB_RETURN_CODE_MAPPING` | +| **GPU support** | Inherited from host environment | Not explicitly managed | `nvidia.com/gpu` resource limit in pod spec | +| **Dependencies** | stdlib only | `docker` Python SDK | `kubernetes` Python client + kubeconfig | +| **Typical use case** | Simulator, 
single-machine POC | Multi-container on single host | Production cluster with shared storage | + +--- + +## 7. Backend (BE) Resource Management + +When using the K8s launcher, the NVFlare process managing jobs may not itself run on a GPU node. The K8s scheduler handles actual resource allocation externally. Two passthrough classes support this case: + +### 7.1 BEResourceManager (`nvflare/app_common/resource_managers/BE_resource_manager.py`) + +`BEResourceManager(ResourceManagerSpec, FLComponent)` is a "Best Effort" resource manager: it always approves resource allocation requests and performs no local tracking, allowing jobs to attempt to run and fail at runtime if resources are genuinely unavailable: + +| Method | Behavior | +|--------|----------| +| `check_resources()` | Always returns `(True, )` -- never rejects a job. | +| `cancel_resources()` | No-op. | +| `allocate_resources()` | Returns empty dict `{}`. | +| `free_resources()` | No-op. | +| `report_resources()` | Returns `{}` (empty dict, conforming to the `ResourceManagerSpec` contract). | + +Use this when the container orchestration backend (K8s) is responsible for all resource accounting. + +### 7.2 BEResourceConsumer (`nvflare/app_common/resource_consumers/BE_resource_consumer.py`) + +`BEResourceConsumer(ResourceConsumerSpec)` implements a no-op `consume()`. Use this alongside `BEResourceManager` when no local resource consumption reporting is needed. + +### 7.3 GPUResourceManager `ignore_host` flag (`nvflare/app_common/resource_managers/gpu_resource_manager.py`) + +`GPUResourceManager` gained an `ignore_host=False` parameter. When `True`, the constructor skips the startup validation that checks whether the declared `num_of_gpus` and `mem_per_gpu_in_GiB` match actual host hardware. This is needed in K8s deployments where the NVFlare process runs on a node without GPUs but still needs to track a GPU resource pool for job scheduling purposes. + +--- + +## 8. 
Sequence Diagram + +The following shows the end-to-end flow for launching and managing a job, applicable to both server and client: + +``` + Engine fed_utils Launcher(s) Handle + │ │ │ │ + │ get_job_launcher() │ │ │ + │───────────────────────>│ │ │ + │ │ fire BEFORE_JOB_LAUNCH │ + │ │─────────────────────>│ │ + │ │ │ check job_meta │ + │ │ │ add_launcher(self) │ + │ │<─────────────────────│ │ + │ return launcher │ │ │ + │<───────────────────────│ │ │ + │ │ │ │ + │ launcher.launch_job(job_meta, fl_ctx) │ │ + │─────────────────────────────────────────────->│ │ + │ │ │ create exec unit │ + │ │ │ (process/container │ + │ │ │ /pod) │ + │ │ │─────────────────────>│ + │ │ │ return handle │ + │<─────────────────────────────────────────────────────────────────────│ + │ │ │ │ + │ store handle in run_processes │ │ + │ │ │ │ + │ [background thread] │ │ │ + │ handle.wait() │ │ │ + │─────────────────────────────────────────────────────────────────────>│ + │ │ │ blocks/polls │ + │ │ │ │ + │ ... (on abort) ... │ │ │ + │ handle.terminate() │ │ │ + │─────────────────────────────────────────────────────────────────────>│ + │ │ │ │ + │ ... (on completion) . │ │ │ + │ get_return_code() │ │ │ + │───────────────────────>│ │ │ + │ │ check RC file │ │ + │ │ or handle.poll() │ │ + │ │─────────────────────────────────────────────>│ + │ return_code │ │ │ + │<───────────────────────│ │ │ +``` + +--- + +## 9. Configuration + +Launchers are registered as FL components in the site's `resources.json`. The configurator loads them at startup so they receive events. 
+ +### 9.1 Process Launcher (default) + +```json +{ + "id": "job_launcher", + "path": "nvflare.app_common.job_launcher.server_process_launcher.ServerProcessJobLauncher", + "args": {} +} +``` + +### 9.2 Docker Launcher + +```json +{ + "id": "job_launcher", + "path": "nvflare.app_opt.job_launcher.docker_launcher.ClientDockerJobLauncher", + "args": { + "mount_path": "/workspace", + "network": "nvflare-network", + "timeout": 60 + } +} +``` + +Requires the `NVFL_DOCKER_WORKSPACE` environment variable to be set on the host to identify the workspace directory to bind-mount. + +### 9.3 Kubernetes Launcher + +```json +{ + "id": "job_launcher", + "path": "nvflare.app_opt.job_launcher.k8s_launcher.ClientK8sJobLauncher", + "args": { + "config_file_path": "/path/to/kubeconfig", + "workspace_pvc": "nvflare-workspace-pvc", + "etc_pvc": "nvflare-etc-pvc", + "data_pvc_file_path": "/path/to/data_pvc.yaml", + "timeout": 120, + "namespace": "nvflare" + } +} +``` + +The `data_pvc_file_path` YAML file maps PVC names to mount paths: + +```yaml +my-data-pvc: /var/tmp/nvflare/data +``` + +> Note: Currently only the PVC name (the YAML key) is used. The mount path value is ignored — the data volume is always mounted at the hardcoded path `/var/tmp/nvflare/data`. + +--- + +## 10. Future Improvements + +1. **Explicit launcher selection** -- Today "has image" → Docker or K8s, "no image" → Process. Allow an explicit `launcher_type` field in job meta or deploy map so a site can support multiple container backends or provide fallback ordering (e.g., try K8s, fall back to Docker). + +2. **Consistent GPU handling** -- The K8s launcher reads `num_of_gpus` from the resource spec; the Docker and Process launchers do not. Standardize resource declaration so job definitions remain portable across backends. + +3. **Unified cleanup** -- Standardize container/pod cleanup policy across launchers (auto-remove on exit, configurable retention for debugging) and centralize it in the handle or engine. + +4. 
**Consistent timeout policy and failure semantics** -- The Process launcher has no start timeout. Docker polls for `RUNNING` and returns `None` on failure. K8s polls for `Running` with stuck detection (terminates and sets `terminal_state` on stuck) but does not act on plain startup timeout — if a pod is slow to start but not stuck in `Pending`, the handle is returned with `terminal_state` unset. Consider terminating explicitly on timeout and unifying failure return across all launchers (either always `None` or always a terminated handle). + +5. **Observability** -- Add an optional `get_info()` method to `JobHandleSpec` so the engine can log launcher-specific details (container ID, pod name, namespace, PID) for debugging and operations. + +6. **Testing** -- Provide `MockJobLauncher` and `MockJobHandle` implementations for unit tests that verify server/client flow without starting real processes or containers. diff --git a/nvflare/apis/job_launcher_spec.py b/nvflare/apis/job_launcher_spec.py index cb75130e6a..8fedfb866e 100644 --- a/nvflare/apis/job_launcher_spec.py +++ b/nvflare/apis/job_launcher_spec.py @@ -61,7 +61,7 @@ class JobHandleSpec: def terminate(self): """To terminate the job run. - Returns: the job run return code. + Returns: None """ raise NotImplementedError() @@ -94,7 +94,7 @@ def launch_job(self, job_meta: dict, fl_ctx: FLContext) -> JobHandleSpec: job_meta: job metadata fl_ctx: FLContext - Returns: boolean to indicates the job launch success or fail. + Returns: a JobHandle instance. """ raise NotImplementedError() diff --git a/nvflare/app_common/resource_consumers/BE_resource_consumer.py b/nvflare/app_common/resource_consumers/BE_resource_consumer.py new file mode 100644 index 0000000000..08059ec621 --- /dev/null +++ b/nvflare/app_common/resource_consumers/BE_resource_consumer.py @@ -0,0 +1,21 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nvflare.apis.resource_manager_spec import ResourceConsumerSpec + + +class BEResourceConsumer(ResourceConsumerSpec): + + def consume(self, resources: dict): + pass diff --git a/nvflare/app_common/resource_managers/BE_resource_manager.py b/nvflare/app_common/resource_managers/BE_resource_manager.py new file mode 100644 index 0000000000..89b92e850c --- /dev/null +++ b/nvflare/app_common/resource_managers/BE_resource_manager.py @@ -0,0 +1,50 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import uuid + +from nvflare.apis.fl_component import FLComponent +from nvflare.apis.fl_context import FLContext +from nvflare.apis.resource_manager_spec import ResourceManagerSpec + + +class BEResourceManager(ResourceManagerSpec, FLComponent): + def __init__(self): + """Best Effort Resource Manager implementation. 
+ + It will accept all resource allocation requests + and let the job fail when the requested resources are unavailable + at runtime. + + """ + super().__init__() + + def check_resources(self, resource_requirement: dict, fl_ctx: FLContext): + if not isinstance(resource_requirement, dict): + raise TypeError(f"resource_requirement should be of type dict, but got {type(resource_requirement)}.") + + token = str(uuid.uuid4()) + return True, token + + def cancel_resources(self, resource_requirement: dict, token: str, fl_ctx: FLContext): + return None + + def allocate_resources(self, resource_requirement: dict, token: str, fl_ctx: FLContext) -> dict: + return {} + + def free_resources(self, resources: dict, token: str, fl_ctx: FLContext): + pass + + def report_resources(self, fl_ctx): + return {} diff --git a/nvflare/app_common/resource_managers/gpu_resource_manager.py b/nvflare/app_common/resource_managers/gpu_resource_manager.py index f14f2f0ccd..c4ea20b52c 100644 --- a/nvflare/app_common/resource_managers/gpu_resource_manager.py +++ b/nvflare/app_common/resource_managers/gpu_resource_manager.py @@ -35,6 +35,7 @@ def __init__( num_gpu_key: str = "num_of_gpus", gpu_mem_key: str = "mem_per_gpu_in_GiB", expiration_period: Union[int, float] = 30, + ignore_host: bool = False, ): """Resource manager for GPUs. @@ -46,6 +47,9 @@ def __init__( expiration_period: Number of seconds to hold the resources reserved. If check_resources is called but after "expiration_period" no allocate resource is called, then the reserved resources will be released. + ignore_host: Whether to skip validation against GPUs present on the local host. Set to True in + environments where the NVFlare process runs on a node without GPUs (for example, some + Kubernetes deployments) but GPU resources are managed externally. 
""" if not isinstance(num_of_gpus, int): raise ValueError(f"num_of_gpus should be of type int, but got {type(num_of_gpus)}.") @@ -62,17 +66,21 @@ def __init__( if expiration_period < 0: raise ValueError("expiration_period should be greater than or equal to 0.") - if num_of_gpus > 0: - num_host_gpus = len(get_host_gpu_ids()) - if num_of_gpus > num_host_gpus: - raise ValueError(f"num_of_gpus specified ({num_of_gpus}) exceeds available GPUs: {num_host_gpus}.") - - host_gpu_mem = get_host_gpu_memory_total() - for i in host_gpu_mem: - if mem_per_gpu_in_GiB * 1024 > i: - raise ValueError( - f"Memory per GPU specified ({mem_per_gpu_in_GiB * 1024}) exceeds available GPU memory: {i}." - ) + if not isinstance(ignore_host, bool): + raise ValueError(f"ignore_host should be of type bool, but got {type(ignore_host)}.") + + if not ignore_host: + if num_of_gpus > 0: + num_host_gpus = len(get_host_gpu_ids()) + if num_of_gpus > num_host_gpus: + raise ValueError(f"num_of_gpus specified ({num_of_gpus}) exceeds available GPUs: {num_host_gpus}.") + + host_gpu_mem = get_host_gpu_memory_total() + for i in host_gpu_mem: + if mem_per_gpu_in_GiB * 1024 > i: + raise ValueError( + f"Memory per GPU specified ({mem_per_gpu_in_GiB * 1024}) exceeds available GPU memory: {i}." 
+ ) self.num_gpu_key = num_gpu_key self.gpu_mem_key = gpu_mem_key diff --git a/nvflare/app_opt/job_launcher/docker_launcher.py b/nvflare/app_opt/job_launcher/docker_launcher.py index 8779e4711d..50e38a504a 100644 --- a/nvflare/app_opt/job_launcher/docker_launcher.py +++ b/nvflare/app_opt/job_launcher/docker_launcher.py @@ -45,10 +45,10 @@ class DOCKER_STATE: class DockerJobHandle(JobHandleSpec): - def __init__(self, container, timeout=None): + def __init__(self, timeout=None): super().__init__() - self.container = container + self.container = None self.timeout = timeout self.logger = logging.getLogger(self.__class__.__name__) @@ -68,6 +68,9 @@ def wait(self): if self.container: self.enter_states([DOCKER_STATE.EXITED, DOCKER_STATE.DEAD], self.timeout) + def _set_container(self, container): + self.container = container + def _get_container(self): try: docker_client = docker.from_env() @@ -120,6 +123,7 @@ def launch_job(self, job_meta: dict, fl_ctx: FLContext) -> JobHandleSpec: docker_workspace = os.environ.get("NVFL_DOCKER_WORKSPACE") self.logger.info(f"launch_job {job_id} in docker_workspace: {docker_workspace}") docker_client = docker.from_env() + handle = DockerJobHandle() try: container = docker_client.containers.run( job_image, @@ -137,24 +141,22 @@ def launch_job(self, job_meta: dict, fl_ctx: FLContext) -> JobHandleSpec: # ports=ports, # Map container ports to host ports (optional) ) self.logger.info(f"Launch the job in DockerJobLauncher using image: {job_image}") - - handle = DockerJobHandle(container) + handle._set_container(container) try: - if handle.enter_states([DOCKER_STATE.RUNNING], timeout=self.timeout): - return handle - else: + launched = handle.enter_states([DOCKER_STATE.RUNNING], timeout=self.timeout) + if not launched: handle.terminate() - return None + return handle except: handle.terminate() - return None + return handle except docker.errors.ImageNotFound: self.logger.error(f"Failed to launcher job: {job_id} in DockerJobLauncher. 
Image '{job_image}' not found.") - return None + return handle except docker.errors.APIError as e: self.logger.error(f"Error starting container: {e}") - return None + return handle def handle_event(self, event_type: str, fl_ctx: FLContext): if event_type == EventType.BEFORE_JOB_LAUNCH: diff --git a/nvflare/app_opt/job_launcher/k8s_launcher.py b/nvflare/app_opt/job_launcher/k8s_launcher.py index 19e6716bc2..3c0ab04884 100644 --- a/nvflare/app_opt/job_launcher/k8s_launcher.py +++ b/nvflare/app_opt/job_launcher/k8s_launcher.py @@ -16,6 +16,7 @@ from abc import abstractmethod from enum import Enum +import yaml from kubernetes import config from kubernetes.client import Configuration from kubernetes.client.api import core_v1_api @@ -24,6 +25,7 @@ from nvflare.apis.event_type import EventType from nvflare.apis.fl_constant import FLContextKey, JobConstants from nvflare.apis.fl_context import FLContext +from nvflare.apis.job_def import JobMetaKey from nvflare.apis.job_launcher_spec import JobHandleSpec, JobLauncherSpec, JobProcessArgs, JobReturnCode, add_launcher from nvflare.utils.job_launcher_utils import extract_job_image, get_client_job_args, get_server_job_args @@ -36,12 +38,20 @@ class JobState(Enum): UNKNOWN = "unknown" +class POD_Phase(Enum): + PENDING = "Pending" + RUNNING = "Running" + SUCCEEDED = "Succeeded" + FAILED = "Failed" + UNKNOWN = "Unknown" + + POD_STATE_MAPPING = { - "Pending": JobState.STARTING, - "Running": JobState.RUNNING, - "Succeeded": JobState.SUCCEEDED, - "Failed": JobState.TERMINATED, - "Unknown": JobState.UNKNOWN, + POD_Phase.PENDING.value: JobState.STARTING, + POD_Phase.RUNNING.value: JobState.RUNNING, + POD_Phase.SUCCEEDED.value: JobState.SUCCEEDED, + POD_Phase.FAILED.value: JobState.TERMINATED, + POD_Phase.UNKNOWN.value: JobState.UNKNOWN, } JOB_RETURN_CODE_MAPPING = { @@ -52,13 +62,39 @@ class JobState(Enum): JobState.UNKNOWN: JobReturnCode.UNKNOWN, } +DEFAULT_CONTAINER_ARGS_MODULE_ARGS_DICT = { + "-m": None, + "-w": None, + "-t": None, + 
"-d": None, + "-n": None, + "-c": None, + "-p": None, + "-g": None, + "-scheme": None, + "-s": None, +} + + +class PV_NAME(Enum): + WORKSPACE = "nvflws" + DATA = "nvfldata" + ETC = "nvfletc" + + +VOLUME_MOUNT_LIST = [ + {"name": PV_NAME.WORKSPACE.value, "mountPath": "/var/tmp/nvflare/workspace"}, + {"name": PV_NAME.DATA.value, "mountPath": "/var/tmp/nvflare/data"}, + {"name": PV_NAME.ETC.value, "mountPath": "/var/tmp/nvflare/etc"}, +] + class K8sJobHandle(JobHandleSpec): def __init__(self, job_id: str, api_instance: core_v1_api, job_config: dict, namespace="default", timeout=None): super().__init__() self.job_id = job_id self.timeout = timeout - + self.terminal_state = None self.api_instance = api_instance self.namespace = namespace self.pod_manifest = { @@ -68,78 +104,56 @@ def __init__(self, job_id: str, api_instance: core_v1_api, job_config: dict, nam "spec": { "containers": None, # link to container_list "volumes": None, # link to volume_list - "restartPolicy": "OnFailure", + "restartPolicy": "Never", }, } - self.volume_list = [{"name": None, "hostPath": {"path": None, "type": "Directory"}}] + self.volume_list = [] + self.container_list = [ { "image": None, "name": None, + "resources": None, "command": ["/usr/local/bin/python"], "args": None, # args_list + args_dict + args_sets "volumeMounts": None, # volume_mount_list "imagePullPolicy": "Always", } ] - self.container_args_python_args_list = ["-u", "-m", job_config.get("command")] - self.container_args_module_args_dict = { - "-m": None, - "-w": None, - "-t": None, - "-d": None, - "-n": None, - "-c": None, - "-p": None, - "-g": None, - "-scheme": None, - "-s": None, - } - self.container_volume_mount_list = [ - { - "name": None, - "mountPath": None, - } - ] + command = job_config.get("command") + if not command: + raise ValueError("job_config must contain a non-empty 'command' key") + self.container_args_python_args_list = ["-u", "-m", command] + self.container_volume_mount_list = [] 
self._make_manifest(job_config) + self._stuck_count = 0 + self._stuck_grace_period = 10 # seconds to wait before counting Pending as stuck + self._max_stuck_count = (self.timeout + self._stuck_grace_period) if self.timeout is not None else None + self.logger = logging.getLogger(self.__class__.__name__) def _make_manifest(self, job_config): - self.container_volume_mount_list = job_config.get( - "volume_mount_list", [{"name": "workspace-nvflare", "mountPath": "/workspace/nvflare"}] - ) + self.container_volume_mount_list.extend(job_config.get("volume_mount_list", [])) set_list = job_config.get("set_list") if set_list is None: self.container_args_module_args_sets = list() else: self.container_args_module_args_sets = ["--set"] + set_list - self.container_args_module_args_dict = job_config.get( - "module_args", - { - "-m": None, - "-w": None, - "-t": None, - "-d": None, - "-n": None, - "-c": None, - "-p": None, - "-g": None, - "-scheme": None, - "-s": None, - }, - ) + if job_config.get("module_args") is None: + self.container_args_module_args_dict = DEFAULT_CONTAINER_ARGS_MODULE_ARGS_DICT.copy() + else: + self.container_args_module_args_dict = job_config.get("module_args") self.container_args_module_args_dict_as_list = list() for k, v in self.container_args_module_args_dict.items(): + if v is None: + continue self.container_args_module_args_dict_as_list.append(k) self.container_args_module_args_dict_as_list.append(v) - self.volume_list = job_config.get( - "volume_list", [{"name": None, "hostPath": {"path": None, "type": "Directory"}}] - ) - + self.volume_list.extend(job_config.get("volume_list", [])) self.pod_manifest["metadata"]["name"] = job_config.get("name") self.pod_manifest["spec"]["containers"] = self.container_list self.pod_manifest["spec"]["volumes"] = self.volume_list - self.container_list[0]["image"] = job_config.get("image", "nvflare/nvflare:2.5.0") + self.container_list[0]["image"] = job_config.get("image", "nvflare/nvflare:2.8.0") 
self.container_list[0]["name"] = job_config.get("container_name", "nvflare_job") self.container_list[0]["args"] = ( self.container_args_python_args_list @@ -147,6 +161,8 @@ def _make_manifest(self, job_config): + self.container_args_module_args_sets ) self.container_list[0]["volumeMounts"] = self.container_volume_mount_list + if job_config.get("resources", {}).get("limits", {}).get("nvidia.com/gpu") is not None: + self.container_list[0]["resources"] = job_config.get("resources") def get_manifest(self): return self.pod_manifest @@ -155,10 +171,14 @@ def enter_states(self, job_states_to_enter: list, timeout=None): starting_time = time.time() if not isinstance(job_states_to_enter, (list, tuple)): job_states_to_enter = [job_states_to_enter] - if not all([isinstance(js, JobState)] for js in job_states_to_enter): + if not all([isinstance(js, JobState) for js in job_states_to_enter]): raise ValueError(f"expect job_states_to_enter with valid values, but get {job_states_to_enter}") while True: - job_state = self._query_state() + pod_phase = self._query_phase() + if self._stuck(pod_phase): + self.terminate() + return False + job_state = POD_STATE_MAPPING.get(pod_phase, JobState.UNKNOWN) if job_state in job_states_to_enter: return True elif timeout is not None and time.time() - starting_time > timeout: @@ -166,42 +186,83 @@ def enter_states(self, job_states_to_enter: list, timeout=None): time.sleep(1) def terminate(self): - resp = self.api_instance.delete_namespaced_pod( - name=self.job_id, namespace=self.namespace, grace_period_seconds=0 - ) - return self.enter_states([JobState.TERMINATED], timeout=self.timeout) + try: + resp = self.api_instance.delete_namespaced_pod( + name=self.job_id, namespace=self.namespace, grace_period_seconds=0 + ) + self.terminal_state = JobState.TERMINATED + except ApiException as e: + # If the pod is already gone, treat it as terminated; otherwise, leave state unchanged. 
+ if getattr(e, "status", None) == 404: + self.logger.info(f"job {self.job_id} pod not found during termination; assuming terminated") + self.terminal_state = JobState.TERMINATED + else: + self.logger.error(f"failed to terminate job {self.job_id}: {e}") + return None def poll(self): + if self.terminal_state is not None: + return JOB_RETURN_CODE_MAPPING.get(self.terminal_state) job_state = self._query_state() return JOB_RETURN_CODE_MAPPING.get(job_state, JobReturnCode.UNKNOWN) - def _query_state(self): + def _query_phase(self): try: resp = self.api_instance.read_namespaced_pod(name=self.job_id, namespace=self.namespace) - except ApiException: - return JobState.UNKNOWN - return POD_STATE_MAPPING.get(resp.status.phase, JobState.UNKNOWN) + except ApiException as e: + return POD_Phase.UNKNOWN.value + return resp.status.phase + + def _query_state(self): + pod_phase = self._query_phase() + return POD_STATE_MAPPING.get(pod_phase, JobState.UNKNOWN) + + def _stuck(self, current_phase): + if self._max_stuck_count is None: + return False + if current_phase == POD_Phase.PENDING.value: + self._stuck_count += 1 + if self._stuck_count > self._max_stuck_count: + return True + return False def wait(self): - self.enter_states([JobState.SUCCEEDED, JobState.TERMINATED]) + while True: + if self.terminal_state is not None: + return + job_state = self._query_state() + if job_state in (JobState.SUCCEEDED, JobState.TERMINATED): + self.terminal_state = job_state # persist so poll() stays accurate + return + time.sleep(1) class K8sJobLauncher(JobLauncherSpec): def __init__( self, - config_file_path, - root_hostpath: str, - workspace: str, - mount_path: str, + config_file_path: str, + workspace_pvc: str, + etc_pvc: str, + data_pvc_file_path: str, timeout=None, namespace="default", ): super().__init__() + self.logger = logging.getLogger(self.__class__.__name__) - self.root_hostpath = root_hostpath - self.workspace = workspace - self.mount_path = mount_path + self.workspace_pvc = workspace_pvc + 
self.etc_pvc = etc_pvc + self.data_pvc_file_path = data_pvc_file_path self.timeout = timeout + self.namespace = namespace + with open(data_pvc_file_path, "rt") as f: + data_pvc_dict = yaml.safe_load(f) + if not data_pvc_dict: + raise ValueError(f"data_pvc_file_path '{data_pvc_file_path}' is empty or contains no PVC entries.") + # data_pvc_dict will be pvc: mountPath + # currently, support one pvc and always mount to /var/tmp/nvflare/data + # ie, ignore the mountPath in data_pvc_dict + self.data_pvc = list(data_pvc_dict.keys())[0] config.load_kube_config(config_file_path) try: @@ -211,17 +272,15 @@ def __init__( c.assert_hostname = False Configuration.set_default(c) self.core_v1 = core_v1_api.CoreV1Api() - self.namespace = namespace - self.job_handle = None - self.logger = logging.getLogger(self.__class__.__name__) def launch_job(self, job_meta: dict, fl_ctx: FLContext) -> JobHandleSpec: - + site_name = fl_ctx.get_identity_name() job_id = job_meta.get(JobConstants.JOB_ID) args = fl_ctx.get_prop(FLContextKey.ARGS) - job_image = extract_job_image(job_meta, fl_ctx.get_identity_name()) - self.logger.info(f"launch job use image: {job_image}") + job_image = extract_job_image(job_meta, site_name) + site_resources = job_meta.get(JobMetaKey.RESOURCE_SPEC.value, {}).get(site_name, {}) + job_resource = site_resources.get("num_of_gpus", None) job_args = fl_ctx.get_prop(FLContextKey.JOB_PROCESS_ARGS) if not job_args: @@ -233,25 +292,28 @@ def launch_job(self, job_meta: dict, fl_ctx: FLContext) -> JobHandleSpec: "image": job_image, "container_name": f"container-{job_id}", "command": job_cmd, - "volume_mount_list": [{"name": self.workspace, "mountPath": self.mount_path}], - "volume_list": [{"name": self.workspace, "hostPath": {"path": self.root_hostpath, "type": "Directory"}}], + "volume_mount_list": VOLUME_MOUNT_LIST, + "volume_list": [ + {"name": PV_NAME.WORKSPACE.value, "persistentVolumeClaim": {"claimName": self.workspace_pvc}}, + {"name": PV_NAME.DATA.value, 
"persistentVolumeClaim": {"claimName": self.data_pvc}}, + {"name": PV_NAME.ETC.value, "persistentVolumeClaim": {"claimName": self.etc_pvc}}, + ], "module_args": self.get_module_args(job_id, fl_ctx), "set_list": args.set, + "resources": {"limits": {"nvidia.com/gpu": job_resource}}, } - self.logger.info(f"launch job with k8s_launcher. Job_id:{job_id}") - job_handle = K8sJobHandle(job_id, self.core_v1, job_config, namespace=self.namespace, timeout=self.timeout) + pod_manifest = job_handle.get_manifest() + self.logger.debug(f"launch job with k8s_launcher. {pod_manifest=}") try: - self.core_v1.create_namespaced_pod(body=job_handle.get_manifest(), namespace=self.namespace) - if job_handle.enter_states([JobState.RUNNING], timeout=self.timeout): - return job_handle - else: - job_handle.terminate() - return None - except ApiException: + self.core_v1.create_namespaced_pod(body=pod_manifest, namespace=self.namespace) + job_handle.enter_states([JobState.RUNNING], timeout=self.timeout) + return job_handle + except ApiException as e: + self.logger.error(f"failed to launch job {self.job_id}: {e}") job_handle.terminate() - return None + return job_handle def handle_event(self, event_type: str, fl_ctx: FLContext): if event_type == EventType.BEFORE_JOB_LAUNCH: diff --git a/nvflare/private/fed/client/communicator.py b/nvflare/private/fed/client/communicator.py index 1387f9ca9c..d44d6e1ad7 100644 --- a/nvflare/private/fed/client/communicator.py +++ b/nvflare/private/fed/client/communicator.py @@ -59,6 +59,7 @@ def __init__( client_register_interval=2, timeout=5.0, maint_msg_timeout=5.0, + cell_creation_timeout=15.0, ): """To init the Communicator. 
@@ -79,7 +80,7 @@ def __init__( self.client_register_interval = client_register_interval self.timeout = timeout self.maint_msg_timeout = maint_msg_timeout - + self.creation_timeout = cell_creation_timeout # token and token_signature are issued by the Server after the client is authenticated # they are added to every message going to the server as proof of authentication self.token = None @@ -273,9 +274,9 @@ def client_registration(self, client_name, project_name, fl_ctx: FLContext): start = time.time() while not self.cell: self.logger.info("Waiting for the client cell to be created.") - if time.time() - start > 15.0: + if time.time() - start > self.creation_timeout: raise RuntimeError("Client cell could not be created. Failed to login the client.") - time.sleep(0.5) + time.sleep(1) shared_fl_ctx = gen_new_peer_ctx(fl_ctx) private_key_file = None diff --git a/tests/unit_test/app_opt/job_launcher/__init__.py b/tests/unit_test/app_opt/job_launcher/__init__.py new file mode 100644 index 0000000000..4fc25d0d3c --- /dev/null +++ b/tests/unit_test/app_opt/job_launcher/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/tests/unit_test/app_opt/job_launcher/docker_launcher_test.py b/tests/unit_test/app_opt/job_launcher/docker_launcher_test.py new file mode 100644 index 0000000000..3e71e25b85 --- /dev/null +++ b/tests/unit_test/app_opt/job_launcher/docker_launcher_test.py @@ -0,0 +1,464 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import Mock, patch + +from nvflare.apis.event_type import EventType +from nvflare.apis.fl_constant import FLContextKey, JobConstants, ReservedKey +from nvflare.apis.fl_context import FLContext +from nvflare.apis.job_def import JobMetaKey +from nvflare.apis.job_launcher_spec import JobReturnCode +from nvflare.app_opt.job_launcher.docker_launcher import ( + DOCKER_STATE, + JOB_RETURN_CODE_MAPPING, + DockerJobHandle, + DockerJobLauncher, +) + + +# --------------------------------------------------------------------------- +# Constants and mappings +# --------------------------------------------------------------------------- +class TestDockerState: + def test_state_values(self): + assert DOCKER_STATE.CREATED == "created" + assert DOCKER_STATE.RESTARTING == "restarting" + assert DOCKER_STATE.RUNNING == "running" + assert DOCKER_STATE.PAUSED == "paused" + assert DOCKER_STATE.EXITED == "exited" + assert DOCKER_STATE.DEAD == "dead" + + +class TestDockerJobReturnCodeMapping: + def test_running_maps_to_unknown(self): + assert 
JOB_RETURN_CODE_MAPPING[DOCKER_STATE.RUNNING] == JobReturnCode.UNKNOWN + + def test_exited_maps_to_success(self): + assert JOB_RETURN_CODE_MAPPING[DOCKER_STATE.EXITED] == JobReturnCode.SUCCESS + + def test_dead_maps_to_aborted(self): + assert JOB_RETURN_CODE_MAPPING[DOCKER_STATE.DEAD] == JobReturnCode.ABORTED + + def test_created_maps_to_unknown(self): + assert JOB_RETURN_CODE_MAPPING[DOCKER_STATE.CREATED] == JobReturnCode.UNKNOWN + + def test_paused_maps_to_unknown(self): + assert JOB_RETURN_CODE_MAPPING[DOCKER_STATE.PAUSED] == JobReturnCode.UNKNOWN + + def test_restarting_maps_to_unknown(self): + assert JOB_RETURN_CODE_MAPPING[DOCKER_STATE.RESTARTING] == JobReturnCode.UNKNOWN + + +# --------------------------------------------------------------------------- +# DockerJobHandle +# --------------------------------------------------------------------------- +class TestDockerJobHandle: + def test_init_defaults(self): + handle = DockerJobHandle() + assert handle.container is None + assert handle.timeout is None + + def test_init_with_timeout(self): + handle = DockerJobHandle(timeout=30) + assert handle.container is None + assert handle.timeout == 30 + + def test_set_container(self): + handle = DockerJobHandle() + container = Mock() + handle._set_container(container) + assert handle.container is container + + def test_terminate_stops_container(self): + handle = DockerJobHandle() + container = Mock() + handle._set_container(container) + handle.terminate() + container.stop.assert_called_once() + + def test_terminate_noop_when_no_container(self): + handle = DockerJobHandle() + handle.terminate() + + # -- poll ----------------------------------------------------------------- + @patch.object(DockerJobHandle, "_get_container") + def test_poll_running_returns_unknown(self, mock_get): + container = Mock() + container.status = DOCKER_STATE.RUNNING + mock_get.return_value = container + handle = DockerJobHandle() + assert handle.poll() == JobReturnCode.UNKNOWN + + 
@patch.object(DockerJobHandle, "_get_container") + def test_poll_exited_removes_and_returns_success(self, mock_get): + container = Mock() + container.status = DOCKER_STATE.EXITED + mock_get.return_value = container + handle = DockerJobHandle() + result = handle.poll() + container.remove.assert_called_once_with(force=True) + assert result == JobReturnCode.SUCCESS + + @patch.object(DockerJobHandle, "_get_container") + def test_poll_dead_removes_and_returns_aborted(self, mock_get): + container = Mock() + container.status = DOCKER_STATE.DEAD + mock_get.return_value = container + handle = DockerJobHandle() + result = handle.poll() + container.remove.assert_called_once_with(force=True) + assert result == JobReturnCode.ABORTED + + @patch.object(DockerJobHandle, "_get_container") + def test_poll_returns_none_when_container_gone(self, mock_get): + mock_get.return_value = None + handle = DockerJobHandle() + assert handle.poll() is None + + @patch.object(DockerJobHandle, "_get_container") + def test_poll_unknown_status_returns_unknown(self, mock_get): + container = Mock() + container.status = "something_unexpected" + mock_get.return_value = container + handle = DockerJobHandle() + assert handle.poll() == JobReturnCode.UNKNOWN + + # -- wait ----------------------------------------------------------------- + @patch.object(DockerJobHandle, "enter_states") + def test_wait_calls_enter_states(self, mock_enter): + handle = DockerJobHandle(timeout=10) + handle._set_container(Mock()) + handle.wait() + mock_enter.assert_called_once_with([DOCKER_STATE.EXITED, DOCKER_STATE.DEAD], 10) + + def test_wait_noop_when_no_container(self): + handle = DockerJobHandle() + handle.wait() + + # -- enter_states --------------------------------------------------------- + @patch.object(DockerJobHandle, "_get_container") + def test_enter_states_returns_true_when_state_matches(self, mock_get): + container = Mock() + container.status = DOCKER_STATE.RUNNING + mock_get.return_value = container + handle = 
DockerJobHandle() + assert handle.enter_states([DOCKER_STATE.RUNNING]) is True + + @patch.object(DockerJobHandle, "_get_container") + def test_enter_states_returns_false_when_container_gone(self, mock_get): + mock_get.return_value = None + handle = DockerJobHandle() + assert handle.enter_states([DOCKER_STATE.RUNNING]) is False + + @patch.object(DockerJobHandle, "_get_container") + def test_enter_states_returns_false_on_timeout(self, mock_get): + container = Mock() + container.status = DOCKER_STATE.CREATED + mock_get.return_value = container + handle = DockerJobHandle() + assert handle.enter_states([DOCKER_STATE.RUNNING], timeout=0) is False + + @patch.object(DockerJobHandle, "_get_container") + def test_enter_states_wraps_single_state(self, mock_get): + container = Mock() + container.status = DOCKER_STATE.EXITED + mock_get.return_value = container + handle = DockerJobHandle() + assert handle.enter_states(DOCKER_STATE.EXITED) is True + + # -- _get_container ------------------------------------------------------- + @patch("nvflare.app_opt.job_launcher.docker_launcher.docker") + def test_get_container_returns_container(self, mock_docker): + orig_container = Mock() + orig_container.id = "abc123" + refreshed = Mock() + mock_docker.from_env.return_value.containers.get.return_value = refreshed + + handle = DockerJobHandle() + handle._set_container(orig_container) + result = handle._get_container() + assert result is refreshed + mock_docker.from_env.return_value.containers.get.assert_called_once_with("abc123") + + @patch("nvflare.app_opt.job_launcher.docker_launcher.docker") + def test_get_container_returns_none_on_exception(self, mock_docker): + orig_container = Mock() + orig_container.id = "abc123" + mock_docker.from_env.side_effect = Exception("connection error") + + handle = DockerJobHandle() + handle._set_container(orig_container) + assert handle._get_container() is None + + +# --------------------------------------------------------------------------- +# 
DockerJobLauncher +# --------------------------------------------------------------------------- +def _make_fl_ctx_for_docker_launch(): + fl_ctx = FLContext() + fl_ctx.set_prop(ReservedKey.IDENTITY_NAME, "client-1", private=False, sticky=True) + workspace_obj = Mock() + workspace_obj.get_app_custom_dir.return_value = "/custom/dir" + fl_ctx.set_prop(FLContextKey.WORKSPACE_OBJECT, workspace_obj, private=True, sticky=False) + fl_ctx.set_prop( + FLContextKey.JOB_PROCESS_ARGS, + { + "exe_module": ("-m", "nvflare.private.fed.app.client.worker_process"), + "workspace": ("-w", "/workspace"), + }, + private=True, + sticky=False, + ) + return fl_ctx + + +def _make_docker_job_meta(image="nvflare/nvflare:test", job_id="job-123"): + return { + JobConstants.JOB_ID: job_id, + JobMetaKey.DEPLOY_MAP.value: {"app": [{"sites": ["client-1"], "image": image}]}, + } + + +class TestDockerJobLauncher: + def test_init_defaults(self): + launcher = DockerJobLauncher() + assert launcher.mount_path == "/workspace" + assert launcher.network == "nvflare-network" + assert launcher.timeout is None + + def test_init_custom(self): + launcher = DockerJobLauncher(mount_path="/custom", network="my-net", timeout=120) + assert launcher.mount_path == "/custom" + assert launcher.network == "my-net" + assert launcher.timeout == 120 + + # -- handle_event --------------------------------------------------------- + def test_handle_event_adds_launcher_when_image_present(self): + from nvflare.app_opt.job_launcher.docker_launcher import ClientDockerJobLauncher + + launcher = ClientDockerJobLauncher() + fl_ctx = FLContext() + job_meta = {JobMetaKey.DEPLOY_MAP.value: {"app": [{"sites": ["client-1"], "image": "nvflare/custom:latest"}]}} + fl_ctx.set_prop(FLContextKey.JOB_META, job_meta, private=True, sticky=False) + fl_ctx.set_prop(ReservedKey.IDENTITY_NAME, "client-1", private=False, sticky=True) + + launcher.handle_event(EventType.BEFORE_JOB_LAUNCH, fl_ctx) + + launchers = 
fl_ctx.get_prop(FLContextKey.JOB_LAUNCHER) + assert launchers is not None + assert launcher in launchers + + def test_handle_event_skips_when_no_image(self): + from nvflare.app_opt.job_launcher.docker_launcher import ClientDockerJobLauncher + + launcher = ClientDockerJobLauncher() + fl_ctx = FLContext() + job_meta = {JobMetaKey.DEPLOY_MAP.value: {}} + fl_ctx.set_prop(FLContextKey.JOB_META, job_meta, private=True, sticky=False) + fl_ctx.set_prop(ReservedKey.IDENTITY_NAME, "client-1", private=False, sticky=True) + + launcher.handle_event(EventType.BEFORE_JOB_LAUNCH, fl_ctx) + assert fl_ctx.get_prop(FLContextKey.JOB_LAUNCHER) is None + + def test_handle_event_ignores_other_events(self): + from nvflare.app_opt.job_launcher.docker_launcher import ClientDockerJobLauncher + + launcher = ClientDockerJobLauncher() + fl_ctx = FLContext() + launcher.handle_event(EventType.SYSTEM_START, fl_ctx) + assert fl_ctx.get_prop(FLContextKey.JOB_LAUNCHER) is None + + # -- launch_job ----------------------------------------------------------- + @patch("nvflare.app_opt.job_launcher.docker_launcher.docker") + @patch("nvflare.app_opt.job_launcher.docker_launcher.os") + def test_launch_job_success(self, mock_os, mock_docker): + from nvflare.app_opt.job_launcher.docker_launcher import ClientDockerJobLauncher + + mock_os.environ.get.return_value = "/docker/workspace" + container = Mock() + container.status = DOCKER_STATE.RUNNING + mock_docker.from_env.return_value.containers.run.return_value = container + + launcher = ClientDockerJobLauncher(timeout=5) + fl_ctx = _make_fl_ctx_for_docker_launch() + job_meta = _make_docker_job_meta() + + with patch.object(DockerJobHandle, "enter_states", return_value=True): + handle = launcher.launch_job(job_meta, fl_ctx) + + assert handle is not None + mock_docker.from_env.return_value.containers.run.assert_called_once() + + @patch("nvflare.app_opt.job_launcher.docker_launcher.docker") + @patch("nvflare.app_opt.job_launcher.docker_launcher.os") + def 
test_launch_job_returns_handle_on_enter_states_failure(self, mock_os, mock_docker): + from nvflare.app_opt.job_launcher.docker_launcher import ClientDockerJobLauncher + + mock_os.environ.get.return_value = "/docker/workspace" + container = Mock() + mock_docker.from_env.return_value.containers.run.return_value = container + + launcher = ClientDockerJobLauncher(timeout=1) + fl_ctx = _make_fl_ctx_for_docker_launch() + job_meta = _make_docker_job_meta() + + with patch.object(DockerJobHandle, "enter_states", return_value=False): + handle = launcher.launch_job(job_meta, fl_ctx) + + assert handle is not None + + @patch("nvflare.app_opt.job_launcher.docker_launcher.docker") + @patch("nvflare.app_opt.job_launcher.docker_launcher.os") + def test_launch_job_terminates_on_enter_states_failure(self, mock_os, mock_docker): + from nvflare.app_opt.job_launcher.docker_launcher import ClientDockerJobLauncher + + mock_os.environ.get.return_value = "/docker/workspace" + container = Mock() + mock_docker.from_env.return_value.containers.run.return_value = container + + launcher = ClientDockerJobLauncher(timeout=1) + fl_ctx = _make_fl_ctx_for_docker_launch() + job_meta = _make_docker_job_meta() + + with patch.object(DockerJobHandle, "enter_states", return_value=False) as mock_enter: + handle = launcher.launch_job(job_meta, fl_ctx) + + container.stop.assert_called_once() + + @patch("nvflare.app_opt.job_launcher.docker_launcher.docker") + @patch("nvflare.app_opt.job_launcher.docker_launcher.os") + def test_launch_job_returns_handle_on_enter_states_exception(self, mock_os, mock_docker): + from nvflare.app_opt.job_launcher.docker_launcher import ClientDockerJobLauncher + + mock_os.environ.get.return_value = "/docker/workspace" + container = Mock() + mock_docker.from_env.return_value.containers.run.return_value = container + + launcher = ClientDockerJobLauncher(timeout=1) + fl_ctx = _make_fl_ctx_for_docker_launch() + job_meta = _make_docker_job_meta() + + with patch.object(DockerJobHandle, 
"enter_states", side_effect=RuntimeError("boom")): + handle = launcher.launch_job(job_meta, fl_ctx) + + assert handle is not None + container.stop.assert_called_once() + + @patch("nvflare.app_opt.job_launcher.docker_launcher.docker") + @patch("nvflare.app_opt.job_launcher.docker_launcher.os") + def test_launch_job_returns_handle_on_image_not_found(self, mock_os, mock_docker): + import docker as docker_pkg + from nvflare.app_opt.job_launcher.docker_launcher import ClientDockerJobLauncher + + mock_os.environ.get.return_value = "/docker/workspace" + mock_docker.from_env.return_value.containers.run.side_effect = docker_pkg.errors.ImageNotFound("not found") + mock_docker.errors = docker_pkg.errors + + launcher = ClientDockerJobLauncher() + fl_ctx = _make_fl_ctx_for_docker_launch() + job_meta = _make_docker_job_meta(image="bad/image:latest") + + handle = launcher.launch_job(job_meta, fl_ctx) + assert handle is not None + assert handle.container is None + + @patch("nvflare.app_opt.job_launcher.docker_launcher.docker") + @patch("nvflare.app_opt.job_launcher.docker_launcher.os") + def test_launch_job_returns_handle_on_api_error(self, mock_os, mock_docker): + import docker as docker_pkg + from nvflare.app_opt.job_launcher.docker_launcher import ClientDockerJobLauncher + + mock_os.environ.get.return_value = "/docker/workspace" + mock_docker.from_env.return_value.containers.run.side_effect = docker_pkg.errors.APIError("api error") + mock_docker.errors = docker_pkg.errors + + launcher = ClientDockerJobLauncher() + fl_ctx = _make_fl_ctx_for_docker_launch() + job_meta = _make_docker_job_meta() + + handle = launcher.launch_job(job_meta, fl_ctx) + assert handle is not None + assert handle.container is None + + @patch("nvflare.app_opt.job_launcher.docker_launcher.docker") + @patch("nvflare.app_opt.job_launcher.docker_launcher.os") + def test_launch_job_empty_custom_folder_uses_pythonpath_only(self, mock_os, mock_docker): + from nvflare.app_opt.job_launcher.docker_launcher import 
ClientDockerJobLauncher + + mock_os.environ.get.return_value = "/docker/workspace" + container = Mock() + mock_docker.from_env.return_value.containers.run.return_value = container + + launcher = ClientDockerJobLauncher() + fl_ctx = FLContext() + fl_ctx.set_prop(ReservedKey.IDENTITY_NAME, "client-1", private=False, sticky=True) + workspace_obj = Mock() + workspace_obj.get_app_custom_dir.return_value = "" + fl_ctx.set_prop(FLContextKey.WORKSPACE_OBJECT, workspace_obj, private=True, sticky=False) + fl_ctx.set_prop( + FLContextKey.JOB_PROCESS_ARGS, + { + "exe_module": ("-m", "worker"), + "workspace": ("-w", "/workspace"), + }, + private=True, + sticky=False, + ) + + job_meta = _make_docker_job_meta() + + with patch.object(DockerJobHandle, "enter_states", return_value=True): + handle = launcher.launch_job(job_meta, fl_ctx) + + call_kwargs = mock_docker.from_env.return_value.containers.run.call_args + command_str = call_kwargs[1]["command"] if "command" in call_kwargs[1] else call_kwargs[0][1] + assert "$PYTHONPATH" in command_str + assert "/custom" not in command_str + + +# --------------------------------------------------------------------------- +# ClientDockerJobLauncher.get_command +# --------------------------------------------------------------------------- +class TestClientDockerJobLauncher: + @patch("nvflare.app_opt.job_launcher.docker_launcher.generate_client_command") + def test_get_command(self, mock_gen_cmd): + from nvflare.app_opt.job_launcher.docker_launcher import ClientDockerJobLauncher + + mock_gen_cmd.return_value = "python -u -m worker_process -w /workspace" + launcher = ClientDockerJobLauncher() + fl_ctx = FLContext() + fl_ctx.set_prop(ReservedKey.IDENTITY_NAME, "client-1", private=False, sticky=True) + job_meta = {JobConstants.JOB_ID: "job-abc"} + + name, cmd = launcher.get_command(job_meta, fl_ctx) + assert name == "client-1-job-abc" + assert cmd == "python -u -m worker_process -w /workspace" + + +# 
--------------------------------------------------------------------------- +# ServerDockerJobLauncher.get_command +# --------------------------------------------------------------------------- +class TestServerDockerJobLauncher: + @patch("nvflare.app_opt.job_launcher.docker_launcher.generate_server_command") + def test_get_command(self, mock_gen_cmd): + from nvflare.app_opt.job_launcher.docker_launcher import ServerDockerJobLauncher + + mock_gen_cmd.return_value = "python -u -m server_process -w /workspace" + launcher = ServerDockerJobLauncher() + fl_ctx = FLContext() + job_meta = {JobConstants.JOB_ID: "job-xyz"} + + name, cmd = launcher.get_command(job_meta, fl_ctx) + assert name == "server-job-xyz" + assert cmd == "python -u -m server_process -w /workspace" diff --git a/tests/unit_test/app_opt/job_launcher/k8s_launcher_test.py b/tests/unit_test/app_opt/job_launcher/k8s_launcher_test.py new file mode 100644 index 0000000000..73b7303e8c --- /dev/null +++ b/tests/unit_test/app_opt/job_launcher/k8s_launcher_test.py @@ -0,0 +1,666 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import sys +from types import ModuleType +from unittest.mock import MagicMock, Mock, patch + +import pytest + +_k8s_mock = ModuleType("kubernetes") +_k8s_client = ModuleType("kubernetes.client") +_k8s_config = ModuleType("kubernetes.config") +_k8s_rest = ModuleType("kubernetes.client.rest") +_k8s_api = ModuleType("kubernetes.client.api") +_k8s_core = ModuleType("kubernetes.client.api.core_v1_api") + + +class _FakeApiException(Exception): + def __init__(self, status=None, reason=None, http_resp=None): + self.status = status + self.reason = reason + + +_k8s_rest.ApiException = _FakeApiException +_k8s_client.Configuration = MagicMock +_k8s_client.rest = _k8s_rest +_k8s_client.api = _k8s_api +_k8s_core.CoreV1Api = MagicMock +_k8s_api.core_v1_api = _k8s_core +_k8s_mock.config = _k8s_config +_k8s_mock.client = _k8s_client + +for _mod_name, _mod_obj in [ + ("kubernetes", _k8s_mock), + ("kubernetes.config", _k8s_config), + ("kubernetes.client", _k8s_client), + ("kubernetes.client.rest", _k8s_rest), + ("kubernetes.client.api", _k8s_api), + ("kubernetes.client.api.core_v1_api", _k8s_core), +]: + sys.modules.setdefault(_mod_name, _mod_obj) + +from nvflare.apis.event_type import EventType +from nvflare.apis.fl_constant import FLContextKey, ReservedKey +from nvflare.apis.fl_context import FLContext +from nvflare.apis.job_def import JobMetaKey +from nvflare.apis.job_launcher_spec import JobProcessArgs, JobReturnCode +from nvflare.app_opt.job_launcher.k8s_launcher import ( + DEFAULT_CONTAINER_ARGS_MODULE_ARGS_DICT, + JOB_RETURN_CODE_MAPPING, + POD_STATE_MAPPING, + PV_NAME, + VOLUME_MOUNT_LIST, + JobState, + K8sJobHandle, + POD_Phase, + _job_args_dict, +) + + +def _make_job_config(**overrides): + cfg = { + "name": "test-job-123", + "image": "nvflare/nvflare:test", + "container_name": "container-test-job-123", + "command": "nvflare.private.fed.app.client.worker_process", + "volume_mount_list": VOLUME_MOUNT_LIST, + "volume_list": [ + {"name": PV_NAME.WORKSPACE.value, 
"persistentVolumeClaim": {"claimName": "ws-pvc"}}, + {"name": PV_NAME.DATA.value, "persistentVolumeClaim": {"claimName": "data-pvc"}}, + {"name": PV_NAME.ETC.value, "persistentVolumeClaim": {"claimName": "etc-pvc"}}, + ], + "module_args": {"-m": "val_m", "-w": "val_w"}, + "set_list": ["key1=val1", "key2=val2"], + "resources": {"limits": {"nvidia.com/gpu": 1}}, + } + cfg.update(overrides) + return cfg + + +def _make_api_instance(): + return MagicMock() + + +# --------------------------------------------------------------------------- +# Mapping tables +# --------------------------------------------------------------------------- +class TestPodStateMapping: + def test_all_phases_mapped(self): + for phase in POD_Phase: + assert phase.value in POD_STATE_MAPPING + + def test_pending_maps_to_starting(self): + assert POD_STATE_MAPPING[POD_Phase.PENDING.value] == JobState.STARTING + + def test_running_maps_to_running(self): + assert POD_STATE_MAPPING[POD_Phase.RUNNING.value] == JobState.RUNNING + + def test_succeeded_maps_to_succeeded(self): + assert POD_STATE_MAPPING[POD_Phase.SUCCEEDED.value] == JobState.SUCCEEDED + + def test_failed_maps_to_terminated(self): + assert POD_STATE_MAPPING[POD_Phase.FAILED.value] == JobState.TERMINATED + + +class TestJobReturnCodeMapping: + def test_all_job_states_mapped(self): + for state in JobState: + assert state in JOB_RETURN_CODE_MAPPING + + def test_succeeded_maps_to_success(self): + assert JOB_RETURN_CODE_MAPPING[JobState.SUCCEEDED] == JobReturnCode.SUCCESS + + def test_terminated_maps_to_aborted(self): + assert JOB_RETURN_CODE_MAPPING[JobState.TERMINATED] == JobReturnCode.ABORTED + + def test_running_maps_to_unknown(self): + assert JOB_RETURN_CODE_MAPPING[JobState.RUNNING] == JobReturnCode.UNKNOWN + + +# --------------------------------------------------------------------------- +# K8sJobHandle +# --------------------------------------------------------------------------- +class TestK8sJobHandle: + # -- construction 
--------------------------------------------------------- + def test_init_raises_on_missing_command(self): + cfg = _make_job_config() + del cfg["command"] + with pytest.raises(ValueError, match="command"): + K8sJobHandle("job-1", _make_api_instance(), cfg) + + def test_init_raises_on_empty_command(self): + cfg = _make_job_config(command="") + with pytest.raises(ValueError, match="command"): + K8sJobHandle("job-1", _make_api_instance(), cfg) + + def test_stuck_count_starts_at_zero(self): + cfg = _make_job_config() + handle = K8sJobHandle("job-1", _make_api_instance(), cfg, timeout=30) + assert handle._stuck_count == 0 + + def test_max_stuck_count_includes_grace_period(self): + cfg = _make_job_config() + handle = K8sJobHandle("job-1", _make_api_instance(), cfg, timeout=30) + assert handle._max_stuck_count == 30 + handle._stuck_grace_period + + def test_max_stuck_count_is_none_with_no_timeout(self): + cfg = _make_job_config() + handle = K8sJobHandle("job-1", _make_api_instance(), cfg, timeout=None) + assert handle._max_stuck_count is None + + # -- manifest ------------------------------------------------------------- + def test_manifest_metadata_name(self): + cfg = _make_job_config() + handle = K8sJobHandle("job-1", _make_api_instance(), cfg) + assert handle.get_manifest()["metadata"]["name"] == "test-job-123" + + def test_manifest_container_image(self): + cfg = _make_job_config() + handle = K8sJobHandle("job-1", _make_api_instance(), cfg) + container = handle.get_manifest()["spec"]["containers"][0] + assert container["image"] == "nvflare/nvflare:test" + + def test_manifest_container_name(self): + cfg = _make_job_config() + handle = K8sJobHandle("job-1", _make_api_instance(), cfg) + container = handle.get_manifest()["spec"]["containers"][0] + assert container["name"] == "container-test-job-123" + + def test_manifest_default_image(self): + cfg = _make_job_config() + del cfg["image"] + handle = K8sJobHandle("job-1", _make_api_instance(), cfg) + container = 
handle.get_manifest()["spec"]["containers"][0] + assert container["image"] == "nvflare/nvflare:2.8.0" + + def test_manifest_default_container_name(self): + cfg = _make_job_config() + del cfg["container_name"] + handle = K8sJobHandle("job-1", _make_api_instance(), cfg) + container = handle.get_manifest()["spec"]["containers"][0] + assert container["name"] == "nvflare_job" + + def test_manifest_restart_policy(self): + cfg = _make_job_config() + handle = K8sJobHandle("job-1", _make_api_instance(), cfg) + assert handle.get_manifest()["spec"]["restartPolicy"] == "Never" + + def test_manifest_volumes(self): + cfg = _make_job_config() + handle = K8sJobHandle("job-1", _make_api_instance(), cfg) + volumes = handle.get_manifest()["spec"]["volumes"] + assert len(volumes) == 3 + pvc_names = [v["persistentVolumeClaim"]["claimName"] for v in volumes] + assert "ws-pvc" in pvc_names + assert "data-pvc" in pvc_names + assert "etc-pvc" in pvc_names + + def test_manifest_volume_mounts(self): + cfg = _make_job_config() + handle = K8sJobHandle("job-1", _make_api_instance(), cfg) + container = handle.get_manifest()["spec"]["containers"][0] + assert container["volumeMounts"] == VOLUME_MOUNT_LIST + + def test_manifest_args_contain_command(self): + cfg = _make_job_config() + handle = K8sJobHandle("job-1", _make_api_instance(), cfg) + args = handle.get_manifest()["spec"]["containers"][0]["args"] + assert "-u" in args + assert "-m" in args + assert "nvflare.private.fed.app.client.worker_process" in args + + def test_manifest_args_contain_module_args(self): + cfg = _make_job_config() + handle = K8sJobHandle("job-1", _make_api_instance(), cfg) + args = handle.get_manifest()["spec"]["containers"][0]["args"] + assert "val_m" in args + assert "val_w" in args + + def test_manifest_args_contain_set_list(self): + cfg = _make_job_config() + handle = K8sJobHandle("job-1", _make_api_instance(), cfg) + args = handle.get_manifest()["spec"]["containers"][0]["args"] + assert "--set" in args + assert 
"key1=val1" in args + assert "key2=val2" in args + + def test_manifest_no_set_list(self): + cfg = _make_job_config() + cfg["set_list"] = None + handle = K8sJobHandle("job-1", _make_api_instance(), cfg) + args = handle.get_manifest()["spec"]["containers"][0]["args"] + assert "--set" not in args + + def test_manifest_none_module_args_skipped(self): + cfg = _make_job_config(module_args={"-a": "keep", "-b": None}) + handle = K8sJobHandle("job-1", _make_api_instance(), cfg) + args = handle.get_manifest()["spec"]["containers"][0]["args"] + assert "-a" in args + assert "keep" in args + assert "-b" not in args + + def test_manifest_default_module_args_copies_dict(self): + cfg = _make_job_config() + cfg["module_args"] = None + handle = K8sJobHandle("job-1", _make_api_instance(), cfg) + assert handle.container_args_module_args_dict is not DEFAULT_CONTAINER_ARGS_MODULE_ARGS_DICT + assert handle.container_args_module_args_dict == DEFAULT_CONTAINER_ARGS_MODULE_ARGS_DICT + + def test_manifest_default_module_args_all_none_produces_empty_args_list(self): + cfg = _make_job_config() + cfg["module_args"] = None + handle = K8sJobHandle("job-1", _make_api_instance(), cfg) + assert handle.container_args_module_args_dict_as_list == [] + + def test_manifest_gpu_resources(self): + cfg = _make_job_config() + handle = K8sJobHandle("job-1", _make_api_instance(), cfg) + container = handle.get_manifest()["spec"]["containers"][0] + assert container["resources"]["limits"]["nvidia.com/gpu"] == 1 + + def test_manifest_no_gpu_resources(self): + cfg = _make_job_config(resources={"limits": {"nvidia.com/gpu": None}}) + handle = K8sJobHandle("job-1", _make_api_instance(), cfg) + container = handle.get_manifest()["spec"]["containers"][0] + assert container["resources"] is None + + # -- poll ----------------------------------------------------------------- + def test_poll_returns_unknown_when_running(self): + api = _make_api_instance() + resp = Mock() + resp.status.phase = POD_Phase.RUNNING.value + 
api.read_namespaced_pod.return_value = resp + handle = K8sJobHandle("job-1", api, _make_job_config()) + assert handle.poll() == JobReturnCode.UNKNOWN + + def test_poll_returns_success_when_succeeded(self): + api = _make_api_instance() + resp = Mock() + resp.status.phase = POD_Phase.SUCCEEDED.value + api.read_namespaced_pod.return_value = resp + handle = K8sJobHandle("job-1", api, _make_job_config()) + assert handle.poll() == JobReturnCode.SUCCESS + + def test_poll_returns_aborted_when_failed(self): + api = _make_api_instance() + resp = Mock() + resp.status.phase = POD_Phase.FAILED.value + api.read_namespaced_pod.return_value = resp + handle = K8sJobHandle("job-1", api, _make_job_config()) + assert handle.poll() == JobReturnCode.ABORTED + + def test_poll_uses_terminal_state_if_set(self): + api = _make_api_instance() + handle = K8sJobHandle("job-1", api, _make_job_config()) + handle.terminal_state = JobState.TERMINATED + assert handle.poll() == JobReturnCode.ABORTED + api.read_namespaced_pod.assert_not_called() + + # -- terminate ------------------------------------------------------------ + def test_terminate_deletes_pod_and_sets_terminated(self): + api = _make_api_instance() + handle = K8sJobHandle("job-1", api, _make_job_config()) + handle.terminate() + api.delete_namespaced_pod.assert_called_once_with(name="job-1", namespace="default", grace_period_seconds=0) + assert handle.terminal_state == JobState.TERMINATED + + def test_terminate_sets_terminated_on_404(self): + api = _make_api_instance() + api.delete_namespaced_pod.side_effect = _FakeApiException(status=404, reason="Not Found") + handle = K8sJobHandle("job-1", api, _make_job_config()) + handle.terminate() + assert handle.terminal_state == JobState.TERMINATED + + def test_terminate_does_not_set_state_on_non_404_error(self): + api = _make_api_instance() + api.delete_namespaced_pod.side_effect = _FakeApiException(status=500, reason="Internal") + handle = K8sJobHandle("job-1", api, _make_job_config()) + 
handle.terminate() + assert handle.terminal_state is None + + def test_terminate_custom_namespace(self): + api = _make_api_instance() + handle = K8sJobHandle("job-1", api, _make_job_config(), namespace="custom-ns") + handle.terminate() + api.delete_namespaced_pod.assert_called_once_with(name="job-1", namespace="custom-ns", grace_period_seconds=0) + + # -- _query_phase --------------------------------------------------------- + def test_query_phase_returns_unknown_on_api_error(self): + api = _make_api_instance() + api.read_namespaced_pod.side_effect = _FakeApiException(status=500, reason="Error") + handle = K8sJobHandle("job-1", api, _make_job_config()) + assert handle._query_phase() == POD_Phase.UNKNOWN.value + + # -- _stuck --------------------------------------------------------------- + def test_stuck_returns_false_when_no_timeout_and_grace_only(self): + api = _make_api_instance() + handle = K8sJobHandle("job-1", api, _make_job_config(), timeout=None) + assert handle._stuck(POD_Phase.PENDING.value) is False + + def test_stuck_returns_true_after_max_count_with_grace(self): + api = _make_api_instance() + handle = K8sJobHandle("job-1", api, _make_job_config(), timeout=5) + handle._stuck_count = handle._max_stuck_count + assert handle._stuck(POD_Phase.PENDING.value) is True + + def test_stuck_returns_false_for_non_pending(self): + api = _make_api_instance() + handle = K8sJobHandle("job-1", api, _make_job_config(), timeout=5) + handle._stuck_count = 9999 + assert handle._stuck(POD_Phase.RUNNING.value) is False + + def test_stuck_increments_count_on_pending(self): + api = _make_api_instance() + handle = K8sJobHandle("job-1", api, _make_job_config(), timeout=100) + initial = handle._stuck_count + handle._stuck(POD_Phase.PENDING.value) + assert handle._stuck_count == initial + 1 + + # -- wait ----------------------------------------------------------------- + def test_wait_returns_immediately_if_terminal_state_set(self): + api = _make_api_instance() + handle = 
K8sJobHandle("job-1", api, _make_job_config()) + handle.terminal_state = JobState.TERMINATED + handle.wait() + api.read_namespaced_pod.assert_not_called() + + def test_wait_persists_succeeded_terminal_state(self): + api = _make_api_instance() + resp = Mock() + resp.status.phase = POD_Phase.SUCCEEDED.value + api.read_namespaced_pod.return_value = resp + handle = K8sJobHandle("job-1", api, _make_job_config()) + handle.wait() + assert handle.terminal_state == JobState.SUCCEEDED + + def test_wait_persists_terminated_terminal_state(self): + api = _make_api_instance() + resp = Mock() + resp.status.phase = POD_Phase.FAILED.value + api.read_namespaced_pod.return_value = resp + handle = K8sJobHandle("job-1", api, _make_job_config()) + handle.wait() + assert handle.terminal_state == JobState.TERMINATED + + def test_wait_poll_consistent_after_wait(self): + api = _make_api_instance() + resp = Mock() + resp.status.phase = POD_Phase.SUCCEEDED.value + api.read_namespaced_pod.return_value = resp + handle = K8sJobHandle("job-1", api, _make_job_config()) + handle.wait() + assert handle.poll() == JobReturnCode.SUCCESS + + # -- enter_states --------------------------------------------------------- + def test_enter_states_returns_true_when_state_matches(self): + api = _make_api_instance() + resp = Mock() + resp.status.phase = POD_Phase.RUNNING.value + api.read_namespaced_pod.return_value = resp + handle = K8sJobHandle("job-1", api, _make_job_config()) + assert handle.enter_states([JobState.RUNNING]) is True + + def test_enter_states_accepts_single_state(self): + api = _make_api_instance() + resp = Mock() + resp.status.phase = POD_Phase.SUCCEEDED.value + api.read_namespaced_pod.return_value = resp + handle = K8sJobHandle("job-1", api, _make_job_config()) + assert handle.enter_states(JobState.SUCCEEDED) is True + + def test_enter_states_returns_false_on_timeout(self): + api = _make_api_instance() + resp = Mock() + resp.status.phase = POD_Phase.PENDING.value + 
api.read_namespaced_pod.return_value = resp + handle = K8sJobHandle("job-1", api, _make_job_config(), timeout=None) + assert handle.enter_states([JobState.RUNNING], timeout=0) is False + + def test_enter_states_raises_on_invalid_state(self): + api = _make_api_instance() + handle = K8sJobHandle("job-1", api, _make_job_config()) + with pytest.raises(ValueError, match="expect job_states_to_enter"): + handle.enter_states(["not_a_state"]) + + +# --------------------------------------------------------------------------- +# _job_args_dict helper +# --------------------------------------------------------------------------- +class TestJobArgsDict: + def test_basic(self): + job_args = { + "workspace": ("-w", "/workspace"), + "job_id": ("-j", "job-1"), + } + result = _job_args_dict(job_args, ["workspace", "job_id"]) + assert result == {"-w": "/workspace", "-j": "job-1"} + + def test_skips_missing_keys(self): + job_args = {"workspace": ("-w", "/workspace")} + result = _job_args_dict(job_args, ["workspace", "missing_key"]) + assert result == {"-w": "/workspace"} + + def test_empty_args(self): + assert _job_args_dict({}, ["a", "b"]) == {} + + def test_empty_arg_names(self): + assert _job_args_dict({"workspace": ("-w", "/workspace")}, []) == {} + + +# --------------------------------------------------------------------------- +# K8sJobLauncher handle_event +# --------------------------------------------------------------------------- +def _make_k8s_launcher_patches(): + return [ + patch("nvflare.app_opt.job_launcher.k8s_launcher.config"), + patch("nvflare.app_opt.job_launcher.k8s_launcher.Configuration"), + patch("nvflare.app_opt.job_launcher.k8s_launcher.core_v1_api"), + patch("builtins.open", create=True), + patch("nvflare.app_opt.job_launcher.k8s_launcher.yaml"), + ] + + +def _enter_patches(patches): + mocks = [p.start() for p in patches] + return mocks + + +def _exit_patches(patches): + for p in patches: + p.stop() + + +def _setup_launcher(mock_yaml, mock_conf, 
launcher_cls): + mock_yaml.safe_load.return_value = {"data-pvc": "/data"} + mock_conf_instance = MagicMock() + mock_conf.return_value = mock_conf_instance + mock_conf.get_default_copy = Mock(return_value=mock_conf_instance) + return launcher_cls( + config_file_path="/fake/kube/config", + workspace_pvc="ws-pvc", + etc_pvc="etc-pvc", + data_pvc_file_path="/fake/data_pvc.yaml", + ) + + +class TestK8sJobLauncherHandleEvent: + def test_adds_launcher_when_image_present(self): + from nvflare.app_opt.job_launcher.k8s_launcher import ClientK8sJobLauncher + + patches = _make_k8s_launcher_patches() + mock_cfg, mock_conf, mock_core, mock_open, mock_yaml = _enter_patches(patches) + try: + launcher = _setup_launcher(mock_yaml, mock_conf, ClientK8sJobLauncher) + fl_ctx = FLContext() + job_meta = {JobMetaKey.DEPLOY_MAP.value: {"app": [{"sites": ["site-1"], "image": "nvflare/custom:latest"}]}} + fl_ctx.set_prop(FLContextKey.JOB_META, job_meta, private=True, sticky=False) + fl_ctx.set_prop(ReservedKey.IDENTITY_NAME, "site-1", private=False, sticky=True) + + launcher.handle_event(EventType.BEFORE_JOB_LAUNCH, fl_ctx) + + launchers = fl_ctx.get_prop(FLContextKey.JOB_LAUNCHER) + assert launchers is not None + assert launcher in launchers + finally: + _exit_patches(patches) + + def test_skips_when_no_image(self): + from nvflare.app_opt.job_launcher.k8s_launcher import ClientK8sJobLauncher + + patches = _make_k8s_launcher_patches() + mock_cfg, mock_conf, mock_core, mock_open, mock_yaml = _enter_patches(patches) + try: + launcher = _setup_launcher(mock_yaml, mock_conf, ClientK8sJobLauncher) + fl_ctx = FLContext() + job_meta = {JobMetaKey.DEPLOY_MAP.value: {}} + fl_ctx.set_prop(FLContextKey.JOB_META, job_meta, private=True, sticky=False) + fl_ctx.set_prop(ReservedKey.IDENTITY_NAME, "site-1", private=False, sticky=True) + + launcher.handle_event(EventType.BEFORE_JOB_LAUNCH, fl_ctx) + + assert fl_ctx.get_prop(FLContextKey.JOB_LAUNCHER) is None + finally: + _exit_patches(patches) + + +# 
--------------------------------------------------------------------------- +# K8sJobLauncher __init__ +# --------------------------------------------------------------------------- +class TestK8sJobLauncherInit: + def test_init_reads_data_pvc(self): + from nvflare.app_opt.job_launcher.k8s_launcher import ClientK8sJobLauncher + + patches = _make_k8s_launcher_patches() + mock_cfg, mock_conf, mock_core, mock_open, mock_yaml = _enter_patches(patches) + try: + mock_yaml.safe_load.return_value = {"my-data-pvc": "/mount/data"} + mock_conf_instance = MagicMock() + mock_conf.return_value = mock_conf_instance + mock_conf.get_default_copy = Mock(return_value=mock_conf_instance) + + launcher = ClientK8sJobLauncher( + config_file_path="/fake/kube/config", + workspace_pvc="ws-pvc", + etc_pvc="etc-pvc", + data_pvc_file_path="/fake/data_pvc.yaml", + timeout=60, + namespace="test-ns", + ) + + assert launcher.workspace_pvc == "ws-pvc" + assert launcher.etc_pvc == "etc-pvc" + assert launcher.data_pvc == "my-data-pvc" + assert launcher.timeout == 60 + assert launcher.namespace == "test-ns" + finally: + _exit_patches(patches) + + def test_init_raises_on_empty_pvc_file(self): + from nvflare.app_opt.job_launcher.k8s_launcher import ClientK8sJobLauncher + + patches = _make_k8s_launcher_patches() + mock_cfg, mock_conf, mock_core, mock_open, mock_yaml = _enter_patches(patches) + try: + mock_yaml.safe_load.return_value = {} + mock_conf_instance = MagicMock() + mock_conf.return_value = mock_conf_instance + mock_conf.get_default_copy = Mock(return_value=mock_conf_instance) + + with pytest.raises(ValueError, match="empty"): + ClientK8sJobLauncher( + config_file_path="/fake/kube/config", + workspace_pvc="ws-pvc", + etc_pvc="etc-pvc", + data_pvc_file_path="/fake/data_pvc.yaml", + ) + finally: + _exit_patches(patches) + + +# --------------------------------------------------------------------------- +# ClientK8sJobLauncher.get_module_args +# 
--------------------------------------------------------------------------- +class TestClientK8sJobLauncherGetModuleArgs: + def test_returns_dict(self): + from nvflare.app_opt.job_launcher.k8s_launcher import ClientK8sJobLauncher + + patches = _make_k8s_launcher_patches() + mock_cfg, mock_conf, mock_core, mock_open, mock_yaml = _enter_patches(patches) + try: + launcher = _setup_launcher(mock_yaml, mock_conf, ClientK8sJobLauncher) + fl_ctx = FLContext() + job_args = { + JobProcessArgs.WORKSPACE: ("-w", "/workspace"), + JobProcessArgs.JOB_ID: ("-j", "job-1"), + } + fl_ctx.set_prop(FLContextKey.JOB_PROCESS_ARGS, job_args, private=True, sticky=False) + + result = launcher.get_module_args("job-1", fl_ctx) + assert isinstance(result, dict) + assert result.get("-w") == "/workspace" + finally: + _exit_patches(patches) + + def test_raises_when_no_args(self): + from nvflare.app_opt.job_launcher.k8s_launcher import ClientK8sJobLauncher + + patches = _make_k8s_launcher_patches() + mock_cfg, mock_conf, mock_core, mock_open, mock_yaml = _enter_patches(patches) + try: + launcher = _setup_launcher(mock_yaml, mock_conf, ClientK8sJobLauncher) + fl_ctx = FLContext() + with pytest.raises(RuntimeError, match="job_process_args"): + launcher.get_module_args("job-1", fl_ctx) + finally: + _exit_patches(patches) + + +# --------------------------------------------------------------------------- +# ServerK8sJobLauncher.get_module_args +# --------------------------------------------------------------------------- +class TestServerK8sJobLauncherGetModuleArgs: + def test_returns_dict(self): + from nvflare.app_opt.job_launcher.k8s_launcher import ServerK8sJobLauncher + + patches = _make_k8s_launcher_patches() + mock_cfg, mock_conf, mock_core, mock_open, mock_yaml = _enter_patches(patches) + try: + launcher = _setup_launcher(mock_yaml, mock_conf, ServerK8sJobLauncher) + fl_ctx = FLContext() + job_args = { + JobProcessArgs.WORKSPACE: ("-w", "/workspace"), + JobProcessArgs.JOB_ID: ("-j", "job-1"), + 
JobProcessArgs.ROOT_URL: ("--root_url", "grpc://server:8003"), + } + fl_ctx.set_prop(FLContextKey.JOB_PROCESS_ARGS, job_args, private=True, sticky=False) + + result = launcher.get_module_args("job-1", fl_ctx) + assert isinstance(result, dict) + assert result.get("-w") == "/workspace" + finally: + _exit_patches(patches) + + def test_raises_when_no_args(self): + from nvflare.app_opt.job_launcher.k8s_launcher import ServerK8sJobLauncher + + patches = _make_k8s_launcher_patches() + mock_cfg, mock_conf, mock_core, mock_open, mock_yaml = _enter_patches(patches) + try: + launcher = _setup_launcher(mock_yaml, mock_conf, ServerK8sJobLauncher) + fl_ctx = FLContext() + with pytest.raises(RuntimeError, match="job_process_args"): + launcher.get_module_args("job-1", fl_ctx) + finally: + _exit_patches(patches) From 999f057c516d9b2048f3ab1b2aa01a1f6375e9a5 Mon Sep 17 00:00:00 2001 From: Isaac Yang Date: Thu, 19 Mar 2026 08:55:19 -0700 Subject: [PATCH 2/3] Fix job configer unit test issue --- tests/unit_test/tool/job/config/configer_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit_test/tool/job/config/configer_test.py b/tests/unit_test/tool/job/config/configer_test.py index 792d88b315..2f3ac153e4 100644 --- a/tests/unit_test/tool/job/config/configer_test.py +++ b/tests/unit_test/tool/job/config/configer_test.py @@ -32,7 +32,7 @@ ( "launch_once", [ - ["app/config_fed_client.conf", "app_script=cifar10_fl.py", "launch_once=False"], + ["app/config_fed_client.conf", "app_script=cifar10_fl.py", "launch_once=false"], ["meta.conf", "min_clients=3"], ], "launch_everytime", @@ -50,7 +50,7 @@ ( "launch_once", [ - ["app/config/config_fed_client.conf", "app_script=cifar10_fl.py", "launch_once=False"], + ["app/config/config_fed_client.conf", "app_script=cifar10_fl.py", "launch_once=false"], ["meta.conf", "min_clients=3"], ], "launch_everytime", From 2d9eb0fbe0368382845f71fe4683effd9c512f5c Mon Sep 17 00:00:00 2001 From: Isaac Yang Date: Thu, 19 Mar 2026 
08:59:06 -0700 Subject: [PATCH 3/3] Fix reference error Address comments --- docs/design/JobLauncher_and_JobHandle.md | 308 ++++----- nvflare/apis/job_launcher_spec.py | 6 +- .../app_opt/job_launcher/docker_launcher.py | 24 +- nvflare/app_opt/job_launcher/k8s_launcher.py | 125 +++- nvflare/private/fed/client/communicator.py | 7 +- .../job_launcher/docker_launcher_test.py | 464 -------------- .../app_opt/job_launcher/k8s_launcher_test.py | 592 ++++++++++++++++-- 7 files changed, 782 insertions(+), 744 deletions(-) delete mode 100644 tests/unit_test/app_opt/job_launcher/docker_launcher_test.py diff --git a/docs/design/JobLauncher_and_JobHandle.md b/docs/design/JobLauncher_and_JobHandle.md index 650b159785..8ddeeae839 100644 --- a/docs/design/JobLauncher_and_JobHandle.md +++ b/docs/design/JobLauncher_and_JobHandle.md @@ -2,7 +2,7 @@ ## 1. Overview -NVFlare runs each federated job as an isolated execution unit -- a subprocess, Docker container, or Kubernetes pod. Two abstractions govern this: +NVFlare runs each federated job as an isolated execution unit -- a subprocess or Kubernetes pod. Two abstractions govern this: - **JobLauncherSpec** -- starts a job and returns a handle. - **JobHandleSpec** -- represents the running job and provides lifecycle control (poll, wait, terminate). @@ -14,21 +14,21 @@ The upper layers (server engine, client executor) program exclusively against th │ Upper Layer │ │ ServerEngine / ClientExecutor │ │ │ -│ 1. Build JOB_PROCESS_ARGS │ -│ 2. get_job_launcher(job_meta, fl_ctx) → launcher │ +│ 1. get_job_launcher(job_meta, fl_ctx) → launcher │ +│ 2. Build JOB_PROCESS_ARGS │ │ 3. launcher.launch_job(job_meta, fl_ctx) → job_handle │ │ 4. 
job_handle.wait() / job_handle.terminate() │ └──────────┬──────────────────────┬────────────────────────┘ │ BEFORE_JOB_LAUNCH │ │ event selects one │ ▼ ▼ -┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐ -│ ProcessJob │ │ DockerJob │ │ K8sJob │ -│ Launcher │ │ Launcher │ │ Launcher │ -│ ─────────────── │ │ ─────────────── │ │ ─────────────── │ -│ ProcessHandle │ │ DockerJobHandle │ │ K8sJobHandle │ -└─────────────────┘ └─────────────────┘ └─────────────────┘ - subprocess container pod +┌─────────────────┐ ┌─────────────────┐ +│ ProcessJob │ │ K8sJob │ +│ Launcher │ │ Launcher │ +│ ─────────────── │ │ ─────────────── │ +│ ProcessHandle │ │ K8sJobHandle │ +└─────────────────┘ └─────────────────┘ + subprocess pod ``` --- @@ -37,7 +37,7 @@ The upper layers (server engine, client executor) program exclusively against th ### 2.1 JobHandleSpec -Abstract base class representing a running job. All methods are `@abstractmethod`. +Abstract base class representing a running job (`class JobHandleSpec(ABC)`). All methods are decorated with `@abstractmethod`; Python enforces that any concrete subclass implements all three. Attempting to instantiate a subclass with any abstract method unimplemented raises `TypeError` at instantiation time. | Method | Signature | Semantics | |--------|-----------|-----------| @@ -47,7 +47,7 @@ Abstract base class representing a running job. All methods are `@abstractmethod ### 2.2 JobLauncherSpec -Abstract base class for launching jobs. Extends `FLComponent`, which gives it access to the event system. +Abstract base class for launching jobs (`class JobLauncherSpec(FLComponent, ABC)`). Extends `FLComponent` for event-system access. Adding `ABC` to the bases means Python selects `ABCMeta` as the metaclass automatically (since `ABCMeta` is a subclass of `type`), so `@abstractmethod` on `launch_job` is enforced at runtime for all subclasses. 
| Method | Signature | Semantics | |--------|-----------|-----------| @@ -55,7 +55,7 @@ Abstract base class for launching jobs. Extends `FLComponent`, which gives it ac ### 2.3 Supporting Types -**JobProcessArgs** -- String constants for the keys the upper layer places in `FLContextKey.JOB_PROCESS_ARGS`. These are the standardized parameters every job process needs (workspace path, auth token, job ID, parent URL, etc.). +**JobProcessArgs** -- String constants for the keys the upper layer places in `FLContextKey.JOB_PROCESS_ARGS`. Each value in the dict is a `(flag, value)` tuple, e.g. `JobProcessArgs.JOB_ID → ("-n", job_id)`. These are the standardized parameters every job process needs (workspace path, auth token, job ID, parent URL, etc.). | Constant | Value | Used by | |----------|-------|---------| @@ -63,7 +63,7 @@ Abstract base class for launching jobs. Extends `FLComponent`, which gives it ac | `WORKSPACE` | `"workspace"` | Server, Client | | `STARTUP_DIR` | `"startup_dir"` | Client | | `APP_ROOT` | `"app_root"` | Server | -| `AUTH_TOKEN` | `"auth_token"` | Client | +| `AUTH_TOKEN` | `"auth_token"` | Server, Client | | `TOKEN_SIGNATURE` | `"auth_signature"` | Server, Client | | `SSID` | `"ssid"` | Server, Client | | `JOB_ID` | `"job_id"` | Server, Client | @@ -77,7 +77,7 @@ Abstract base class for launching jobs. 
Extends `FLComponent`, which gives it ac | `TARGET` | `"target"` | Client | | `SCHEME` | `"scheme"` | Client | | `STARTUP_CONFIG_FILE` | `"startup_config_file"` | Server, Client | -| `RESTORE_SNAPSHOT` | `"restore_snapshot"` | Server | +| `RESTORE_SNAPSHOT` | `"restore_snapshot"` | (defined but not set as a standalone entry; server embeds `restore_snapshot=` into `OPTIONS` via `--set`) | | `OPTIONS` | `"options"` | Server, Client | **JobReturnCode** -- Standard exit semantics (extends `ProcessExitCode`): @@ -104,10 +104,14 @@ def get_job_launcher(job_meta, fl_ctx) -> JobLauncherSpec: engine = fl_ctx.get_engine() with engine.new_context() as job_launcher_ctx: job_launcher_ctx.remove_prop(FLContextKey.JOB_LAUNCHER) - job_launcher_ctx.set_prop(FLContextKey.JOB_META, job_meta, ...) + job_launcher_ctx.set_prop(FLContextKey.JOB_META, job_meta, private=True, sticky=False) engine.fire_event(EventType.BEFORE_JOB_LAUNCH, job_launcher_ctx) job_launcher = job_launcher_ctx.get_prop(FLContextKey.JOB_LAUNCHER) - ... + if not (job_launcher and isinstance(job_launcher, list)): + raise RuntimeError(f"There's no job launcher can handle this job: {job_meta}.") + launcher = job_launcher[0] + if not isinstance(launcher, JobLauncherSpec): + raise RuntimeError(f"The job launcher must be JobLauncherSpec but got {type(launcher)}") return job_launcher[0] ``` @@ -118,7 +122,7 @@ Every registered `FLComponent` receives the `BEFORE_JOB_LAUNCH` event. 
Each laun | Condition | Launcher selected | |-----------|-------------------| | `extract_job_image(job_meta, site_name)` returns `None` | **ProcessJobLauncher** (no container image → run as subprocess) | -| `extract_job_image(job_meta, site_name)` returns an image | **DockerJobLauncher** or **K8sJobLauncher** (whichever is configured as a component) | +| `extract_job_image(job_meta, site_name)` returns an image | **K8sJobLauncher** (configured as a component) | ### 3.2 Server Side (`ServerEngine`) @@ -127,20 +131,26 @@ Location: `nvflare/private/fed/server/server_engine.py` ``` _start_runner_process(job, job_clients, snapshot, fl_ctx) │ -├─ 1. Build job_args dict with server-specific JobProcessArgs -│ (WORKSPACE, APP_ROOT, PARENT_URL, AUTH_TOKEN, HA_MODE, ...) +├─ 1. job_launcher = get_job_launcher(job.meta, fl_ctx) +│ (fires BEFORE_JOB_LAUNCH; JOB_PROCESS_ARGS not yet set) │ -├─ 2. fl_ctx.set_prop(FLContextKey.JOB_PROCESS_ARGS, job_args) +├─ 2. Build job_args dict with server-specific JobProcessArgs +│ (EXE_MODULE, JOB_ID, WORKSPACE, STARTUP_CONFIG_FILE, +│ APP_ROOT, HA_MODE, AUTH_TOKEN, TOKEN_SIGNATURE, +│ PARENT_URL, ROOT_URL, SERVICE_HOST, SERVICE_PORT, +│ SSID, OPTIONS) │ -├─ 3. job_launcher = get_job_launcher(job.meta, fl_ctx) +├─ 3. fl_ctx.set_prop(FLContextKey.JOB_PROCESS_ARGS, job_args) │ ├─ 4. job_handle = job_launcher.launch_job(job.meta, fl_ctx) │ -├─ 5. Store in run_processes[job_id][RunProcessKey.JOB_HANDLE] +├─ 5. Store in run_processes[job_id] +│ {JOB_HANDLE: job_handle, JOB_ID: job_id, PARTICIPANTS: job_clients} │ └─ 6. Start background thread → wait_for_complete(workspace, job_id, job_handle) │ ├─ job_handle.wait() # blocks until job finishes + ├─ wait up to 2s for UPDATE_RUN_STATUS message to arrive └─ get_return_code(job_handle, job_id, workspace, logger) ``` @@ -156,18 +166,24 @@ Location: `nvflare/private/fed/client/client_executor.py` ``` start_app(job_id, job_meta, ...) │ -├─ 1. 
Build job_args dict with client-specific JobProcessArgs -│ (WORKSPACE, STARTUP_DIR, CLIENT_NAME, PARENT_URL, AUTH_TOKEN, ...) +├─ 1. job_launcher = get_job_launcher(job_meta, fl_ctx) +│ (fires BEFORE_JOB_LAUNCH; JOB_PROCESS_ARGS not yet set) │ -├─ 2. fl_ctx.set_prop(FLContextKey.JOB_PROCESS_ARGS, job_args) +├─ 2. Build job_args dict with client-specific JobProcessArgs +│ (EXE_MODULE, JOB_ID, CLIENT_NAME, AUTH_TOKEN, TOKEN_SIGNATURE, +│ SSID, WORKSPACE, STARTUP_DIR, PARENT_URL, SCHEME, TARGET, +│ STARTUP_CONFIG_FILE, OPTIONS, optionally PARENT_CONN_SEC) │ -├─ 3. job_launcher = get_job_launcher(job_meta, fl_ctx) +├─ 3. fl_ctx.set_prop(FLContextKey.JOB_PROCESS_ARGS, job_args) │ ├─ 4. job_handle = job_launcher.launch_job(job_meta, fl_ctx) │ -├─ 5. Store in run_processes[job_id][RunProcessKey.JOB_HANDLE] +├─ 5. Fire EventType.AFTER_JOB_LAUNCH event │ -└─ 6. Start background thread → _wait_child_process_finish(...) +├─ 6. Store in run_processes[job_id] +│ {JOB_HANDLE: job_handle, STATUS: ClientStatus.STARTING} +│ +└─ 7. Start background thread → _wait_child_process_finish(...) │ ├─ job_handle.wait() └─ get_return_code(job_handle, job_id, workspace, logger) @@ -175,8 +191,8 @@ start_app(job_id, job_meta, ...) **Abort path** (`_terminate_job`): -1. Wait up to 10 seconds for the child to exit gracefully (polling `job_handle.poll()`). -2. Call `job_handle.terminate()`. +1. Poll `self.run_processes.get(job_id)` every 50 ms for up to 10 seconds; if the entry disappears the job finished gracefully. +2. Always call `job_handle.terminate()` regardless of whether the graceful exit was detected. ### 3.4 Return Code Resolution @@ -187,7 +203,7 @@ start_app(job_id, job_meta, ...) --- -## 4. The Three Implementations +## 4. The Two Implementations ### 4.1 Process Launcher (Subprocess) @@ -213,7 +229,7 @@ JobLauncherSpec (FLComponent) #### ProcessHandle -Wraps a `ProcessAdapter` (from `nvflare/utils/process_utils.py`) that manages a `subprocess.Popen` or a PID. 
+Wraps a `ProcessAdapter` (from `nvflare/utils/process_utils.py`) that manages a `subprocess.Popen` or a PID. The constructor accepts any one of: a `ProcessAdapter` directly, a `subprocess.Popen` object, or an integer `pid`. | Method | Implementation | |--------|---------------| @@ -225,9 +241,9 @@ Wraps a `ProcessAdapter` (from `nvflare/utils/process_utils.py`) that manages a | Step | Action | |------|--------| -| 1 | Copy `os.environ` and add `app_custom_folder` to `PYTHONPATH`. | +| 1 | Copy `os.environ`. If `app_custom_folder` is non-empty, call `add_custom_dir_to_path()`: appends the folder to `sys.path` and serializes the result into `PYTHONPATH` in the child environment. | | 2 | Call `self.get_command(job_meta, fl_ctx)` (abstract -- implemented by server/client subclasses). | -| 3 | Parse command with `shlex.split()`, spawn the process via `spawn_process(argv, new_env)`. | +| 3 | Parse command with `shlex.split(command, posix=True)`, spawn the process via `spawn_process(argv, new_env)`. | | 4 | Return `ProcessHandle(process_adapter=...)`. | **Event registration:** @@ -243,107 +259,51 @@ def handle_event(self, event_type, fl_ctx): **Server/Client subclasses** only override `get_command()`: -- `ServerProcessJobLauncher.get_command()` → `generate_server_command(fl_ctx)` → `python -m -w ...` -- `ClientProcessJobLauncher.get_command()` → `generate_client_command(fl_ctx)` → `python -m -w -n ...` - ---- - -### 4.2 Docker Launcher - -**File:** `nvflare/app_opt/job_launcher/docker_launcher.py` - -**Class hierarchy:** - -``` -JobHandleSpec - └── DockerJobHandle - -JobLauncherSpec (FLComponent) - └── DockerJobLauncher - ├── ClientDockerJobLauncher - └── ServerDockerJobLauncher -``` - -#### DockerJobHandle - -Wraps a Docker SDK `Container` object. - -| Method | Implementation | -|--------|---------------| -| `terminate()` | `container.stop()`. | -| `poll()` | Re-fetches container via `docker.from_env().containers.get(id)`. 
Maps status: `EXITED` → `SUCCESS`, `DEAD` → `ABORTED`, all others → `UNKNOWN`. Removes the container on terminal states. | -| `wait()` | `enter_states([EXITED, DEAD], timeout)` -- polls container status in a 1-second loop until a terminal state is reached. | - -Docker container states and their mappings: - -| Docker Status | JobReturnCode | -|---------------|---------------| -| `created` | `UNKNOWN` | -| `restarting` | `UNKNOWN` | -| `running` | `UNKNOWN` | -| `paused` | `UNKNOWN` | -| `exited` | `SUCCESS` | -| `dead` | `ABORTED` | - -#### DockerJobLauncher - -Constructor parameters: - -| Parameter | Default | Purpose | -|-----------|---------|---------| -| `mount_path` | `"/workspace"` | Container-side mount point for the host workspace. | -| `network` | `"nvflare-network"` | Docker network the container joins. | -| `timeout` | `None` | Maximum seconds to wait for the container to reach `RUNNING`. | - -Launch sequence: - -| Step | Action | -|------|--------| -| 1 | Extract `job_image` from `job_meta` via `extract_job_image()`. | -| 2 | Build `PYTHONPATH` with `app_custom_folder`. | -| 3 | Call `self.get_command(job_meta, fl_ctx)` → `(container_name, command_string)`. | -| 4 | Read `NVFL_DOCKER_WORKSPACE` env var for the host-side workspace path. | -| 5 | `docker_client.containers.run(image, command, name, network, volumes, detach=True)`. | -| 6 | `DockerJobHandle(container).enter_states([RUNNING], timeout)`. | -| 7 | If timeout or error → `handle.terminate()`, return `None`. Otherwise return handle. | - -**Event registration:** Same pattern as Process but with the opposite condition -- registers when `extract_job_image()` returns a truthy value. - -**Server/Client subclasses** override `get_command()`: - -- `ClientDockerJobLauncher` → returns `("{client_name}-{job_id}", generate_client_command(fl_ctx))`. -- `ServerDockerJobLauncher` → returns `("server-{job_id}", generate_server_command(fl_ctx))`. 
+- `ServerProcessJobLauncher.get_command()` → `generate_server_command(fl_ctx)` → `sys.executable -m -w ...` +- `ClientProcessJobLauncher.get_command()` → `generate_client_command(fl_ctx)` → `sys.executable -m -w -n ...` --- -### 4.3 Kubernetes Launcher +### 4.2 Kubernetes Launcher **File:** `nvflare/app_opt/job_launcher/k8s_launcher.py` **Class hierarchy:** ``` -JobHandleSpec +JobHandleSpec (ABC) └── K8sJobHandle -JobLauncherSpec (FLComponent) +JobLauncherSpec (FLComponent, ABC) └── K8sJobLauncher ├── ClientK8sJobLauncher └── ServerK8sJobLauncher ``` +#### Pod Name Sanitization + +`launch_job` calls `uuid4_to_rfc1123(job_meta.get(JobConstants.JOB_ID))` before constructing the pod. This converts a raw UUID4 job ID into an RFC 1123-compliant Kubernetes pod name: + +1. Lowercase the string. +2. Strip characters that are not alphanumeric or hyphens (`[^a-z0-9-]`). +3. Prefix with `"j"` if the first character is a digit (Kubernetes pod names must start with a letter). +4. Strip trailing hyphens. +5. Truncate to 63 characters. + +The sanitized name is used as both the pod name (`metadata.name`) and the `job_id` stored in `K8sJobHandle` for all subsequent API calls (`terminate`, `_query_phase`). + #### K8sJobHandle Wraps a Kubernetes Pod managed through the `CoreV1Api`. | Method | Implementation | |--------|---------------| -| `terminate()` | Calls `delete_namespaced_pod(grace_period_seconds=0)` in a try/except. `terminal_state = TERMINATED` is set when the delete succeeds, or when the `ApiException` has status 404 (pod already gone). For any other `ApiException`, the error is logged and `terminal_state` is left unchanged. | +| `terminate()` | Calls `delete_namespaced_pod(grace_period_seconds=0)`. 
`terminal_state = TERMINATED` is always set regardless of outcome: on success, on 404 `ApiException` (pod already gone, logged at `info`), on any other `ApiException` (logged at `error`), and on any other `Exception` such as network or serialization errors (also logged at `error`). This guarantees that callers holding a handle never poll indefinitely after calling `terminate()`, even when the K8s API is unreachable. | | `poll()` | If `terminal_state` is set, maps it through `JOB_RETURN_CODE_MAPPING` and returns a `JobReturnCode`. Otherwise calls `_query_state()` and maps the result the same way. Both paths consistently return `JobReturnCode`. | | `wait()` | Direct while loop: returns immediately if `terminal_state` is set; otherwise calls `_query_state()` and when `SUCCEEDED` or `TERMINATED` is reached, persists that state into `terminal_state` (so subsequent `poll()` calls remain accurate) and returns. Sleeps 1 second per iteration. No timeout. | -| `_query_phase()` | Calls `read_namespaced_pod` and returns the raw pod phase string (e.g. `"Pending"`, `"Running"`). On `ApiException`, returns `POD_Phase.UNKNOWN.value`. | +| `_query_phase()` | Calls `read_namespaced_pod` and returns the raw pod phase string (e.g. `"Pending"`, `"Running"`). On `ApiException`, logs the error and returns `POD_Phase.UNKNOWN.value`. On any other `Exception` (e.g. network errors, `urllib3.exceptions.MaxRetryError`), logs the error and also returns `POD_Phase.UNKNOWN.value` — preventing unhandled exceptions from propagating through `enter_states`/`wait`/`poll` and orphaning running pods. | | `_query_state()` | Calls `_query_phase()` and maps the raw phase through `POD_STATE_MAPPING` to a `JobState`. Used by `poll()` and `wait()`. | -| `enter_states()` | Per iteration: calls `_query_phase()` once, passes the raw phase to both `_stuck()` and directly to `POD_STATE_MAPPING.get()` — single K8s API call per poll cycle. 
Returns `True` when target state is reached, `False` on timeout or stuck detection. | +| `enter_states()` | Takes only `job_states_to_enter`; no `timeout` parameter — reads `self.timeout` from the instance. Per iteration: calls `_query_phase()` once, passes the raw phase to `_stuck_in_pending()` and to `POD_STATE_MAPPING.get()` — single K8s API call per poll cycle. Three early-exit paths, evaluated in this order: (1) **Stuck detection** — calls `self.terminate()` (sets `terminal_state = TERMINATED`) then returns `False`; (2) **Terminal phase** — if `pod_phase` is `"Failed"` or `"Succeeded"`, sets `terminal_state` from `POD_STATE_MAPPING` then returns `False` *without* calling `terminate()`. The terminal-phase check is evaluated before the plain-timeout check so that a job completing exactly when the timeout expires is recorded as `SUCCEEDED`/`TERMINATED` (per `POD_STATE_MAPPING`) rather than being misclassified as `TERMINATED` by `terminate()`; (3) **Plain timeout** (`self.timeout` elapsed) — calls `self.terminate()`, sets `terminal_state = TERMINATED`, returns `False`. Setting `terminal_state` in all `False`-return paths prevents `wait()` from looping indefinitely if the pod is GC'd before `wait()` runs. Returns `True` when target state is reached. | Pod phase mapping: @@ -355,9 +315,7 @@ Pod phase mapping: | `Failed` | `TERMINATED` | `ABORTED` | | `Unknown` | `UNKNOWN` | `UNKNOWN` | -> Note: `POD_Phase.TERMINATED` has been removed from the enum. `POD_STATE_MAPPING` now covers only the five real Kubernetes pod phases: `Pending`, `Running`, `Succeeded`, `Failed`, `Unknown`. - -**Stuck detection:** `_stuck_count` starts at `0`. A separate `_stuck_grace_period = 10` is added to `timeout` to form `_max_stuck_count = timeout + _stuck_grace_period`, giving a grace window of ~10 extra poll cycles before stuck detection activates. If `timeout` is `None`, `_max_stuck_count` is also `None` and stuck detection is disabled entirely. 
`enter_states()` passes the raw phase string from `_query_phase()` directly to `_stuck()`. `_stuck()` compares `current_phase == POD_Phase.PENDING.value` (i.e. `"Pending" == "Pending"`), incrementing `_stuck_count` on each match. When `_stuck_count > _max_stuck_count`, `_stuck()` returns `True`, `enter_states()` calls `terminate()` (which sets `terminal_state = TERMINATED` when the delete call succeeds or returns 404) and returns `False`. Note: `_stuck_count` and `_max_stuck_count` are poll-iteration counts (each ~1 second), not wall-clock seconds — the semantics coincide only because each poll sleeps exactly 1 second. +**Stuck detection:** `_stuck_count` starts at `0`. `_max_stuck_count` is set in the constructor as: `timeout if timeout is not None else pending_timeout`. So stuck detection is **always active** — if `timeout` is provided, `_max_stuck_count = timeout`; if `timeout` is `None`, `_max_stuck_count = pending_timeout` (default 30). The method `_stuck_in_pending(current_phase)` increments `_stuck_count` each time `current_phase == "Pending"` and returns `True` when `_stuck_count > _max_stuck_count`. When the phase is **not** `Pending`, `_stuck_count` is **reset to 0** — so a pod that transitions out of Pending (e.g. briefly reaches Running then goes back to Pending) starts its stuck count fresh. When it returns `True`, `enter_states()` calls `terminate()` and returns `False`. Plain startup timeout (wall-clock elapsed > `timeout`) also calls `terminate()` before returning `False`. In both cases `terminal_state` is always set to `TERMINATED` by `terminate()`, regardless of whether the API call succeeds or raises. Note: `_stuck_count` and `_max_stuck_count` are poll-iteration counts (~1 second each). #### K8sJobHandle Pod Manifest @@ -373,7 +331,7 @@ spec: containers: - name: container- image: - command: ["/usr/local/bin/python"] + command: [""] # default: /usr/local/bin/python; configurable via K8sJobLauncher.python_path args: ["-u", "-m", "", "-w", "", ...] 
volumeMounts: - name: nvflws @@ -402,33 +360,37 @@ spec: Constructor parameters: -| Parameter | Purpose | -|-----------|---------| -| `config_file_path` | Path to kubeconfig file. Loaded via `config.load_kube_config()`. | -| `workspace_pvc` | PVC claim name for the NVFlare workspace. | -| `etc_pvc` | PVC claim name for configuration/etc data. | -| `data_pvc_file_path` | Path to a YAML file mapping PVC names to mount paths for training data. | -| `timeout` | Maximum seconds to wait for pod to reach `Running` (also used as stuck threshold). | -| `namespace` | Kubernetes namespace (default: `"default"`). | +| Parameter | Default | Purpose | +|-----------|---------|---------| +| `config_file_path` | (required) | Path to kubeconfig file. Loaded via `config.load_kube_config()` at init time. | +| `workspace_pvc` | (required) | PVC claim name for the NVFlare workspace. | +| `etc_pvc` | (required) | PVC claim name for configuration/etc data. | +| `data_pvc_file_path` | (required) | Path to a YAML file mapping PVC names to mount paths for training data. Read and validated at init time; raises `ValueError` if the file is empty/contains no entries, or if the parsed YAML is not a `dict`. Only the first key (PVC name) is used. | +| `timeout` | `None` | Maximum wall-clock seconds for `enter_states([RUNNING])`. Also used as `_max_stuck_count`. If `None`, `pending_timeout` governs stuck detection instead. | +| `namespace` | `"default"` | Kubernetes namespace. | +| `pending_timeout` | `30` | Stuck-detection threshold (poll iterations) when `timeout` is `None`. Passed to `K8sJobHandle`. | +| `python_path` | `"/usr/local/bin/python"` | Absolute path to the Python executable used as the container `command`. Override for non-standard images where Python is not at `/usr/local/bin/python` (e.g. `/usr/bin/python3`). Passed through to `K8sJobHandle`. | Launch sequence: | Step | Action | |------|--------| -| 1 | Extract `job_image`, `site_name`, and optional `num_of_gpus` from `job_meta`. 
| -| 2 | Read `JOB_PROCESS_ARGS` from `fl_ctx`; extract `EXE_MODULE` as the container command. | -| 3 | Build `job_config` dict: name, image, container name, command, volume mounts/PVCs, `module_args` from `get_module_args()`, set list, GPU resources. | -| 4 | Create `K8sJobHandle(job_id, core_v1, job_config, namespace, timeout)` which builds the pod manifest. | -| 5 | `core_v1.create_namespaced_pod(body=pod_manifest, namespace)`. | -| 6 | Call `job_handle.enter_states([RUNNING], timeout)`. The return value is not checked. If stuck detection fires, `terminate()` is called inside `enter_states` (sets `terminal_state = TERMINATED` via `finally`) before returning the handle, so the caller can detect failure via `poll()`. On plain timeout (no stuck), the handle is returned with `terminal_state` unset and the pod may still be starting. | -| 7 | On `ApiException` from `create_namespaced_pod` → `job_handle.terminate()` then return the handle. Unlike Docker (which returns `None` on failure), the K8s launcher always returns a handle; callers detect failure when `poll()` or `wait()` resolves. | +| 1 | Validate and sanitize job ID: `raw_job_id = job_meta.get(JobConstants.JOB_ID)`; raises `RuntimeError` if missing or falsy. Then `job_id = uuid4_to_rfc1123(raw_job_id)`. All subsequent operations use this RFC 1123-compliant name. Extract `job_image`, `site_name`, and optional `num_of_gpus` from `job_meta`. | +| 2 | Read `JOB_PROCESS_ARGS` from `fl_ctx`; raises `RuntimeError` if the dict is absent. Raises `RuntimeError` if `EXE_MODULE` is missing or falsy. Extract the module name via `_, job_cmd = exe_module_entry`. | +| 3 | Build `job_config` dict: name (`job_id`), image, container name (`container-{job_id}`), command, volume mounts/PVCs, `module_args` from `get_module_args()`. `set_list` is conditionally set: if `fl_ctx.get_prop(FLContextKey.ARGS)` is non-None and `getattr(args, "set", None)` is non-None, `set_list = args.set` is added (see note below). 
Using `getattr` rather than direct attribute access guards against non-standard `ARGS` objects that lack a `set` attribute. If `num_of_gpus` is truthy (non-None **and** non-zero), adds `job_config["resources"] = {"limits": {"nvidia.com/gpu": num_of_gpus}}`; a value of `0` is intentionally excluded to avoid injecting `nvidia.com/gpu: 0` into the pod spec, which would explicitly request zero GPUs and can affect scheduling on GPU-enabled clusters differently than omitting the limit entirely. | +| 4 | Create `K8sJobHandle(job_id, core_v1, job_config, namespace=self.namespace, timeout=self.timeout, pending_timeout=self.pending_timeout)` which builds the pod manifest. | +| 5 | `core_v1.create_namespaced_pod(body=pod_manifest, namespace)` in a `try/except Exception` block. On any exception — including `ApiException` (K8s API error) and lower-level errors such as network timeouts or serialization failures — `job_handle.terminate()` is called (always sets `terminal_state = TERMINATED`) and the handle is returned. The scope of this handler is intentionally limited to this single API call; it does not swallow exceptions from the polling loop in step 6. | +| 6 | Call `job_handle.enter_states([RUNNING])` in a separate `try/except BaseException` block. On any exception (including `KeyboardInterrupt`) → `job_handle.terminate()` then re-raise. This ensures a pod already created in step 5 is not orphaned if the blocking poll loop is interrupted. The return value is captured: if `False`, logs a warning `"unable to enter running phase {job_id}"`. Inside `enter_states`: stuck detection and plain timeout both call `terminate()` (always sets `terminal_state = TERMINATED`) then return `False`; if the pod reaches a terminal phase (`Failed`/`Succeeded`), `terminal_state` is set from `POD_STATE_MAPPING` and `enter_states` returns `False` without calling `terminate()`. Setting `terminal_state` in all `False`-return paths prevents `wait()` from looping if the pod is GC'd before `wait()` runs. 
| +| 7 | Return `job_handle`. The K8s launcher always returns a handle; callers detect launch failure when `poll()` or `wait()` resolves. | + +> **`set_list` note:** `args.set` is the CLI `--set` items stored in `FLContextKey.ARGS` at the time `launch_job` is called. The server and client both make a deep copy of `FLContextKey.ARGS`, append `print_conf=True` (and server also appends `restore_snapshot=`) to that copy, and embed the expanded string into `JOB_PROCESS_ARGS[OPTIONS]`. They do **not** write the modified copy back to `FLContextKey.ARGS`. As a result, the K8s launcher's `set_list` contains only the original CLI `--set` items — **without** `print_conf=True` or `restore_snapshot=...`. The Process launcher receives those extra flags through `OPTIONS`, which K8s excludes from `module_args` (`get_*_job_args(include_set_options=False)`). **Server/Client subclasses** override `get_module_args()`: -- `ClientK8sJobLauncher` → Filters `JOB_PROCESS_ARGS` through `get_client_job_args(include_exe_module=False, include_set_options=False)` to produce the dict of `-flag value` pairs for the container args list. -- `ServerK8sJobLauncher` → Same pattern with `get_server_job_args(...)`. +- `ClientK8sJobLauncher` → Calls `_job_args_dict(job_args, get_client_job_args(False, False))` — filters `JOB_PROCESS_ARGS` excluding `EXE_MODULE` and `OPTIONS`, producing a `{flag: value}` dict for the container `args` list. +- `ServerK8sJobLauncher` → Same pattern with `get_server_job_args(False, False)`. -**Key difference from Process/Docker:** The K8s launcher does not build a shell command string. Instead, it passes the Python executable as `command` and constructs a structured `args` list (`["-u", "-m", "", "-w", "", ...]`) directly in the pod spec. +**Key difference from Process:** The K8s launcher does not build a shell command string. 
Instead, it passes the Python executable as `command` and constructs a structured `args` list (`["-u", "-m", "", "-w", "", ...]`) directly in the pod spec. --- @@ -437,19 +399,15 @@ Launch sequence: ### 5.1 Full Class Hierarchy ``` -JobHandleSpec (abstract) +JobHandleSpec (ABC) ├── ProcessHandle (wraps ProcessAdapter / subprocess.Popen) -├── DockerJobHandle (wraps docker.Container) └── K8sJobHandle (wraps CoreV1Api + pod name) -JobLauncherSpec (abstract, extends FLComponent) +JobLauncherSpec (FLComponent, ABC) ├── ProcessJobLauncher (abstract: get_command) │ ├── ServerProcessJobLauncher │ └── ClientProcessJobLauncher -├── DockerJobLauncher (abstract: get_command) -│ ├── ServerDockerJobLauncher -│ └── ClientDockerJobLauncher -└── K8sJobLauncher (abstract: get_module_args) +└── K8sJobLauncher (abstract: get_module_args; inherits ABCMeta from JobLauncherSpec) ├── ServerK8sJobLauncher └── ClientK8sJobLauncher ``` @@ -458,12 +416,11 @@ JobLauncherSpec (abstract, extends FLComponent) **Strategy Pattern** -- Each launcher is a strategy for running jobs. The engine programs against `JobLauncherSpec`; the concrete strategy is selected at runtime through the event system. 
-**Template Method Pattern** -- Each base launcher (`ProcessJobLauncher`, `DockerJobLauncher`, `K8sJobLauncher`) implements `launch_job()` with a fixed algorithm, delegating the variable part to an abstract method: +**Template Method Pattern** -- Each base launcher (`ProcessJobLauncher`, `K8sJobLauncher`) implements `launch_job()` with a fixed algorithm, delegating the variable part to an abstract method: | Base Launcher | Template method calls | Abstract hook | |---------------|----------------------|---------------| | `ProcessJobLauncher` | `launch_job()` → `get_command()` | `get_command(job_meta, fl_ctx) -> str` | -| `DockerJobLauncher` | `launch_job()` → `get_command()` | `get_command(job_meta, fl_ctx) -> (str, str)` | | `K8sJobLauncher` | `launch_job()` → `get_module_args()` | `get_module_args(job_id, fl_ctx) -> dict` | Server and client subclasses provide the implementation of these hooks, producing the correct command-line arguments for each role. @@ -472,23 +429,23 @@ Server and client subclasses provide the implementation of these hooks, producin --- -## 6. 
Comparison: Process vs Docker vs Kubernetes - -| Aspect | Process | Docker | Kubernetes | -|--------|---------|--------|------------| -| **When selected** | No `job_image` for site | `job_image` present | `job_image` present | -| **Execution unit** | OS subprocess | Docker container | Kubernetes Pod | -| **Isolation** | Shared host, inherited env | Container isolation, mounted workspace | Pod isolation, PVC-backed volumes | -| **Command format** | Shell command string (`python -m ...`) | Shell command inside `/bin/bash -c` | Structured `command` + `args` list in pod spec | -| **Workspace access** | Direct filesystem (same host) | Host directory bind-mounted to container | PersistentVolumeClaims | -| **Data access** | Direct filesystem | Via bind mount | Via PVC (configured in YAML) | -| **Start verification** | None (spawn returns immediately) | Poll for `RUNNING` state with timeout | `enter_states([RUNNING], timeout)` with stuck detection; return value not checked — on stuck, `terminal_state` is set so caller can detect via `poll()` | -| **Wait mechanism** | `subprocess.Popen.wait()` (OS-level block) | Poll container status for `EXITED`/`DEAD` | Direct while loop via `_query_state()`; no timeout; exits when `terminal_state` set or `SUCCEEDED`/`TERMINATED` reached | -| **Terminate** | `SIGTERM`/`SIGKILL` via `ProcessAdapter` | `container.stop()` | `delete_namespaced_pod(grace_period=0)`; `terminal_state` set to `TERMINATED` on success or 404; left unchanged (error logged) for other exceptions | -| **Return code source** | Process exit code or RC file | Container status mapping or RC file | Pod phase mapping or RC file; `poll()` now consistently returns `JobReturnCode` via `JOB_RETURN_CODE_MAPPING` | -| **GPU support** | Inherited from host environment | Not explicitly managed | `nvidia.com/gpu` resource limit in pod spec | -| **Dependencies** | stdlib only | `docker` Python SDK | `kubernetes` Python client + kubeconfig | -| **Typical use case** | Simulator, 
single-machine POC | Multi-container on single host | Production cluster with shared storage |
+## 6. Comparison: Process vs Kubernetes
+
+| Aspect | Process | Kubernetes |
+|--------|---------|------------|
+| **When selected** | No `job_image` for site | `job_image` present |
+| **Execution unit** | OS subprocess | Kubernetes Pod |
+| **Isolation** | Shared host, inherited env | Pod isolation, PVC-backed volumes |
+| **Command format** | Shell command string (`sys.executable -m ...`) | Structured `command: ["/usr/local/bin/python"]` + `args` list in pod spec |
+| **Workspace access** | Direct filesystem (same host) | PersistentVolumeClaims |
+| **Data access** | Direct filesystem | Via PVC (configured in YAML) |
+| **Start verification** | None (spawn returns immediately) | `enter_states([RUNNING])` called (reads `self.timeout`); return value is checked — `False` logs a warning but handle is still always returned; on stuck detection or plain timeout, `terminate()` is called (sets `terminal_state = TERMINATED`); callers detect failure via `poll()` |
+| **Wait mechanism** | `subprocess.Popen.wait()` (OS-level block) | Direct while loop via `_query_state()`; no timeout; exits when `terminal_state` set or `SUCCEEDED`/`TERMINATED` reached |
+| **Terminate** | `SIGTERM`/`SIGKILL` via `ProcessAdapter` | `delete_namespaced_pod(grace_period=0)`; `terminal_state` always set to `TERMINATED` — on success, 404, other `ApiException`, or any lower-level `Exception` (network/serialization). Errors are logged but never propagated. 
| +| **Return code source** | Process exit code or RC file | Pod phase mapping or RC file; `poll()` consistently returns `JobReturnCode` via `JOB_RETURN_CODE_MAPPING` | +| **GPU support** | Inherited from host environment | `nvidia.com/gpu` resource limit set in pod spec from `job_meta` resource spec (`num_of_gpus`); omitted if not specified | +| **Dependencies** | stdlib only | `kubernetes` Python client + kubeconfig | +| **Typical use case** | Simulator, single-machine POC | Production cluster with shared storage | --- @@ -540,8 +497,7 @@ The following shows the end-to-end flow for launching and managing a job, applic │ launcher.launch_job(job_meta, fl_ctx) │ │ │─────────────────────────────────────────────->│ │ │ │ │ create exec unit │ - │ │ │ (process/container │ - │ │ │ /pod) │ + │ │ │ (process/pod) │ │ │ │─────────────────────>│ │ │ │ return handle │ │<─────────────────────────────────────────────────────────────────────│ @@ -583,23 +539,7 @@ Launchers are registered as FL components in the site's `resources.json`. The co } ``` -### 9.2 Docker Launcher - -```json -{ - "id": "job_launcher", - "path": "nvflare.app_opt.job_launcher.docker_launcher.ClientDockerJobLauncher", - "args": { - "mount_path": "/workspace", - "network": "nvflare-network", - "timeout": 60 - } -} -``` - -Requires the `NVFL_DOCKER_WORKSPACE` environment variable to be set on the host to identify the workspace directory to bind-mount. - -### 9.3 Kubernetes Launcher +### 9.2 Kubernetes Launcher ```json { @@ -628,14 +568,14 @@ my-data-pvc: /var/tmp/nvflare/data ## 10. Future Improvements -1. **Explicit launcher selection** -- Today "has image" → Docker or K8s, "no image" → Process. Allow an explicit `launcher_type` field in job meta or deploy map so a site can support multiple container backends or provide fallback ordering (e.g., try K8s, fall back to Docker). +1. **Explicit launcher selection** -- Today "has image" → K8s, "no image" → Process. 
Allow an explicit `launcher_type` field in job meta or deploy map so a site can support multiple backends or provide fallback ordering.
 
-2. **Consistent GPU handling** -- The K8s launcher reads `num_of_gpus` from the resource spec; the Docker and Process launchers do not. Standardize resource declaration so job definitions remain portable across backends.
+2. **Consistent GPU handling** -- The K8s launcher applies `num_of_gpus` from the job resource spec as a pod `nvidia.com/gpu` limit; the Process launcher ignores it entirely. Standardize resource declaration so job definitions remain portable across backends.
 
-3. **Unified cleanup** -- Standardize container/pod cleanup policy across launchers (auto-remove on exit, configurable retention for debugging) and centralize it in the handle or engine.
+3. **Unified cleanup** -- Standardize pod cleanup policy (auto-remove on exit, configurable retention for debugging) and centralize it in the handle or engine.
 
-4. **Consistent timeout policy and failure semantics** -- The Process launcher has no start timeout. Docker polls for `RUNNING` and returns `None` on failure. K8s polls for `Running` with stuck detection (terminates and sets `terminal_state` on stuck) but does not act on plain startup timeout — if a pod is slow to start but not stuck in `Pending`, the handle is returned with `terminal_state` unset. Consider terminating explicitly on timeout and unifying failure return across all launchers (either always `None` or always a terminated handle).
+4. **Consistent timeout policy and failure semantics** -- The Process launcher has no start timeout. K8s checks the `enter_states` return value and logs a warning on failure (termination already happens inside `enter_states` for stuck/timeout paths), always returning a handle. Consider exposing a distinct "startup failed" state on the handle so callers can react without polling.
 
-5. 
**Observability** -- Add an optional `get_info()` method to `JobHandleSpec` so the engine can log launcher-specific details (container ID, pod name, namespace, PID) for debugging and operations. +5. **Observability** -- Add an optional `get_info()` method to `JobHandleSpec` so the engine can log launcher-specific details (pod name, namespace, PID) for debugging and operations. -6. **Testing** -- Provide `MockJobLauncher` and `MockJobHandle` implementations for unit tests that verify server/client flow without starting real processes or containers. +6. **Testing** -- Provide `MockJobLauncher` and `MockJobHandle` implementations for unit tests that verify server/client flow without starting real processes or pods. diff --git a/nvflare/apis/job_launcher_spec.py b/nvflare/apis/job_launcher_spec.py index 8fedfb866e..f8cd4a7192 100644 --- a/nvflare/apis/job_launcher_spec.py +++ b/nvflare/apis/job_launcher_spec.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from abc import abstractmethod +from abc import ABC, abstractmethod from nvflare.apis.fl_component import FLComponent from nvflare.apis.fl_constant import FLContextKey @@ -56,7 +56,7 @@ def add_launcher(launcher, fl_ctx: FLContext): fl_ctx.set_prop(FLContextKey.JOB_LAUNCHER, job_launcher, private=True, sticky=False) -class JobHandleSpec: +class JobHandleSpec(ABC): @abstractmethod def terminate(self): """To terminate the job run. @@ -85,7 +85,7 @@ def wait(self): raise NotImplementedError() -class JobLauncherSpec(FLComponent): +class JobLauncherSpec(FLComponent, ABC): @abstractmethod def launch_job(self, job_meta: dict, fl_ctx: FLContext) -> JobHandleSpec: """To launch a job run. 
diff --git a/nvflare/app_opt/job_launcher/docker_launcher.py b/nvflare/app_opt/job_launcher/docker_launcher.py index 50e38a504a..8779e4711d 100644 --- a/nvflare/app_opt/job_launcher/docker_launcher.py +++ b/nvflare/app_opt/job_launcher/docker_launcher.py @@ -45,10 +45,10 @@ class DOCKER_STATE: class DockerJobHandle(JobHandleSpec): - def __init__(self, timeout=None): + def __init__(self, container, timeout=None): super().__init__() - self.container = None + self.container = container self.timeout = timeout self.logger = logging.getLogger(self.__class__.__name__) @@ -68,9 +68,6 @@ def wait(self): if self.container: self.enter_states([DOCKER_STATE.EXITED, DOCKER_STATE.DEAD], self.timeout) - def _set_container(self, container): - self.container = container - def _get_container(self): try: docker_client = docker.from_env() @@ -123,7 +120,6 @@ def launch_job(self, job_meta: dict, fl_ctx: FLContext) -> JobHandleSpec: docker_workspace = os.environ.get("NVFL_DOCKER_WORKSPACE") self.logger.info(f"launch_job {job_id} in docker_workspace: {docker_workspace}") docker_client = docker.from_env() - handle = DockerJobHandle() try: container = docker_client.containers.run( job_image, @@ -141,22 +137,24 @@ def launch_job(self, job_meta: dict, fl_ctx: FLContext) -> JobHandleSpec: # ports=ports, # Map container ports to host ports (optional) ) self.logger.info(f"Launch the job in DockerJobLauncher using image: {job_image}") - handle._set_container(container) + + handle = DockerJobHandle(container) try: - launched = handle.enter_states([DOCKER_STATE.RUNNING], timeout=self.timeout) - if not launched: + if handle.enter_states([DOCKER_STATE.RUNNING], timeout=self.timeout): + return handle + else: handle.terminate() - return handle + return None except: handle.terminate() - return handle + return None except docker.errors.ImageNotFound: self.logger.error(f"Failed to launcher job: {job_id} in DockerJobLauncher. 
Image '{job_image}' not found.") - return handle + return None except docker.errors.APIError as e: self.logger.error(f"Error starting container: {e}") - return handle + return None def handle_event(self, event_type: str, fl_ctx: FLContext): if event_type == EventType.BEFORE_JOB_LAUNCH: diff --git a/nvflare/app_opt/job_launcher/k8s_launcher.py b/nvflare/app_opt/job_launcher/k8s_launcher.py index 3c0ab04884..5fe5b58e46 100644 --- a/nvflare/app_opt/job_launcher/k8s_launcher.py +++ b/nvflare/app_opt/job_launcher/k8s_launcher.py @@ -11,7 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import copy import logging +import re import time from abc import abstractmethod from enum import Enum @@ -89,8 +91,29 @@ class PV_NAME(Enum): ] +def uuid4_to_rfc1123(uuid_str: str) -> str: + name = uuid_str.lower() + # Strip any chars that aren't alphanumeric or hyphen + name = re.sub(r"[^a-z0-9-]", "", name) + # Prefix with a letter if it starts with a digit + if name and name[0].isdigit(): + name = "j" + name + # Kubernetes label limit: 63 chars; strip trailing hyphens after truncation + # (truncation can expose a hyphen that was interior before slicing) + return name[:63].rstrip("-") + + class K8sJobHandle(JobHandleSpec): - def __init__(self, job_id: str, api_instance: core_v1_api, job_config: dict, namespace="default", timeout=None): + def __init__( + self, + job_id: str, + api_instance: core_v1_api, + job_config: dict, + namespace="default", + timeout=None, + pending_timeout=30, + python_path="/usr/local/bin/python", + ): super().__init__() self.job_id = job_id self.timeout = timeout @@ -113,8 +136,7 @@ def __init__(self, job_id: str, api_instance: core_v1_api, job_config: dict, nam { "image": None, "name": None, - "resources": None, - "command": ["/usr/local/bin/python"], + "command": [python_path], "args": None, # args_list + args_dict + 
args_sets "volumeMounts": None, # volume_mount_list "imagePullPolicy": "Always", @@ -127,14 +149,13 @@ def __init__(self, job_id: str, api_instance: core_v1_api, job_config: dict, nam self.container_volume_mount_list = [] self._make_manifest(job_config) self._stuck_count = 0 - self._stuck_grace_period = 10 # seconds to wait before counting Pending as stuck - self._max_stuck_count = (self.timeout + self._stuck_grace_period) if self.timeout is not None else None + self._max_stuck_count = self.timeout if self.timeout is not None else pending_timeout self.logger = logging.getLogger(self.__class__.__name__) def _make_manifest(self, job_config): self.container_volume_mount_list.extend(job_config.get("volume_mount_list", [])) set_list = job_config.get("set_list") - if set_list is None: + if not set_list: self.container_args_module_args_sets = list() else: self.container_args_module_args_sets = ["--set"] + set_list @@ -147,13 +168,16 @@ def _make_manifest(self, job_config): if v is None: continue self.container_args_module_args_dict_as_list.append(k) - self.container_args_module_args_dict_as_list.append(v) + self.container_args_module_args_dict_as_list.append(str(v)) self.volume_list.extend(job_config.get("volume_list", [])) self.pod_manifest["metadata"]["name"] = job_config.get("name") self.pod_manifest["spec"]["containers"] = self.container_list self.pod_manifest["spec"]["volumes"] = self.volume_list - self.container_list[0]["image"] = job_config.get("image", "nvflare/nvflare:2.8.0") + image = job_config.get("image") + if not image: + raise ValueError("job_config must contain a non-empty 'image' key") + self.container_list[0]["image"] = image self.container_list[0]["name"] = job_config.get("container_name", "nvflare_job") self.container_list[0]["args"] = ( self.container_args_python_args_list @@ -161,13 +185,13 @@ def _make_manifest(self, job_config): + self.container_args_module_args_sets ) self.container_list[0]["volumeMounts"] = self.container_volume_mount_list - if 
job_config.get("resources", {}).get("limits", {}).get("nvidia.com/gpu") is not None: + if job_config.get("resources", {}).get("limits", {}).get("nvidia.com/gpu"): self.container_list[0]["resources"] = job_config.get("resources") def get_manifest(self): - return self.pod_manifest + return copy.deepcopy(self.pod_manifest) - def enter_states(self, job_states_to_enter: list, timeout=None): + def enter_states(self, job_states_to_enter: list): starting_time = time.time() if not isinstance(job_states_to_enter, (list, tuple)): job_states_to_enter = [job_states_to_enter] @@ -175,29 +199,33 @@ def enter_states(self, job_states_to_enter: list, timeout=None): raise ValueError(f"expect job_states_to_enter with valid values, but get {job_states_to_enter}") while True: pod_phase = self._query_phase() - if self._stuck(pod_phase): + if self._stuck_in_pending(pod_phase): self.terminate() return False job_state = POD_STATE_MAPPING.get(pod_phase, JobState.UNKNOWN) if job_state in job_states_to_enter: return True - elif timeout is not None and time.time() - starting_time > timeout: + elif pod_phase in [POD_Phase.FAILED.value, POD_Phase.SUCCEEDED.value]: # terminal state + self.terminal_state = POD_STATE_MAPPING.get(pod_phase, JobState.UNKNOWN) + return False + elif self.timeout is not None and time.time() - starting_time > self.timeout: + self.terminate() return False time.sleep(1) def terminate(self): try: - resp = self.api_instance.delete_namespaced_pod( - name=self.job_id, namespace=self.namespace, grace_period_seconds=0 - ) + self.api_instance.delete_namespaced_pod(name=self.job_id, namespace=self.namespace, grace_period_seconds=0) self.terminal_state = JobState.TERMINATED except ApiException as e: - # If the pod is already gone, treat it as terminated; otherwise, leave state unchanged. 
if getattr(e, "status", None) == 404: self.logger.info(f"job {self.job_id} pod not found during termination; assuming terminated") - self.terminal_state = JobState.TERMINATED else: self.logger.error(f"failed to terminate job {self.job_id}: {e}") + self.terminal_state = JobState.TERMINATED + except Exception as e: + self.logger.error(f"unexpected error terminating job {self.job_id}: {e}") + self.terminal_state = JobState.TERMINATED return None def poll(self): @@ -210,6 +238,10 @@ def _query_phase(self): try: resp = self.api_instance.read_namespaced_pod(name=self.job_id, namespace=self.namespace) except ApiException as e: + self.logger.warning(f"failed to query pod phase {self.job_id}: {e}") + return POD_Phase.UNKNOWN.value + except Exception as e: + self.logger.warning(f"unexpected error querying pod phase {self.job_id}: {e}") return POD_Phase.UNKNOWN.value return resp.status.phase @@ -217,13 +249,13 @@ def _query_state(self): pod_phase = self._query_phase() return POD_STATE_MAPPING.get(pod_phase, JobState.UNKNOWN) - def _stuck(self, current_phase): - if self._max_stuck_count is None: - return False + def _stuck_in_pending(self, current_phase): if current_phase == POD_Phase.PENDING.value: self._stuck_count += 1 - if self._stuck_count > self._max_stuck_count: + if self._max_stuck_count is not None and self._stuck_count >= self._max_stuck_count: return True + else: + self._stuck_count = 0 return False def wait(self): @@ -246,6 +278,8 @@ def __init__( data_pvc_file_path: str, timeout=None, namespace="default", + pending_timeout=30, + python_path="/usr/local/bin/python", ): super().__init__() self.logger = logging.getLogger(self.__class__.__name__) @@ -255,6 +289,8 @@ def __init__( self.data_pvc_file_path = data_pvc_file_path self.timeout = timeout self.namespace = namespace + self.pending_timeout = pending_timeout + self.python_path = python_path with open(data_pvc_file_path, "rt") as f: data_pvc_dict = yaml.safe_load(f) if not data_pvc_dict: @@ -262,8 +298,9 @@ def 
__init__( # data_pvc_dict will be pvc: mountPath # currently, support one pvc and always mount to /var/tmp/nvflare/data # ie, ignore the mountPath in data_pvc_dict + if not isinstance(data_pvc_dict, dict): + raise ValueError(f"file at data_pvc_file_path '{data_pvc_file_path}' does not contain a dictionary.") self.data_pvc = list(data_pvc_dict.keys())[0] - config.load_kube_config(config_file_path) try: c = Configuration().get_default_copy() @@ -276,17 +313,22 @@ def __init__( def launch_job(self, job_meta: dict, fl_ctx: FLContext) -> JobHandleSpec: site_name = fl_ctx.get_identity_name() - job_id = job_meta.get(JobConstants.JOB_ID) + raw_job_id = job_meta.get(JobConstants.JOB_ID) + if not raw_job_id: + raise RuntimeError(f"missing {JobConstants.JOB_ID} in job_meta") + job_id = uuid4_to_rfc1123(raw_job_id) args = fl_ctx.get_prop(FLContextKey.ARGS) job_image = extract_job_image(job_meta, site_name) site_resources = job_meta.get(JobMetaKey.RESOURCE_SPEC.value, {}).get(site_name, {}) job_resource = site_resources.get("num_of_gpus", None) - job_args = fl_ctx.get_prop(FLContextKey.JOB_PROCESS_ARGS) if not job_args: raise RuntimeError(f"missing {FLContextKey.JOB_PROCESS_ARGS} in FLContext") - _, job_cmd = job_args[JobProcessArgs.EXE_MODULE] + exe_module_entry = job_args.get(JobProcessArgs.EXE_MODULE) + if not exe_module_entry: + raise RuntimeError(f"missing {JobProcessArgs.EXE_MODULE} in {FLContextKey.JOB_PROCESS_ARGS}") + _, job_cmd = exe_module_entry job_config = { "name": job_id, "image": job_image, @@ -299,21 +341,36 @@ def launch_job(self, job_meta: dict, fl_ctx: FLContext) -> JobHandleSpec: {"name": PV_NAME.ETC.value, "persistentVolumeClaim": {"claimName": self.etc_pvc}}, ], "module_args": self.get_module_args(job_id, fl_ctx), - "set_list": args.set, - "resources": {"limits": {"nvidia.com/gpu": job_resource}}, } - - job_handle = K8sJobHandle(job_id, self.core_v1, job_config, namespace=self.namespace, timeout=self.timeout) + if args is not None and getattr(args, "set", 
None) is not None: + job_config.update({"set_list": args.set}) + if job_resource: + job_config.update({"resources": {"limits": {"nvidia.com/gpu": job_resource}}}) + job_handle = K8sJobHandle( + job_id, + self.core_v1, + job_config, + namespace=self.namespace, + timeout=self.timeout, + pending_timeout=self.pending_timeout, + python_path=self.python_path, + ) pod_manifest = job_handle.get_manifest() self.logger.debug(f"launch job with k8s_launcher. {pod_manifest=}") try: self.core_v1.create_namespaced_pod(body=pod_manifest, namespace=self.namespace) - job_handle.enter_states([JobState.RUNNING], timeout=self.timeout) + except Exception as e: + self.logger.error(f"failed to launch job {job_id}: {e}") + job_handle.terminal_state = JobState.TERMINATED return job_handle - except ApiException as e: - self.logger.error(f"failed to launch job {self.job_id}: {e}") + try: + entered_running = job_handle.enter_states([JobState.RUNNING]) + except BaseException: job_handle.terminate() - return job_handle + raise + if not entered_running: + self.logger.warning(f"unable to enter running phase {job_id}") + return job_handle def handle_event(self, event_type: str, fl_ctx: FLContext): if event_type == EventType.BEFORE_JOB_LAUNCH: diff --git a/nvflare/private/fed/client/communicator.py b/nvflare/private/fed/client/communicator.py index d44d6e1ad7..1387f9ca9c 100644 --- a/nvflare/private/fed/client/communicator.py +++ b/nvflare/private/fed/client/communicator.py @@ -59,7 +59,6 @@ def __init__( client_register_interval=2, timeout=5.0, maint_msg_timeout=5.0, - cell_creation_timeout=15.0, ): """To init the Communicator. 
@@ -80,7 +79,7 @@ def __init__( self.client_register_interval = client_register_interval self.timeout = timeout self.maint_msg_timeout = maint_msg_timeout - self.creation_timeout = cell_creation_timeout + # token and token_signature are issued by the Server after the client is authenticated # they are added to every message going to the server as proof of authentication self.token = None @@ -274,9 +273,9 @@ def client_registration(self, client_name, project_name, fl_ctx: FLContext): start = time.time() while not self.cell: self.logger.info("Waiting for the client cell to be created.") - if time.time() - start > self.creation_timeout: + if time.time() - start > 15.0: raise RuntimeError("Client cell could not be created. Failed to login the client.") - time.sleep(1) + time.sleep(0.5) shared_fl_ctx = gen_new_peer_ctx(fl_ctx) private_key_file = None diff --git a/tests/unit_test/app_opt/job_launcher/docker_launcher_test.py b/tests/unit_test/app_opt/job_launcher/docker_launcher_test.py deleted file mode 100644 index 3e71e25b85..0000000000 --- a/tests/unit_test/app_opt/job_launcher/docker_launcher_test.py +++ /dev/null @@ -1,464 +0,0 @@ -# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from unittest.mock import Mock, patch - -from nvflare.apis.event_type import EventType -from nvflare.apis.fl_constant import FLContextKey, JobConstants, ReservedKey -from nvflare.apis.fl_context import FLContext -from nvflare.apis.job_def import JobMetaKey -from nvflare.apis.job_launcher_spec import JobReturnCode -from nvflare.app_opt.job_launcher.docker_launcher import ( - DOCKER_STATE, - JOB_RETURN_CODE_MAPPING, - DockerJobHandle, - DockerJobLauncher, -) - - -# --------------------------------------------------------------------------- -# Constants and mappings -# --------------------------------------------------------------------------- -class TestDockerState: - def test_state_values(self): - assert DOCKER_STATE.CREATED == "created" - assert DOCKER_STATE.RESTARTING == "restarting" - assert DOCKER_STATE.RUNNING == "running" - assert DOCKER_STATE.PAUSED == "paused" - assert DOCKER_STATE.EXITED == "exited" - assert DOCKER_STATE.DEAD == "dead" - - -class TestDockerJobReturnCodeMapping: - def test_running_maps_to_unknown(self): - assert JOB_RETURN_CODE_MAPPING[DOCKER_STATE.RUNNING] == JobReturnCode.UNKNOWN - - def test_exited_maps_to_success(self): - assert JOB_RETURN_CODE_MAPPING[DOCKER_STATE.EXITED] == JobReturnCode.SUCCESS - - def test_dead_maps_to_aborted(self): - assert JOB_RETURN_CODE_MAPPING[DOCKER_STATE.DEAD] == JobReturnCode.ABORTED - - def test_created_maps_to_unknown(self): - assert JOB_RETURN_CODE_MAPPING[DOCKER_STATE.CREATED] == JobReturnCode.UNKNOWN - - def test_paused_maps_to_unknown(self): - assert JOB_RETURN_CODE_MAPPING[DOCKER_STATE.PAUSED] == JobReturnCode.UNKNOWN - - def test_restarting_maps_to_unknown(self): - assert JOB_RETURN_CODE_MAPPING[DOCKER_STATE.RESTARTING] == JobReturnCode.UNKNOWN - - -# --------------------------------------------------------------------------- -# DockerJobHandle -# --------------------------------------------------------------------------- -class TestDockerJobHandle: - def test_init_defaults(self): - handle = 
DockerJobHandle() - assert handle.container is None - assert handle.timeout is None - - def test_init_with_timeout(self): - handle = DockerJobHandle(timeout=30) - assert handle.container is None - assert handle.timeout == 30 - - def test_set_container(self): - handle = DockerJobHandle() - container = Mock() - handle._set_container(container) - assert handle.container is container - - def test_terminate_stops_container(self): - handle = DockerJobHandle() - container = Mock() - handle._set_container(container) - handle.terminate() - container.stop.assert_called_once() - - def test_terminate_noop_when_no_container(self): - handle = DockerJobHandle() - handle.terminate() - - # -- poll ----------------------------------------------------------------- - @patch.object(DockerJobHandle, "_get_container") - def test_poll_running_returns_unknown(self, mock_get): - container = Mock() - container.status = DOCKER_STATE.RUNNING - mock_get.return_value = container - handle = DockerJobHandle() - assert handle.poll() == JobReturnCode.UNKNOWN - - @patch.object(DockerJobHandle, "_get_container") - def test_poll_exited_removes_and_returns_success(self, mock_get): - container = Mock() - container.status = DOCKER_STATE.EXITED - mock_get.return_value = container - handle = DockerJobHandle() - result = handle.poll() - container.remove.assert_called_once_with(force=True) - assert result == JobReturnCode.SUCCESS - - @patch.object(DockerJobHandle, "_get_container") - def test_poll_dead_removes_and_returns_aborted(self, mock_get): - container = Mock() - container.status = DOCKER_STATE.DEAD - mock_get.return_value = container - handle = DockerJobHandle() - result = handle.poll() - container.remove.assert_called_once_with(force=True) - assert result == JobReturnCode.ABORTED - - @patch.object(DockerJobHandle, "_get_container") - def test_poll_returns_none_when_container_gone(self, mock_get): - mock_get.return_value = None - handle = DockerJobHandle() - assert handle.poll() is None - - 
@patch.object(DockerJobHandle, "_get_container") - def test_poll_unknown_status_returns_unknown(self, mock_get): - container = Mock() - container.status = "something_unexpected" - mock_get.return_value = container - handle = DockerJobHandle() - assert handle.poll() == JobReturnCode.UNKNOWN - - # -- wait ----------------------------------------------------------------- - @patch.object(DockerJobHandle, "enter_states") - def test_wait_calls_enter_states(self, mock_enter): - handle = DockerJobHandle(timeout=10) - handle._set_container(Mock()) - handle.wait() - mock_enter.assert_called_once_with([DOCKER_STATE.EXITED, DOCKER_STATE.DEAD], 10) - - def test_wait_noop_when_no_container(self): - handle = DockerJobHandle() - handle.wait() - - # -- enter_states --------------------------------------------------------- - @patch.object(DockerJobHandle, "_get_container") - def test_enter_states_returns_true_when_state_matches(self, mock_get): - container = Mock() - container.status = DOCKER_STATE.RUNNING - mock_get.return_value = container - handle = DockerJobHandle() - assert handle.enter_states([DOCKER_STATE.RUNNING]) is True - - @patch.object(DockerJobHandle, "_get_container") - def test_enter_states_returns_false_when_container_gone(self, mock_get): - mock_get.return_value = None - handle = DockerJobHandle() - assert handle.enter_states([DOCKER_STATE.RUNNING]) is False - - @patch.object(DockerJobHandle, "_get_container") - def test_enter_states_returns_false_on_timeout(self, mock_get): - container = Mock() - container.status = DOCKER_STATE.CREATED - mock_get.return_value = container - handle = DockerJobHandle() - assert handle.enter_states([DOCKER_STATE.RUNNING], timeout=0) is False - - @patch.object(DockerJobHandle, "_get_container") - def test_enter_states_wraps_single_state(self, mock_get): - container = Mock() - container.status = DOCKER_STATE.EXITED - mock_get.return_value = container - handle = DockerJobHandle() - assert handle.enter_states(DOCKER_STATE.EXITED) is True - 
- # -- _get_container ------------------------------------------------------- - @patch("nvflare.app_opt.job_launcher.docker_launcher.docker") - def test_get_container_returns_container(self, mock_docker): - orig_container = Mock() - orig_container.id = "abc123" - refreshed = Mock() - mock_docker.from_env.return_value.containers.get.return_value = refreshed - - handle = DockerJobHandle() - handle._set_container(orig_container) - result = handle._get_container() - assert result is refreshed - mock_docker.from_env.return_value.containers.get.assert_called_once_with("abc123") - - @patch("nvflare.app_opt.job_launcher.docker_launcher.docker") - def test_get_container_returns_none_on_exception(self, mock_docker): - orig_container = Mock() - orig_container.id = "abc123" - mock_docker.from_env.side_effect = Exception("connection error") - - handle = DockerJobHandle() - handle._set_container(orig_container) - assert handle._get_container() is None - - -# --------------------------------------------------------------------------- -# DockerJobLauncher -# --------------------------------------------------------------------------- -def _make_fl_ctx_for_docker_launch(): - fl_ctx = FLContext() - fl_ctx.set_prop(ReservedKey.IDENTITY_NAME, "client-1", private=False, sticky=True) - workspace_obj = Mock() - workspace_obj.get_app_custom_dir.return_value = "/custom/dir" - fl_ctx.set_prop(FLContextKey.WORKSPACE_OBJECT, workspace_obj, private=True, sticky=False) - fl_ctx.set_prop( - FLContextKey.JOB_PROCESS_ARGS, - { - "exe_module": ("-m", "nvflare.private.fed.app.client.worker_process"), - "workspace": ("-w", "/workspace"), - }, - private=True, - sticky=False, - ) - return fl_ctx - - -def _make_docker_job_meta(image="nvflare/nvflare:test", job_id="job-123"): - return { - JobConstants.JOB_ID: job_id, - JobMetaKey.DEPLOY_MAP.value: {"app": [{"sites": ["client-1"], "image": image}]}, - } - - -class TestDockerJobLauncher: - def test_init_defaults(self): - launcher = DockerJobLauncher() - 
assert launcher.mount_path == "/workspace" - assert launcher.network == "nvflare-network" - assert launcher.timeout is None - - def test_init_custom(self): - launcher = DockerJobLauncher(mount_path="/custom", network="my-net", timeout=120) - assert launcher.mount_path == "/custom" - assert launcher.network == "my-net" - assert launcher.timeout == 120 - - # -- handle_event --------------------------------------------------------- - def test_handle_event_adds_launcher_when_image_present(self): - from nvflare.app_opt.job_launcher.docker_launcher import ClientDockerJobLauncher - - launcher = ClientDockerJobLauncher() - fl_ctx = FLContext() - job_meta = {JobMetaKey.DEPLOY_MAP.value: {"app": [{"sites": ["client-1"], "image": "nvflare/custom:latest"}]}} - fl_ctx.set_prop(FLContextKey.JOB_META, job_meta, private=True, sticky=False) - fl_ctx.set_prop(ReservedKey.IDENTITY_NAME, "client-1", private=False, sticky=True) - - launcher.handle_event(EventType.BEFORE_JOB_LAUNCH, fl_ctx) - - launchers = fl_ctx.get_prop(FLContextKey.JOB_LAUNCHER) - assert launchers is not None - assert launcher in launchers - - def test_handle_event_skips_when_no_image(self): - from nvflare.app_opt.job_launcher.docker_launcher import ClientDockerJobLauncher - - launcher = ClientDockerJobLauncher() - fl_ctx = FLContext() - job_meta = {JobMetaKey.DEPLOY_MAP.value: {}} - fl_ctx.set_prop(FLContextKey.JOB_META, job_meta, private=True, sticky=False) - fl_ctx.set_prop(ReservedKey.IDENTITY_NAME, "client-1", private=False, sticky=True) - - launcher.handle_event(EventType.BEFORE_JOB_LAUNCH, fl_ctx) - assert fl_ctx.get_prop(FLContextKey.JOB_LAUNCHER) is None - - def test_handle_event_ignores_other_events(self): - from nvflare.app_opt.job_launcher.docker_launcher import ClientDockerJobLauncher - - launcher = ClientDockerJobLauncher() - fl_ctx = FLContext() - launcher.handle_event(EventType.SYSTEM_START, fl_ctx) - assert fl_ctx.get_prop(FLContextKey.JOB_LAUNCHER) is None - - # -- launch_job 
----------------------------------------------------------- - @patch("nvflare.app_opt.job_launcher.docker_launcher.docker") - @patch("nvflare.app_opt.job_launcher.docker_launcher.os") - def test_launch_job_success(self, mock_os, mock_docker): - from nvflare.app_opt.job_launcher.docker_launcher import ClientDockerJobLauncher - - mock_os.environ.get.return_value = "/docker/workspace" - container = Mock() - container.status = DOCKER_STATE.RUNNING - mock_docker.from_env.return_value.containers.run.return_value = container - - launcher = ClientDockerJobLauncher(timeout=5) - fl_ctx = _make_fl_ctx_for_docker_launch() - job_meta = _make_docker_job_meta() - - with patch.object(DockerJobHandle, "enter_states", return_value=True): - handle = launcher.launch_job(job_meta, fl_ctx) - - assert handle is not None - mock_docker.from_env.return_value.containers.run.assert_called_once() - - @patch("nvflare.app_opt.job_launcher.docker_launcher.docker") - @patch("nvflare.app_opt.job_launcher.docker_launcher.os") - def test_launch_job_returns_handle_on_enter_states_failure(self, mock_os, mock_docker): - from nvflare.app_opt.job_launcher.docker_launcher import ClientDockerJobLauncher - - mock_os.environ.get.return_value = "/docker/workspace" - container = Mock() - mock_docker.from_env.return_value.containers.run.return_value = container - - launcher = ClientDockerJobLauncher(timeout=1) - fl_ctx = _make_fl_ctx_for_docker_launch() - job_meta = _make_docker_job_meta() - - with patch.object(DockerJobHandle, "enter_states", return_value=False): - handle = launcher.launch_job(job_meta, fl_ctx) - - assert handle is not None - - @patch("nvflare.app_opt.job_launcher.docker_launcher.docker") - @patch("nvflare.app_opt.job_launcher.docker_launcher.os") - def test_launch_job_terminates_on_enter_states_failure(self, mock_os, mock_docker): - from nvflare.app_opt.job_launcher.docker_launcher import ClientDockerJobLauncher - - mock_os.environ.get.return_value = "/docker/workspace" - container = Mock() - 
mock_docker.from_env.return_value.containers.run.return_value = container - - launcher = ClientDockerJobLauncher(timeout=1) - fl_ctx = _make_fl_ctx_for_docker_launch() - job_meta = _make_docker_job_meta() - - with patch.object(DockerJobHandle, "enter_states", return_value=False) as mock_enter: - handle = launcher.launch_job(job_meta, fl_ctx) - - container.stop.assert_called_once() - - @patch("nvflare.app_opt.job_launcher.docker_launcher.docker") - @patch("nvflare.app_opt.job_launcher.docker_launcher.os") - def test_launch_job_returns_handle_on_enter_states_exception(self, mock_os, mock_docker): - from nvflare.app_opt.job_launcher.docker_launcher import ClientDockerJobLauncher - - mock_os.environ.get.return_value = "/docker/workspace" - container = Mock() - mock_docker.from_env.return_value.containers.run.return_value = container - - launcher = ClientDockerJobLauncher(timeout=1) - fl_ctx = _make_fl_ctx_for_docker_launch() - job_meta = _make_docker_job_meta() - - with patch.object(DockerJobHandle, "enter_states", side_effect=RuntimeError("boom")): - handle = launcher.launch_job(job_meta, fl_ctx) - - assert handle is not None - container.stop.assert_called_once() - - @patch("nvflare.app_opt.job_launcher.docker_launcher.docker") - @patch("nvflare.app_opt.job_launcher.docker_launcher.os") - def test_launch_job_returns_handle_on_image_not_found(self, mock_os, mock_docker): - import docker as docker_pkg - from nvflare.app_opt.job_launcher.docker_launcher import ClientDockerJobLauncher - - mock_os.environ.get.return_value = "/docker/workspace" - mock_docker.from_env.return_value.containers.run.side_effect = docker_pkg.errors.ImageNotFound("not found") - mock_docker.errors = docker_pkg.errors - - launcher = ClientDockerJobLauncher() - fl_ctx = _make_fl_ctx_for_docker_launch() - job_meta = _make_docker_job_meta(image="bad/image:latest") - - handle = launcher.launch_job(job_meta, fl_ctx) - assert handle is not None - assert handle.container is None - - 
@patch("nvflare.app_opt.job_launcher.docker_launcher.docker") - @patch("nvflare.app_opt.job_launcher.docker_launcher.os") - def test_launch_job_returns_handle_on_api_error(self, mock_os, mock_docker): - import docker as docker_pkg - from nvflare.app_opt.job_launcher.docker_launcher import ClientDockerJobLauncher - - mock_os.environ.get.return_value = "/docker/workspace" - mock_docker.from_env.return_value.containers.run.side_effect = docker_pkg.errors.APIError("api error") - mock_docker.errors = docker_pkg.errors - - launcher = ClientDockerJobLauncher() - fl_ctx = _make_fl_ctx_for_docker_launch() - job_meta = _make_docker_job_meta() - - handle = launcher.launch_job(job_meta, fl_ctx) - assert handle is not None - assert handle.container is None - - @patch("nvflare.app_opt.job_launcher.docker_launcher.docker") - @patch("nvflare.app_opt.job_launcher.docker_launcher.os") - def test_launch_job_empty_custom_folder_uses_pythonpath_only(self, mock_os, mock_docker): - from nvflare.app_opt.job_launcher.docker_launcher import ClientDockerJobLauncher - - mock_os.environ.get.return_value = "/docker/workspace" - container = Mock() - mock_docker.from_env.return_value.containers.run.return_value = container - - launcher = ClientDockerJobLauncher() - fl_ctx = FLContext() - fl_ctx.set_prop(ReservedKey.IDENTITY_NAME, "client-1", private=False, sticky=True) - workspace_obj = Mock() - workspace_obj.get_app_custom_dir.return_value = "" - fl_ctx.set_prop(FLContextKey.WORKSPACE_OBJECT, workspace_obj, private=True, sticky=False) - fl_ctx.set_prop( - FLContextKey.JOB_PROCESS_ARGS, - { - "exe_module": ("-m", "worker"), - "workspace": ("-w", "/workspace"), - }, - private=True, - sticky=False, - ) - - job_meta = _make_docker_job_meta() - - with patch.object(DockerJobHandle, "enter_states", return_value=True): - handle = launcher.launch_job(job_meta, fl_ctx) - - call_kwargs = mock_docker.from_env.return_value.containers.run.call_args - command_str = call_kwargs[1]["command"] if "command" in 
call_kwargs[1] else call_kwargs[0][1] - assert "$PYTHONPATH" in command_str - assert "/custom" not in command_str - - -# --------------------------------------------------------------------------- -# ClientDockerJobLauncher.get_command -# --------------------------------------------------------------------------- -class TestClientDockerJobLauncher: - @patch("nvflare.app_opt.job_launcher.docker_launcher.generate_client_command") - def test_get_command(self, mock_gen_cmd): - from nvflare.app_opt.job_launcher.docker_launcher import ClientDockerJobLauncher - - mock_gen_cmd.return_value = "python -u -m worker_process -w /workspace" - launcher = ClientDockerJobLauncher() - fl_ctx = FLContext() - fl_ctx.set_prop(ReservedKey.IDENTITY_NAME, "client-1", private=False, sticky=True) - job_meta = {JobConstants.JOB_ID: "job-abc"} - - name, cmd = launcher.get_command(job_meta, fl_ctx) - assert name == "client-1-job-abc" - assert cmd == "python -u -m worker_process -w /workspace" - - -# --------------------------------------------------------------------------- -# ServerDockerJobLauncher.get_command -# --------------------------------------------------------------------------- -class TestServerDockerJobLauncher: - @patch("nvflare.app_opt.job_launcher.docker_launcher.generate_server_command") - def test_get_command(self, mock_gen_cmd): - from nvflare.app_opt.job_launcher.docker_launcher import ServerDockerJobLauncher - - mock_gen_cmd.return_value = "python -u -m server_process -w /workspace" - launcher = ServerDockerJobLauncher() - fl_ctx = FLContext() - job_meta = {JobConstants.JOB_ID: "job-xyz"} - - name, cmd = launcher.get_command(job_meta, fl_ctx) - assert name == "server-job-xyz" - assert cmd == "python -u -m server_process -w /workspace" diff --git a/tests/unit_test/app_opt/job_launcher/k8s_launcher_test.py b/tests/unit_test/app_opt/job_launcher/k8s_launcher_test.py index 73b7303e8c..b1ff897255 100644 --- a/tests/unit_test/app_opt/job_launcher/k8s_launcher_test.py +++ 
b/tests/unit_test/app_opt/job_launcher/k8s_launcher_test.py @@ -52,7 +52,7 @@ def __init__(self, status=None, reason=None, http_resp=None): sys.modules.setdefault(_mod_name, _mod_obj) from nvflare.apis.event_type import EventType -from nvflare.apis.fl_constant import FLContextKey, ReservedKey +from nvflare.apis.fl_constant import FLContextKey, JobConstants, ReservedKey from nvflare.apis.fl_context import FLContext from nvflare.apis.job_def import JobMetaKey from nvflare.apis.job_launcher_spec import JobProcessArgs, JobReturnCode @@ -66,6 +66,7 @@ def __init__(self, status=None, reason=None, http_resp=None): K8sJobHandle, POD_Phase, _job_args_dict, + uuid4_to_rfc1123, ) @@ -93,6 +94,57 @@ def _make_api_instance(): return MagicMock() +def _make_handle(job_id="job-1", api=None, cfg=None, **kwargs): + if api is None: + api = _make_api_instance() + if cfg is None: + cfg = _make_job_config() + handle = K8sJobHandle(job_id, api, cfg, **kwargs) + handle.job_id = job_id + return handle + + +# --------------------------------------------------------------------------- +# uuid4_to_rfc1123 +# --------------------------------------------------------------------------- +class TestUuid4ToRfc1123: + def test_lowercase(self): + assert uuid4_to_rfc1123("ABCD-1234") == "abcd-1234" + + def test_strips_invalid_chars(self): + assert uuid4_to_rfc1123("abc_def.ghi") == "abcdefghi" + + def test_prefixes_leading_digit(self): + result = uuid4_to_rfc1123("1234-abcd") + assert result[0] == "j" + assert result == "j1234-abcd" + + def test_strips_trailing_hyphens(self): + assert uuid4_to_rfc1123("abc-") == "abc" + + def test_strips_trailing_hyphen_exposed_by_truncation(self): + # 62 'a's followed by '-' followed by more chars: truncation exposes the hyphen + name = "a" * 62 + "-" + "b" * 10 + result = uuid4_to_rfc1123(name) + assert not result.endswith("-"), f"trailing hyphen in {result!r}" + assert len(result) == 62 + + def test_truncates_to_63_chars(self): + long_str = "a" * 100 + assert 
len(uuid4_to_rfc1123(long_str)) == 63 + + def test_typical_uuid_gets_prefixed(self): + result = uuid4_to_rfc1123("550e8400-e29b-41d4-a716-446655440000") + assert result == "j550e8400-e29b-41d4-a716-446655440000" + + def test_letter_leading_uuid_no_prefix(self): + result = uuid4_to_rfc1123("abcd1234-e29b-41d4-a716-446655440000") + assert result == "abcd1234-e29b-41d4-a716-446655440000" + + def test_empty_string(self): + assert uuid4_to_rfc1123("") == "" + + # --------------------------------------------------------------------------- # Mapping tables # --------------------------------------------------------------------------- @@ -150,15 +202,20 @@ def test_stuck_count_starts_at_zero(self): handle = K8sJobHandle("job-1", _make_api_instance(), cfg, timeout=30) assert handle._stuck_count == 0 - def test_max_stuck_count_includes_grace_period(self): + def test_max_stuck_count_equals_timeout(self): cfg = _make_job_config() handle = K8sJobHandle("job-1", _make_api_instance(), cfg, timeout=30) - assert handle._max_stuck_count == 30 + handle._stuck_grace_period + assert handle._max_stuck_count == 30 - def test_max_stuck_count_is_none_with_no_timeout(self): + def test_max_stuck_count_uses_pending_timeout_when_no_timeout(self): cfg = _make_job_config() handle = K8sJobHandle("job-1", _make_api_instance(), cfg, timeout=None) - assert handle._max_stuck_count is None + assert handle._max_stuck_count == 30 + + def test_max_stuck_count_uses_custom_pending_timeout(self): + cfg = _make_job_config() + handle = K8sJobHandle("job-1", _make_api_instance(), cfg, timeout=None, pending_timeout=60) + assert handle._max_stuck_count == 60 # -- manifest ------------------------------------------------------------- def test_manifest_metadata_name(self): @@ -178,12 +235,16 @@ def test_manifest_container_name(self): container = handle.get_manifest()["spec"]["containers"][0] assert container["name"] == "container-test-job-123" - def test_manifest_default_image(self): + def 
test_manifest_raises_on_missing_image(self): cfg = _make_job_config() del cfg["image"] - handle = K8sJobHandle("job-1", _make_api_instance(), cfg) - container = handle.get_manifest()["spec"]["containers"][0] - assert container["image"] == "nvflare/nvflare:2.8.0" + with pytest.raises(ValueError, match="image"): + K8sJobHandle("job-1", _make_api_instance(), cfg) + + def test_manifest_raises_on_empty_image(self): + cfg = _make_job_config(image="") + with pytest.raises(ValueError, match="image"): + K8sJobHandle("job-1", _make_api_instance(), cfg) def test_manifest_default_container_name(self): cfg = _make_job_config() @@ -251,6 +312,14 @@ def test_manifest_none_module_args_skipped(self): assert "keep" in args assert "-b" not in args + def test_manifest_non_string_module_arg_values_are_stringified(self): + cfg = _make_job_config(module_args={"-p": 8080, "-n": 42}) + handle = K8sJobHandle("job-1", _make_api_instance(), cfg) + args = handle.get_manifest()["spec"]["containers"][0]["args"] + assert "8080" in args + assert "42" in args + assert all(isinstance(a, str) for a in args), "all container args must be str" + def test_manifest_default_module_args_copies_dict(self): cfg = _make_job_config() cfg["module_args"] = None @@ -274,7 +343,27 @@ def test_manifest_no_gpu_resources(self): cfg = _make_job_config(resources={"limits": {"nvidia.com/gpu": None}}) handle = K8sJobHandle("job-1", _make_api_instance(), cfg) container = handle.get_manifest()["spec"]["containers"][0] - assert container["resources"] is None + assert "resources" not in container + + def test_manifest_no_resources_key(self): + cfg = _make_job_config() + del cfg["resources"] + handle = K8sJobHandle("job-1", _make_api_instance(), cfg) + container = handle.get_manifest()["spec"]["containers"][0] + assert "resources" not in container + + def test_get_manifest_returns_independent_copy(self): + handle = K8sJobHandle("job-1", _make_api_instance(), _make_job_config()) + manifest = handle.get_manifest() + # Mutate the 
returned copy at every mutable level + manifest["metadata"]["name"] = "MUTATED" + manifest["spec"]["containers"][0]["image"] = "MUTATED" + manifest["spec"]["volumes"].clear() + # Internal state must be unchanged + internal = handle.pod_manifest + assert internal["metadata"]["name"] == "test-job-123" + assert internal["spec"]["containers"][0]["image"] == "nvflare/nvflare:test" + assert len(internal["spec"]["volumes"]) > 0 # -- poll ----------------------------------------------------------------- def test_poll_returns_unknown_when_running(self): @@ -282,7 +371,7 @@ def test_poll_returns_unknown_when_running(self): resp = Mock() resp.status.phase = POD_Phase.RUNNING.value api.read_namespaced_pod.return_value = resp - handle = K8sJobHandle("job-1", api, _make_job_config()) + handle = _make_handle(api=api) assert handle.poll() == JobReturnCode.UNKNOWN def test_poll_returns_success_when_succeeded(self): @@ -290,7 +379,7 @@ def test_poll_returns_success_when_succeeded(self): resp = Mock() resp.status.phase = POD_Phase.SUCCEEDED.value api.read_namespaced_pod.return_value = resp - handle = K8sJobHandle("job-1", api, _make_job_config()) + handle = _make_handle(api=api) assert handle.poll() == JobReturnCode.SUCCESS def test_poll_returns_aborted_when_failed(self): @@ -298,12 +387,12 @@ def test_poll_returns_aborted_when_failed(self): resp = Mock() resp.status.phase = POD_Phase.FAILED.value api.read_namespaced_pod.return_value = resp - handle = K8sJobHandle("job-1", api, _make_job_config()) + handle = _make_handle(api=api) assert handle.poll() == JobReturnCode.ABORTED def test_poll_uses_terminal_state_if_set(self): api = _make_api_instance() - handle = K8sJobHandle("job-1", api, _make_job_config()) + handle = _make_handle(api=api) handle.terminal_state = JobState.TERMINATED assert handle.poll() == JobReturnCode.ABORTED api.read_namespaced_pod.assert_not_called() @@ -311,7 +400,7 @@ def test_poll_uses_terminal_state_if_set(self): # -- terminate 
------------------------------------------------------------ def test_terminate_deletes_pod_and_sets_terminated(self): api = _make_api_instance() - handle = K8sJobHandle("job-1", api, _make_job_config()) + handle = _make_handle(api=api) handle.terminate() api.delete_namespaced_pod.assert_called_once_with(name="job-1", namespace="default", grace_period_seconds=0) assert handle.terminal_state == JobState.TERMINATED @@ -319,20 +408,27 @@ def test_terminate_deletes_pod_and_sets_terminated(self): def test_terminate_sets_terminated_on_404(self): api = _make_api_instance() api.delete_namespaced_pod.side_effect = _FakeApiException(status=404, reason="Not Found") - handle = K8sJobHandle("job-1", api, _make_job_config()) + handle = _make_handle(api=api) handle.terminate() assert handle.terminal_state == JobState.TERMINATED - def test_terminate_does_not_set_state_on_non_404_error(self): + def test_terminate_sets_terminated_on_non_404_api_error(self): api = _make_api_instance() api.delete_namespaced_pod.side_effect = _FakeApiException(status=500, reason="Internal") - handle = K8sJobHandle("job-1", api, _make_job_config()) + handle = _make_handle(api=api) handle.terminate() - assert handle.terminal_state is None + assert handle.terminal_state == JobState.TERMINATED + + def test_terminate_sets_terminated_on_network_error(self): + api = _make_api_instance() + api.delete_namespaced_pod.side_effect = ConnectionError("network unreachable") + handle = _make_handle(api=api) + handle.terminate() + assert handle.terminal_state == JobState.TERMINATED def test_terminate_custom_namespace(self): api = _make_api_instance() - handle = K8sJobHandle("job-1", api, _make_job_config(), namespace="custom-ns") + handle = _make_handle(api=api, namespace="custom-ns") handle.terminate() api.delete_namespaced_pod.assert_called_once_with(name="job-1", namespace="custom-ns", grace_period_seconds=0) @@ -340,38 +436,69 @@ def test_terminate_custom_namespace(self): def 
test_query_phase_returns_unknown_on_api_error(self): api = _make_api_instance() api.read_namespaced_pod.side_effect = _FakeApiException(status=500, reason="Error") - handle = K8sJobHandle("job-1", api, _make_job_config()) + handle = _make_handle(api=api) + assert handle._query_phase() == POD_Phase.UNKNOWN.value + + def test_query_phase_returns_unknown_on_generic_exception(self): + api = _make_api_instance() + api.read_namespaced_pod.side_effect = RuntimeError("connection lost") + handle = _make_handle(api=api) assert handle._query_phase() == POD_Phase.UNKNOWN.value - # -- _stuck --------------------------------------------------------------- - def test_stuck_returns_false_when_no_timeout_and_grace_only(self): + # -- _stuck_in_pending ---------------------------------------------------- + def test_stuck_in_pending_returns_true_at_max_count(self): + # With _stuck_count seeded to max-1, one more PENDING call increments to + # exactly max, which should fire (>=, not >). api = _make_api_instance() - handle = K8sJobHandle("job-1", api, _make_job_config(), timeout=None) - assert handle._stuck(POD_Phase.PENDING.value) is False + handle = K8sJobHandle("job-1", api, _make_job_config(), timeout=5) + handle._stuck_count = handle._max_stuck_count - 1 + assert handle._stuck_in_pending(POD_Phase.PENDING.value) is True - def test_stuck_returns_true_after_max_count_with_grace(self): + def test_stuck_in_pending_returns_false_one_before_max(self): + # One iteration before the threshold must not fire. 
api = _make_api_instance() handle = K8sJobHandle("job-1", api, _make_job_config(), timeout=5) - handle._stuck_count = handle._max_stuck_count - assert handle._stuck(POD_Phase.PENDING.value) is True + handle._stuck_count = handle._max_stuck_count - 2 + assert handle._stuck_in_pending(POD_Phase.PENDING.value) is False - def test_stuck_returns_false_for_non_pending(self): + def test_stuck_in_pending_returns_false_for_non_pending(self): api = _make_api_instance() handle = K8sJobHandle("job-1", api, _make_job_config(), timeout=5) handle._stuck_count = 9999 - assert handle._stuck(POD_Phase.RUNNING.value) is False + assert handle._stuck_in_pending(POD_Phase.RUNNING.value) is False - def test_stuck_increments_count_on_pending(self): + def test_stuck_in_pending_resets_count_on_non_pending(self): + api = _make_api_instance() + handle = K8sJobHandle("job-1", api, _make_job_config(), timeout=5) + handle._stuck_count = 3 + handle._stuck_in_pending(POD_Phase.RUNNING.value) + assert handle._stuck_count == 0 + + def test_stuck_in_pending_increments_count(self): api = _make_api_instance() handle = K8sJobHandle("job-1", api, _make_job_config(), timeout=100) initial = handle._stuck_count - handle._stuck(POD_Phase.PENDING.value) + handle._stuck_in_pending(POD_Phase.PENDING.value) assert handle._stuck_count == initial + 1 + def test_stuck_in_pending_returns_false_when_under_max(self): + api = _make_api_instance() + handle = K8sJobHandle("job-1", api, _make_job_config(), timeout=None, pending_timeout=100) + assert handle._stuck_in_pending(POD_Phase.PENDING.value) is False + + def test_stuck_in_pending_never_fires_when_pending_timeout_none(self): + # pending_timeout=None with timeout=None → _max_stuck_count=None → stuck detection disabled + api = _make_api_instance() + handle = K8sJobHandle("job-1", api, _make_job_config(), timeout=None, pending_timeout=None) + assert handle._max_stuck_count is None + # Drive _stuck_count very high — must not raise and must return False + 
handle._stuck_count = 10_000 + assert handle._stuck_in_pending(POD_Phase.PENDING.value) is False + # -- wait ----------------------------------------------------------------- def test_wait_returns_immediately_if_terminal_state_set(self): api = _make_api_instance() - handle = K8sJobHandle("job-1", api, _make_job_config()) + handle = _make_handle(api=api) handle.terminal_state = JobState.TERMINATED handle.wait() api.read_namespaced_pod.assert_not_called() @@ -381,7 +508,7 @@ def test_wait_persists_succeeded_terminal_state(self): resp = Mock() resp.status.phase = POD_Phase.SUCCEEDED.value api.read_namespaced_pod.return_value = resp - handle = K8sJobHandle("job-1", api, _make_job_config()) + handle = _make_handle(api=api) handle.wait() assert handle.terminal_state == JobState.SUCCEEDED @@ -390,7 +517,7 @@ def test_wait_persists_terminated_terminal_state(self): resp = Mock() resp.status.phase = POD_Phase.FAILED.value api.read_namespaced_pod.return_value = resp - handle = K8sJobHandle("job-1", api, _make_job_config()) + handle = _make_handle(api=api) handle.wait() assert handle.terminal_state == JobState.TERMINATED @@ -399,7 +526,7 @@ def test_wait_poll_consistent_after_wait(self): resp = Mock() resp.status.phase = POD_Phase.SUCCEEDED.value api.read_namespaced_pod.return_value = resp - handle = K8sJobHandle("job-1", api, _make_job_config()) + handle = _make_handle(api=api) handle.wait() assert handle.poll() == JobReturnCode.SUCCESS @@ -409,7 +536,7 @@ def test_enter_states_returns_true_when_state_matches(self): resp = Mock() resp.status.phase = POD_Phase.RUNNING.value api.read_namespaced_pod.return_value = resp - handle = K8sJobHandle("job-1", api, _make_job_config()) + handle = _make_handle(api=api) assert handle.enter_states([JobState.RUNNING]) is True def test_enter_states_accepts_single_state(self): @@ -417,7 +544,7 @@ def test_enter_states_accepts_single_state(self): resp = Mock() resp.status.phase = POD_Phase.SUCCEEDED.value api.read_namespaced_pod.return_value = 
resp - handle = K8sJobHandle("job-1", api, _make_job_config()) + handle = _make_handle(api=api) assert handle.enter_states(JobState.SUCCEEDED) is True def test_enter_states_returns_false_on_timeout(self): @@ -425,15 +552,105 @@ def test_enter_states_returns_false_on_timeout(self): resp = Mock() resp.status.phase = POD_Phase.PENDING.value api.read_namespaced_pod.return_value = resp - handle = K8sJobHandle("job-1", api, _make_job_config(), timeout=None) - assert handle.enter_states([JobState.RUNNING], timeout=0) is False + handle = _make_handle(api=api, timeout=0) + assert handle.enter_states([JobState.RUNNING]) is False - def test_enter_states_raises_on_invalid_state(self): + def test_enter_states_terminates_on_timeout(self): + api = _make_api_instance() + resp = Mock() + resp.status.phase = POD_Phase.PENDING.value + api.read_namespaced_pod.return_value = resp + handle = _make_handle(api=api, timeout=0) + handle.enter_states([JobState.RUNNING]) + api.delete_namespaced_pod.assert_called_once() + assert handle.terminal_state == JobState.TERMINATED + + def test_enter_states_returns_false_on_terminal_pod_phase(self): api = _make_api_instance() - handle = K8sJobHandle("job-1", api, _make_job_config()) + resp = Mock() + resp.status.phase = POD_Phase.FAILED.value + api.read_namespaced_pod.return_value = resp + handle = _make_handle(api=api) + assert handle.enter_states([JobState.RUNNING]) is False + assert handle.terminal_state == JobState.TERMINATED + + def test_enter_states_returns_false_on_succeeded_when_waiting_for_running(self): + api = _make_api_instance() + resp = Mock() + resp.status.phase = POD_Phase.SUCCEEDED.value + api.read_namespaced_pod.return_value = resp + handle = _make_handle(api=api) + assert handle.enter_states([JobState.RUNNING]) is False + assert handle.terminal_state == JobState.SUCCEEDED + + def test_enter_states_raises_on_invalid_state(self): + handle = _make_handle() with pytest.raises(ValueError, match="expect job_states_to_enter"): 
handle.enter_states(["not_a_state"]) + # -- enter_states: wall-clock timeout branch ------------------------------ + # The pod is placed in UNKNOWN phase (non-pending so stuck detection does + # not fire, non-terminal so the terminal-phase branch is skipped) and + # time.time is mocked so the elapsed-time check fires on the first + # iteration. This is the branch that existing tests miss because they + # use timeout=0 with a PENDING pod, which hits stuck detection instead. + + @patch("nvflare.app_opt.job_launcher.k8s_launcher.time") + def test_enter_states_wall_clock_timeout_returns_false(self, mock_time): + mock_time.time.side_effect = [0.0, 100.0] # start=0, check=100 → 100 > timeout=10 + mock_time.sleep = Mock() + api = _make_api_instance() + resp = Mock() + resp.status.phase = POD_Phase.UNKNOWN.value + api.read_namespaced_pod.return_value = resp + handle = _make_handle(api=api, timeout=10) + assert handle.enter_states([JobState.RUNNING]) is False + + @patch("nvflare.app_opt.job_launcher.k8s_launcher.time") + def test_enter_states_wall_clock_timeout_calls_terminate_and_sets_terminal_state(self, mock_time): + mock_time.time.side_effect = [0.0, 100.0] + mock_time.sleep = Mock() + api = _make_api_instance() + resp = Mock() + resp.status.phase = POD_Phase.UNKNOWN.value + api.read_namespaced_pod.return_value = resp + handle = _make_handle(api=api, timeout=10) + handle.enter_states([JobState.RUNNING]) + api.delete_namespaced_pod.assert_called_once() + assert handle.terminal_state == JobState.TERMINATED + + @patch("nvflare.app_opt.job_launcher.k8s_launcher.time") + def test_enter_states_wall_clock_not_fired_when_timeout_none(self, mock_time): + # With timeout=None the wall-clock guard (self.timeout is not None) is + # unconditionally False; the loop exits through the terminal-phase path. 
+ mock_time.time.return_value = 9999.0 + mock_time.sleep = Mock() + api = _make_api_instance() + resp = Mock() + resp.status.phase = POD_Phase.SUCCEEDED.value + api.read_namespaced_pod.return_value = resp + handle = _make_handle(api=api, timeout=None) + handle.enter_states([JobState.RUNNING]) + api.delete_namespaced_pod.assert_not_called() + + @patch("nvflare.app_opt.job_launcher.k8s_launcher.time") + def test_enter_states_wall_clock_not_fired_before_elapsed(self, mock_time): + # First iteration: time not yet elapsed → wall-clock skipped, loop continues. + # Second iteration: pod completes → exits via terminal-phase path, no terminate(). + mock_time.time.side_effect = [0.0, 0.5] # start=0, first check=0.5 < timeout=10 + mock_time.sleep = Mock() + api = _make_api_instance() + resp_unknown = Mock() + resp_unknown.status.phase = POD_Phase.UNKNOWN.value + resp_succeeded = Mock() + resp_succeeded.status.phase = POD_Phase.SUCCEEDED.value + api.read_namespaced_pod.side_effect = [resp_unknown, resp_succeeded] + handle = _make_handle(api=api, timeout=10) + result = handle.enter_states([JobState.RUNNING]) + assert result is False + api.delete_namespaced_pod.assert_not_called() + assert handle.terminal_state == JobState.SUCCEEDED + # --------------------------------------------------------------------------- # _job_args_dict helper @@ -588,6 +805,27 @@ def test_init_raises_on_empty_pvc_file(self): finally: _exit_patches(patches) + def test_init_raises_on_non_dict_pvc_file(self): + from nvflare.app_opt.job_launcher.k8s_launcher import ClientK8sJobLauncher + + patches = _make_k8s_launcher_patches() + mock_cfg, mock_conf, mock_core, mock_open, mock_yaml = _enter_patches(patches) + try: + mock_yaml.safe_load.return_value = "not-a-dict" + mock_conf_instance = MagicMock() + mock_conf.return_value = mock_conf_instance + mock_conf.get_default_copy = Mock(return_value=mock_conf_instance) + + with pytest.raises(ValueError, match="dictionary"): + ClientK8sJobLauncher( + 
# ---------------------------------------------------------------------------
# K8sJobLauncher launch_job — integration-style happy path
# ---------------------------------------------------------------------------

# Worker module the launched pod is expected to run via ``python -u -m``.
_WORKER_MODULE = "nvflare.private.fed.app.client.worker_process"
# Fixed job UUID so the RFC-1123-sanitized pod name is deterministic in asserts.
_JOB_UUID = "550e8400-e29b-41d4-a716-446655440000"
_EXPECTED_JOB_ID = uuid4_to_rfc1123(_JOB_UUID)


def _make_launch_job_meta(site_name="site-1", image="nvflare/nvflare:latest", gpu=None):
    """Build a minimal job-meta dict accepted by launch_job().

    Args:
        site_name: site whose deploy-map entry carries the container image.
        image: container image recorded in the deploy map for that site.
        gpu: when not None, add a RESOURCE_SPEC entry requesting that many
            GPUs for the site (drives the nvidia.com/gpu limit in the manifest).
    """
    meta = {
        JobConstants.JOB_ID: _JOB_UUID,
        JobMetaKey.DEPLOY_MAP.value: {"app": [{"sites": [site_name], "image": image}]},
    }
    if gpu is not None:
        meta[JobMetaKey.RESOURCE_SPEC.value] = {site_name: {"num_of_gpus": gpu}}
    return meta


def _make_launch_fl_ctx(site_name="site-1", set_items=None):
    """Build an FLContext carrying identity, JOB_PROCESS_ARGS and optional ARGS.

    Args:
        site_name: value stored under ReservedKey.IDENTITY_NAME.
        set_items: when not None, an ARGS object with this ``set`` list is
            placed in the context; launch_job() propagates it into the
            container args after ``--set``.
    """
    fl_ctx = FLContext()
    fl_ctx.set_prop(ReservedKey.IDENTITY_NAME, site_name, private=False, sticky=True)
    job_args = {
        JobProcessArgs.EXE_MODULE: ("-m", _WORKER_MODULE),
        JobProcessArgs.WORKSPACE: ("-w", "/var/tmp/nvflare/workspace"),
        JobProcessArgs.JOB_ID: ("-n", "job-abc"),
    }
    fl_ctx.set_prop(FLContextKey.JOB_PROCESS_ARGS, job_args, private=True, sticky=False)
    if set_items is not None:
        args_obj = Mock()
        args_obj.set = set_items
        fl_ctx.set_prop(FLContextKey.ARGS, args_obj, private=False, sticky=False)
    return fl_ctx


class TestK8sJobLauncherLaunchJob:
    """Integration-style tests that exercise the full launch_job() code path.

    The kubernetes API is mocked but the rest of the code — uuid sanitization,
    manifest construction, enter_states polling, and handle construction — runs
    for real. read_namespaced_pod is primed to return Running immediately so
    enter_states returns True on the first iteration without sleeping.
    """

    def _setup(self, patches, namespace="test-ns"):
        """Enter the kubernetes patches and return (launcher, mocked CoreV1Api)."""
        # Imported here (not at file top) so the real kubernetes package is
        # only touched after the patches are in place.
        from nvflare.app_opt.job_launcher.k8s_launcher import ClientK8sJobLauncher

        mock_cfg, mock_conf, mock_core_module, mock_open, mock_yaml = _enter_patches(patches)
        # The data-PVC yaml maps pvc-name -> mount path.
        mock_yaml.safe_load.return_value = {"data-pvc": "/data"}
        mock_conf_instance = MagicMock()
        mock_conf.return_value = mock_conf_instance
        mock_conf.get_default_copy = Mock(return_value=mock_conf_instance)
        mock_api = MagicMock()
        mock_core_module.CoreV1Api.return_value = mock_api
        launcher = ClientK8sJobLauncher(
            config_file_path="/fake/kube/config",
            workspace_pvc="ws-pvc",
            etc_pvc="etc-pvc",
            data_pvc_file_path="/fake/data_pvc.yaml",
            namespace=namespace,
        )
        return launcher, mock_api

    def _prime_running(self, mock_api):
        """Make read_namespaced_pod report Running so enter_states succeeds at once."""
        resp = Mock()
        resp.status.phase = POD_Phase.RUNNING.value
        mock_api.read_namespaced_pod.return_value = resp

    # -- return value ---------------------------------------------------------

    def test_returns_k8s_job_handle(self):
        from nvflare.app_opt.job_launcher.k8s_launcher import K8sJobHandle

        patches = _make_k8s_launcher_patches()
        launcher, mock_api = self._setup(patches)
        self._prime_running(mock_api)
        try:
            handle = launcher.launch_job(_make_launch_job_meta(), _make_launch_fl_ctx())
            assert isinstance(handle, K8sJobHandle)
        finally:
            _exit_patches(patches)

    def test_terminal_state_is_none_after_clean_launch(self):
        patches = _make_k8s_launcher_patches()
        launcher, mock_api = self._setup(patches)
        self._prime_running(mock_api)
        try:
            handle = launcher.launch_job(_make_launch_job_meta(), _make_launch_fl_ctx())
            assert handle.terminal_state is None
        finally:
            _exit_patches(patches)

    # -- API call -------------------------------------------------------------

    def test_create_namespaced_pod_called_once(self):
        patches = _make_k8s_launcher_patches()
        launcher, mock_api = self._setup(patches)
        self._prime_running(mock_api)
        try:
            launcher.launch_job(_make_launch_job_meta(), _make_launch_fl_ctx())
            mock_api.create_namespaced_pod.assert_called_once()
        finally:
            _exit_patches(patches)

    def test_create_namespaced_pod_uses_launcher_namespace(self):
        patches = _make_k8s_launcher_patches()
        launcher, mock_api = self._setup(patches, namespace="prod-ns")
        self._prime_running(mock_api)
        try:
            launcher.launch_job(_make_launch_job_meta(), _make_launch_fl_ctx())
            assert mock_api.create_namespaced_pod.call_args.kwargs["namespace"] == "prod-ns"
        finally:
            _exit_patches(patches)

    # -- pod manifest: identity fields ----------------------------------------

    def test_pod_manifest_name_is_rfc1123_of_job_id(self):
        patches = _make_k8s_launcher_patches()
        launcher, mock_api = self._setup(patches)
        self._prime_running(mock_api)
        try:
            launcher.launch_job(_make_launch_job_meta(), _make_launch_fl_ctx())
            manifest = mock_api.create_namespaced_pod.call_args.kwargs["body"]
            assert manifest["metadata"]["name"] == _EXPECTED_JOB_ID
        finally:
            _exit_patches(patches)

    def test_pod_manifest_image_from_job_meta(self):
        patches = _make_k8s_launcher_patches()
        launcher, mock_api = self._setup(patches)
        self._prime_running(mock_api)
        try:
            launcher.launch_job(_make_launch_job_meta(image="myrepo/myimage:v2"), _make_launch_fl_ctx())
            manifest = mock_api.create_namespaced_pod.call_args.kwargs["body"]
            assert manifest["spec"]["containers"][0]["image"] == "myrepo/myimage:v2"
        finally:
            _exit_patches(patches)

    def test_pod_manifest_restart_policy_never(self):
        patches = _make_k8s_launcher_patches()
        launcher, mock_api = self._setup(patches)
        self._prime_running(mock_api)
        try:
            launcher.launch_job(_make_launch_job_meta(), _make_launch_fl_ctx())
            manifest = mock_api.create_namespaced_pod.call_args.kwargs["body"]
            assert manifest["spec"]["restartPolicy"] == "Never"
        finally:
            _exit_patches(patches)

    # -- pod manifest: container args -----------------------------------------

    def test_pod_manifest_container_args_contain_worker_module(self):
        patches = _make_k8s_launcher_patches()
        launcher, mock_api = self._setup(patches)
        self._prime_running(mock_api)
        try:
            launcher.launch_job(_make_launch_job_meta(), _make_launch_fl_ctx())
            manifest = mock_api.create_namespaced_pod.call_args.kwargs["body"]
            args = manifest["spec"]["containers"][0]["args"]
            assert "-u" in args
            assert "-m" in args
            assert _WORKER_MODULE in args
        finally:
            _exit_patches(patches)

    def test_pod_manifest_set_list_propagated_from_args(self):
        patches = _make_k8s_launcher_patches()
        launcher, mock_api = self._setup(patches)
        self._prime_running(mock_api)
        try:
            fl_ctx = _make_launch_fl_ctx(set_items=["lr=0.01", "epochs=5"])
            launcher.launch_job(_make_launch_job_meta(), fl_ctx)
            manifest = mock_api.create_namespaced_pod.call_args.kwargs["body"]
            args = manifest["spec"]["containers"][0]["args"]
            assert "--set" in args
            assert "lr=0.01" in args
            assert "epochs=5" in args
        finally:
            _exit_patches(patches)

    def test_pod_manifest_no_set_list_when_args_not_set(self):
        patches = _make_k8s_launcher_patches()
        launcher, mock_api = self._setup(patches)
        self._prime_running(mock_api)
        try:
            launcher.launch_job(_make_launch_job_meta(), _make_launch_fl_ctx())
            manifest = mock_api.create_namespaced_pod.call_args.kwargs["body"]
            args = manifest["spec"]["containers"][0]["args"]
            assert "--set" not in args
        finally:
            _exit_patches(patches)

    # -- pod manifest: volumes ------------------------------------------------

    def test_pod_manifest_pvcs_use_launcher_pvc_names(self):
        patches = _make_k8s_launcher_patches()
        launcher, mock_api = self._setup(patches)
        self._prime_running(mock_api)
        try:
            launcher.launch_job(_make_launch_job_meta(), _make_launch_fl_ctx())
            manifest = mock_api.create_namespaced_pod.call_args.kwargs["body"]
            claims = {v["name"]: v["persistentVolumeClaim"]["claimName"] for v in manifest["spec"]["volumes"]}
            assert claims[PV_NAME.WORKSPACE.value] == "ws-pvc"
            # "data-pvc" comes from the yaml primed in _setup(), not a ctor arg.
            assert claims[PV_NAME.DATA.value] == "data-pvc"
            assert claims[PV_NAME.ETC.value] == "etc-pvc"
        finally:
            _exit_patches(patches)

    # -- pod manifest: GPU resources ------------------------------------------

    def test_pod_manifest_includes_gpu_limit_when_specified(self):
        patches = _make_k8s_launcher_patches()
        launcher, mock_api = self._setup(patches)
        self._prime_running(mock_api)
        try:
            launcher.launch_job(_make_launch_job_meta(gpu=2), _make_launch_fl_ctx())
            manifest = mock_api.create_namespaced_pod.call_args.kwargs["body"]
            assert manifest["spec"]["containers"][0]["resources"]["limits"]["nvidia.com/gpu"] == 2
        finally:
            _exit_patches(patches)

    def test_pod_manifest_omits_gpu_limit_when_not_specified(self):
        patches = _make_k8s_launcher_patches()
        launcher, mock_api = self._setup(patches)
        self._prime_running(mock_api)
        try:
            launcher.launch_job(_make_launch_job_meta(), _make_launch_fl_ctx())
            manifest = mock_api.create_namespaced_pod.call_args.kwargs["body"]
            assert "resources" not in manifest["spec"]["containers"][0]
        finally:
            _exit_patches(patches)

    # -- create_namespaced_pod failure paths ----------------------------------

    def test_network_error_on_create_returns_handle_with_terminal_state(self):
        """Non-ApiException (e.g. network timeout) must not leave terminal_state=None."""
        # FIX: import K8sJobHandle under the patches, exactly like
        # test_returns_k8s_job_handle does — this file's convention is to pull
        # launcher-module names in locally after the kubernetes patches are
        # entered, and the name was previously referenced without any local
        # import here.
        from nvflare.app_opt.job_launcher.k8s_launcher import K8sJobHandle

        patches = _make_k8s_launcher_patches()
        launcher, mock_api = self._setup(patches)
        mock_api.create_namespaced_pod.side_effect = ConnectionError("network unreachable")
        try:
            handle = launcher.launch_job(_make_launch_job_meta(), _make_launch_fl_ctx())
            assert isinstance(handle, K8sJobHandle)
            # NOTE(review): JobState is assumed to be a module-level import of
            # this test file — confirm against the file header.
            assert handle.terminal_state == JobState.TERMINATED
        finally:
            _exit_patches(patches)

    def test_network_error_on_create_does_not_call_terminate_api(self):
        """Pod was never created; no delete API call should be made."""
        patches = _make_k8s_launcher_patches()
        launcher, mock_api = self._setup(patches)
        mock_api.create_namespaced_pod.side_effect = ConnectionError("network unreachable")
        try:
            launcher.launch_job(_make_launch_job_meta(), _make_launch_fl_ctx())
            mock_api.delete_namespaced_pod.assert_not_called()
        finally:
            _exit_patches(patches)

    def test_network_error_on_create_does_not_call_enter_states(self):
        """enter_states must not be reached when pod creation fails."""
        patches = _make_k8s_launcher_patches()
        launcher, mock_api = self._setup(patches)
        mock_api.create_namespaced_pod.side_effect = ConnectionError("network unreachable")
        try:
            launcher.launch_job(_make_launch_job_meta(), _make_launch_fl_ctx())
            # read_namespaced_pod is the backing call for _query_phase / enter_states
            mock_api.read_namespaced_pod.assert_not_called()
        finally:
            _exit_patches(patches)