Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
581 changes: 581 additions & 0 deletions docs/design/JobLauncher_and_JobHandle.md

Large diffs are not rendered by default.

10 changes: 5 additions & 5 deletions nvflare/apis/job_launcher_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import abstractmethod
from abc import ABC, abstractmethod

from nvflare.apis.fl_component import FLComponent
from nvflare.apis.fl_constant import FLContextKey
Expand Down Expand Up @@ -56,12 +56,12 @@ def add_launcher(launcher, fl_ctx: FLContext):
fl_ctx.set_prop(FLContextKey.JOB_LAUNCHER, job_launcher, private=True, sticky=False)


class JobHandleSpec:
class JobHandleSpec(ABC):
@abstractmethod
def terminate(self):
"""To terminate the job run.

Returns: the job run return code.
Returns: None

"""
raise NotImplementedError()
Expand All @@ -85,7 +85,7 @@ def wait(self):
raise NotImplementedError()


class JobLauncherSpec(FLComponent):
class JobLauncherSpec(FLComponent, ABC):
@abstractmethod
def launch_job(self, job_meta: dict, fl_ctx: FLContext) -> JobHandleSpec:
"""To launch a job run.
Expand All @@ -94,7 +94,7 @@ def launch_job(self, job_meta: dict, fl_ctx: FLContext) -> JobHandleSpec:
job_meta: job metadata
fl_ctx: FLContext

Returns: boolean to indicates the job launch success or fail.
Returns: a JobHandle instance.

"""
raise NotImplementedError()
21 changes: 21 additions & 0 deletions nvflare/app_common/resource_consumers/BE_resource_consumer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from nvflare.apis.resource_manager_spec import ResourceConsumerSpec


class BEResourceConsumer(ResourceConsumerSpec):
    """Best Effort (BE) resource consumer.

    This consumer is deliberately a no-op: in best-effort mode no
    resources are applied or configured on the consumer side.
    """

    def consume(self, resources: dict):
        # Intentionally does nothing with the granted resources.
        return None
50 changes: 50 additions & 0 deletions nvflare/app_common/resource_managers/BE_resource_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import uuid

from nvflare.apis.fl_component import FLComponent
from nvflare.apis.fl_context import FLContext
from nvflare.apis.resource_manager_spec import ResourceManagerSpec


class BEResourceManager(ResourceManagerSpec, FLComponent):
    def __init__(self):
        """Best Effort Resource Manager implementation.

        It will accept all resource allocation requests
        and let the job fail when the requested resources are unavailable
        at runtime.

        """
        super().__init__()

    def check_resources(self, resource_requirement: dict, fl_ctx: FLContext):
        # Validate the argument type, then unconditionally grant the
        # request and hand back a freshly generated token.
        if not isinstance(resource_requirement, dict):
            raise TypeError(f"resource_requirement should be of type dict, but got {type(resource_requirement)}.")

        return True, str(uuid.uuid4())

    def cancel_resources(self, resource_requirement: dict, token: str, fl_ctx: FLContext):
        # Nothing was ever reserved, so there is nothing to cancel.
        return None

    def allocate_resources(self, resource_requirement: dict, token: str, fl_ctx: FLContext) -> dict:
        # No concrete resources are tracked; report an empty allocation.
        return {}

    def free_resources(self, resources: dict, token: str, fl_ctx: FLContext):
        # No-op: best-effort mode keeps no bookkeeping to undo.
        return None

    def report_resources(self, fl_ctx):
        # No resources are managed, so the report is always empty.
        return {}
30 changes: 19 additions & 11 deletions nvflare/app_common/resource_managers/gpu_resource_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def __init__(
num_gpu_key: str = "num_of_gpus",
gpu_mem_key: str = "mem_per_gpu_in_GiB",
expiration_period: Union[int, float] = 30,
ignore_host: bool = False,
):
"""Resource manager for GPUs.

Expand All @@ -46,6 +47,9 @@ def __init__(
expiration_period: Number of seconds to hold the resources reserved.
If check_resources is called but after "expiration_period" no allocate resource is called,
then the reserved resources will be released.
ignore_host: Whether to skip validation against GPUs present on the local host. Set to True in
environments where the NVFlare process runs on a node without GPUs (for example, some
Kubernetes deployments) but GPU resources are managed externally.
"""
if not isinstance(num_of_gpus, int):
raise ValueError(f"num_of_gpus should be of type int, but got {type(num_of_gpus)}.")
Expand All @@ -62,17 +66,21 @@ def __init__(
if expiration_period < 0:
raise ValueError("expiration_period should be greater than or equal to 0.")

if num_of_gpus > 0:
num_host_gpus = len(get_host_gpu_ids())
if num_of_gpus > num_host_gpus:
raise ValueError(f"num_of_gpus specified ({num_of_gpus}) exceeds available GPUs: {num_host_gpus}.")

host_gpu_mem = get_host_gpu_memory_total()
for i in host_gpu_mem:
if mem_per_gpu_in_GiB * 1024 > i:
raise ValueError(
f"Memory per GPU specified ({mem_per_gpu_in_GiB * 1024}) exceeds available GPU memory: {i}."
)
if not isinstance(ignore_host, bool):
raise ValueError(f"ignore_host should be of type bool, but got {type(ignore_host)}.")

if not ignore_host:
if num_of_gpus > 0:
num_host_gpus = len(get_host_gpu_ids())
if num_of_gpus > num_host_gpus:
raise ValueError(f"num_of_gpus specified ({num_of_gpus}) exceeds available GPUs: {num_host_gpus}.")

host_gpu_mem = get_host_gpu_memory_total()
for i in host_gpu_mem:
if mem_per_gpu_in_GiB * 1024 > i:
raise ValueError(
f"Memory per GPU specified ({mem_per_gpu_in_GiB * 1024}) exceeds available GPU memory: {i}."
)

self.num_gpu_key = num_gpu_key
self.gpu_mem_key = gpu_mem_key
Expand Down
Loading
Loading