diff --git a/nvflare/apis/fl_constant.py b/nvflare/apis/fl_constant.py index 8f9b7ec8fc..881bb92004 100644 --- a/nvflare/apis/fl_constant.py +++ b/nvflare/apis/fl_constant.py @@ -197,6 +197,7 @@ class FLContextKey(object): EVENT_PROCESSED = "__event_processed__" CELL_MESSAGE = "__cell_message__" CLIENT_HIERARCHY = "__client_hierarchy__" + FOX_MODE = "__fox_mode__" class ProcessType: diff --git a/nvflare/apis/fl_exception.py b/nvflare/apis/fl_exception.py index 3bc5f13fbc..81d51d04ea 100644 --- a/nvflare/apis/fl_exception.py +++ b/nvflare/apis/fl_exception.py @@ -66,3 +66,9 @@ class NotReadyToEndRun(Exception): """Raised when a component is not ready to end run""" pass + + +class RunAborted(Exception): + """Raised when a run is aborted""" + + pass diff --git a/nvflare/app_common/workflows/lr/fedavg.py b/nvflare/app_common/workflows/lr/fedavg.py index f807ae7aa0..7904aa1b84 100644 --- a/nvflare/app_common/workflows/lr/fedavg.py +++ b/nvflare/app_common/workflows/lr/fedavg.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. """ - Federated Averaging for Logistic Regression with Newton-Raphson method - using Numpy +Federated Averaging for Logistic Regression with Newton-Raphson method +using Numpy """ from typing import List, Optional diff --git a/nvflare/fox/__init__.py b/nvflare/fox/__init__.py new file mode 100644 index 0000000000..8afd3df8ce --- /dev/null +++ b/nvflare/fox/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .api.facade import facade as fox diff --git a/nvflare/fox/api/__init__.py b/nvflare/fox/api/__init__.py new file mode 100644 index 0000000000..341a77c5bc --- /dev/null +++ b/nvflare/fox/api/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nvflare/fox/api/app.py b/nvflare/fox/api/app.py new file mode 100644 index 0000000000..282f816a84 --- /dev/null +++ b/nvflare/fox/api/app.py @@ -0,0 +1,448 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +import copy +import fnmatch +import os +import re +from typing import List + +from nvflare.fuel.utils.log_utils import get_obj_logger +from nvflare.fuel.utils.tree_utils import Forest, Node, build_forest + +from .constants import BackendType, CollabMethodArgName, ContextKey, FilterDirection +from .ctx import Context, set_call_context +from .dec import ( + collab, + get_object_algo_funcs, + get_object_collab_interface, + get_object_final_funcs, + get_object_init_funcs, + is_collab, + supports_context, +) +from .filter import CallFilter, FilterChain, ResultFilter +from .proxy import Proxy +from .utils import check_context_support, get_collab_object_name +from .workspace import Workspace + + +class App: + + def __init__(self, obj, name: str): + self.obj = obj + self.name = name + self._fqn = None + self._server_proxy = None + self._client_proxies = None + self._client_hierarchy = None + self._backend_type = None + self._me = None + self._collab_objs = {} + self._abort_signal = None + self._props = {} + self._event_handlers = {} # event type => list of (cb, kwargs) + self._incoming_call_filter_chains = [] + self._outgoing_call_filter_chains = [] + self._incoming_result_filter_chains = [] + self._outgoing_result_filter_chains = [] + self._workspace = None + self._resource_dirs = {} + self._managed_objects = {} # id => obj + self.logger = get_obj_logger(self) + self._collab_interface = {"": get_object_collab_interface(self)} + self.add_collab_object(name, obj) + + def set_resource_dirs(self, resource_dirs: dict[str, str]): + if not isinstance(resource_dirs, dict): + raise TypeError(f"resource_dirs must be a dict but got {type(resource_dirs)}") + for name, resource_dir in resource_dirs.items(): + if not os.path.isdir(resource_dir): + raise ValueError(f"Resource dir {resource_dir} does not exist for {name}") + self._resource_dirs = resource_dirs + + def 
get_resource_dirs(self): + return self._resource_dirs + + def _add_managed_object(self, obj): + self._managed_objects[id(obj)] = obj + + def set_fqn(self, fqn): + self._fqn = fqn + + @property + def fqn(self): + return self._fqn + + @property + def backend(self): + if not self._me: + return None + else: + return self._me.backend + + @property + def backend_type(self): + return self._backend_type + + def set_backend_type(self, t: str): + valid_types = [BackendType.SIMULATION, BackendType.FLARE] + if t not in valid_types: + raise ValueError(f"bad backend type: {t}: must be one of {valid_types}") + self._backend_type = t + + @property + def workspace(self): + return self._workspace + + @property + def server_proxy(self): + return self._server_proxy + + @property + def client_proxies(self): + return copy.copy(self._client_proxies) + + @property + def client_hierarchy(self): + return self._client_hierarchy + + def _add_filters(self, pattern: str, filters, to_list: list, filter_type, incoming): + if not filters: + return + + if not isinstance(filters, list): + raise ValueError(f"filters must be a list but got {type(filters)}") + + filter_objs = [] + for f in filters: + if not isinstance(f, filter_type): + # convert to proper filter type + filter_obj = filter_type(f, incoming) + else: + filter_obj = f + + filter_objs.append(filter_obj) + + # f is a managed object, but the filter_obj (if wrapped) is not! 
+ self._add_managed_object(f) + + chain = FilterChain(pattern, filter_type) + chain.add_filters(filter_objs) + to_list.append(chain) + + def add_incoming_call_filters(self, pattern: str, filters: List[object]): + self._add_filters(pattern, filters, self._incoming_call_filter_chains, CallFilter, True) + + def get_incoming_call_filters(self): + return self._incoming_call_filter_chains + + def add_outgoing_call_filters(self, pattern: str, filters: List[object]): + self._add_filters(pattern, filters, self._outgoing_call_filter_chains, CallFilter, False) + + def get_outgoing_call_filters(self): + return self._outgoing_call_filter_chains + + def add_incoming_result_filters(self, pattern: str, filters: List[object]): + self._add_filters(pattern, filters, self._incoming_result_filter_chains, ResultFilter, True) + + def get_incoming_result_filters(self): + return self._incoming_result_filter_chains + + def add_outgoing_result_filters(self, pattern: str, filters: List[object]): + self._add_filters(pattern, filters, self._outgoing_result_filter_chains, ResultFilter, False) + + def get_outgoing_result_filters(self): + return self._outgoing_result_filter_chains + + @staticmethod + def _find_filter_chain(direction, chains: List[FilterChain], target_name: str, func_name: str, ctx: Context): + """ + + Args: + chains: + target_name: + func_name: + + Returns: + + """ + collab_obj_name = get_collab_object_name(target_name) + qualified_func_name = f"{collab_obj_name}.{func_name}" + ctx.set_prop(ContextKey.QUALIFIED_FUNC_NAME, qualified_func_name) + ctx.set_prop(ContextKey.DIRECTION, direction) + + if not chains: + return None + + for c in chains: + if fnmatch.fnmatch(qualified_func_name, c.pattern): + return c + return None + + def apply_incoming_call_filters(self, target_name: str, func_name: str, func_kwargs, context: Context): + filter_chain = self._find_filter_chain( + FilterDirection.INCOMING, self._incoming_call_filter_chains, target_name, func_name, context + ) + if 
filter_chain: + return filter_chain.apply_filters(func_kwargs, context) + else: + return func_kwargs + + def apply_outgoing_call_filters(self, target_name: str, func_name: str, func_kwargs, context: Context): + filter_chain = self._find_filter_chain( + FilterDirection.OUTGOING, self._outgoing_call_filter_chains, target_name, func_name, context + ) + if filter_chain: + return filter_chain.apply_filters(func_kwargs, context) + else: + return func_kwargs + + def apply_incoming_result_filters(self, target_name: str, func_name: str, result, context: Context): + filter_chain = self._find_filter_chain( + FilterDirection.INCOMING, self._incoming_result_filter_chains, target_name, func_name, context + ) + if filter_chain: + return filter_chain.apply_filters(result, context) + else: + return result + + def apply_outgoing_result_filters(self, target_name: str, func_name: str, result, context: Context): + filter_chain = self._find_filter_chain( + FilterDirection.OUTGOING, self._outgoing_result_filter_chains, target_name, func_name, context + ) + if filter_chain: + return filter_chain.apply_filters(result, context) + else: + return result + + def set_prop(self, name: str, value): + self._props[name] = value + + def get_prop(self, name: str, default=None): + return self._props.get(name, default) + + def get_props(self): + return self._props + + def update_props(self, props: dict): + if isinstance(props, dict): + self._props.update(props) + + def add_collab_object(self, name: str, obj): + # name must be acceptable str + pattern = r"^[A-Za-z][A-Za-z0-9_]*$" + if not re.match(pattern, name): + raise ValueError( + f"invalid name {name} for collab object - must start with a letter, " + "followed by one or more alphanumeric and/or underscore chars" + ) + + if name in self._collab_objs: + raise ValueError(f"conflict with existing collab object '{name}' of {type(self._collab_objs[name])}") + + if hasattr(self, name): + raise ValueError(f"conflict with reserved name {name}") + + 
setattr(self, name, obj) + self._collab_objs[name] = obj + self._collab_interface[name] = get_object_collab_interface(obj) + self._add_managed_object(obj) + + def get_collab_objects(self): + return self._collab_objs + + def setup(self, workspace: Workspace, server: Proxy, clients: List[Proxy], abort_signal): + self._workspace = workspace + workspace.resource_dirs = self._resource_dirs + + self._server_proxy = server + self._abort_signal = abort_signal + + self._client_proxies = clients + self._me = None + if not self.name or self.name == "server": + self._me = server + else: + for c in clients: + if c.name == self.name: + self._me = c + break + + if not self._me: + raise ValueError(f"cannot find site for {self.name}") + + forest = build_forest(objs=clients, get_fqn_f=lambda c: c.fqn, get_name_f=lambda c: c.name) + self._client_hierarchy = forest + + @property + def my_site(self) -> Proxy: + return self._me + + def find_method(self, target_obj, method_name): + m = getattr(target_obj, method_name, None) + if m: + return m + + if isinstance(target_obj, App): + # see whether any targets have this method + default_target = self.obj + m = getattr(default_target, method_name, None) + if m: + return m + + targets = self.get_collab_objects() + for _, obj in targets.items(): + m = getattr(obj, method_name, None) + if m: + return m + return None + + def find_collab_method(self, target_obj, method_name): + m = self.find_method(target_obj, method_name) + if m and is_collab(m): + return m + return None + + def _fox_init(self, obj, ctx: Context): + init_funcs = get_object_init_funcs(obj) + for name, f in init_funcs: + self.logger.debug(f"calling init func {name} ...") + if supports_context(f): + kwargs = {CollabMethodArgName.CONTEXT: ctx} + else: + kwargs = {} + f(**kwargs) + + def initialize(self, context: Context): + self._fox_init(self, context) + + # initialize target objects + for obj in self._managed_objects.values(): + self._fox_init(obj, context) + + def 
_fox_finalize(self, obj, ctx: Context): + funcs = get_object_final_funcs(obj) + for name, f in funcs: + self.logger.debug(f"calling final func {name} ...") + if supports_context(f): + kwargs = {CollabMethodArgName.CONTEXT: ctx} + else: + kwargs = {} + f(**kwargs) + + def finalize(self, context: Context): + self._fox_finalize(self, context) + + # finalize target objects + for obj in self._managed_objects.values(): + self._fox_finalize(obj, context) + + def new_context(self, caller: str, callee: str, target_group=None, set_call_ctx=True): + ctx = Context(self, caller, callee, self._abort_signal, target_group=target_group) + if set_call_ctx: + set_call_context(ctx) + return ctx + + def register_event_handler(self, event_type: str, handler, **handler_kwargs): + handlers = self._event_handlers.get(event_type) + if not handlers: + handlers = [] + self._event_handlers[event_type] = handlers + handlers.append((handler, handler_kwargs)) + self.logger.debug(f"registered event handler {handler.__qualname__} for {event_type=}") + + def get_collab_interface(self): + return self._collab_interface + + def get_target_object_collab_interface(self, target_name: str): + if not target_name or target_name.lower() == "app": + return self._collab_interface.get("") + else: + return self._collab_interface.get(target_name) + + @collab + def fire_event(self, event_type: str, data, context: Context): + result = {} + for e, handlers in self._event_handlers.items(): + if e == event_type: + for h, kwargs in handlers: + kwargs = copy.copy(kwargs) + kwargs.update({CollabMethodArgName.CONTEXT: context}) + check_context_support(h, kwargs) + result[h.__qualname__] = h(event_type, data, **kwargs) + return result + + def get_children(self): + return [] + + def has_children(self): + return False + + def get_leaf_clients(self): + if not isinstance(self._client_hierarchy, Forest): + raise RuntimeError(f"client_hierarchy must be Forest but got {type(self._client_hierarchy)}") + leaf_nodes = 
[self._client_hierarchy.nodes[n] for n in self._client_hierarchy.leaves] + return [node.obj for node in leaf_nodes] + + +class ServerApp(App): + + def __init__(self, obj, name: str = "server"): + if not obj: + raise ValueError("server object must be specified") + super().__init__(obj, name) + self.algos = get_object_algo_funcs(obj) + if not self.algos: + raise ValueError("server object must have at least one algo") + + def get_children(self): + if not isinstance(self._client_hierarchy, Forest): + raise RuntimeError( + f"client_hierarchy in app {self.name} must be Forest but got {type(self._client_hierarchy)}" + ) + root_nodes = [self._client_hierarchy.nodes[n] for n in self._client_hierarchy.roots] + return [node.obj for node in root_nodes] + + def has_children(self): + return True + + +class ClientApp(App): + + def __init__(self, obj, name: str = "client"): + if not obj: + raise ValueError("client object must be specified") + super().__init__(obj, name) + + def _get_my_node(self): + if not isinstance(self._client_hierarchy, Forest): + raise RuntimeError( + f"client_hierarchy in app {self.name} must be Forest but got {type(self._client_hierarchy)}" + ) + + node = self._client_hierarchy.nodes.get(self.name) + if not isinstance(node, Node): + raise RuntimeError(f"node for site {self.name} must be a Node but got {type(node)}") + return node + + def get_children(self): + my_node = self._get_my_node() + if my_node.children: + return [node.obj for node in my_node.children] + else: + return [] + + def has_children(self): + my_node = self._get_my_node() + return True if my_node.children else False diff --git a/nvflare/fox/api/backend.py b/nvflare/fox/api/backend.py new file mode 100644 index 0000000000..5699de4009 --- /dev/null +++ b/nvflare/fox/api/backend.py @@ -0,0 +1,71 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from abc import ABC, abstractmethod + +from nvflare.apis.signal import Signal +from nvflare.fuel.utils.log_utils import get_obj_logger +from nvflare.security.logging import secure_format_traceback + +from .call_opt import CallOption +from .ctx import Context +from .gcc import GroupCallContext + + +class Backend(ABC): + """A FOX Backend implements remote object calls. This interface defines the required methods that a Backend + must implement. + """ + + def __init__(self, abort_signal: Signal): + self.abort_signal = abort_signal + self.logger = get_obj_logger(self) + + @abstractmethod + def call_target(self, context: Context, target_name: str, call_opt: CallOption, func_name: str, *args, **kwargs): + """ + Call a target function with arguments and return a result. + + Args: + context: the call context + target_name: the fully qualified name of the target object to be called in the remote app. + call_opt: call options. + func_name: name of the function to be called in the remote app. + *args: args to pass to the target function. + **kwargs: kwargs to pass to the target function. + + Notes: the target name is fully qualified: . + + Returns: + + """ + pass + + @abstractmethod + def call_target_in_group(self, gcc: GroupCallContext, func_name: str, *args, **kwargs): + """Call a remote object as part of a group. + + Args: + gcc: contextual information about group call. 
+ func_name: name of the function to be called in the remote app. + *args: args to pass to the target function. + **kwargs: kwargs to pass to the target function. + + Returns: + + """ + pass + + def handle_exception(self, exception: Exception): + self.logger.error(f"exception occurred: {secure_format_traceback()}") + raise exception diff --git a/nvflare/fox/api/call_opt.py b/nvflare/fox/api/call_opt.py new file mode 100644 index 0000000000..f927d17cff --- /dev/null +++ b/nvflare/fox/api/call_opt.py @@ -0,0 +1,53 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +class CallOption: + + def __init__( + self, + expect_result: bool = True, + blocking: bool = True, + timeout: float = 5.0, + secure: bool = False, + optional: bool = False, + target=None, + parallel=0, + ): + """CallOption defines behavior of a collab call. + + Args: + expect_result: whether result is expected from the remote object. + blocking: whether the call is blocking. Only for group calls. + timeout: when expecting result, the max number of secs to wait for result. + secure: whether to use P2P secure messaging. + optional: whether the call is optional. + target: name of the collab object to be called. + parallel: number of parallel outgoing messages. 
+ """ + self.expect_result = expect_result + self.blocking = blocking + self.timeout = timeout + self.secure = secure + self.optional = optional + self.target = target + self.parallel = parallel + + if not self.expect_result: + # fire and forget - no need to control parallel + self.parallel = 0 + + def __str__(self): + return ( + f"expect_result={self.expect_result} blocking={self.blocking} timeout={self.timeout} " + f"secure={self.secure} optional={self.optional} target={self.target} parallel={self.parallel}" + ) diff --git a/nvflare/fox/api/constants.py b/nvflare/fox/api/constants.py new file mode 100644 index 0000000000..863844d45c --- /dev/null +++ b/nvflare/fox/api/constants.py @@ -0,0 +1,35 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +class CollabMethodArgName: + # defines optional args that a target's collaboration method can have + CONTEXT = "context" + + +class ContextKey: + RESULT = "result" + QUALIFIED_FUNC_NAME = "qualified_func_name" + DIRECTION = "direction" + + +class FilterDirection: + INCOMING = "incoming" + OUTGOING = "outgoing" + + +class BackendType: + SIMULATION = "simulation" + FLARE = "flare" + + +MAKE_CLIENT_APP_METHOD = "make_client_app" diff --git a/nvflare/fox/api/ctx.py b/nvflare/fox/api/ctx.py new file mode 100644 index 0000000000..26f9b7ecff --- /dev/null +++ b/nvflare/fox/api/ctx.py @@ -0,0 +1,99 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import threading + +from nvflare.apis.signal import Signal + +fox_context = threading.local() + + +class Context: + + def __init__(self, app, caller: str, callee: str, abort_signal: Signal, target_group=None): + if not isinstance(caller, str): + raise ValueError(f"caller must be str but got {type(caller)}") + + if not isinstance(callee, str): + raise ValueError(f"callee must be str but got {type(callee)}") + + self.caller = caller + self.callee = callee + self.target_group = target_group + self.abort_signal = abort_signal + self.app = app + self.props = {} + self.parent_ctx = get_call_context() + + @property + def backend(self): + return self.app.backend + + @property + def backend_type(self): + return self.app.backend_type + + @property + def clients(self): + return self.app.client_proxies + + @property + def server(self): + return self.app.server_proxy + + @property + def client_hierarchy(self): + return self.app.client_hierarchy + + @property + def workspace(self): + return self.app.workspace + + @property + def target_group_size(self): + if self.target_group: + return self.target_group.size + else: + return 1 + + def set_prop(self, name: str, value): + self.props[name] = value + + def get_prop(self, name: str, default=None): + return self.props.get(name, default) + + def is_aborted(self): + return self.abort_signal and self.abort_signal.triggered + + def __str__(self): + return 
f"{self.app.name}:{self.caller}=>{self.callee}" + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.parent_ctx: + set_call_context(self.parent_ctx) + else: + set_call_context(None) + + +def get_call_context(): + if hasattr(fox_context, "call_ctx"): + return fox_context.call_ctx + else: + return None + + +def set_call_context(ctx): + fox_context.call_ctx = ctx diff --git a/nvflare/fox/api/dec.py b/nvflare/fox/api/dec.py new file mode 100644 index 0000000000..1ec7d03c0c --- /dev/null +++ b/nvflare/fox/api/dec.py @@ -0,0 +1,224 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import inspect + +from .constants import CollabMethodArgName + +_FLAG_COLLAB = "_fox_is_collab" +_FLAG_INIT = "_fox_is_init" +_FLAG_FINAL = "_fox_is_final" +_FLAG_ALGO = "_fox_is_algo" +_FLAG_CALL_FILTER = "_fox_is_call_filter" +_FLAG_IN_CALL_FILTER = "_fox_is_in_call_filter" +_FLAG_OUT_CALL_FILTER = "_fox_is_out_call_filter" +_FLAG_RESULT_FILTER = "_fox_is_result_filter" +_FLAG_IN_RESULT_FILTER = "_fox_is_in_result_filter" +_FLAG_OUT_RESULT_FILTER = "_fox_is_out_result_filter" +_FLAG_SUPPORT_CTX = "_fox_supports_ctx" +_ATTR_PARAM_NAMES = "_fox_param_names" + + +class classproperty: + def __init__(self, fget): + self.fget = fget + + def __get__(self, owner_instance, owner_class): + return self.fget(owner_class) + + +def _set_attrs(func, wrapper): + signature = inspect.signature(func) + parameter_names = list(signature.parameters.keys()) + if "self" in parameter_names: + parameter_names.remove("self") + setattr(wrapper, _ATTR_PARAM_NAMES, parameter_names) + if CollabMethodArgName.CONTEXT in parameter_names: + setattr(wrapper, _FLAG_SUPPORT_CTX, True) + + +def collab(func): + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + _set_attrs(func, wrapper) + setattr(wrapper, _FLAG_COLLAB, True) + return wrapper + + +def is_collab(func): + return _has_flag(func, _FLAG_COLLAB) + + +def get_object_collab_interface(obj): + result = {} + for name in dir(obj): + func = getattr(obj, name) + if callable(func) and is_collab(func): + result[name] = get_param_names(func) + return result + + +def init(func): + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + _set_attrs(func, wrapper) + setattr(wrapper, _FLAG_INIT, True) + return wrapper + + +def get_object_init_funcs(obj): + return _get_object_funcs(obj, _FLAG_INIT, "init") + + +def final(func): + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + _set_attrs(func, wrapper) + setattr(wrapper, _FLAG_FINAL, True) + return wrapper + + +def get_object_final_funcs(obj): + return 
_get_object_funcs(obj, _FLAG_FINAL, "final") + + +def algo(func): + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + _set_attrs(func, wrapper) + setattr(wrapper, _FLAG_ALGO, True) + return wrapper + + +def get_object_algo_funcs(obj): + return _get_object_funcs(obj, _FLAG_ALGO, "algo") + + +def call_filter(func): + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + _set_attrs(func, wrapper) + setattr(wrapper, _FLAG_CALL_FILTER, True) + return wrapper + + +def get_object_call_filter_funcs(obj): + return _get_object_funcs(obj, _FLAG_CALL_FILTER, "call_filter") + + +def in_call_filter(func): + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + _set_attrs(func, wrapper) + setattr(wrapper, _FLAG_IN_CALL_FILTER, True) + return wrapper + + +def get_object_in_call_filter_funcs(obj): + return _get_object_funcs(obj, _FLAG_IN_CALL_FILTER, "in_call_filter") + + +def out_call_filter(func): + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + _set_attrs(func, wrapper) + setattr(wrapper, _FLAG_OUT_CALL_FILTER, True) + return wrapper + + +def get_object_out_call_filter_funcs(obj): + return _get_object_funcs(obj, _FLAG_OUT_CALL_FILTER, "out_call_filter") + + +def result_filter(func): + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + _set_attrs(func, wrapper) + setattr(wrapper, _FLAG_RESULT_FILTER, True) + return wrapper + + +def get_object_result_filter_funcs(obj): + return _get_object_funcs(obj, _FLAG_RESULT_FILTER, "result_filter") + + +def in_result_filter(func): + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + _set_attrs(func, wrapper) + setattr(wrapper, _FLAG_IN_RESULT_FILTER, True) + return wrapper + + +def get_object_in_result_filter_funcs(obj): + return _get_object_funcs(obj, _FLAG_IN_RESULT_FILTER, "in_result_filter") + + +def out_result_filter(func): + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + _set_attrs(func, wrapper) + setattr(wrapper, 
_FLAG_OUT_RESULT_FILTER, True) + return wrapper + + +def get_object_out_result_filter_funcs(obj): + return _get_object_funcs(obj, _FLAG_OUT_RESULT_FILTER, "out_result_filter") + + +def get_param_names(func): + return getattr(func, _ATTR_PARAM_NAMES, None) + + +def _has_flag(func, flag: str) -> bool: + v = getattr(func, flag, None) + return v is True + + +def _get_object_funcs(obj, flag, func_type): + result = [] + for name in dir(obj): + func = getattr(obj, name) + if callable(func) and _has_flag(func, flag): + # print(f"found {func_type} func of object {obj.__class__.__name__}.{name}") + result.append((name, func)) + return result + + +def supports_context(func): + return _has_flag(func, _FLAG_SUPPORT_CTX) + + +def adjust_kwargs(func, kwargs): + """Adjust the kwargs and remove keys that are not supported by the func. + + Args: + func: the func to be checked + kwargs: the kwargs to be adjusted + + Returns: the adjusted kwargs + + """ + if not supports_context(func): + kwargs.pop(CollabMethodArgName.CONTEXT, None) + return kwargs diff --git a/nvflare/fox/api/facade.py b/nvflare/fox/api/facade.py new file mode 100644 index 0000000000..27b02abee1 --- /dev/null +++ b/nvflare/fox/api/facade.py @@ -0,0 +1,342 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
from .constants import ContextKey
from .ctx import get_call_context
from .dec import algo as dec_algo
from .dec import call_filter as dec_call_filter
from .dec import classproperty
from .dec import collab as dec_collab
from .dec import final as dec_final
from .dec import in_call_filter as dec_in_call_filter
from .dec import in_result_filter as dec_in_result_filter
from .dec import init as dec_init
from .dec import out_call_filter as dec_out_call_filter
from .dec import out_result_filter as dec_out_result_filter
from .dec import result_filter as dec_result_filter
from .proxy_list import ProxyList


class facade:
    """Static access point ("fox") for collab application code.

    Exposes the fox decorators (collab, init, final, algo, and the filter
    decorators) plus convenience accessors that read the current call
    context.  Everything here is class-level; the class is never instantiated.
    """

    # decorator aliases
    collab = dec_collab
    init = dec_init
    final = dec_final
    algo = dec_algo
    call_filter = dec_call_filter
    in_call_filter = dec_in_call_filter
    out_call_filter = dec_out_call_filter
    result_filter = dec_result_filter
    in_result_filter = dec_in_result_filter
    out_result_filter = dec_out_result_filter

    @classproperty
    def context(cls):
        """The current call context object."""
        return get_call_context()

    @classproperty
    def caller(cls):
        """Site name of the caller."""
        return get_call_context().caller

    @classproperty
    def callee(cls):
        """Fully qualified collab object name of the invoked object.

        NOTE(review): the original docstring's name template was garbled by
        text extraction; presumably "<obj_name>[.<sub_obj_name>]" - confirm.
        """
        return get_call_context().callee

    @classproperty
    def call_info(cls):
        """String form of the current call context (whatever Context.__str__ produces)."""
        return str(get_call_context())

    @classproperty
    def site_name(cls):
        """Name of the current site (the name of this site's "app" object)."""
        return get_call_context().app.name

    @classproperty
    def server(cls):
        """The server proxy."""
        return get_call_context().server

    @classproperty
    def clients(cls):
        """All client proxies, as a ProxyList."""
        return ProxyList(get_call_context().clients)

    @classproperty
    def other_clients(cls):
        """All client proxies except this site's own proxy."""
        ctx = get_call_context()
        # ctx.clients returns a copy of the client proxy list, not the
        # original, so it is safe to remove our own entry from it.
        peers = ctx.clients
        me = ctx.app.my_site
        if me in peers:
            peers.remove(me)
        return ProxyList(peers)

    @classproperty
    def child_clients(cls):
        """All child client proxies.

        Raises:
            RuntimeError: if the site has no children.
        """
        ctx = get_call_context()
        children = ctx.app.get_children()
        if not children:
            raise RuntimeError(f"app {ctx.app.name} has no child clients")
        return ProxyList(children)

    @classproperty
    def has_children(cls):
        """Whether the site has any child proxies."""
        return get_call_context().app.has_children()

    @classproperty
    def leaf_clients(cls):
        """All leaf client proxies.

        Raises:
            RuntimeError: if the site has no leaf clients.
        """
        ctx = get_call_context()
        leaves = ctx.app.get_leaf_clients()
        if not leaves:
            raise RuntimeError(f"app {ctx.app.name} has no leaf clients")
        return ProxyList(leaves)

    @classmethod
    def get_clients(cls, names: list[str]):
        """Get proxies for the specified site names.

        Args:
            names: names of the sites for which to get proxies.

        Returns: a ProxyList, in the same order as ``names``.

        Raises:
            RuntimeError: if any name has no matching client proxy.
        """
        ctx = get_call_context()
        candidates = ctx.clients
        result = []
        for n in names:
            p = next((c for c in candidates if c.name == n), None)
            if not p:
                # no proxy for this name
                raise RuntimeError(f"app {ctx.app.name} has no client '{n}'")
            result.append(p)
        return ProxyList(result)

    @classproperty
    def backend_type(cls):
        """Backend type of the current site."""
        return get_call_context().backend_type

    @classproperty
    def is_aborted(cls):
        """Whether the job/experiment has been aborted."""
        return get_call_context().is_aborted()

    @classproperty
    def workspace(cls):
        """The workspace object."""
        return get_call_context().workspace

    @classproperty
    def filter_direction(cls):
        """Direction of the filter call (incoming or outgoing).

        Only available to filter functions.
        """
        return get_call_context().get_prop(ContextKey.DIRECTION)

    @classproperty
    def qual_func_name(cls):
        """The filter's qualified function name. Only available to filter functions."""
        return get_call_context().get_prop(ContextKey.QUALIFIED_FUNC_NAME)

    @staticmethod
    def fire_event(event_type: str, data):
        """Fire an event to listening objects within the site.

        Args:
            event_type: type of the event.
            data: data of the event.

        Returns: results from event handlers.
        """
        ctx = get_call_context()
        return ctx.app.fire_event(event_type, data, ctx)

    @staticmethod
    def register_event_handler(event_type: str, handler, **handler_kwargs):
        """Register an event handler for a specified event type.

        Args:
            event_type: type of the event.
            handler: the handler function to be registered.
            **handler_kwargs: kwargs to be passed to the handler.
        """
        get_call_context().app.register_event_handler(event_type, handler, **handler_kwargs)

    @staticmethod
    def get_app_prop(name: str, default=None):
        """Get a property from the site's app (usually a configuration property).

        Args:
            name: name of the property.
            default: value returned when the property does not exist.
        """
        return get_call_context().app.get_prop(name, default)

    @staticmethod
    def set_app_prop(name: str, value):
        """Set a property on the site's app.

        App properties persist for the whole job/experiment execution.
        """
        return get_call_context().app.set_prop(name, value)

    @staticmethod
    def get_prop(name: str, default=None):
        """Get a property from the call context (per-call shared state)."""
        return get_call_context().get_prop(name, default)

    @staticmethod
    def set_prop(name: str, value):
        """Set a property in the call context (per-call shared state)."""
        return get_call_context().set_prop(name, value)

    @staticmethod
    def get_result(default=None):
        """Get the last algo execution result from the call context.

        Args:
            default: value returned when no result exists in the call context.
        """
        return facade.get_prop(ContextKey.RESULT, default)


# ---- next file in original patch: nvflare/fox/api/filter.py (Apache-2.0 header) ----
from typing import Any

from nvflare.fuel.utils.log_utils import get_obj_logger

from .constants import CollabMethodArgName
from .ctx import Context
from .dec import (
    get_object_call_filter_funcs,
    get_object_in_call_filter_funcs,
    get_object_in_result_filter_funcs,
    get_object_out_call_filter_funcs,
    get_object_out_result_filter_funcs,
    get_object_result_filter_funcs,
    supports_context,
)


class _Filter:
    """Common base for call/result filters.

    Wraps an optional impl object whose single decorated filter function
    (stored as ``impl_func``, a (name, func) tuple set by subclasses) does
    the actual filtering.
    """

    def __init__(self, filter_type: str, impl: object = None, incoming=True):
        self.filter_type = filter_type
        self.impl = impl
        self.incoming = incoming
        self.impl_func = None  # (name, func) tuple; set by subclasses when impl is given
        self.logger = get_obj_logger(self)

    def get_impl_object(self):
        # when no impl object was provided, the filter itself is the impl
        return self.impl if self.impl else self

    def filter_data(self, data, context: Context):
        """Run the impl function on data; pass data through when there is none."""
        if self.impl_func is None:
            return data

        d = "incoming" if self.incoming else "outgoing"
        name, f = self.impl_func
        self.logger.info(f"calling {d} {self.filter_type}: {name} on ctx {id(context)}")
        # only pass the context to funcs that declare support for it
        kwargs = {CollabMethodArgName.CONTEXT: context} if supports_context(f) else {}
        return f(data, **kwargs)


def _determine_filter_impl_func(
    obj,
    filter_type: str,
    incoming: bool,
    get_filter_f,
    get_in_filter_f,
    get_out_filter_f,
):
    """Locate the single filter function defined on obj.

    A direction-specific func (in_/out_) is preferred; otherwise the
    direction-neutral func is used.  Exactly one must be defined either way.

    Raises:
        ValueError: if no filter func, or more than one, is found.
    """
    if incoming:
        funcs = get_in_filter_f(obj)
        d = "in"
    else:
        funcs = get_out_filter_f(obj)
        d = "out"

    if len(funcs) > 1:
        raise ValueError(
            f"filter object {obj.__class__.__name__} must have one {d}_{filter_type} func but got {len(funcs)}"
        )
    if len(funcs) == 1:
        return funcs[0]

    # no direction-specific func - fall back to the neutral one
    funcs = get_filter_f(obj)
    if not funcs:
        raise ValueError(f"filter impl object {obj.__class__.__name__} has no {filter_type} func")
    if len(funcs) > 1:
        raise ValueError(
            f"filter object {obj.__class__.__name__} must have one {filter_type} func but got {len(funcs)}"
        )
    return funcs[0]


class CallFilter(_Filter):
    """Filters the kwargs of an outgoing or incoming collab function call."""

    def __init__(self, impl: object = None, incoming=True):
        super().__init__("call filter", impl, incoming)
        if impl:
            self.impl_func = _determine_filter_impl_func(
                obj=impl,
                incoming=incoming,
                filter_type="call_filter",
                get_filter_f=get_object_call_filter_funcs,
                get_in_filter_f=get_object_in_call_filter_funcs,
                get_out_filter_f=get_object_out_call_filter_funcs,
            )

    def filter_call(self, func_kwargs: dict, context: Context):
        """Filter kwargs of a function call.

        Args:
            func_kwargs: kwargs to be filtered.
            context: call context.

        Returns: filtered kwargs that will be passed to a collab func.
        """
        return self.filter_data(func_kwargs, context)


class ResultFilter(_Filter):
    """Filters the result produced by a collab function."""

    def __init__(self, impl: object = None, incoming=True):
        super().__init__("result filter", impl, incoming)
        if impl:
            self.impl_func = _determine_filter_impl_func(
                obj=impl,
                filter_type="result_filter",
                incoming=incoming,
                get_filter_f=get_object_result_filter_funcs,
                get_in_filter_f=get_object_in_result_filter_funcs,
                get_out_filter_f=get_object_out_result_filter_funcs,
            )

    def filter_result(self, result: Any, context: Context):
        """Filter a result produced by a collab func.

        Args:
            result: data to be filtered.
            context: call context.

        Returns: filtered result.
        """
        return self.filter_data(result, context)


class FilterChain:
    """An ordered chain of same-typed filters applied under one pattern."""

    def __init__(self, pattern, filter_type):
        if filter_type not in [ResultFilter, CallFilter]:
            raise ValueError(
                f"filter_type must be type of {ResultFilter.__name__} or {CallFilter.__name__} but got {filter_type}"
            )
        self.pattern = pattern
        self.filter_type = filter_type
        self.filters = []

    def add_filters(self, filters):
        """Append one filter or a list of filters, validating their type."""
        if not filters:
            return

        if isinstance(filters, list):
            if not all(isinstance(item, self.filter_type) for item in filters):
                raise ValueError(f"some items in filters are not {self.filter_type}")
            self.filters.extend(filters)
        else:
            if not isinstance(filters, self.filter_type):
                raise ValueError(f"filter item must be {self.filter_type} but got {type(filters)}")
            self.filters.append(filters)

    def apply_filters(self, data, context: Context):
        """Run data through every filter in chain order and return the result."""
        for f in self.filters:
            if isinstance(f, ResultFilter):
                data = f.filter_result(data, context)
            else:
                data = f.filter_call(data, context)
        return data


# ---- next file in original patch: nvflare/fox/api/gcc.py (Apache-2.0 header) ----
import copy
import queue
import threading

from nvflare.apis.fl_exception import RunAborted
from nvflare.fuel.utils.log_utils import get_obj_logger

from .call_opt import CallOption
from .constants import CollabMethodArgName
from .ctx import Context, set_call_context
from .utils import check_context_support

# polling interval (seconds) used while waiting, so abort is noticed promptly
_SHORT_WAIT = 1.0


class ResultQueue:
    """An iterable queue of results from a group call.

    Producers call append(); a single consumer thread iterates the queue.
    Iteration ends once ``limit`` whole items have been received AND the
    queue has been drained.  Partial results may also be queued but do not
    count toward ``limit``.
    """

    def __init__(self, limit: int):
        if limit <= 0:
            raise ValueError(f"bad queue limit {limit}: must be > 0")

        self.limit = limit
        self.q = queue.Queue()

        # num_whole_items_received is the number of WHOLE items received.
        # The queue could also contain partial results.
        self.num_whole_items_received = 0
        self.update_lock = threading.Lock()

    def append(self, item, is_whole=True):
        """Append an item to the result queue.

        Args:
            item: the item to be appended.
            is_whole: whether the item is whole (counts toward the limit).

        Returns: whether the queue has received all expected whole items.

        Raises:
            RuntimeError: if all whole items have already been received.
        """
        with self.update_lock:
            if self.num_whole_items_received >= self.limit:
                # do not allow any items (partial or whole) to be added once the
                # queue has already received all expected whole items.
                raise RuntimeError(f"queue is full: {self.limit} whole items are already appended")
            self.q.put_nowait(item)
            if is_whole:
                # increment only for whole items; partial items are queued
                # but do not count toward completion.
                self.num_whole_items_received += 1
            return self.num_whole_items_received == self.limit

    def __iter__(self):
        return self

    def __next__(self):
        # fast path: an item is already available
        try:
            return self.q.get_nowait()
        except queue.Empty:
            pass

        # FIX: the completeness check must be done under update_lock and must
        # re-check the queue.  Previously the emptiness check and the count
        # check were unlocked, so a producer could append the final item
        # between them, making this method raise StopIteration while an item
        # was still queued (item lost).
        with self.update_lock:
            if self.num_whole_items_received >= self.limit:
                # complete: append() can no longer add items (it raises once
                # full), so if the queue is empty now it stays empty.
                try:
                    return self.q.get_nowait()
                except queue.Empty:
                    raise StopIteration()

        # more items are expected - block until the next one arrives
        return self.q.get(block=True)

    def __len__(self):
        """Return the number of whole items that have been received.

        Note that this is NOT the current number of items in the queue!
        """
        return self.num_whole_items_received


class ResultWaiter(threading.Event):
    """Collects per-site results of a group call; becomes set when all whole
    results have arrived.  Also throttles parallel outgoing calls."""

    def __init__(self, sites: list[str]):
        super().__init__()
        self.sites = sites
        self.results = ResultQueue(len(sites))
        self.standing_call_count = 0
        self.call_count_decreased = threading.Condition(threading.Lock())

    def inc_call_count(self):
        """Increment the standing (in-flight) call count by 1."""
        with self.call_count_decreased:
            self.standing_call_count += 1

    def dec_call_count(self):
        """Decrement the standing call count by 1 and notify threads waiting
        for the count to decrease."""
        with self.call_count_decreased:
            self.standing_call_count -= 1
            self.call_count_decreased.notify()

    def wait_for_call_permission(self, limit, abort_signal):
        """Wait for permission to make the next call.

        Permission is granted when the in-flight call count is lower than
        the specified limit.

        Args:
            limit: the limit to check against.
            abort_signal: abort signal (may be None).

        Raises:
            RunAborted: if the run is aborted while waiting.
        """
        while True:
            with self.call_count_decreased:
                if abort_signal and abort_signal.triggered:
                    raise RunAborted("run is aborted while waiting for sending availability")
                if self.standing_call_count < limit:
                    return
                # Condition.wait releases the lock while waiting
                self.call_count_decreased.wait(_SHORT_WAIT)

    def wait_for_responses(self, abort_signal):
        """Wait until all whole results are received, polling for abort.

        Raises:
            RunAborted: if the run is aborted while waiting.
        """
        while True:
            # FIX: guard against a None abort_signal, consistent with
            # wait_for_call_permission above.
            if abort_signal and abort_signal.triggered:
                raise RunAborted("run is aborted while waiting for remote responses")
            if self.wait(_SHORT_WAIT):
                return

    @staticmethod
    def _get_site_name(target_name: str):
        # target_name is either "<site>" or "<site>.<obj>"; keep the site part
        parts = target_name.split(".")
        return parts[0]

    def set_result(self, target_name: str, result):
        """Record a whole result; sets the event once all have arrived."""
        site_name = self._get_site_name(target_name)
        all_received = self.results.append((site_name, result))
        if all_received:
            self.set()

    def add_partial_result(self, target_name: str, partial_result):
        """Record a partial result (does not count toward completion)."""
        site_name = self._get_site_name(target_name)
        self.results.append((site_name, partial_result), is_whole=False)


class GroupCallContext:

    def __init__(
        self,
        app,
        target_name: str,
        call_opt: CallOption,
        func_name: str,
        process_cb,
        cb_kwargs,
        context: Context,
        waiter: ResultWaiter,
    ):
        """GroupCallContext contains contextual information about a group call to a target.

        Args:
            app: the calling app.
            target_name: name of the target to be called in the remote app.
            call_opt: call options.
            func_name: name of the function to be called in the remote app.
            process_cb: callback called to process the response from the remote app.
            cb_kwargs: kwargs passed to the callback function.
            context: call context.
            waiter: the waiter used to wait for the result.
        """
        self.app = app
        self.call_opt = call_opt
        self.target_name = target_name
        self.func_name = func_name
        self.process_cb = process_cb
        self.cb_kwargs = cb_kwargs
        self.context = context
        self.waiter = waiter
        self.send_complete_cb = None
        self.send_complete_cb_kwargs = {}
        self.logger = get_obj_logger(self)

    def set_send_complete_cb(self, cb, **cb_kwargs):
        """Register a callback invoked when the request has been sent."""
        if not callable(cb):
            raise ValueError("send_complete_cb must be callable")
        self.send_complete_cb = cb
        self.send_complete_cb_kwargs = cb_kwargs

    def send_completed(self):
        """Invoke the send-complete callback, if one was registered."""
        if self.send_complete_cb:
            self.send_complete_cb(**self.send_complete_cb_kwargs)

    def set_result(self, result):
        """Called by the backend to set the result received from the remote app.

        If process_cb is available, it is called with the (filtered) result.
        Any exception raised during processing becomes the recorded result.

        Args:
            result: the result received from the remote app.
        """
        try:
            # work on a copy of the context with caller/callee swapped, since
            # from the response's point of view the roles are reversed
            ctx = copy.copy(self.context)
            original_caller = ctx.caller
            ctx.caller = ctx.callee
            ctx.callee = original_caller

            if not isinstance(result, Exception):
                set_call_context(ctx)
                try:
                    result = self.app.apply_incoming_result_filters(self.target_name, self.func_name, result, ctx)
                    if self.process_cb:
                        self.cb_kwargs[CollabMethodArgName.CONTEXT] = ctx
                        check_context_support(self.process_cb, self.cb_kwargs)
                        result = self.process_cb(self, result, **self.cb_kwargs)
                finally:
                    # set back to original context
                    set_call_context(self.context)
        except Exception as ex:
            result = ex
        finally:
            self.waiter.set_result(self.target_name, result)

    def set_exception(self, ex):
        """Called by the backend to set an exception received from the remote
        app.  The process_cb is NOT called.

        Args:
            ex: the exception received from the remote app.
        """
        self.waiter.set_result(self.target_name, ex)

    def add_partial_result(self, partial_result):
        """Record a partial result from the remote app."""
        self.waiter.add_partial_result(self.target_name, partial_result)
+ + Returns: + + """ + self.waiter.set_result(self.target_name, ex) + + def add_partial_result(self, partial_result): + self.waiter.add_partial_result(self.target_name, partial_result) diff --git a/nvflare/fox/api/group.py b/nvflare/fox/api/group.py new file mode 100644 index 0000000000..a3a46f1c0d --- /dev/null +++ b/nvflare/fox/api/group.py @@ -0,0 +1,217 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import copy +from typing import List + +from nvflare.apis.signal import Signal +from nvflare.fuel.utils.log_utils import get_obj_logger + +from .app import App +from .call_opt import CallOption +from .ctx import Context +from .gcc import GroupCallContext, ResultWaiter +from .proxy import Proxy +from .utils import check_call_args + + +class Group: + + def __init__( + self, + app, + abort_signal: Signal, + proxies: List[Proxy], + call_opt: CallOption = None, + process_resp_cb=None, + **cb_kwargs, + ): + """A Group is a group of remote apps to be called. + + Args: + app: the calling app. + abort_signal: signal to abort execution. + proxies: proxies of the remote apps to be called. + call_opt: call option that specifies call behavior + process_resp_cb: callback function to be called to process responses from remote apps. + **cb_kwargs: kwargs passed to process_resp_cb. 
import copy
from typing import List

from nvflare.apis.signal import Signal
from nvflare.fuel.utils.log_utils import get_obj_logger

from .app import App
from .call_opt import CallOption
from .ctx import Context
from .gcc import GroupCallContext, ResultWaiter
from .proxy import Proxy
from .utils import check_call_args


class Group:
    """A group of remote apps to be called together."""

    def __init__(
        self,
        app,
        abort_signal: Signal,
        proxies: List[Proxy],
        call_opt: CallOption = None,
        process_resp_cb=None,
        **cb_kwargs,
    ):
        """Create a group call over the given proxies.

        Args:
            app: the calling app.
            abort_signal: signal to abort execution.
            proxies: proxies of the remote apps to be called.
            call_opt: call option that specifies call behavior.
            process_resp_cb: callback called to process responses from remote apps.
            **cb_kwargs: kwargs passed to process_resp_cb.
        """
        if not proxies:
            raise ValueError("no proxies to group")

        self._app = app
        self._abort_signal = abort_signal
        self._proxies = proxies
        self._call_opt = call_opt if call_opt else CallOption()
        self._process_resp_cb = process_resp_cb
        self._cb_kwargs = cb_kwargs
        self._logger = get_obj_logger(self)

    @property
    def size(self):
        """Number of remote apps to be called."""
        return len(self._proxies)

    @property
    def members(self):
        """The proxies of all remote apps to be called."""
        return self._proxies

    def _get_work_proxy(self, p, func_name):
        # when a target is named in the call option, route to that child proxy
        if not self._call_opt.target:
            return p
        child = p.get_child(self._call_opt.target)
        if not child:
            raise RuntimeError(
                f"site {p.name} does not have collab target named '{self._call_opt.target}': "
                f"make sure to use correct target in the group call of '{func_name}'."
            )
        return child

    def __getattr__(self, func_name):
        """Return a callable that invokes func_name on every group member.

        If expect_result is False, the call immediately returns None.

        If expect_result is True, a ResultQueue object is returned; results
        from each site are appended to the queue as (site_name, result) when
        they become available (a timed-out site yields a TimeoutError).

        The blocking flag only matters when expect_result is True: with
        blocking=True the call waits for all results (or timeouts) before
        returning the queue; with blocking=False it returns the queue right
        away for the caller to iterate.
        """

        def method(*args, **kwargs):
            the_backend = None
            try:
                # filter once for all targets, using the first proxy
                first = self._get_work_proxy(self._proxies[0], func_name)

                # func_proxy is the proxy that actually has the func: either
                # "first" itself or one of its children.
                func_proxy, func_itf, adj_args, adj_kwargs = first.adjust_func_args(func_name, args, kwargs)
                the_backend = first.backend

                with func_proxy.app.new_context(func_proxy.caller_name, func_proxy.name, target_group=self) as ctx:
                    self._logger.info(
                        f"[{ctx}] calling {func_name} {self._call_opt} of group {[p.name for p in self._proxies]}"
                    )

                    # apply outgoing call filters
                    assert isinstance(self._app, App)
                    adj_kwargs = self._app.apply_outgoing_call_filters(
                        func_proxy.target_name, func_name, adj_kwargs, ctx
                    )
                    check_call_args(func_name, func_itf, adj_args, adj_kwargs)

                    waiter = ResultWaiter([p.name for p in self._proxies])
                    max_parallel = self._call_opt.parallel
                    if max_parallel <= 0:
                        # no parallelism limit - allow all calls at once
                        max_parallel = len(self._proxies)

                    for member in self._proxies:
                        work = self._get_work_proxy(member, func_name)
                        func_proxy, func_itf, call_args, call_kwargs = work.adjust_func_args(
                            func_name, adj_args, adj_kwargs
                        )
                        call_kwargs = copy.copy(call_kwargs)
                        call_ctx = self._app.new_context(
                            func_proxy.caller_name, func_proxy.name, target_group=self, set_call_ctx=False
                        )

                        gcc = GroupCallContext(
                            app=self._app,
                            target_name=func_proxy.target_name,
                            call_opt=self._call_opt,
                            func_name=func_name,
                            process_cb=self._process_resp_cb,
                            cb_kwargs=self._cb_kwargs,
                            context=call_ctx,
                            waiter=waiter,
                        )

                        # try to get permission to make the next call
                        gcc.set_send_complete_cb(self._request_sent, gcc=gcc, proxy=func_proxy)
                        waiter.wait_for_call_permission(max_parallel, self._abort_signal)

                        # make the next call
                        waiter.inc_call_count()
                        func_proxy.backend.call_target_in_group(gcc, func_name, *call_args, **call_kwargs)

                    if not self._call_opt.expect_result:
                        # fire-and-forget: do not wait for responses
                        return None

                    if not self._call_opt.blocking:
                        self._logger.debug(f"not blocking {func_name}")
                        return waiter.results

                    # wait for responses
                    waiter.wait_for_responses(self._abort_signal)
                    return waiter.results
            except Exception as ex:
                self._logger.error(f"exception {type(ex)} occurred: {ex}")
                if the_backend:
                    the_backend.handle_exception(ex)
                raise ex

        return method

    def _request_sent(self, gcc: GroupCallContext, proxy: Proxy):
        # backend callback: the request left this site, so a parallel slot frees up
        self._logger.debug(f"[{gcc.context}] call has been sent to '{proxy.name}' for func '{gcc.func_name}'")
        gcc.waiter.dec_call_count()


def group(
    ctx: Context,
    proxies: List[Proxy],
    call_opt: CallOption = None,
    process_resp_cb=None,
    **cb_kwargs,
):
    """Convenience function for creating a Group.

    Args:
        ctx: call context.
        proxies: list of proxies.
        call_opt: call option that defines call behavior.
        process_resp_cb: callback called to process responses from remote sites.
        **cb_kwargs: kwargs to be passed to the CB.

    Returns: a Group object.
    """
    return Group(
        ctx.app,
        ctx.abort_signal,
        proxies,
        call_opt,
        process_resp_cb,
        **cb_kwargs,
    )


# ---- next file in original patch: nvflare/fox/api/proxy.py (Apache-2.0 header) ----
import copy

from nvflare.fuel.utils.log_utils import get_obj_logger

from .backend import Backend
from .call_opt import CallOption
from .utils import check_call_args


class _ProxyCall:
    """A one-shot binding of a Proxy to an explicit CallOption."""

    def __init__(
        self,
        proxy,
        expect_result: bool = True,
        timeout: float = 5.0,
        optional: bool = False,
        secure: bool = False,
        target: str = None,
    ):
        self.proxy = proxy
        # blocking follows expect_result: a call that expects a result blocks
        self.call_opt = CallOption(
            expect_result=expect_result,
            blocking=expect_result,
            timeout=timeout,
            optional=optional,
            secure=secure,
            target=target,
        )

    def __getattr__(self, func_name):
        def method(*args, **kwargs):
            return self.proxy.call_func(self.call_opt, func_name, args, kwargs)

        return method


class Proxy:
    """Represents a collab target in a remote app."""

    def __init__(self, app, target_name, target_fqn: str, backend: Backend, target_interface):
        self.app = app
        self.target_name = target_name
        self.fqn = target_fqn  # fully qualified name of the target in hierarchy
        self.backend = backend
        self.caller_name = app.name
        self.target_interface = target_interface
        self.children = {}  # child proxies, keyed by name
        self.logger = get_obj_logger(self)

    def __call__(
        self,
        expect_result: bool = True,
        timeout: float = 5.0,
        optional: bool = False,
        secure: bool = False,
        target: str = None,
    ):
        """Called when the proxy is used with explicit call options.

        Args:
            expect_result: whether a result is expected back.
            timeout: call timeout.
            optional: whether the call is optional.
            secure: whether the call is secure.
            target: name of the child target to call, if any.

        Returns: a _ProxyCall bound to the resulting CallOption.
        """
        return _ProxyCall(
            proxy=self,
            expect_result=expect_result,
            timeout=timeout,
            optional=optional,
            secure=secure,
            target=target,
        )

    @property
    def name(self):
        return self.target_name

    def add_child(self, name, p):
        """Register a child proxy, also exposing it as an attribute."""
        self.children[name] = p
        setattr(self, name, p)

    def get_child(self, name):
        """Get the specified child proxy.

        Args:
            name: name of the child proxy.

        Returns: the child proxy if defined, otherwise None.
        """
        return self.children.get(name)

    def _find_interface(self, func_name):
        """Find the interface for the specified func name.

        The proxy represents a remote object which may have sub-objects; in
        that case the proxy has child proxies, each representing a sub-object.
        The func is looked up on the proxy itself first, then on children.

        Args:
            func_name: name of the func.

        Returns: (the proxy that owns the func, the func interface).

        Raises:
            RuntimeError: if more than one child defines the func (ambiguous).
        """
        args = self.target_interface.get(func_name) if self.target_interface else None
        if args:
            return self, args

        # not on this proxy - try children
        the_args = None
        the_proxy = None
        the_name = None
        for n, c in self.children.items():
            args = c.target_interface.get(func_name) if c.target_interface else None
            if not args:
                continue

            if not the_proxy:
                the_name = n
                the_proxy = c
                the_args = args
            else:
                # a second child also has this func - ambiguity
                raise RuntimeError(
                    f"multiple collab objects ({the_name} and {n}) have {func_name}: please use qualified call"
                )
        return the_proxy, the_args

    def adjust_func_args(self, func_name, args, kwargs):
        """Convert positional args to keyword args based on the func interface.

        Once done, all args have keywords, which makes it easy for call
        filters to identify the args to process.

        Args:
            func_name: name of the func.
            args: positional arg values.
            kwargs: keyword arg values.

        Returns: (owning proxy, func interface, empty args, new kwargs).

        Raises:
            RuntimeError: if the func is unknown or too many args are given.
        """
        call_args = args
        call_kwargs = kwargs

        # find the proxy and interface for the func
        p, func_itf = self._find_interface(func_name)
        if not p:
            raise RuntimeError(f"target {self.target_name} does not have method '{func_name}'")

        if func_itf:
            # check args and turn them to kwargs
            num_call_args = len(args) + len(kwargs)
            if num_call_args > len(func_itf):
                raise RuntimeError(
                    f"there are {num_call_args} call args ({args=} {kwargs=}), "
                    f"but function '{func_name}' only supports {len(func_itf)} args ({func_itf})"
                )
            call_kwargs = copy.copy(kwargs)
            call_args = []
            for i, arg_value in enumerate(args):
                call_kwargs[func_itf[i]] = arg_value

        return p, func_itf, call_args, call_kwargs

    def call_func(self, call_opt: CallOption, func_name, args, kwargs):
        """Call the specified function with call options.

        Args:
            call_opt: call option that controls the call behavior.
            func_name: name of func to be called.
            args: args to be passed to the func.
            kwargs: kwargs to be passed to the func.

        Returns: result of the function, or the exception that occurred.
        """
        try:
            if call_opt.target:
                p = self.get_child(call_opt.target)
                if not p:
                    raise RuntimeError(
                        f"site {self.name} does not have collab target named '{call_opt.target}': "
                        f"make sure to use correct target when calling '{func_name}'."
                    )
            else:
                p = self

            p, func_itf, call_args, call_kwargs = p.adjust_func_args(func_name, args, kwargs)

            with p.app.new_context(self.caller_name, self.name) as ctx:
                # apply outgoing call filters
                call_kwargs = self.app.apply_outgoing_call_filters(p.target_name, func_name, call_kwargs, ctx)
                check_call_args(func_name, func_itf, call_args, call_kwargs)

                result = p.backend.call_target(ctx, p.target_name, call_opt, func_name, *call_args, **call_kwargs)
                if isinstance(result, Exception):
                    raise result

                if result is not None:
                    # apply incoming result filters
                    result = self.app.apply_incoming_result_filters(p.target_name, func_name, result, ctx)
                return result
        except Exception as ex:
            if self.backend:
                try:
                    self.backend.handle_exception(ex)
                except Exception as ex2:
                    # ignore exception from backend handling
                    self.logger.error(f"ignored backend's exception {type(ex2)}")

            # Must return the exception as the result of the func call.
            # Do NOT raise it!
            return ex

    def __getattr__(self, func_name):
        """Called when the proxy is invoked without explicit call options; a
        default CallOption is used."""

        def method(*args, **kwargs):
            return self.call_func(CallOption(), func_name, args, kwargs)

        return method


# ---- next file in original patch: nvflare/fox/api/proxy_list.py (Apache-2.0 header) ----
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .call_opt import CallOption +from .ctx import get_call_context +from .group import group + + +class ProxyList(list): + + def __init__(self, proxies: list): + super().__init__(proxies) + + def __getattr__(self, func_name): + """This is called to invoke the specified func without specifying call option. + In this case, default call option will be used. + + Args: + func_name: + + Returns: + + """ + + def method(*args, **kwargs): + grp = group( + ctx=get_call_context(), + proxies=self, + ) + return getattr(grp, func_name)(*args, **kwargs) + + return method + + def __call__( + self, + blocking: bool = True, + expect_result: bool = True, + timeout: float = 5.0, + optional: bool = False, + secure: bool = False, + target=None, + parallel=0, + process_resp_cb=None, + **cb_kwargs, + ): + """This is called to define the behavior (Call Option) of the group call. + + Args: + blocking: + expect_result: + timeout: + optional: + secure: + target: + parallel: + process_resp_cb: + **cb_kwargs: + + Returns: + + """ + return group( + ctx=get_call_context(), + proxies=self, + call_opt=CallOption( + blocking=blocking, + expect_result=expect_result, + timeout=timeout, + optional=optional, + secure=secure, + target=target, + parallel=parallel, + ), + process_resp_cb=process_resp_cb, + **cb_kwargs, + ) diff --git a/nvflare/fox/api/run_server.py b/nvflare/fox/api/run_server.py new file mode 100644 index 0000000000..194372a087 --- /dev/null +++ b/nvflare/fox/api/run_server.py @@ -0,0 +1,50 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from nvflare.security.logging import secure_log_traceback + +from .app import ServerApp +from .constants import CollabMethodArgName, ContextKey +from .dec import supports_context + + +def run_server(server_app: ServerApp, logger): + server_ctx = server_app.new_context(caller=server_app.name, callee=server_app.name) + logger.info("initializing server app") + server_app.initialize(server_ctx) + + if not server_app.algos: + raise RuntimeError("server app does not have any algos!") + + result = None + for name, f in server_app.algos: + if server_ctx.is_aborted(): + break + + try: + logger.info(f"Running algo {name}") + kwargs = {CollabMethodArgName.CONTEXT: server_ctx} + if not supports_context(f): + kwargs = {} + result = f(**kwargs) + server_ctx.set_prop(ContextKey.RESULT, result) + except Exception as ex: + secure_log_traceback(logger) + backend = server_app.backend + if backend: + backend.handle_exception(ex) + break + + logger.info("finalizing server app") + server_app.finalize(server_ctx) + return result diff --git a/nvflare/fox/api/utils.py b/nvflare/fox/api/utils.py new file mode 100644 index 0000000000..417ebb2550 --- /dev/null +++ b/nvflare/fox/api/utils.py @@ -0,0 +1,79 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +import logging +from typing import List + +from .constants import CollabMethodArgName + + +def check_optional_args(func, kwargs, arg_names: List[str]): + signature = inspect.signature(func) + parameter_names = signature.parameters.keys() + + # make sure to expose the optional args if the collab method supports them + for n in arg_names: + if n not in parameter_names: + kwargs.pop(n, None) + + +def check_context_support(func, kwargs): + check_optional_args(func, kwargs, [CollabMethodArgName.CONTEXT]) + + +def get_collab_object_name(target_name: str): + """The target_name is either the site name or .. + This function gets the collab object name. + + Args: + target_name: + + Returns: + + """ + parts = target_name.split(".") + if len(parts) == 1: + return "_app_" + else: + return parts[1] + + +def check_call_args(func_name, func_itf, call_args, call_kwargs: dict): + """Check call args against the function's interface. + + Args: + func_name: + func_itf: + call_args: + call_kwargs: + + Returns: + + """ + num_call_args = len(call_args) + len(call_kwargs) + if num_call_args > len(func_itf): + # For security, collab funcs must only have fixed args - no flexible args are allowed. 
+ raise RuntimeError( + f"there are {num_call_args} call args ({len(call_args)=} {len(call_kwargs)=}), " + f"but function '{func_name}' only supports {len(func_itf)} args ({func_itf})" + ) + + # make sure every arg in kwargs is valid + for arg_name in call_kwargs.keys(): + if arg_name not in func_itf: + raise RuntimeError(f"call arg {arg_name} is not supported by func '{func_name}'") + + +def simple_logging(level=logging.INFO): + logging.basicConfig(level=level, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") diff --git a/nvflare/fox/api/workspace.py b/nvflare/fox/api/workspace.py new file mode 100644 index 0000000000..405bd20fc0 --- /dev/null +++ b/nvflare/fox/api/workspace.py @@ -0,0 +1,48 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import os +from abc import ABC, abstractmethod + + +class Workspace(ABC): + + def __init__(self): + self.resource_dirs = {} + + def add_resource_dir(self, name, resource_dir): + if not os.path.isdir(resource_dir): + raise ValueError(f"Resource dir {resource_dir} does not exist") + self.resource_dirs[name] = resource_dir + + @abstractmethod + def get_root_dir(self) -> str: + pass + + @abstractmethod + def get_work_dir(self) -> str: + pass + + def get_resource_dir(self, name: str, create: bool = True) -> str: + resource_dir = self.resource_dirs.get(name) + if resource_dir: + return resource_dir + + p = os.path.join(self.get_work_dir(), name) + if not os.path.exists(p) and create: + os.makedirs(p, exist_ok=True) + return p + + @abstractmethod + def get_experiment_dir(self) -> str: + pass diff --git a/nvflare/fox/examples/__init__.py b/nvflare/fox/examples/__init__.py new file mode 100644 index 0000000000..85481fa6cc --- /dev/null +++ b/nvflare/fox/examples/__init__.py @@ -0,0 +1,29 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import os + + +def get_job_root_dir(): + return os.getenv("FOX_JOB_ROOT", ".") + + +def get_experiment_root(): + return os.getenv("FOX_EXP_ROOT", ".") + + +def export_recipe(job_name: str, make_recipe_f): + recipe = make_recipe_f(job_name) + job_root = get_job_root_dir() + recipe.export(job_root) + print(f"job exported at {job_root}/{job_name}") diff --git a/nvflare/fox/examples/np/__init__.py b/nvflare/fox/examples/np/__init__.py new file mode 100644 index 0000000000..341a77c5bc --- /dev/null +++ b/nvflare/fox/examples/np/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nvflare/fox/examples/np/algos/__init__.py b/nvflare/fox/examples/np/algos/__init__.py new file mode 100644 index 0000000000..341a77c5bc --- /dev/null +++ b/nvflare/fox/examples/np/algos/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nvflare/fox/examples/np/algos/avg_stream.py b/nvflare/fox/examples/np/algos/avg_stream.py new file mode 100644 index 0000000000..ed205e461a --- /dev/null +++ b/nvflare/fox/examples/np/algos/avg_stream.py @@ -0,0 +1,156 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os.path +import threading +import uuid + +from nvflare.fox import fox +from nvflare.fox.api.constants import BackendType +from nvflare.fox.examples.np.algos.utils import load_np_model, parse_array_def, save_np_model +from nvflare.fox.sys.downloader import Downloader, download_file +from nvflare.fuel.utils.log_utils import get_obj_logger + + +class _AggrResult: + + def __init__(self): + self.total = 0 + self.count = 0 + self.lock = threading.Lock() + + +class NPFedAvgStream: + + def __init__(self, initial_model, num_rounds=10, timeout=2.0): + self.num_rounds = num_rounds + self.initial_model = initial_model + self.timeout = timeout + self.name = "NPFedAvgStream" + self.logger = get_obj_logger(self) + self._init_model = parse_array_def(initial_model) + + @fox.algo + def execute(self): + self.logger.info(f"[{fox.call_info}] Start training for {self.num_rounds} rounds") + current_model = self._init_model + for i in range(self.num_rounds): + current_model = self._do_one_round(i, current_model) + if current_model is None: + 
self.logger.error(f"training failed at round {i}") + break + self.logger.info(f"FINAL MODEL: {current_model}") + return current_model + + def _do_one_round(self, r, current_model): + aggr_result = _AggrResult() + grp = fox.clients( + process_resp_cb=self._accept_train_result, + aggr_result=aggr_result, + ) + + # pretend the model is big + file_name = None + if fox.backend_type == BackendType.FLARE: + file_name = f"/tmp/np_{str(uuid.uuid4())}.npy" + save_np_model(current_model, file_name) + downloader = Downloader( + num_receivers=grp.size, + timeout=5.0, + ) + model = downloader.add_file(file_name=file_name, file_downloaded_cb=self._model_downloaded) + model_type = "ref" + self.logger.info(f"prepared model as ref: {model}") + else: + model = current_model + model_type = "model" + + grp.train(r, model, model_type) + + if file_name: + # train is a blocking call that does not return until train results (success or not) are received + # from all clients. + # remove the file regardless. + os.remove(file_name) + + if aggr_result.count == 0: + return None + else: + result = aggr_result.total / aggr_result.count + self.logger.info(f"[{fox.call_info}] round {r}: aggr result from {aggr_result.count} clients: {result}") + return result + + def _accept_train_result(self, gcc, result, aggr_result: _AggrResult): + self.logger.info(f"[{fox.call_info}] got train result from {fox.caller}: {result}") + + model, model_type = result + if model_type == "ref": + err, file_path = download_file(ref=model, per_request_timeout=5.0) + if err: + raise RuntimeError(f"failed to download model file {model}: {err}") + self.logger.info(f"downloaded model file to {file_path}") + model = load_np_model(file_path) + os.remove(file_path) + + with aggr_result.lock: + aggr_result.total += model + aggr_result.count += 1 + return None + + def _model_downloaded(self, to_site: str, status: str, file_name): + self.logger.info(f"model file {file_name} downloaded by {to_site}: {status=}") + + +class NPTrainer: 
+ + def __init__(self, delta: float): + self.delta = delta + self.logger = get_obj_logger(self) + + @fox.collab + def train(self, current_round, weights, model_type: str): + if fox.is_aborted: + self.logger.debug("training aborted") + return None, "" + + self.logger.debug(f"[{fox.call_info}] training round {current_round}: {model_type=} {weights=}") + if model_type == "ref": + err, file_path = download_file(ref=weights, per_request_timeout=5.0) + if err: + raise RuntimeError(f"failed to download model file {weights}: {err}") + self.logger.info(f"downloaded model file to {file_path}") + weights = load_np_model(file_path) + self.logger.info(f"loaded model from file: {weights}") + os.remove(file_path) + + result = weights + self.delta + + if model_type == "ref": + # stream it + file_name = f"/tmp/np_{str(uuid.uuid4())}.npy" + save_np_model(result, file_name) + downloader = Downloader( + num_receivers=1, + timeout=5.0, + ) + result = downloader.add_file(file_name=file_name, file_downloaded_cb=self._result_downloaded) + self.logger.info(f"prepared result as ref: {result}") + + return result, model_type + + def _result_downloaded(self, to_site: str, status: str, file_name): + self.logger.info(f"model file {file_name} downloaded to {to_site}: {status=}") + if not to_site: + # downloaded to all sites + os.remove(file_name) + self.logger.info(f"model file {file_name} removed") diff --git a/nvflare/fox/examples/np/algos/client.py b/nvflare/fox/examples/np/algos/client.py new file mode 100644 index 0000000000..887da4207c --- /dev/null +++ b/nvflare/fox/examples/np/algos/client.py @@ -0,0 +1,93 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import random +import time + +from nvflare.fox import fox +from nvflare.fuel.utils.log_utils import get_obj_logger + + +class NPTrainer: + + def __init__(self, delta: float, delay=0): + self.delta = delta + self.delay = delay + self.logger = get_obj_logger(self) + + @fox.init + def init_trainer(self): + delta_config = fox.get_app_prop("client_delta", {}) + self.delta = delta_config.get(fox.site_name, self.delta) + self.logger.info(f"init_trainer: client {fox.site_name}: delta={self.delta}") + + @fox.init + def init_trainer2(self): + self.logger.info(f"init_trainer2: client {fox.site_name}: init again") + + @fox.collab + def train(self, current_round, weights): + if fox.is_aborted: + self.logger.debug("training aborted") + return 0 + self.logger.info(f"[{fox.call_info}] training round {current_round=} {weights=}") + # result = fox.server(expect_result=True).fire_event("metrics", {"round": current_round, "y": 10}) + # self.logger.info(f"[{fox.call_info}] got event result: {result}") + + if self.delay > 0: + time.sleep(self.delay) + return weights + self.delta + + @fox.collab + def evaluate(self, model): + self.logger.debug(f"[{fox.call_info}] evaluate") + return random.random() + + +class NPHierarchicalTrainer: + + def __init__(self, delta: float): + self.delta = delta + self.logger = get_obj_logger(self) + + @fox.collab + def train(self, current_round, weights): + if fox.is_aborted: + self.logger.debug("training aborted") + return None + + self.logger.debug(f"[{fox.call_info}] training round {current_round}") + if fox.has_children: + total = 0 + results = 
fox.child_clients.train(current_round, weights) + for n, v in results: + total += v + result = total / len(results) + self.logger.debug(f"[{fox.call_info}]: aggr result from children of round {current_round}: {result}") + else: + result = self._local_train(current_round, weights) + self.logger.debug(f"[{fox.call_info}]: local train result of round {current_round}: {result}") + fox.server(expect_result=False).fire_event("metrics", {"round": current_round, "y": 10}) + return result + + def _local_train(self, current_round, weights): + if fox.is_aborted: + self.logger.debug("training aborted") + return None + self.logger.info(f"[{fox.call_info}] local trained round {current_round} {weights} {type(weights)}") + return weights + self.delta + + @fox.collab + def evaluate(self, model): + self.logger.debug(f"[{fox.call_info}] evaluate") + return random.random() diff --git a/nvflare/fox/examples/np/algos/filters.py b/nvflare/fox/examples/np/algos/filters.py new file mode 100644 index 0000000000..8faffec79a --- /dev/null +++ b/nvflare/fox/examples/np/algos/filters.py @@ -0,0 +1,69 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import random + +from nvflare.fox import fox +from nvflare.fuel.utils.log_utils import get_obj_logger + + +class AddNoiseToModel: + + def __init__(self): + self.logger = get_obj_logger(self) + + @fox.call_filter + def add_noise(self, func_kwargs: dict): + direction = fox.filter_direction + qual_func_name = fox.qual_func_name + self.logger.debug(f"[{fox.call_info}] filtering call: {func_kwargs=} {direction=} {qual_func_name=}") + weights_key = "weights" + weights = func_kwargs.get(weights_key) + if weights is None: + # nothing to filter + self.logger.info(f"nothing to filter in {func_kwargs}") + return func_kwargs + + # add some noise to weights + noise = random.random() + self.logger.debug(f"[{fox.call_info}] adding noise {noise}") + weights += noise + func_kwargs[weights_key] = weights + self.logger.info(f"[{fox.call_info}] weights after adding noise {noise}: {weights}") + return func_kwargs + + +class Print: + + def __init__(self): + self.logger = get_obj_logger(self) + + @fox.call_filter + def print_call(self, func_kwargs: dict): + self.logger.info(f"[{fox.call_info}] print_call on fox ctx {id(fox.context)}") + direction = fox.filter_direction + qual_func_name = fox.qual_func_name + self.logger.info( + f"[{fox.call_info}] printing call ctx {id(fox.context)}: {func_kwargs=} {direction=} {qual_func_name=}" + ) + return func_kwargs + + @fox.result_filter + def print_result(self, result, context): + self.logger.info(f"[{fox.call_info}] print_result on {id(context)} fox ctx {id(fox.context)}") + direction = fox.filter_direction + qual_func_name = fox.qual_func_name + self.logger.info( + f"[{fox.call_info}] printing result ctx {id(fox.context)}: {result=} {direction=} {qual_func_name=}" + ) + return result diff --git a/nvflare/fox/examples/np/algos/strategies/__init__.py b/nvflare/fox/examples/np/algos/strategies/__init__.py new file mode 100644 index 0000000000..341a77c5bc --- /dev/null +++ b/nvflare/fox/examples/np/algos/strategies/__init__.py @@ -0,0 +1,13 @@ +# 
Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nvflare/fox/examples/np/algos/strategies/avg_h.py b/nvflare/fox/examples/np/algos/strategies/avg_h.py new file mode 100644 index 0000000000..923270d79b --- /dev/null +++ b/nvflare/fox/examples/np/algos/strategies/avg_h.py @@ -0,0 +1,59 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from nvflare.fox import fox +from nvflare.fox.examples.np.algos.utils import parse_array_def +from nvflare.fuel.utils.log_utils import get_obj_logger + + +class NPHierarchicalFedAvg: + + def __init__(self, initial_model, num_rounds=10): + self.num_rounds = num_rounds + self.initial_model = initial_model + self._initial_model = parse_array_def(initial_model) + self.name = self.__class__.__name__ + self.logger = get_obj_logger(self) + + @fox.algo + def execute(self): + self.logger.info(f"[{fox.call_info}] Start training for {self.num_rounds} rounds") + current_model = self._initial_model + for i in range(self.num_rounds): + current_model = self._do_one_round(i, current_model) + if current_model is None: + self.logger.error(f"training failed in round {i}") + break + score = self._do_eval(current_model) + self.logger.info(f"[{fox.call_info}]: eval score in round {i}: {score}") + self.logger.info(f"FINAL MODEL: {current_model}") + return current_model + + def _do_eval(self, model): + results = fox.leaf_clients.evaluate(model) + total = 0.0 + for n, v in results: + self.logger.info(f"[{fox.call_info}]: got eval result from client {n}: {v}") + total += v + num_results = len(results) + return total / num_results if num_results > 0 else 0.0 + + def _do_one_round(self, r, current_model): + total = 0 + results = fox.child_clients.train(r, current_model) + for n, v in results: + self.logger.info(f"[{fox.call_info}] round {r}: got group result from client {n}: {v}") + total += v + + num_results = len(results) + return total / num_results if num_results > 0 else None diff --git a/nvflare/fox/examples/np/algos/strategies/avg_intime.py b/nvflare/fox/examples/np/algos/strategies/avg_intime.py new file mode 100644 index 0000000000..b735da7afe --- /dev/null +++ b/nvflare/fox/examples/np/algos/strategies/avg_intime.py @@ -0,0 +1,88 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import threading + +from nvflare.fox import fox +from nvflare.fox.api.constants import ContextKey +from nvflare.fox.examples.np.algos.utils import parse_array_def +from nvflare.fuel.utils.log_utils import get_obj_logger + + +class _AggrResult: + + def __init__(self): + self.total = 0 + self.count = 0 + self.lock = threading.Lock() + + +class NPFedAvgInTime: + + def __init__(self, initial_model, num_rounds=10, timeout=2.0): + self.num_rounds = num_rounds + self.initial_model = initial_model + self.timeout = timeout + self.name = "NPFedAvgInTime" + self.logger = get_obj_logger(self) + self._init_model = parse_array_def(initial_model) + + @fox.algo + def execute(self): + self.logger.info(f"[{fox.call_info}] Start training for {self.num_rounds} rounds") + current_model = fox.get_prop(ContextKey.RESULT, self._init_model) + for i in range(self.num_rounds): + current_model = self._do_one_round(i, current_model) + if current_model is None: + self.logger.error(f"training failed at round {i}") + break + score = self._do_eval(current_model) + self.logger.info(f"[{fox.call_info}]: eval score in round {i}: {score}") + self.logger.info(f"FINAL MODEL: {current_model}") + return current_model + + def _do_eval(self, model): + results = fox.clients.evaluate(model) + total = 0.0 + for n, v in results: + self.logger.info(f"[{fox.call_info}]: got eval result from client {n}: {v}") + total += v + + num_results = len(results) + return total / 
num_results if num_results > 0 else 0.0 + + def _do_one_round(self, r, current_model): + aggr_result = _AggrResult() + + # try to get the configured timeout value + timeout = fox.get_app_prop("default_timeout", self.timeout) + self.logger.info(f"got timeout: {timeout}") + fox.clients( + timeout=timeout, + process_resp_cb=self._accept_train_result, + aggr_result=aggr_result, + ).train(r, current_model) + + if aggr_result.count == 0: + return None + else: + result = aggr_result.total / aggr_result.count + self.logger.info(f"[{fox.call_info}] round {r}: aggr result from {aggr_result.count} clients: {result}") + return result + + def _accept_train_result(self, gcc, result, aggr_result: _AggrResult): + self.logger.info(f"[{fox.call_info}] got train result from {fox.caller} {result}") + with aggr_result.lock: + aggr_result.total += result + aggr_result.count += 1 + return None diff --git a/nvflare/fox/examples/np/algos/strategies/avg_para.py b/nvflare/fox/examples/np/algos/strategies/avg_para.py new file mode 100644 index 0000000000..041e4f8a2f --- /dev/null +++ b/nvflare/fox/examples/np/algos/strategies/avg_para.py @@ -0,0 +1,64 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from nvflare.fox import fox
+from nvflare.fox.examples.np.algos.utils import parse_array_def
+from nvflare.fuel.utils.log_utils import get_obj_logger
+
+
+class NPFedAvgParallel:
+
+    def __init__(self, initial_model, num_rounds=10):
+        self.num_rounds = num_rounds
+        self.initial_model = initial_model
+        self._initial_model = parse_array_def(initial_model)
+        self.name = "NPFedAvgParallel"
+        self.logger = get_obj_logger(self)
+
+    @fox.algo
+    def execute(self):
+        self.logger.info(f"[{fox.call_info}] Start training for {self.num_rounds} rounds")
+        current_model = self._initial_model
+        for i in range(self.num_rounds):
+            current_model = self._do_one_round(i, current_model)
+            if current_model is None:
+                self.logger.error(f"training failed at round {i}")
+                break
+            score = self._do_eval(current_model)
+            self.logger.info(f"[{fox.call_info}]: eval score in round {i}: {score}")
+        return current_model
+
+    def _do_eval(self, model):
+        results = fox.clients.evaluate(model)
+        total = 0.0
+        for n, v in results:
+            self.logger.info(f"[{fox.call_info}]: got eval result from client {n}: {v}")
+            total += v
+
+        num_results = len(results)
+        return total / num_results if num_results > 0 else 0.0
+
+    def _do_one_round(self, r, current_model):
+        total = 0
+        results = fox.clients(timeout=4, blocking=False, target="client").train(r, current_model)
+        for n, v in results:
+            # the value 'v' could be an exception!
+            if isinstance(v, Exception):
+                # this site encountered problem
+                self.logger.error(f"[{fox.call_info}] round {r}: got exception from client {n}: {v}")
+                raise v
+
+            self.logger.info(f"[{fox.call_info}] round {r}: got group result from client {n}: {v}")
+            total += v
+        num_results = len(results)
+        return total / num_results if num_results > 0 else None
diff --git a/nvflare/fox/examples/np/algos/strategies/avg_para_tc.py b/nvflare/fox/examples/np/algos/strategies/avg_para_tc.py
new file mode 100644
index 0000000000..fb9802ea3f
--- /dev/null
+++ b/nvflare/fox/examples/np/algos/strategies/avg_para_tc.py
@@ -0,0 +1,57 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from nvflare.fox import fox +from nvflare.fox.examples.np.algos.utils import parse_array_def +from nvflare.fuel.utils.log_utils import get_obj_logger + + +class NPFedAvgParallelWithTrafficControl: + + def __init__(self, initial_model, num_rounds=10, parallel=2): + self.num_rounds = num_rounds + self.initial_model = initial_model + self._initial_model = parse_array_def(initial_model) + self.parallel = parallel + self.logger = get_obj_logger(self) + + @fox.algo + def execute(self): + self.logger.info(f"[{fox.call_info}] Start training for {self.num_rounds} rounds") + current_model = self._initial_model + for i in range(self.num_rounds): + current_model = self._do_one_round(i, current_model) + if current_model is None: + self.logger.error(f"training failed at round {i}") + break + self.logger.info(f"FINAL MODEL: {current_model}") + return current_model + + def _do_one_round(self, r, current_model): + total = 0 + results = fox.clients(timeout=4, blocking=False, target="client", parallel=self.parallel).train( + r, current_model + ) + + for n, v in results: + # the value 'v' could be an exception! + if isinstance(v, Exception): + # this site encountered problem + self.logger.error(f"[{fox.call_info}] round {r}: got exception from client {n}: {v}") + raise v + + self.logger.info(f"[{fox.call_info}] round {r}: got group result from client {n}: {v}") + total += v + + num_results = len(results) + return total / num_results if num_results > 0 else None diff --git a/nvflare/fox/examples/np/algos/strategies/avg_seq.py b/nvflare/fox/examples/np/algos/strategies/avg_seq.py new file mode 100644 index 0000000000..f5db9a4fa2 --- /dev/null +++ b/nvflare/fox/examples/np/algos/strategies/avg_seq.py @@ -0,0 +1,68 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# ---- nvflare/fox/examples/np/algos/strategies/avg_seq.py ----
import os

from nvflare.fox import fox
from nvflare.fox.examples.np.algos.utils import parse_array_def, save_np_model
from nvflare.fuel.utils.log_utils import get_obj_logger


class NPFedAvgSequential:
    """FedAvg that visits clients one at a time and combines their results
    using per-client weights read from app property 'client_weight_config'."""

    def __init__(self, initial_model, num_rounds=10):
        self.name = "NPFedAvgSequential"
        self.num_rounds = num_rounds
        self.initial_model = initial_model  # need to remember init for job API to work!
        self._initial_model = parse_array_def(initial_model)
        self.logger = get_obj_logger(self)
        self.client_weights = None

    @fox.init
    def init(self):
        """Compute normalized per-client aggregation weights (default weight 100)."""
        self.logger.info("fox init NPFedAvgSequential")
        weight_config = fox.get_app_prop("client_weight_config", {})

        client_weights = {}
        total = 0
        for c in fox.clients:
            w = weight_config.get(c.name, 100)
            client_weights[c.name] = w
            total += w

        # normalize weights so they sum to 1
        for name in client_weights:
            client_weights[name] = client_weights[name] / total

        self.client_weights = client_weights
        self.logger.info("client_weights: {}".format(client_weights))

    @fox.algo
    def execute(self):
        """Run all rounds sequentially; save and return the final model."""
        self.logger.info(f"[{fox.call_info}] Start training for {self.num_rounds} rounds")
        current_model = self._initial_model
        for i in range(self.num_rounds):
            current_model = self._do_one_round(i, current_model)

        # save model to work dir
        file_name = os.path.join(fox.workspace.get_work_dir(), "model.npy")
        save_np_model(current_model, file_name)
        self.logger.info(f"FINAL RESULT: {current_model}")
        return current_model

    def _do_one_round(self, r, current_model):
        # weighted sum of one train result per client, in order
        total = 0
        for c in fox.clients:
            result = c(expect_result=True, timeout=2.0, optional=True, secure=False).train(r, current_model)
            self.logger.info(f"[{fox.call_info}] round {r}: got result from client {c.name}: {result}")
            total += result * self.client_weights[c.name]
        return total


# ---- nvflare/fox/examples/np/algos/strategies/cyclic.py ----
import random

from nvflare.fox.examples.np.algos.utils import load_np_model


class NPCyclic:
    """Cyclic (relay) training: the model is handed client-to-client in a
    freshly shuffled order each round."""

    def __init__(self, initial_model, num_rounds=2):
        self.num_rounds = num_rounds
        self.initial_model = initial_model
        self._initial_model = parse_array_def(initial_model)
        self.final_model = None
        self.logger = get_obj_logger(self)

    @fox.init
    def check_initial_model(self):
        """If the initial model is a file base name, load it from the 'data' resource dir."""
        if isinstance(self._initial_model, str):
            # this is name of the file that contains model data; load the model
            resource_dir = fox.workspace.get_resource_dir("data")
            file_name = os.path.join(resource_dir, self._initial_model)
            self._initial_model = load_np_model(file_name)
            self.logger.info(f"loaded initial model from {file_name}: {self._initial_model}")

    @fox.algo
    def execute(self):
        current_model = self._initial_model
        for current_round in range(self.num_rounds):
            current_model = self._do_one_round(current_round, current_model)
            if current_model is None:
                self.logger.error(f"training failed at round {current_round}")
                break
        self.logger.info(f"[{fox.call_info}] final result: {current_model}")
        self.final_model = current_model
        return current_model

    @fox.final
    def save_result(self):
        """Persist the algo's final result into the work dir."""
        final_result = fox.get_result()
        file_name = os.path.join(fox.workspace.get_work_dir(), "final_model.npy")
        save_np_model(final_result, file_name)
        self.logger.info(f"[{fox.call_info}]: saved final model {final_result} to {file_name}")

    def _do_one_round(self, current_round, current_model):
        # Note: fox.clients always returns a new copy of all clients!
        clients = fox.clients
        random.shuffle(clients)
        for c in clients:
            current_model = c.train(current_round, current_model)
            if current_model is None:
                self.logger.error(f"training failed on client {c.name} at round {current_round}")
                return None
            self.logger.info(f"[{fox.call_info}] result from {c.name}: {current_model}")
        return current_model
# ---- nvflare/fox/examples/np/algos/swarm.py ----
import os
import random
import threading
import traceback

from nvflare.fox import fox
from nvflare.fox.examples.np.algos.utils import parse_array_def, save_np_model
from nvflare.fuel.utils.log_utils import get_obj_logger


class NPSwarm:
    """Server side of swarm learning: kicks off one randomly chosen client,
    then waits until some client reports completion via 'all_done'."""

    def __init__(self, initial_model, num_rounds=10):
        self.num_rounds = num_rounds
        self.initial_model = initial_model
        self._initial_model = parse_array_def(initial_model)
        self.waiter = threading.Event()
        self.logger = get_obj_logger(self)

    @fox.algo
    def execute(self):
        fox.register_event_handler("all_done", self._all_done)

        # randomly pick a client to start
        start_client = fox.clients[random.randint(0, len(fox.clients) - 1)]
        start_client(target="client").start(self.num_rounds, self._initial_model)

        # poll so that an abort of the run is noticed quickly
        while not fox.is_aborted:
            if self.waiter.wait(timeout=0.5):
                break

    def _all_done(self, event_type: str, data):
        # event handler: forward to the collab method
        self.logger.info(f"[{fox.call_info}]: received {event_type} from client: {fox.caller}: {data}")
        self.all_done(data)

    @fox.collab
    def all_done(self, reason: str):
        self.logger.info(f"[{fox.call_info}]: all done from client: {fox.caller}: {reason}")
        self.waiter.set()


class NPSwarmClient:
    """Client side of swarm learning: each round one client aggregates the
    other clients' training results and relays the model to a random next client."""

    def __init__(self, delta: float):
        self.delta = delta
        self.logger = get_obj_logger(self)

    @fox.init
    def init(self):
        # This example shows that there could be multiple listeners for the same event
        fox.register_event_handler("final_model", self._accept_final_model)
        fox.register_event_handler("final_model", self._save_final_model)

    @fox.collab
    def train(self, weights, current_round):
        self.logger.info(f"[{fox.call_info}]: train asked by {fox.caller}: {current_round=}")
        return weights + self.delta

    def sag(self, model, current_round):
        # scatter-and-gather: average train results from the other clients
        results = fox.other_clients.train(model, current_round)
        total = 0
        for n, v in results:
            total += v
        return total / len(results)

    @fox.collab
    def swarm_learn(self, num_rounds, model, current_round):
        self.logger.info(f"[{fox.call_info}]: swarm learn asked: {num_rounds=} {current_round=} {model=}")
        new_model = self.sag(model, current_round)
        self.logger.info(f"[{fox.call_info}]: trained model {new_model=}")

        if current_round == num_rounds - 1:
            # all done
            result = fox.clients(expect_result=True).fire_event("final_model", new_model)
            for n, v in result:
                self.logger.info(f"[{fox.call_info}] final_model reply from {n}: {v}")
            self.logger.info("notify server all done!")
            try:
                fox.server(expect_result=False).all_done("OK")
            except Exception as ex:
                traceback.print_exc()
                self.logger.error(f"exception occurred in learning: {type(ex)}")
            self.logger.info("Swarm Training is DONE!")
            return

        # determine next client
        next_round = current_round + 1
        next_client_idx = random.randint(0, len(fox.clients) - 1)
        self.logger.debug(f"chose aggr client for round {next_round}: {next_client_idx}")
        next_client = fox.clients[next_client_idx]
        next_client(expect_result=False).swarm_learn(num_rounds, new_model, next_round)

    @fox.collab
    def start(self, num_rounds, initial_model):
        self.logger.info(f"[{fox.call_info}]: starting swarm learning")
        self.swarm_learn(num_rounds, initial_model, 0)

    def _accept_final_model(self, event_type: str, model):
        # accept the final model
        self.logger.info(f"[{fox.call_info}]: received event '{event_type}' from {fox.caller}: {model}")
        return "received"

    def _save_final_model(self, event_type: str, model):
        # save the final model to disk
        file_name = os.path.join(fox.workspace.get_work_dir(), "final_model.npy")
        save_np_model(model, file_name)
        self.logger.info(f"[{fox.call_info}]: saved model {model} to {file_name}")
        return "saved"
# ---- nvflare/fox/examples/np/algos/utils.py ----

import numpy as np


def parse_array_def(array_def):
    """Normalize an array definition to a usable form.

    Args:
        array_def: one of
            - None: returned as-is;
            - np.ndarray: returned as-is;
            - str: treated as the base name of a file containing an NP array
              and returned unchanged (loading happens elsewhere);
            - list or tuple of numbers: converted to a float32 ndarray.
              (Generalized: tuples are now accepted in addition to lists.)

    Returns: the parsed array definition

    Raises:
        ValueError: if array_def is of an unsupported type.
    """
    if array_def is None:
        return array_def

    if isinstance(array_def, np.ndarray):
        return array_def

    if isinstance(array_def, str):
        # this is base name of the file that contains NP array
        return array_def

    if isinstance(array_def, (list, tuple)):
        return np.array(array_def, dtype=np.float32)

    raise ValueError(f"unsupported array def: {array_def}")


def parse_state_dict(d):
    """Parse each value of a state dict with parse_array_def; returns a new dict."""
    result = {}
    for k, v in d.items():
        result[k] = parse_array_def(v)
    return result


def parse_model_def(model_def):
    """Parse a model definition: a dict is treated as a state dict, anything
    else as a single array definition."""
    if isinstance(model_def, dict):
        return parse_state_dict(model_def)
    else:
        return parse_array_def(model_def)


def save_np_model(model: np.ndarray, file_name: str):
    """Save a NumPy model to the specified file (np.save format)."""
    np.save(file_name, model)


def load_np_model(file_name: str):
    """Load a NumPy model previously saved with save_np_model."""
    return np.load(file_name)


def add(model: dict, to_model: dict):
    """Add specified model to another model

    Args:
        model: the model to be added
        to_model: the model to be added to.

    Returns: the updated model

    Notes: the to_model is updated in place; keys missing from to_model are copied over.
    """
    for k, v in model.items():
        if k not in to_model:
            to_model[k] = v
        else:
            to_model[k] += v
    return to_model


def div(model: dict, value):
    """Divide a model by a specified value

    Args:
        model: the model to be divided
        value: the value to divide the model with

    Returns: the updated model (modified in place)
    """
    for k, v in model.items():
        model[k] = v / value
    return model
# ---- nvflare/fox/examples/np/algos/widgets.py ----
from nvflare.fox import fox
from nvflare.fuel.utils.log_utils import get_obj_logger


class MetricReceiver:
    """Server-side widget that accepts metric reports from clients, both as
    direct collab calls and as 'metrics' events."""

    def __init__(self):
        self.logger = get_obj_logger(self)

    @fox.init
    def init(self):
        # listen for 'metrics' events fired by clients
        fox.register_event_handler("metrics", self._accept_metric)
        self.logger.info("MetricReceiver initialized!")

    @fox.collab
    def accept_metric(self, metrics: dict):
        self.logger.info(f"[{fox.callee}] received metric report from {fox.caller}: {metrics}")

    def _accept_metric(self, event_type: str, data):
        # event handler for the 'metrics' event
        self.logger.info(f"[{fox.callee}] received metrics event '{event_type}' from {fox.caller}: {data}")
        return "OK"
# ---- nvflare/fox/examples/np/cyclic.py ----
import logging

from nvflare.fox.api.utils import simple_logging
from nvflare.fox.examples import get_experiment_root
from nvflare.fox.examples.np.algos.client import NPTrainer
from nvflare.fox.examples.np.algos.strategies.cyclic import NPCyclic
from nvflare.fox.sim.simulator import Simulator


def main():
    """Run the cyclic example in the FOX simulator with 2 clients."""
    simple_logging(logging.DEBUG)

    server = NPCyclic(initial_model=[[1, 2, 3], [4, 5, 6], [7, 8, 9]], num_rounds=2)
    simulator = Simulator(
        root_dir=get_experiment_root(),
        experiment_name="cyclic",
        server=server,
        client=NPTrainer(delta=1.0),
        num_clients=2,
    )

    final_result = simulator.run()
    print(f"final model: {final_result}")


if __name__ == "__main__":
    main()
# ---- nvflare/fox/examples/np/cyclic_avg.py ----
import logging

from nvflare.fox import fox
from nvflare.fox.api.utils import simple_logging
from nvflare.fox.examples import get_experiment_root
from nvflare.fox.examples.np.algos.client import NPTrainer
from nvflare.fox.examples.np.algos.strategies.avg_para import NPFedAvgParallel
from nvflare.fox.examples.np.algos.strategies.cyclic import NPCyclic
from nvflare.fox.sim.simulator import Simulator
from nvflare.fuel.utils.log_utils import get_obj_logger


class Controller:
    """Composite algo: run cyclic training first, then feed its result into
    parallel FedAvg as the initial model."""

    def __init__(
        self,
        initial_model,
        cyclic_rounds,
        avg_rounds,
    ):
        self.initial_model = initial_model
        self.cyclic_rounds = cyclic_rounds
        self.avg_rounds = avg_rounds
        self.logger = get_obj_logger(self)

    @fox.algo
    def run(self):
        # phase 1: cyclic relay training
        self.logger.info("running cyclic ...")
        ctl = NPCyclic(self.initial_model, num_rounds=self.cyclic_rounds)
        result = ctl.execute()
        self.logger.info(f"final cyclic model: {result}")

        # phase 2: parallel fed-avg seeded with the cyclic result
        self.logger.info("running fed-avg ...")
        ctl = NPFedAvgParallel(initial_model=result, num_rounds=self.avg_rounds)
        result = ctl.execute()
        self.logger.info(f"final model: {result}")
        return result


def main():
    """Run the cyclic-then-fedavg example in the simulator with 3 clients."""
    simple_logging(logging.DEBUG)
    exp_name = "cyclic_avg"

    server = Controller(initial_model=[[1, 2, 3], [4, 5, 6], [7, 8, 9]], cyclic_rounds=2, avg_rounds=3)

    simulator = Simulator(
        root_dir=get_experiment_root(),
        experiment_name=exp_name,
        server=server,
        client=NPTrainer(delta=1.0),
        num_clients=3,
    )

    final_result = simulator.run()
    print(f"final model: {final_result}")


if __name__ == "__main__":
    main()
# ---- nvflare/fox/examples/np/cyclic_file.py ----
import logging
import os

from nvflare.fox.api.utils import simple_logging
from nvflare.fox.examples import get_experiment_root
from nvflare.fox.examples.np.algos.client import NPTrainer
from nvflare.fox.examples.np.algos.strategies.cyclic import NPCyclic
from nvflare.fox.sim.simulator import Simulator

# Fix: the data dir was hard-coded to one developer's machine, making the
# example non-portable. Allow overriding via environment variable while
# keeping the original value as the default for backward compatibility.
DEFAULT_DATA_DIR = "/Users/yanc/NVFlare/sandbox/data"


def main():
    """Run the cyclic example with the initial model loaded from a file.

    The NPCyclic server is given a file base name ("initial_model.npy"),
    which it loads from the "data" resource dir during fox init.
    """
    simple_logging(logging.DEBUG)

    simulator = Simulator(
        root_dir=get_experiment_root(),
        experiment_name="cyclic_file",
        server=NPCyclic(initial_model="initial_model.npy", num_rounds=2),
        client=NPTrainer(delta=1.0),
        num_clients=2,
    )

    # the "data" resource dir must contain initial_model.npy
    data_dir = os.environ.get("FOX_NP_DATA_DIR", DEFAULT_DATA_DIR)
    simulator.set_server_resource_dirs({"data": data_dir})

    final_result = simulator.run()
    print(f"final model: {final_result}")


if __name__ == "__main__":
    main()
# ---- nvflare/fox/examples/np/cyclic_runner.py ----
import logging

from nvflare.fox.api.app import ClientApp, ServerApp
from nvflare.fox.api.utils import simple_logging
from nvflare.fox.examples import get_experiment_root
from nvflare.fox.examples.np.algos.client import NPTrainer
from nvflare.fox.examples.np.algos.strategies.cyclic import NPCyclic
from nvflare.fox.sim.simulator import AppRunner


def main():
    """Run the cyclic example via AppRunner with explicit server/client apps."""
    simple_logging(logging.DEBUG)

    server_app = ServerApp(NPCyclic(initial_model=[[1, 2, 3], [4, 5, 6], [7, 8, 9]], num_rounds=2))
    client_app = ClientApp(NPTrainer(delta=1.0))

    runner = AppRunner(
        root_dir=get_experiment_root(),
        experiment_name="cyclic_runner",
        server_app=server_app,
        client_app=client_app,
        num_clients=2,
    )

    final_result = runner.run()
    print(f"final model: {final_result}")


if __name__ == "__main__":
    main()
# ---- nvflare/fox/examples/np/fed_avg_h.py ----
import logging

from nvflare.fox.api.utils import simple_logging
from nvflare.fox.examples import get_experiment_root
from nvflare.fox.examples.np.algos.client import NPHierarchicalTrainer
from nvflare.fox.examples.np.algos.strategies.avg_h import NPHierarchicalFedAvg
from nvflare.fox.examples.np.algos.widgets import MetricReceiver
from nvflare.fox.sim.simulator import Simulator


def main():
    """Run hierarchical FedAvg: 3 aggregation nodes with 2 leaf clients each."""
    simple_logging(logging.DEBUG)

    server = NPHierarchicalFedAvg(initial_model=[[1, 2, 3], [4, 5, 6], [7, 8, 9]], num_rounds=3)
    simulator = Simulator(
        root_dir=get_experiment_root(),
        experiment_name="fedavg_h",
        server=server,
        client=NPHierarchicalTrainer(delta=1.0),
        server_objects={"metric_receiver": MetricReceiver()},
        num_clients=(3, 2),
    )

    result = simulator.run()
    print(f"Final Result: {result}")


if __name__ == "__main__":
    main()
# ---- nvflare/fox/examples/np/fed_avg_intime.py ----
import logging

from nvflare.fox.api.utils import simple_logging
from nvflare.fox.examples import get_experiment_root
from nvflare.fox.examples.np.algos.client import NPTrainer
from nvflare.fox.examples.np.algos.filters import AddNoiseToModel, Print
from nvflare.fox.examples.np.algos.strategies.avg_intime import NPFedAvgInTime
from nvflare.fox.examples.np.algos.widgets import MetricReceiver
from nvflare.fox.sim.simulator import Simulator


def main():
    """Run in-time-aggregation FedAvg with call/result filters on both sides."""
    simple_logging(logging.DEBUG)

    simulator = Simulator(
        root_dir=get_experiment_root(),
        experiment_name="fedavg_intime",
        server=NPFedAvgInTime(initial_model=[[1, 2, 3], [4, 5, 6], [7, 8, 9]], num_rounds=2),
        client=NPTrainer(delta=1.0),
        server_objects={"metric_receiver": MetricReceiver()},
        num_clients=2,
    )

    # server side: perturb outgoing models, print incoming train results
    simulator.add_server_outgoing_call_filters("*.train", [AddNoiseToModel()])
    simulator.add_server_incoming_result_filters("*.train", [Print()])
    simulator.set_server_prop("default_timeout", 8.0)

    # client side: print incoming train calls and outgoing results
    simulator.add_client_incoming_call_filters("*.train", [Print()])
    simulator.add_client_outgoing_result_filters("*.train", [Print()])
    simulator.set_client_prop("default_timeout", 5.0)

    result = simulator.run()
    print(f"final model: {result}")


if __name__ == "__main__":
    main()
# ---- nvflare/fox/examples/np/fed_avg_para.py ----
import logging

from nvflare.fox.api.utils import simple_logging
from nvflare.fox.examples import get_experiment_root
from nvflare.fox.examples.np.algos.client import NPTrainer
from nvflare.fox.examples.np.algos.strategies.avg_para import NPFedAvgParallel
from nvflare.fox.examples.np.algos.widgets import MetricReceiver
from nvflare.fox.sim.simulator import Simulator


def main():
    """Run parallel FedAvg with 10 simulated clients."""
    simple_logging(logging.DEBUG)

    server = NPFedAvgParallel(initial_model=[[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], num_rounds=2)

    simulator = Simulator(
        root_dir=get_experiment_root(),
        experiment_name="fedavg_para",
        server=server,
        client=NPTrainer(delta=1.0),
        server_objects={"metric_receiver": MetricReceiver()},
        num_clients=10,
    )

    result = simulator.run()
    print(f"Final result: {result}")


if __name__ == "__main__":
    main()
# ---- nvflare/fox/examples/np/fed_avg_para_tc.py ----
import logging

from nvflare.fox.api.utils import simple_logging
from nvflare.fox.examples import get_experiment_root
from nvflare.fox.examples.np.algos.client import NPTrainer
from nvflare.fox.examples.np.algos.strategies.avg_para_tc import NPFedAvgParallelWithTrafficControl
from nvflare.fox.examples.np.algos.widgets import MetricReceiver
from nvflare.fox.sim.simulator import Simulator


def main():
    """Run parallel FedAvg with traffic control: 10 slow clients, at most 3 concurrent."""
    simple_logging(logging.DEBUG)

    server = NPFedAvgParallelWithTrafficControl(
        initial_model=[[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]],
        num_rounds=2,
        parallel=3,
    )

    simulator = Simulator(
        root_dir=get_experiment_root(),
        experiment_name="fedavg_para_tc",
        server=server,
        client=NPTrainer(delta=1.0, delay=1.5),  # delay makes traffic control observable
        server_objects={"metric_receiver": MetricReceiver()},
        num_clients=10,
    )

    result = simulator.run()
    print(f"Final result: {result}")


if __name__ == "__main__":
    main()
# ---- nvflare/fox/examples/np/fed_avg_seq.py ----
import logging

from nvflare.fox.api.utils import simple_logging
from nvflare.fox.examples import get_experiment_root
from nvflare.fox.examples.np.algos.client import NPTrainer
from nvflare.fox.examples.np.algos.strategies.avg_seq import NPFedAvgSequential
from nvflare.fox.examples.np.algos.widgets import MetricReceiver
from nvflare.fox.sim.simulator import Simulator


def main():
    """Run sequential weighted FedAvg with per-site weights and deltas."""
    simple_logging(logging.DEBUG)

    server = NPFedAvgSequential(
        num_rounds=2,
        initial_model=[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
    )

    simulator = Simulator(
        root_dir=get_experiment_root(),
        experiment_name="fedavg_seq",
        server=server,
        client=NPTrainer(delta=1.0),
        server_objects={"metric_receiver": MetricReceiver()},
        num_clients=2,
    )

    # per-site aggregation weights and training deltas
    simulator.set_server_prop("client_weight_config", {"site-1": 70, "site-2": 100})
    simulator.set_client_prop("client_delta", {"site-1": 1.0, "site-2": 2.0})

    result = simulator.run()
    print(f"Final result: {result}")


if __name__ == "__main__":
    main()
# ---- nvflare/fox/examples/np/fed_avg_stream.py ----
import logging

from nvflare.fox.api.utils import simple_logging
from nvflare.fox.examples import get_experiment_root
from nvflare.fox.examples.np.algos.avg_stream import NPFedAvgStream, NPTrainer
from nvflare.fox.sim.simulator import Simulator


def main():
    """Run the streaming FedAvg example with 2 clients for 4 rounds."""
    simple_logging(logging.DEBUG)

    server = NPFedAvgStream(initial_model=[[1, 2, 3], [4, 5, 6], [7, 8, 9]], num_rounds=4)
    simulator = Simulator(
        root_dir=get_experiment_root(),
        experiment_name="fedavg_stream",
        server=server,
        client=NPTrainer(delta=1.0),
        num_clients=2,
    )

    result = simulator.run()
    print(f"Final result: {result}")


if __name__ == "__main__":
    main()
+from nvflare.fox.examples import export_recipe +from nvflare.fox.examples.np.algos.client import NPTrainer +from nvflare.fox.examples.np.algos.strategies.cyclic import NPCyclic +from nvflare.fox.sys.recipe import FoxRecipe + + +def main(): + export_recipe("fox_cyclic", _make_recipe) + + +def _make_recipe(job_name): + return FoxRecipe( + job_name=job_name, + server=NPCyclic(initial_model=[[1, 2, 3], [4, 5, 6], [7, 8, 9]], num_rounds=2), + client=NPTrainer(delta=1.0), + ) + + +if __name__ == "__main__": + main() diff --git a/nvflare/fox/examples/np/recipe_cyclic_avg.py b/nvflare/fox/examples/np/recipe_cyclic_avg.py new file mode 100644 index 0000000000..d111b84eca --- /dev/null +++ b/nvflare/fox/examples/np/recipe_cyclic_avg.py @@ -0,0 +1,78 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import os + +from nvflare.fox import fox +from nvflare.fox.examples import export_recipe +from nvflare.fox.examples.np.algos.client import NPTrainer +from nvflare.fox.examples.np.algos.strategies.avg_para import NPFedAvgParallel +from nvflare.fox.examples.np.algos.strategies.cyclic import NPCyclic +from nvflare.fox.examples.np.algos.utils import save_np_model +from nvflare.fox.sys.recipe import FoxRecipe +from nvflare.fuel.utils.log_utils import get_obj_logger + + +class Controller: + + def __init__( + self, + initial_model, + cyclic_rounds, + avg_rounds, + ): + self.initial_model = initial_model + self.cyclic_rounds = cyclic_rounds + self.avg_rounds = avg_rounds + self.logger = get_obj_logger(self) + + @fox.algo + def run(self): + self.logger.info("running cyclic ...") + ctl = NPCyclic(self.initial_model, num_rounds=self.cyclic_rounds) + result = ctl.execute() + + file_name = os.path.join(fox.workspace.get_work_dir(), "cyclic_model.npy") + save_np_model(result, file_name) + self.logger.info(f"[{fox.call_info}]: saved cyclic model {result} to {file_name}") + + self.logger.info("running fed-avg ...") + ctl = NPFedAvgParallel(initial_model=result, num_rounds=self.avg_rounds) + return ctl.execute() + + @fox.final + def save_result(self): + final_result = fox.get_result() + file_name = os.path.join(fox.workspace.get_work_dir(), "final_model.npy") + save_np_model(final_result, file_name) + self.logger.info(f"[{fox.call_info}]: saved final model {final_result} to {file_name}") + + +def main(): + export_recipe("fox_cyclic_avg", _make_recipe) + + +def _make_recipe(job_name): + return FoxRecipe( + job_name=job_name, + server=Controller( + initial_model=[[1, 2, 3], [4, 5, 6], [7, 8, 9]], + cyclic_rounds=2, + avg_rounds=3, + ), + client=NPTrainer(delta=1.0), + ) + + +if __name__ == "__main__": + main() diff --git a/nvflare/fox/examples/np/recipe_cyclic_file.py b/nvflare/fox/examples/np/recipe_cyclic_file.py new file mode 100644 index 0000000000..21876fb8a5 --- /dev/null +++ 
b/nvflare/fox/examples/np/recipe_cyclic_file.py @@ -0,0 +1,35 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from nvflare.fox.examples import export_recipe +from nvflare.fox.examples.np.algos.client import NPTrainer +from nvflare.fox.examples.np.algos.strategies.cyclic import NPCyclic +from nvflare.fox.sys.recipe import FoxRecipe + + +def main(): + export_recipe("fox_cyclic_file", _make_recipe) + + +def _make_recipe(job_name): + recipe = FoxRecipe( + job_name=job_name, + server=NPCyclic(initial_model="initial_model.npy", num_rounds=2), + client=NPTrainer(delta=1.0), + ) + recipe.set_server_resource_dirs({"data": "/Users/yanc/NVFlare/sandbox/data"}) + return recipe + + +if __name__ == "__main__": + main() diff --git a/nvflare/fox/examples/np/recipe_fed_avg_h.py b/nvflare/fox/examples/np/recipe_fed_avg_h.py new file mode 100644 index 0000000000..5439685adc --- /dev/null +++ b/nvflare/fox/examples/np/recipe_fed_avg_h.py @@ -0,0 +1,35 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from nvflare.fox.examples import export_recipe +from nvflare.fox.examples.np.algos.client import NPHierarchicalTrainer +from nvflare.fox.examples.np.algos.strategies.avg_h import NPHierarchicalFedAvg +from nvflare.fox.examples.np.algos.widgets import MetricReceiver +from nvflare.fox.sys.recipe import FoxRecipe + + +def main(): + export_recipe("fox_fedavg_h", _make_recipe) + + +def _make_recipe(job_name): + return FoxRecipe( + job_name=job_name, + server=NPHierarchicalFedAvg(initial_model=[[1, 2, 3], [4, 5, 6], [7, 8, 9]], num_rounds=2), + client=NPHierarchicalTrainer(delta=1.0), + server_objects={"metric_receiver": MetricReceiver()}, + ) + + +if __name__ == "__main__": + main() diff --git a/nvflare/fox/examples/np/recipe_fed_avg_intime.py b/nvflare/fox/examples/np/recipe_fed_avg_intime.py new file mode 100644 index 0000000000..627f9c668d --- /dev/null +++ b/nvflare/fox/examples/np/recipe_fed_avg_intime.py @@ -0,0 +1,48 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from nvflare.fox.examples import export_recipe +from nvflare.fox.examples.np.algos.client import NPTrainer +from nvflare.fox.examples.np.algos.filters import AddNoiseToModel, Print +from nvflare.fox.examples.np.algos.strategies.avg_intime import NPFedAvgInTime +from nvflare.fox.examples.np.algos.widgets import MetricReceiver +from nvflare.fox.sys.recipe import FoxRecipe + + +def main(): + export_recipe("fox_fedavg_intime", _make_recipe) + + +def _make_recipe(job_name): + recipe = FoxRecipe( + job_name=job_name, + server=NPFedAvgInTime(initial_model=[[1, 2, 3], [4, 5, 6], [7, 8, 9]], num_rounds=2), + server_objects={ + "metric_receiver": MetricReceiver(), + }, + client=NPTrainer(delta=1.0), + ) + + print_filter = Print() + recipe.add_server_outgoing_call_filters("*.train", [AddNoiseToModel()]) + recipe.add_server_incoming_result_filters("*.train", [print_filter]) + recipe.set_server_prop("default_timeout", 5.0) + + recipe.add_client_incoming_call_filters("*.train", [print_filter]) + recipe.add_client_outgoing_result_filters("*.train", [print_filter]) + recipe.set_client_prop("default_timeout", 8.0) + return recipe + + +if __name__ == "__main__": + main() diff --git a/nvflare/fox/examples/np/recipe_fed_avg_seq.py b/nvflare/fox/examples/np/recipe_fed_avg_seq.py new file mode 100644 index 0000000000..79d2282411 --- /dev/null +++ b/nvflare/fox/examples/np/recipe_fed_avg_seq.py @@ -0,0 +1,38 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +from nvflare.fox.examples import export_recipe +from nvflare.fox.examples.np.algos.client import NPTrainer +from nvflare.fox.examples.np.algos.strategies.avg_seq import NPFedAvgSequential +from nvflare.fox.examples.np.algos.widgets import MetricReceiver +from nvflare.fox.sys.recipe import FoxRecipe + + +def main(): + export_recipe("fox_fedavg_seq", _make_recipe) + + +def _make_recipe(job_name): + recipe = FoxRecipe( + job_name=job_name, + server=NPFedAvgSequential(initial_model=[[1, 2, 3], [4, 5, 6], [7, 8, 9]], num_rounds=2), + client=NPTrainer(delta=1.0), + server_objects={"metric_receiver": MetricReceiver()}, + ) + recipe.set_server_prop("client_weight_config", {"red": 70, "blue": 100, "silver": 50}) + recipe.set_client_prop("client_delta", {"red": 1.0, "blue": 2.0, "silver": 3.0}) + return recipe + + +if __name__ == "__main__": + main() diff --git a/nvflare/fox/examples/np/recipe_fed_avg_stream.py b/nvflare/fox/examples/np/recipe_fed_avg_stream.py new file mode 100644 index 0000000000..10d899188b --- /dev/null +++ b/nvflare/fox/examples/np/recipe_fed_avg_stream.py @@ -0,0 +1,32 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from nvflare.fox.examples import export_recipe +from nvflare.fox.examples.np.algos.avg_stream import NPFedAvgStream, NPTrainer +from nvflare.fox.sys.recipe import FoxRecipe + + +def main(): + export_recipe("fox_fedavg_stream", _make_recipe) + + +def _make_recipe(job_name): + return FoxRecipe( + job_name=job_name, + server=NPFedAvgStream(initial_model=[[1, 2, 3], [4, 5, 6], [7, 8, 9]], num_rounds=2), + client=NPTrainer(delta=1.0), + ) + + +if __name__ == "__main__": + main() diff --git a/nvflare/fox/examples/np/recipe_fed_avg_tc.py b/nvflare/fox/examples/np/recipe_fed_avg_tc.py new file mode 100644 index 0000000000..3285622c54 --- /dev/null +++ b/nvflare/fox/examples/np/recipe_fed_avg_tc.py @@ -0,0 +1,39 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from nvflare.fox.examples import export_recipe +from nvflare.fox.examples.np.algos.client import NPTrainer +from nvflare.fox.examples.np.algos.strategies.avg_para_tc import NPFedAvgParallelWithTrafficControl +from nvflare.fox.examples.np.algos.widgets import MetricReceiver +from nvflare.fox.sys.recipe import FoxRecipe + + +def main(): + export_recipe("fox_fedavg_tc", _make_recipe) + + +def _make_recipe(job_name): + return FoxRecipe( + job_name=job_name, + server=NPFedAvgParallelWithTrafficControl( + initial_model=[[1, 2, 3], [4, 5, 6], [7, 8, 9]], + num_rounds=2, + parallel=2, + ), + client=NPTrainer(delta=1.0, delay=2.0), + server_objects={"metric_receiver": MetricReceiver()}, + ) + + +if __name__ == "__main__": + main() diff --git a/nvflare/fox/examples/np/recipe_swarm.py b/nvflare/fox/examples/np/recipe_swarm.py new file mode 100644 index 0000000000..7a6e6903f2 --- /dev/null +++ b/nvflare/fox/examples/np/recipe_swarm.py @@ -0,0 +1,32 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from nvflare.fox.examples import export_recipe +from nvflare.fox.examples.np.algos.swarm import NPSwarm, NPSwarmClient +from nvflare.fox.sys.recipe import FoxRecipe + + +def main(): + export_recipe("fox_swarm", _make_recipe) + + +def _make_recipe(job_name): + return FoxRecipe( + job_name=job_name, + server=NPSwarm(initial_model=[[1, 2, 3], [4, 5, 6], [7, 8, 9]], num_rounds=5), + client=NPSwarmClient(delta=1.0), + ) + + +if __name__ == "__main__": + main() diff --git a/nvflare/fox/examples/np/recipes/__init__.py b/nvflare/fox/examples/np/recipes/__init__.py new file mode 100644 index 0000000000..341a77c5bc --- /dev/null +++ b/nvflare/fox/examples/np/recipes/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nvflare/fox/examples/np/recipes/cyclic_recipe.py b/nvflare/fox/examples/np/recipes/cyclic_recipe.py new file mode 100644 index 0000000000..3ca9cef037 --- /dev/null +++ b/nvflare/fox/examples/np/recipes/cyclic_recipe.py @@ -0,0 +1,32 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from nvflare.fox.examples.np.algos.strategies.cyclic import NPCyclic +from nvflare.fox.sys.recipe import FoxRecipe + + +class CyclicRecipe(FoxRecipe): + + def __init__( + self, + job_name, + initial_model, + num_rounds, + client, + ): + FoxRecipe.__init__( + self, + job_name, + server=NPCyclic(initial_model, num_rounds), + client=client, + ) diff --git a/nvflare/fox/examples/np/swarm.py b/nvflare/fox/examples/np/swarm.py new file mode 100644 index 0000000000..9132ead1cc --- /dev/null +++ b/nvflare/fox/examples/np/swarm.py @@ -0,0 +1,38 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import logging + +from nvflare.fox.api.utils import simple_logging +from nvflare.fox.examples import get_experiment_root +from nvflare.fox.examples.np.algos.swarm import NPSwarm, NPSwarmClient +from nvflare.fox.sim.simulator import Simulator + + +def main(): + simple_logging(logging.DEBUG) + + simulator = Simulator( + root_dir=get_experiment_root(), + experiment_name="swarm", + server=NPSwarm(initial_model=[[1, 2, 3], [4, 5, 6], [7, 8, 9]], num_rounds=5), + client=NPSwarmClient(delta=1.0), + num_clients=3, + ) + + result = simulator.run() + print(f"Final result: {result}") + + +if __name__ == "__main__": + main() diff --git a/nvflare/fox/examples/pt/__init__.py b/nvflare/fox/examples/pt/__init__.py new file mode 100644 index 0000000000..341a77c5bc --- /dev/null +++ b/nvflare/fox/examples/pt/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nvflare/fox/examples/pt/filters.py b/nvflare/fox/examples/pt/filters.py new file mode 100644 index 0000000000..a54fd5affe --- /dev/null +++ b/nvflare/fox/examples/pt/filters.py @@ -0,0 +1,91 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any + +from nvflare.fox.api.ctx import Context +from nvflare.fox.api.filter import CallFilter, ResultFilter +from nvflare.fox.sys.downloader import Downloader, download_tensors +from nvflare.fuel.utils.log_utils import get_obj_logger + + +class OutgoingModelCallFilter(CallFilter): + + def __init__(self, model_arg_name: str): + super().__init__() + self.model_arg_name = model_arg_name + self.logger = get_obj_logger(self) + + def filter_call(self, func_kwargs: dict, context: Context): + arg_value = func_kwargs.get(self.model_arg_name) + if not arg_value: + return func_kwargs + + num_receivers = context.target_group_size + self.logger.info(f"target group size={num_receivers}") + + downloader = Downloader( + num_receivers=num_receivers, + timeout=5.0, + ) + model = downloader.add_tensors(arg_value, 0) + func_kwargs[self.model_arg_name] = model + return func_kwargs + + +class IncomingModelCallFilter(CallFilter): + + def __init__(self, model_arg_name: str): + super().__init__() + self.model_arg_name = model_arg_name + self.logger = get_obj_logger(self) + + def filter_call(self, func_kwargs: dict, context: Context): + arg_value = func_kwargs.get(self.model_arg_name) + if not arg_value: + return func_kwargs + + err, model = download_tensors(ref=arg_value, per_request_timeout=5.0) + if err: + self.logger.error(f"error filtering call arg {arg_value}: {err}") + else: + func_kwargs[self.model_arg_name] = model + return func_kwargs + + +class OutgoingModelResultFilter(ResultFilter): + + def filter_result(self, result: Any, context: Context): + if not 
isinstance(result, dict): + return result + + downloader = Downloader( + num_receivers=1, + timeout=5.0, + ) + return downloader.add_tensors(result, 0) + + +class IncomingModelResultFilter(ResultFilter): + + def __init__(self): + super().__init__() + self.logger = get_obj_logger(self) + + def filter_result(self, result: Any, context: Context): + err, model = download_tensors(ref=result, per_request_timeout=5.0) + if err: + self.logger.error(f"error filtering result {result}: {err}") + return result + else: + return model diff --git a/nvflare/fox/examples/pt/filters2.py b/nvflare/fox/examples/pt/filters2.py new file mode 100644 index 0000000000..1db8eb654e --- /dev/null +++ b/nvflare/fox/examples/pt/filters2.py @@ -0,0 +1,76 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import Any + +from nvflare.fox import fox +from nvflare.fox.sys.downloader import Downloader, download_tensors +from nvflare.fuel.utils.log_utils import get_obj_logger + + +class ModelFilter: + + def __init__(self, model_arg_name: str): + super().__init__() + self.model_arg_name = model_arg_name + self.logger = get_obj_logger(self) + + @fox.out_call_filter + def prepare_weights_for_download(self, func_kwargs: dict): + arg_value = func_kwargs.get(self.model_arg_name) + if not arg_value: + return func_kwargs + + num_receivers = fox.context.target_group_size + self.logger.info(f"target group size={num_receivers}") + + downloader = Downloader( + num_receivers=num_receivers, + timeout=5.0, + ) + model = downloader.add_tensors(arg_value, 0) + func_kwargs[self.model_arg_name] = model + return func_kwargs + + @fox.in_call_filter + def download_weights(self, func_kwargs: dict): + arg_value = func_kwargs.get(self.model_arg_name) + if not arg_value: + return func_kwargs + + err, model = download_tensors(ref=arg_value, per_request_timeout=5.0) + if err: + self.logger.error(f"error filtering call arg {arg_value}: {err}") + else: + func_kwargs[self.model_arg_name] = model + return func_kwargs + + @fox.out_result_filter + def prepare_result_for_download(self, result: Any): + if not isinstance(result, dict): + return result + + downloader = Downloader( + num_receivers=1, + timeout=5.0, + ) + return downloader.add_tensors(result, 0) + + @fox.in_result_filter + def download_result(self, result: Any): + err, model = download_tensors(ref=result, per_request_timeout=5.0) + if err: + self.logger.error(f"error filtering result {result}: {err}") + return result + else: + return model diff --git a/nvflare/fox/examples/pt/pt_avg_filter.py b/nvflare/fox/examples/pt/pt_avg_filter.py new file mode 100644 index 0000000000..e631ecb14d --- /dev/null +++ b/nvflare/fox/examples/pt/pt_avg_filter.py @@ -0,0 +1,107 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from nvflare.fox import fox +from nvflare.fox.api.utils import simple_logging +from nvflare.fox.examples import get_experiment_root +from nvflare.fox.examples.pt.utils import add as add_pt +from nvflare.fox.examples.pt.utils import div as div_pt +from nvflare.fox.examples.pt.utils import parse_state_dict +from nvflare.fox.sim.simulator import Simulator +from nvflare.fuel.utils.log_utils import get_obj_logger + + +class PTFedAvg: + + def __init__(self, initial_model, num_rounds=10, timeout=2.0): + self.num_rounds = num_rounds + self.initial_model = initial_model + self.timeout = timeout + self.name = "PTFedAvg" + self.logger = get_obj_logger(self) + self._init_model = parse_state_dict(initial_model) + + @fox.algo + def execute(self): + self.logger.info(f"[{fox.call_info}] Start training for {self.num_rounds} rounds") + current_model = self._init_model + for i in range(self.num_rounds): + current_model = self._do_one_round(i, current_model) + if current_model is None: + self.logger.error(f"training failed at round {i}") + break + self.logger.info(f"FINAL MODEL: {current_model}") + return current_model + + def _do_one_round(self, r, current_model): + aggr_result = {} + + results = fox.clients(timeout=self.timeout).train(r, current_model) + for n, v in results: + add_pt(v, aggr_result) + + num_results = len(results) + aggr_result = div_pt(aggr_result, num_results) if num_results > 0 else None + 
self.logger.info(f"[{fox.call_info}] round {r}: aggr result from {num_results} clients: {aggr_result}") + return aggr_result + + +class PTTrainer: + + def __init__(self, delta: float): + self.delta = delta + self.logger = get_obj_logger(self) + + @fox.collab + def train(self, current_round, weights): + if fox.is_aborted: + self.logger.debug("training aborted") + return None + + self.logger.debug(f"[{fox.call_info}] training round {current_round}: {weights=}") + result = {} + for k, v in weights.items(): + result[k] = v + self.delta + return result + + +def main(): + simple_logging(logging.DEBUG) + + server = PTFedAvg( + initial_model={ + "x": [[1, 2, 3], [4, 5, 6], [7, 8, 9]], + "y": [[1, 2, 3], [4, 5, 6], [7, 8, 9]], + "z": [[1, 2, 3], [4, 5, 6], [7, 8, 9]], + }, + num_rounds=4, + ) + + client = PTTrainer(delta=1.0) + + simulator = Simulator( + root_dir=get_experiment_root(), + experiment_name="pt_fedavg_intime", + server=server, + client=client, + num_clients=2, + ) + + result = simulator.run() + print(f"final result: {result}") + + +if __name__ == "__main__": + main() diff --git a/nvflare/fox/examples/pt/pt_avg_mixed.py b/nvflare/fox/examples/pt/pt_avg_mixed.py new file mode 100644 index 0000000000..102957b9df --- /dev/null +++ b/nvflare/fox/examples/pt/pt_avg_mixed.py @@ -0,0 +1,225 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import logging +import threading + +import numpy as np +import torch + +from nvflare.fox import fox +from nvflare.fox.api.constants import BackendType +from nvflare.fox.api.utils import simple_logging +from nvflare.fox.examples import get_experiment_root +from nvflare.fox.examples.np.algos.utils import add as add_np +from nvflare.fox.examples.np.algos.utils import div as div_np +from nvflare.fox.examples.np.algos.utils import parse_state_dict as parse_np +from nvflare.fox.examples.pt.utils import add as add_pt +from nvflare.fox.examples.pt.utils import div as div_pt +from nvflare.fox.examples.pt.utils import parse_state_dict as parse_pt +from nvflare.fox.sim.simulator import Simulator +from nvflare.fox.sys.downloader import Downloader, download_arrays, download_tensors +from nvflare.fuel.utils.log_utils import get_obj_logger + + +class _AggrResult: + + def __init__(self): + self.pt_total = {} + self.np_total = {} + self.count = 0 + self.lock = threading.Lock() # ensure update integrity + + +class PTFedAvgMixed: + + def __init__(self, pt_model, np_model, num_rounds=10, timeout=2.0): + self.num_rounds = num_rounds + self.pt_model = pt_model + self.np_model = np_model + self.timeout = timeout + self.name = "PTFedAvgMixed" + self.logger = get_obj_logger(self) + self._pt_model = parse_pt(pt_model) + self._np_model = parse_np(np_model) + + @fox.algo + def execute(self): + self.logger.info(f"[{fox.call_info}] Start training for {self.num_rounds} rounds") + pt_model, np_model = self._pt_model, self._np_model + for i in range(self.num_rounds): + pt_model, np_model = self._do_one_round(i, pt_model, np_model) + if pt_model is None or np_model is None: + self.logger.error(f"training failed at round {i}") + break + self.logger.info(f"FINAL MODEL: {pt_model=} {np_model=}") + return pt_model, np_model + + def _do_one_round(self, r, pt_model, np_model): + aggr_result = _AggrResult() + + grp = fox.clients( + process_resp_cb=self._accept_train_result, + aggr_result=aggr_result, + ) 
+ + if fox.backend_type == BackendType.FLARE: + downloader = Downloader( + num_receivers=grp.size, + timeout=5.0, + ) + model_type = "ref" + pt_model = downloader.add_tensors(pt_model, 0) + np_model = downloader.add_arrays(np_model, 0) + self.logger.info(f"prepared model as ref: {pt_model=} {np_model=}") + else: + model_type = "model" + + grp.train(r, pt_model, np_model, model_type) + + if aggr_result.count == 0: + return None, None + else: + pt_result = aggr_result.pt_total + div_pt(pt_result, aggr_result.count) + self.logger.info( + f"[{fox.call_info}] round {r}: aggr PT result from {aggr_result.count} clients: {pt_result}" + ) + + np_result = aggr_result.np_total + div_np(np_result, aggr_result.count) + self.logger.info( + f"[{fox.call_info}] round {r}: aggr NP result from {aggr_result.count} clients: {np_result}" + ) + return pt_result, np_result + + def _accept_train_result(self, gcc, result, aggr_result: _AggrResult): + self.logger.info(f"[{fox.call_info}] got train result from {fox.caller}: {result}") + + pt_result, np_result, model_type = result + if model_type == "ref": + err, pt_result = download_tensors( + ref=pt_result, + per_request_timeout=5.0, + tensors_received_cb=self._aggregate_tensors, + aggr_result=aggr_result, + ) + if err: + raise RuntimeError(f"failed to download model {pt_result}: {err}") + + err, np_result = download_arrays( + ref=np_result, + per_request_timeout=5.0, + arrays_received_cb=self._aggregate_arrays, + aggr_result=aggr_result, + ) + if err: + raise RuntimeError(f"failed to download NP model file {np_result}: {err}") + else: + with aggr_result.lock: + add_pt(pt_result, aggr_result.pt_total) + add_np(np_result, aggr_result.np_total) + + with aggr_result.lock: + aggr_result.count += 1 + return None + + def _aggregate_tensors(self, td: dict[str, torch.Tensor], aggr_result: _AggrResult): + self.logger.info(f"[{fox.call_info}] aggregating received tensor: {td}") + with aggr_result.lock: + add_pt(td, aggr_result.pt_total) + + def 
_aggregate_arrays(self, td: dict[str, np.ndarray], aggr_result: _AggrResult): + self.logger.info(f"[{fox.call_info}] aggregating received array: {td}") + with aggr_result.lock: + add_np(td, aggr_result.np_total) + + +class PTTrainer: + + def __init__(self, delta: float): + self.delta = delta + self.logger = get_obj_logger(self) + + @fox.collab + def train(self, current_round, pt_model, np_model, model_type: str): + if fox.is_aborted: + self.logger.debug("training aborted") + return None, None, "" + + self.logger.debug(f"[{fox.call_info}] training round {current_round}: {model_type=} {pt_model=} {np_model=}") + + if model_type == "ref": + err, pt_model = download_tensors(ref=pt_model, per_request_timeout=5.0) + if err: + raise RuntimeError(f"failed to download PT model {pt_model}: {err}") + self.logger.info(f"downloaded PT model {pt_model}") + + err, np_model = download_arrays(ref=np_model, per_request_timeout=5.0) + if err: + raise RuntimeError(f"failed to download NP model {np_model}: {err}") + self.logger.info(f"downloaded NP model {np_model}") + + pt_result = {} + for k, v in pt_model.items(): + pt_result[k] = v + self.delta + + np_result = {} + for k, v in np_model.items(): + np_result[k] = v + self.delta + + if model_type == "ref": + # stream it + downloader = Downloader( + num_receivers=1, + timeout=5.0, + ) + pt_result = downloader.add_tensors(pt_result, 0) + self.logger.info(f"prepared PT result as ref: {pt_result}") + + np_result = downloader.add_arrays(np_result, 0) + self.logger.info(f"prepared NP result as ref: {np_result}") + return pt_result, np_result, model_type + + +def main(): + simple_logging(logging.DEBUG) + + init_model = { + "x": [[1, 2, 3], [4, 5, 6], [7, 8, 9]], + "y": [[1, 2, 3], [4, 5, 6], [7, 8, 9]], + "z": [[1, 2, 3], [4, 5, 6], [7, 8, 9]], + } + + server = PTFedAvgMixed( + pt_model=init_model, + np_model=init_model, + num_rounds=4, + ) + + client = PTTrainer(delta=1.0) + + simulator = Simulator( + root_dir=get_experiment_root(), + 
experiment_name="fedavg_mixed", + server=server, + client=client, + num_clients=2, + ) + + result = simulator.run() + print(f"Final result: {result}") + + +if __name__ == "__main__": + main() diff --git a/nvflare/fox/examples/pt/pt_avg_stream.py b/nvflare/fox/examples/pt/pt_avg_stream.py new file mode 100644 index 0000000000..fbb4f02e1a --- /dev/null +++ b/nvflare/fox/examples/pt/pt_avg_stream.py @@ -0,0 +1,203 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import logging +import threading + +import torch + +from nvflare.fox import fox +from nvflare.fox.api.constants import BackendType +from nvflare.fox.api.utils import simple_logging +from nvflare.fox.examples import get_experiment_root +from nvflare.fox.examples.pt.utils import parse_state_dict +from nvflare.fox.sim.simulator import Simulator +from nvflare.fox.sys.downloader import Downloader, download_tensors +from nvflare.fuel.utils.log_utils import get_obj_logger + + +class _AggrResult: + + def __init__(self): + self.total = {} + self.count = 0 + self.lock = threading.Lock() # ensure update integrity + + +class PTFedAvgStream: + + def __init__(self, initial_model, num_rounds=10, timeout=2.0): + self.num_rounds = num_rounds + self.initial_model = initial_model + self.timeout = timeout + self.name = "PTFedAvgStream" + self.logger = get_obj_logger(self) + self._init_model = parse_state_dict(initial_model) + + @fox.algo + def execute(self): + self.logger.info(f"[{fox.call_info}] Start training for {self.num_rounds} rounds") + current_model = self._init_model + for i in range(self.num_rounds): + current_model = self._do_one_round(i, current_model) + if current_model is None: + self.logger.error(f"training failed at round {i}") + break + self.logger.info(f"FINAL MODEL: {current_model}") + return current_model + + def _do_one_round(self, r, current_model): + aggr_result = _AggrResult() + model2 = {} + for k, v in current_model.items(): + model2[k] = v + 2.0 + + grp = fox.clients( + process_resp_cb=self._accept_train_result, + aggr_result=aggr_result, + ) + + if fox.backend_type == BackendType.FLARE: + downloader = Downloader( + num_receivers=grp.size, + timeout=5.0, + ) + model_type = "ref" + model = downloader.add_tensors(current_model, 0) + model2 = downloader.add_tensors(model2, 0) + self.logger.info(f"prepared model as ref: {model}") + else: + model = current_model + model_type = "model" + + grp.train(r, model, model2, model_type) + + if aggr_result.count == 0: + 
return None + else: + result = {} + for k, v in aggr_result.total.items(): + result[k] = torch.div(v, aggr_result.count) + self.logger.info(f"[{fox.call_info}] round {r}: aggr result from {aggr_result.count} clients: {result}") + return result + + def _accept_train_result(self, gcc, result, aggr_result: _AggrResult): + self.logger.info(f"[{fox.call_info}] got train result from {fox.caller}: {result}") + + model, model_type = result + if model_type == "ref": + err, model = download_tensors( + ref=model, + per_request_timeout=5.0, + tensors_received_cb=self._aggregate_tensors, + aggr_result=aggr_result, + ) + if err: + raise RuntimeError(f"failed to download model {model}: {err}") + else: + with aggr_result.lock: + for k, v in model.items(): + if k not in aggr_result.total: + aggr_result.total[k] = v + else: + aggr_result.total[k] += v + + aggr_result.count += 1 + return None + + def _aggregate_tensors(self, td: dict[str, torch.Tensor], aggr_result: _AggrResult): + self.logger.info(f"[{fox.call_info}] aggregating received tensor: {td}") + with aggr_result.lock: + for k, v in td.items(): + if k not in aggr_result.total: + aggr_result.total[k] = v + else: + aggr_result.total[k] += v + aggr_result.count += 1 + + +class PTTrainer: + + def __init__(self, delta: float): + self.delta = delta + self.logger = get_obj_logger(self) + + @fox.collab + def train(self, current_round, model1, model2, model_type: str): + if fox.is_aborted: + self.logger.debug("training aborted") + return None, "model" + + self.logger.debug(f"[{fox.call_info}] training round {current_round}: {model_type=} {model1=} {model2=}") + if model_type == "ref": + err, model1 = download_tensors(ref=model1, per_request_timeout=5.0) + if err: + raise RuntimeError(f"failed to download model1 {model1}: {err}") + self.logger.info(f"downloaded model1 {model1}") + + err, model2 = download_tensors(ref=model2, per_request_timeout=5.0) + if err: + raise RuntimeError(f"failed to download model2 {model2}: {err}") + 
self.logger.info(f"downloaded model2 {model2}") + + weights = {} + for k, v in model1.items(): + weights[k] = v + model2[k] + + result = {} + for k, v in weights.items(): + result[k] = v + self.delta + + if model_type == "ref": + # stream it + downloader = Downloader( + num_receivers=1, + timeout=5.0, + ) + model_type = "ref" + model = downloader.add_tensors(result, 0) + self.logger.info(f"prepared result as ref: {model}") + else: + model = result + model_type = "model" + return model, model_type + + +def main(): + simple_logging(logging.DEBUG) + + server = PTFedAvgStream( + initial_model={ + "x": [[1, 2, 3], [4, 5, 6], [7, 8, 9]], + "y": [[1, 2, 3], [4, 5, 6], [7, 8, 9]], + "z": [[1, 2, 3], [4, 5, 6], [7, 8, 9]], + }, + num_rounds=2, + ) + + client = PTTrainer(delta=1.0) + + simulator = Simulator( + root_dir=get_experiment_root(), + experiment_name="pt_fedavg_stream", + server=server, + client=client, + num_clients=2, + ) + + result = simulator.run() + print(f"final result: {result}") + + +if __name__ == "__main__": + main() diff --git a/nvflare/fox/examples/pt/pt_avg_stream2.py b/nvflare/fox/examples/pt/pt_avg_stream2.py new file mode 100644 index 0000000000..4ab4dd0a6f --- /dev/null +++ b/nvflare/fox/examples/pt/pt_avg_stream2.py @@ -0,0 +1,162 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import logging + +import torch + +from nvflare.fox import fox +from nvflare.fox.api.constants import BackendType +from nvflare.fox.api.utils import simple_logging +from nvflare.fox.examples import get_experiment_root +from nvflare.fox.examples.pt.utils import parse_state_dict +from nvflare.fox.sim.simulator import Simulator +from nvflare.fox.sys.downloader import Downloader, download_tensors +from nvflare.fox.utils.tensor_receiver import TensorReceiver +from nvflare.fuel.utils.log_utils import get_obj_logger + + +class PTFedAvgStream: + + def __init__(self, initial_model, num_rounds=10, timeout=2.0): + self.num_rounds = num_rounds + self.initial_model = initial_model + self.timeout = timeout + self.name = "PTFedAvgStream" + self.logger = get_obj_logger(self) + self._init_model = parse_state_dict(initial_model) + + @fox.algo + def execute(self): + self.logger.info(f"[{fox.call_info}] Start training for {self.num_rounds} rounds") + current_model = self._init_model + for i in range(self.num_rounds): + current_model = self._do_one_round(i, current_model) + if current_model is None: + self.logger.error(f"training failed at round {i}") + break + self.logger.info(f"FINAL MODEL: {current_model}") + return current_model + + def _do_one_round(self, r, current_model): + grp = fox.clients( + blocking=False, + process_resp_cb=TensorReceiver(), + ) + + if fox.backend_type == BackendType.FLARE: + downloader = Downloader( + num_receivers=grp.size, + timeout=5.0, + ) + model_type = "ref" + model = downloader.add_tensors(current_model, 0) + self.logger.info(f"prepared model as ref: {model}") + else: + model = current_model + model_type = "model" + + aggr_result = {} + aggr_count = {} + results = grp.train(r, model, model_type) + + # results is a queue that contains chunks of tensors that are downloaded from clients. + # we aggregate them while they are being downloaded in parallel. + # Note that the chunks from different sites may arrive in any order. 
+ for n, tensors in results: + self.logger.info(f"got tensors from {n}: {tensors}") + if not tensors: + # we use None to indicate the end of all chunks from a site. + continue + + for k, v in tensors.items(): + if k not in aggr_result: + aggr_result[k] = v + aggr_count[k] = 1 + else: + aggr_result[k] += v + aggr_count[k] += 1 + + final_result = {} + for k, v in aggr_result.items(): + final_result[k] = torch.div(v, aggr_count[k]) + self.logger.info(f"[{fox.call_info}] round {r}: aggr result: {final_result}") + return final_result + + +class PTTrainer: + + def __init__(self, delta: float): + self.delta = delta + self.logger = get_obj_logger(self) + + @fox.collab + def train(self, current_round, model, model_type: str): + if fox.is_aborted: + self.logger.debug("training aborted") + return None, "model" + + self.logger.debug(f"[{fox.call_info}] training round {current_round}: {model_type=} {model=}") + if model_type == "ref": + err, model = download_tensors(ref=model, per_request_timeout=5.0) + if err: + raise RuntimeError(f"failed to download model {model}: {err}") + self.logger.info(f"downloaded model1 {model}") + + result = {} + for k, v in model.items(): + result[k] = v + self.delta + + if model_type == "ref": + # stream it + downloader = Downloader( + num_receivers=1, + timeout=5.0, + ) + model_type = "ref" + model = downloader.add_tensors(result, 0) + self.logger.info(f"prepared result as ref: {model}") + else: + model = result + model_type = "model" + return model, model_type + + +def main(): + simple_logging(logging.DEBUG) + + server = PTFedAvgStream( + initial_model={ + "x": [[1, 2, 3], [4, 5, 6], [7, 8, 9]], + "y": [[1, 2, 3], [4, 5, 6], [7, 8, 9]], + "z": [[1, 2, 3], [4, 5, 6], [7, 8, 9]], + }, + num_rounds=2, + ) + + client = PTTrainer(delta=1.0) + + simulator = Simulator( + root_dir=get_experiment_root(), + experiment_name="pt_fedavg_stream2", + server=server, + client=client, + num_clients=2, + ) + + result = simulator.run() + print(f"final result: 
{result}") + + +if __name__ == "__main__": + main() diff --git a/nvflare/fox/examples/pt/pt_np.py b/nvflare/fox/examples/pt/pt_np.py new file mode 100644 index 0000000000..8c2f2d82de --- /dev/null +++ b/nvflare/fox/examples/pt/pt_np.py @@ -0,0 +1,151 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import threading + +from nvflare.fox import fox +from nvflare.fox.api.utils import simple_logging +from nvflare.fox.examples import get_experiment_root +from nvflare.fox.examples.np.algos.utils import add as add_np +from nvflare.fox.examples.np.algos.utils import div as div_np +from nvflare.fox.examples.np.algos.utils import parse_state_dict as parse_np +from nvflare.fox.examples.pt.utils import add as add_pt +from nvflare.fox.examples.pt.utils import div as div_pt +from nvflare.fox.examples.pt.utils import parse_state_dict as parse_pt +from nvflare.fox.sim.simulator import Simulator +from nvflare.fuel.utils.log_utils import get_obj_logger + + +class _AggrResult: + + def __init__(self): + self.pt_total = {} + self.np_total = {} + self.count = 0 + self.lock = threading.Lock() # ensure update integrity + + +class PTFedAvgMixed: + + def __init__(self, pt_model, np_model, num_rounds=10, timeout=2.0): + self.num_rounds = num_rounds + self.pt_model = pt_model + self.np_model = np_model + self.timeout = timeout + self.name = "PTFedAvg" + self.logger = get_obj_logger(self) + self._pt_model = 
parse_pt(pt_model) + self._np_model = parse_np(np_model) + + @fox.algo + def execute(self): + self.logger.info(f"[{fox.call_info}] Start training for {self.num_rounds} rounds") + pt_model, np_model = self._pt_model, self._np_model + for i in range(self.num_rounds): + pt_model, np_model = self._do_one_round(i, pt_model, np_model) + if pt_model is None or np_model is None: + self.logger.error(f"training failed at round {i}") + break + self.logger.info(f"FINAL MODEL: {pt_model=} {np_model=}") + return pt_model, np_model + + def _do_one_round(self, r, pt_model, np_model): + aggr_result = _AggrResult() + + fox.clients( + process_resp_cb=self._accept_train_result, + aggr_result=aggr_result, + ).train(r, pt_model, np_model) + + if aggr_result.count == 0: + return None, None + else: + pt_result = aggr_result.pt_total + div_pt(pt_result, aggr_result.count) + self.logger.info( + f"[{fox.call_info}] round {r}: aggr PT result from {aggr_result.count} clients: {pt_result}" + ) + + np_result = aggr_result.np_total + div_np(np_result, aggr_result.count) + self.logger.info( + f"[{fox.call_info}] round {r}: aggr NP result from {aggr_result.count} clients: {np_result}" + ) + return pt_result, np_result + + def _accept_train_result(self, gcc, result, aggr_result: _AggrResult): + self.logger.info(f"[{fox.call_info}] got train result from {fox.caller}: {result}") + + pt_result, np_result = result + with aggr_result.lock: + add_pt(pt_result, aggr_result.pt_total) + add_np(np_result, aggr_result.np_total) + aggr_result.count += 1 + return None + + +class PTTrainer: + + def __init__(self, delta: float): + self.delta = delta + self.logger = get_obj_logger(self) + + @fox.collab + def train(self, current_round, pt_model, np_model): + if fox.is_aborted: + self.logger.debug("training aborted") + return None, None + + self.logger.debug(f"[{fox.call_info}] training round {current_round}: {pt_model=} {np_model=}") + + pt_result = {} + for k, v in pt_model.items(): + pt_result[k] = v + self.delta 
+ + np_result = {} + for k, v in np_model.items(): + np_result[k] = v + self.delta + + return pt_result, np_result + + +def main(): + simple_logging(logging.DEBUG) + + init_model = { + "x": [[1, 2, 3], [4, 5, 6], [7, 8, 9]], + "y": [[1, 2, 3], [4, 5, 6], [7, 8, 9]], + "z": [[1, 2, 3], [4, 5, 6], [7, 8, 9]], + } + + server = PTFedAvgMixed( + pt_model=init_model, + np_model=init_model, + num_rounds=4, + ) + + simulator = Simulator( + root_dir=get_experiment_root(), + experiment_name="pt_np", + server=server, + client=PTTrainer(delta=1.0), + num_clients=2, + ) + + result = simulator.run() + print(f"Final result: {result}") + + +if __name__ == "__main__": + main() diff --git a/nvflare/fox/examples/pt/recipe_pt_avg.py b/nvflare/fox/examples/pt/recipe_pt_avg.py new file mode 100644 index 0000000000..e22978d5aa --- /dev/null +++ b/nvflare/fox/examples/pt/recipe_pt_avg.py @@ -0,0 +1,42 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from nvflare.app_opt.pt.decomposers import TensorDecomposer +from nvflare.fox.examples import export_recipe +from nvflare.fox.examples.pt.pt_avg_filter import PTFedAvg, PTTrainer +from nvflare.fox.sys.recipe import FoxRecipe + + +def main(): + export_recipe("fox_pt_fedavg", _make_recipe) + + +def _make_recipe(job_name): + recipe = FoxRecipe( + job_name=job_name, + server=PTFedAvg( + initial_model={ + "x": [[1, 2, 3], [4, 5, 6], [7, 8, 9]], + "y": [[1, 2, 3], [4, 5, 6], [7, 8, 9]], + "z": [[1, 2, 3], [4, 5, 6], [7, 8, 9]], + }, + num_rounds=2, + ), + client=PTTrainer(delta=1.0), + ) + recipe.add_decomposers([TensorDecomposer()]) + return recipe + + +if __name__ == "__main__": + main() diff --git a/nvflare/fox/examples/pt/recipe_pt_avg_filter.py b/nvflare/fox/examples/pt/recipe_pt_avg_filter.py new file mode 100644 index 0000000000..5f9b2ef74e --- /dev/null +++ b/nvflare/fox/examples/pt/recipe_pt_avg_filter.py @@ -0,0 +1,57 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from nvflare.fox.examples import export_recipe +from nvflare.fox.examples.pt.filters import ( + IncomingModelCallFilter, + IncomingModelResultFilter, + OutgoingModelCallFilter, + OutgoingModelResultFilter, +) +from nvflare.fox.examples.pt.pt_avg_filter import PTFedAvg, PTTrainer +from nvflare.fox.sys.recipe import FoxRecipe + + +def main(): + export_recipe("fox_pt_fedavg_filter", _make_recipe) + + +def _make_recipe(job_name): + recipe = FoxRecipe( + job_name=job_name, + server=PTFedAvg( + initial_model={ + "x": [[1, 2, 3], [4, 5, 6], [7, 8, 9]], + "y": [[1, 2, 3], [4, 5, 6], [7, 8, 9]], + "z": [[1, 2, 3], [4, 5, 6], [7, 8, 9]], + }, + num_rounds=2, + ), + client=PTTrainer(delta=1.0), + ) + recipe.add_server_outgoing_call_filters( + pattern="*.train", + filters=[OutgoingModelCallFilter("weights")], + ) + recipe.add_server_incoming_result_filters(pattern="*.train", filters=[IncomingModelResultFilter()]) + + recipe.add_client_incoming_call_filters( + pattern="*.train", + filters=[IncomingModelCallFilter("weights")], + ) + recipe.add_client_outgoing_result_filters(pattern="*.train", filters=[OutgoingModelResultFilter()]) + return recipe + + +if __name__ == "__main__": + main() diff --git a/nvflare/fox/examples/pt/recipe_pt_avg_filter2.py b/nvflare/fox/examples/pt/recipe_pt_avg_filter2.py new file mode 100644 index 0000000000..db744bf08c --- /dev/null +++ b/nvflare/fox/examples/pt/recipe_pt_avg_filter2.py @@ -0,0 +1,53 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +from nvflare.fox.examples import export_recipe +from nvflare.fox.examples.pt.filters2 import ModelFilter +from nvflare.fox.examples.pt.pt_avg_filter import PTFedAvg, PTTrainer +from nvflare.fox.sys.recipe import FoxRecipe + + +def main(): + export_recipe("fox_pt_fedavg_filter2", _make_recipe) + + +def _make_recipe(job_name): + recipe = FoxRecipe( + job_name=job_name, + server=PTFedAvg( + initial_model={ + "x": [[1, 2, 3], [4, 5, 6], [7, 8, 9]], + "y": [[1, 2, 3], [4, 5, 6], [7, 8, 9]], + "z": [[1, 2, 3], [4, 5, 6], [7, 8, 9]], + }, + num_rounds=2, + ), + client=PTTrainer(delta=1.0), + ) + model_filter = ModelFilter("weights") + recipe.add_server_outgoing_call_filters( + pattern="*.train", + filters=[model_filter], + ) + recipe.add_server_incoming_result_filters(pattern="*.train", filters=[model_filter]) + + recipe.add_client_incoming_call_filters( + pattern="*.train", + filters=[model_filter], + ) + recipe.add_client_outgoing_result_filters(pattern="*.train", filters=[model_filter]) + return recipe + + +if __name__ == "__main__": + main() diff --git a/nvflare/fox/examples/pt/recipe_pt_avg_mixed.py b/nvflare/fox/examples/pt/recipe_pt_avg_mixed.py new file mode 100644 index 0000000000..634637cc19 --- /dev/null +++ b/nvflare/fox/examples/pt/recipe_pt_avg_mixed.py @@ -0,0 +1,42 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +from nvflare.fox.examples import export_recipe +from nvflare.fox.examples.pt.pt_avg_mixed import PTFedAvgMixed, PTTrainer +from nvflare.fox.sys.recipe import FoxRecipe + + +def main(): + export_recipe("fox_pt_fedavg_mixed", _make_recipe) + + +def _make_recipe(job_name): + init_model = { + "x": [[1, 2, 3], [4, 5, 6], [7, 8, 9]], + "y": [[1, 2, 3], [4, 5, 6], [7, 8, 9]], + "z": [[1, 2, 3], [4, 5, 6], [7, 8, 9]], + } + + return FoxRecipe( + job_name=job_name, + server=PTFedAvgMixed( + pt_model=init_model, + np_model=init_model, + num_rounds=2, + ), + client=PTTrainer(delta=1.0), + ) + + +if __name__ == "__main__": + main() diff --git a/nvflare/fox/examples/pt/recipe_pt_avg_stream.py b/nvflare/fox/examples/pt/recipe_pt_avg_stream.py new file mode 100644 index 0000000000..fce248b77e --- /dev/null +++ b/nvflare/fox/examples/pt/recipe_pt_avg_stream.py @@ -0,0 +1,39 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from nvflare.fox.examples import export_recipe +from nvflare.fox.examples.pt.pt_avg_stream import PTFedAvgStream, PTTrainer +from nvflare.fox.sys.recipe import FoxRecipe + + +def main(): + export_recipe("fox_pt_fedavg_stream", _make_recipe) + + +def _make_recipe(job_name): + return FoxRecipe( + job_name=job_name, + server=PTFedAvgStream( + initial_model={ + "x": [[1, 2, 3], [4, 5, 6], [7, 8, 9]], + "y": [[1, 2, 3], [4, 5, 6], [7, 8, 9]], + "z": [[1, 2, 3], [4, 5, 6], [7, 8, 9]], + }, + num_rounds=2, + ), + client=PTTrainer(delta=1.0), + ) + + +if __name__ == "__main__": + main() diff --git a/nvflare/fox/examples/pt/recipe_pt_avg_stream2.py b/nvflare/fox/examples/pt/recipe_pt_avg_stream2.py new file mode 100644 index 0000000000..b17d98f6f9 --- /dev/null +++ b/nvflare/fox/examples/pt/recipe_pt_avg_stream2.py @@ -0,0 +1,39 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from nvflare.fox.examples import export_recipe +from nvflare.fox.examples.pt.pt_avg_stream2 import PTFedAvgStream, PTTrainer +from nvflare.fox.sys.recipe import FoxRecipe + + +def main(): + export_recipe("fox_pt_fedavg_stream2", _make_recipe) + + +def _make_recipe(job_name): + return FoxRecipe( + job_name=job_name, + server=PTFedAvgStream( + initial_model={ + "x": [[1, 2, 3], [4, 5, 6], [7, 8, 9]], + "y": [[1, 2, 3], [4, 5, 6], [7, 8, 9]], + "z": [[1, 2, 3], [4, 5, 6], [7, 8, 9]], + }, + num_rounds=2, + ), + client=PTTrainer(delta=1.0), + ) + + +if __name__ == "__main__": + main() diff --git a/nvflare/fox/examples/pt/recipe_pt_np.py b/nvflare/fox/examples/pt/recipe_pt_np.py new file mode 100644 index 0000000000..65a15ba75f --- /dev/null +++ b/nvflare/fox/examples/pt/recipe_pt_np.py @@ -0,0 +1,46 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from nvflare.app_common.decomposers.numpy_decomposers import NumpyArrayDecomposer +from nvflare.app_opt.pt.decomposers import TensorDecomposer +from nvflare.fox.examples import export_recipe +from nvflare.fox.examples.pt.pt_np import PTFedAvgMixed, PTTrainer +from nvflare.fox.sys.recipe import FoxRecipe + + +def main(): + export_recipe("fox_pt_np", _make_recipe) + + +def _make_recipe(job_name): + init_model = { + "x": [[1, 2, 3], [4, 5, 6], [7, 8, 9]], + "y": [[1, 2, 3], [4, 5, 6], [7, 8, 9]], + "z": [[1, 2, 3], [4, 5, 6], [7, 8, 9]], + } + + recipe = FoxRecipe( + job_name=job_name, + server=PTFedAvgMixed( + pt_model=init_model, + np_model=init_model, + num_rounds=2, + ), + client=PTTrainer(delta=1.0), + ) + recipe.add_decomposers([TensorDecomposer(), NumpyArrayDecomposer()]) + return recipe + + +if __name__ == "__main__": + main() diff --git a/nvflare/fox/examples/pt/utils.py b/nvflare/fox/examples/pt/utils.py new file mode 100644 index 0000000000..a41e5cb130 --- /dev/null +++ b/nvflare/fox/examples/pt/utils.py @@ -0,0 +1,74 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import torch
+
+
+def parse_array_def(array_def):
+    if array_def is None:
+        return array_def
+
+    if isinstance(array_def, torch.Tensor):
+        return array_def
+
+    if isinstance(array_def, list):
+        return torch.tensor(array_def)
+    else:
+        raise ValueError(f"unsupported array def: {array_def}")
+
+
+def parse_state_dict(d):
+    result = {}
+    for k, v in d.items():
+        result[k] = parse_array_def(v)
+    return result
+
+
+def parse_model_def(model_def):
+    if isinstance(model_def, dict):
+        return parse_state_dict(model_def)
+    else:
+        return parse_array_def(model_def)
+
+
+def add(value: dict, to_model: dict):
+    """Add value to a specified model in-place.
+
+    Args:
+        value: dict of tensors to be added into the model.
+        to_model: the model dict that is updated in place.
+
+    Returns:
+        The updated to_model dict.
+    """
+    for k, v in value.items():
+        if k not in to_model:
+            to_model[k] = v
+        else:
+            to_model[k] += v
+    return to_model
+
+
+def div(model: dict, value):
+    """Divide the model in-place by a specified value.
+
+    Args:
+        model: the model dict that is updated in place.
+        value: the divisor applied to every entry of the model.
+
+    Returns:
+        The updated model dict.
+    """
+    for k, v in model.items():
+        model[k] = torch.div(v, value)
+    return model
diff --git a/nvflare/fox/sim/__init__.py b/nvflare/fox/sim/__init__.py
new file mode 100644
index 0000000000..341a77c5bc
--- /dev/null
+++ b/nvflare/fox/sim/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/nvflare/fox/sim/backend.py b/nvflare/fox/sim/backend.py new file mode 100644 index 0000000000..6f128b4a5a --- /dev/null +++ b/nvflare/fox/sim/backend.py @@ -0,0 +1,136 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import threading +import time + +from nvflare.apis.fl_exception import RunAborted +from nvflare.fox.api.app import App +from nvflare.fox.api.backend import Backend +from nvflare.fox.api.call_opt import CallOption +from nvflare.fox.api.constants import CollabMethodArgName +from nvflare.fox.api.dec import adjust_kwargs +from nvflare.fox.api.gcc import GroupCallContext +from nvflare.fox.api.utils import check_call_args + + +class _Waiter(threading.Event): + + def __init__(self): + super().__init__() + self.result = None + + +class SimBackend(Backend): + + def __init__(self, target_obj_name: str, target_app: App, target_obj, abort_signal, thread_executor): + Backend.__init__(self, abort_signal) + self.target_obj_name = target_obj_name + self.target_app = target_app + self.target_obj = target_obj + self.executor = thread_executor + + def _get_func(self, func_name): + return self.target_app.find_collab_method(self.target_obj, func_name) + + def call_target(self, context, target_name: str, call_opt: CallOption, func_name: str, *args, **kwargs): + func = self._get_func(func_name) + if not func: + raise AttributeError(f"{target_name} does not have method '{func_name}' or it is not 
collab") + + if not callable(func): + raise AttributeError(f"the method '{func_name}' of {target_name} is not callable") + + expect_result = call_opt.expect_result + timeout = call_opt.timeout + + waiter = None + if expect_result: + waiter = _Waiter() + + self.executor.submit(self._run_func, waiter, context, target_name, func_name, func, args, kwargs) + if waiter: + start_time = time.time() + while True: + if self.abort_signal.triggered: + waiter.result = RunAborted("job is aborted") + break + + ok = waiter.wait(0.1) + if ok: + break + + waited = time.time() - start_time + if waited > timeout: + # timed out + waiter.result = TimeoutError(f"function {func_name} timed out after {waited} seconds") + break + + return waiter.result + else: + return None + + def _preprocess(self, context, target_name, func_name, func, kwargs): + caller_ctx = context + my_ctx = self.target_app.new_context(caller_ctx.caller, caller_ctx.callee) + kwargs = self.target_app.apply_incoming_call_filters(target_name, func_name, kwargs, my_ctx) + + # make sure the final kwargs conforms to func interface + obj_itf = self.target_app.get_target_object_collab_interface(self.target_obj_name) + if not obj_itf: + raise RuntimeError(f"cannot find collab interface for object {self.target_obj_name}") + + func_itf = obj_itf.get(func_name) + if not func_itf: + raise RuntimeError(f"cannot find interface for func '{func_name}' of object {self.target_obj_name}") + + check_call_args(func_name, func_itf, [], kwargs) + kwargs[CollabMethodArgName.CONTEXT] = my_ctx + adjust_kwargs(func, kwargs) + return my_ctx, kwargs + + def _run_func(self, waiter: _Waiter, context, target_name, func_name, func, args, kwargs): + try: + ctx, kwargs = self._preprocess(context, target_name, func_name, func, kwargs) + result = func(*args, **kwargs) + + # apply result filter + result = self.target_app.apply_outgoing_result_filters(target_name, func_name, result, ctx) + if waiter: + waiter.result = result + except Exception as ex: + if 
waiter: + waiter.result = ex + finally: + if waiter: + waiter.set() + + def call_target_in_group(self, gcc: GroupCallContext, func_name: str, *args, **kwargs): + target_name = gcc.target_name + func = self._get_func(func_name) + if not func: + raise AttributeError(f"{target_name} does not have method '{func_name}' or it is not collab") + + if not callable(func): + raise AttributeError(f"the method '{func_name}' of {target_name} is not callable") + + self.executor.submit(self._run_func_in_group, gcc, func_name, args, kwargs) + + def _run_func_in_group(self, gcc: GroupCallContext, func_name, args, kwargs): + try: + target_name = gcc.target_name + result = self.call_target(gcc.context, target_name, gcc.call_opt, func_name, *args, **kwargs) + gcc.send_completed() + gcc.set_result(result) + except Exception as ex: + gcc.set_exception(ex) diff --git a/nvflare/fox/sim/simulator.py b/nvflare/fox/sim/simulator.py new file mode 100644 index 0000000000..250805b714 --- /dev/null +++ b/nvflare/fox/sim/simulator.py @@ -0,0 +1,307 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import copy +import uuid +from concurrent.futures import ThreadPoolExecutor +from typing import List, Tuple, Union + +from nvflare.apis.signal import Signal +from nvflare.fox.api.app import App, ClientApp, ServerApp +from nvflare.fox.api.constants import MAKE_CLIENT_APP_METHOD, BackendType +from nvflare.fox.api.dec import get_object_collab_interface +from nvflare.fox.api.proxy import Proxy +from nvflare.fox.api.run_server import run_server +from nvflare.fox.sim.backend import SimBackend +from nvflare.fox.sim.ws import SimWorkspace +from nvflare.fuel.utils.log_utils import get_obj_logger + + +class AppRunner: + + def _prepare_app_backends(self, app: App): + bes = {"": SimBackend("", app, app, self.abort_signal, self.thread_executor)} + targets = app.get_collab_objects() + for name, obj in targets.items(): + bes[name] = SimBackend(name, app, obj, self.abort_signal, self.thread_executor) + return bes + + @staticmethod + def _prepare_proxy(for_app: App, target_app: App, backends: dict): + app_proxy = Proxy( + app=for_app, + target_name=target_app.name, + target_fqn=target_app.fqn, + backend=backends[""], + target_interface=get_object_collab_interface(target_app), + ) + collab_objs = target_app.get_collab_objects() + for name, obj in collab_objs.items(): + p = Proxy( + app=for_app, + target_name=f"{target_app.name}.{name}", + target_fqn="", + backend=backends[name], + target_interface=get_object_collab_interface(obj), + ) + app_proxy.add_child(name, p) + return app_proxy + + def _make_app(self, site_name, fqn): + """Make a new client app instance for the specified site + + Args: + site_name: name of the site + fqn: fully qualified name of the site + + Returns: a new instance of the app + + """ + # If the app contains "make_client_app" method, call it to make the app instance! + # Otherwise, make the instance by deep copying. + # If the client_app object cannot be deep-copied, then it must provide the make_client_app method. 
+ make_client_app_f = getattr(self.client_app, MAKE_CLIENT_APP_METHOD, None) + if make_client_app_f and callable(make_client_app_f): + app = make_client_app_f(site_name, BackendType.SIMULATION) + if not isinstance(app, ClientApp): + raise RuntimeError(f"result returned by {MAKE_CLIENT_APP_METHOD} must be ClientApp but got {type(app)}") + else: + try: + app = copy.deepcopy(self.client_app) + except Exception as ex: + self.logger.error( + f"exception occurred {type(ex)} creating client app with deepcopy. " + f"Please implement the {MAKE_CLIENT_APP_METHOD} method in the client app class" + ) + raise ex + + app.name = site_name + app.set_fqn(fqn) + app.set_backend_type(BackendType.SIMULATION) + return app + + def _prepare_proxies(self, for_app: App, server_app: App, client_apps: dict, backends: dict): + server_proxy = self._prepare_proxy(for_app, server_app, backends[server_app.name]) + client_proxies = [] + for name, app in client_apps.items(): + p = self._prepare_proxy(for_app, app, backends[name]) + client_proxies.append(p) + + return server_proxy, client_proxies + + def __init__( + self, + root_dir: str, + experiment_name: str, + server_app: ServerApp, + client_app: ClientApp, + max_workers: int = 100, + num_clients: Union[int, Tuple[int, int]] = 2, + ): + if not isinstance(server_app, ServerApp): + raise ValueError(f"server_app must be ServerApp but got {type(server_app)}") + + if not isinstance(client_app, ClientApp): + raise ValueError(f"client_app must be ClientApp but got {type(client_app)}") + + self.logger = get_obj_logger(self) + self.abort_signal = Signal() + server_app.name = "server" + server_app.set_fqn(server_app.name) + server_app.set_backend_type(BackendType.SIMULATION) + self.server_app = server_app + self.client_app = client_app + self.thread_executor = ThreadPoolExecutor(max_workers=max_workers, thread_name_prefix="fox_call") + + if isinstance(num_clients, int): + if num_clients <= 0: + raise ValueError(f"num_clients must > 0 but got 
{num_clients}") + client_apps = {} + for i in range(num_clients): + name = f"site-{i + 1}" + client_apps[name] = self._make_app(name, name) + elif isinstance(num_clients, tuple): + if len(num_clients) != 2: + raise ValueError(f"num_clients must be an int or tuple(int, int) but got {num_clients}") + + # tuple of (height x width) + height, num_children_per_parent = num_clients + if not isinstance(height, int) or not isinstance(num_children_per_parent, int): + raise ValueError(f"num_clients must be an int or tuple(int, int) but got {num_clients}") + + if height <= 0 or num_children_per_parent <= 0: + raise ValueError(f"num_clients must contain positive ints but got {num_clients}") + + self.logger.info(f"creating clients {height} x {num_children_per_parent}") + client_apps = self._build_hierarchical_clients(height, num_children_per_parent) + else: + raise ValueError(f"num_clients must be an int or tuple(int, int) but got {type(num_clients)}") + + self.logger.info(f"created client apps: {client_apps.keys()}") + + backends = {server_app.name: self._prepare_app_backends(server_app)} + + for name, app in client_apps.items(): + backends[name] = self._prepare_app_backends(app) + + exp_id = str(uuid.uuid4()) + + for name, app in client_apps.items(): + server_proxy, client_proxies = self._prepare_proxies(app, server_app, client_apps, backends) + ws = SimWorkspace(root_dir=root_dir, experiment_name=experiment_name, site_name=name, exp_id=exp_id) + app.setup(ws, server_proxy, client_proxies, self.abort_signal) + + # prepare server + server_proxy, client_proxies = self._prepare_proxies(server_app, server_app, client_apps, backends) + ws = SimWorkspace(root_dir=root_dir, experiment_name=experiment_name, site_name=server_app.name, exp_id=exp_id) + server_app.setup(ws, server_proxy, client_proxies, self.abort_signal) + self.client_apps = client_apps + self.exp_dir = ws.get_experiment_dir() + + def _build_hierarchical_clients(self, height: int, num_children_per_parent: int): + 
client_apps = {} + last_client_fqns = {} + current_client_fqns = {} + for i in range(height): + if not last_client_fqns: + for j in range(num_children_per_parent): + name = f"site-{j + 1}" + fqn = name + app = self._make_app(name, fqn) + client_apps[name] = app + current_client_fqns[fqn] = app + else: + for fqn, parent_app in last_client_fqns.items(): + # create w clients for each parent + for k in range(num_children_per_parent): + child_name = f"{parent_app.name}-{k + 1}" + child_fqn = f"{fqn}.{child_name}" + app = self._make_app(child_name, child_fqn) + client_apps[child_name] = app + current_client_fqns[child_fqn] = app + last_client_fqns = current_client_fqns + current_client_fqns = {} + return client_apps + + def run(self): + self.logger.debug(f"Server Collab Interface: {self.server_app.get_collab_interface()}") + self.logger.debug(f"Client Collab Interface: {self.client_app.get_collab_interface()}") + + try: + result = self._try_run() + except KeyboardInterrupt: + self.logger.info("execution is aborted by user") + self.abort_signal.trigger(True) + result = None + finally: + self.thread_executor.shutdown(wait=True, cancel_futures=True) + self.logger.info(f"Experiment results are in {self.exp_dir}") + return result + + def _try_run(self): + # initialize all apps + client_ctx = {} + for n, app in self.client_apps.items(): + ctx = app.new_context(n, n) + client_ctx[n] = ctx + self.logger.info(f"initializing client app {n}") + app.initialize(ctx) + + # run the server + result = run_server(self.server_app, self.logger) + for n, app in self.client_apps.items(): + ctx = client_ctx[n] + self.logger.info(f"finalizing client app {n}") + app.finalize(ctx) + + return result + + +class Simulator: + + def __init__( + self, + root_dir: str, + experiment_name: str, + server, + client, + server_objects: dict[str, object] = None, + client_objects: dict[str, object] = None, + max_workers: int = 100, + num_clients: Union[int, Tuple[int, int]] = 2, + ): + server_app: ServerApp = 
ServerApp(server) + client_app: ClientApp = ClientApp(client) + + self.root_dir = root_dir + self.experiment_name = experiment_name + self.max_workers = max_workers + self.num_clients = num_clients + + if server_objects: + for name, obj in server_objects.items(): + server_app.add_collab_object(name, obj) + + if client_objects: + for name, obj in client_objects.items(): + client_app.add_collab_object(name, obj) + + self.server_app = server_app + self.client_app = client_app + + def add_server_outgoing_call_filters(self, pattern: str, filters: List[object]): + self.server_app.add_outgoing_call_filters(pattern, filters) + + def add_server_incoming_call_filters(self, pattern: str, filters: List[object]): + self.server_app.add_incoming_call_filters(pattern, filters) + + def add_server_outgoing_result_filters(self, pattern: str, filters: List[object]): + self.server_app.add_outgoing_result_filters(pattern, filters) + + def add_server_incoming_result_filters(self, pattern: str, filters: List[object]): + self.server_app.add_incoming_result_filters(pattern, filters) + + def add_client_outgoing_call_filters(self, pattern: str, filters: List[object]): + self.client_app.add_outgoing_call_filters(pattern, filters) + + def add_client_incoming_call_filters(self, pattern: str, filters: List[object]): + self.client_app.add_incoming_call_filters(pattern, filters) + + def add_client_outgoing_result_filters(self, pattern: str, filters: List[object]): + self.client_app.add_outgoing_result_filters(pattern, filters) + + def add_client_incoming_result_filters(self, pattern: str, filters: List[object]): + self.client_app.add_incoming_result_filters(pattern, filters) + + def set_server_prop(self, name: str, value): + self.server_app.set_prop(name, value) + + def set_client_prop(self, name: str, value): + self.client_app.set_prop(name, value) + + def set_server_resource_dirs(self, resource_dirs): + self.server_app.set_resource_dirs(resource_dirs) + + def set_client_resource_dirs(self, 
resource_dirs): + self.client_app.set_resource_dirs(resource_dirs) + + def run(self): + runner = AppRunner( + root_dir=self.root_dir, + experiment_name=self.experiment_name, + server_app=self.server_app, + client_app=self.client_app, + max_workers=self.max_workers, + num_clients=self.num_clients, + ) + return runner.run() diff --git a/nvflare/fox/sim/ws.py b/nvflare/fox/sim/ws.py new file mode 100644 index 0000000000..8b5b6e8638 --- /dev/null +++ b/nvflare/fox/sim/ws.py @@ -0,0 +1,49 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import os.path + +from nvflare.fox.api.workspace import Workspace + + +class SimWorkspace(Workspace): + + def __init__(self, root_dir: str, experiment_name: str, exp_id: str, site_name: str): + super().__init__() + if not isinstance(root_dir, str): + raise ValueError(f"root_dir must be str but got {type(root_dir)}") + + if not isinstance(exp_id, str): + raise ValueError(f"exp_id must be str but got {type(exp_id)}") + + if not isinstance(experiment_name, str): + raise ValueError(f"experiment_name must be str but got {type(experiment_name)}") + + if not isinstance(site_name, str): + raise ValueError(f"site_name must be str but got {type(site_name)}") + + self.root_dir = root_dir + self.site_name = site_name + self.exp_name = experiment_name + self.exp_dir = os.path.join(root_dir, experiment_name, exp_id) + self.work_dir = os.path.join(self.exp_dir, site_name) + os.makedirs(self.work_dir, exist_ok=True) + + def get_root_dir(self) -> str: + return self.root_dir + + def get_work_dir(self) -> str: + return self.work_dir + + def get_experiment_dir(self) -> str: + return self.exp_dir diff --git a/nvflare/fox/sys/__init__.py b/nvflare/fox/sys/__init__.py new file mode 100644 index 0000000000..341a77c5bc --- /dev/null +++ b/nvflare/fox/sys/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/nvflare/fox/sys/adaptor.py b/nvflare/fox/sys/adaptor.py new file mode 100644 index 0000000000..7e6cd9655c --- /dev/null +++ b/nvflare/fox/sys/adaptor.py @@ -0,0 +1,111 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Dict, List + +from nvflare.apis.fl_context import FLContext +from nvflare.fox.api.app import App + + +class FoxAdaptor: + + def __init__( + self, + collab_obj_ids: List[str] = None, + props: Dict[str, Any] = None, + resource_dirs: Dict[str, str] = None, + incoming_call_filters=None, + outgoing_call_filters=None, + incoming_result_filters=None, + outgoing_result_filters=None, + ): + if not collab_obj_ids: + collab_obj_ids = [] + self.props = props + self.resource_dirs = resource_dirs + self.collab_obj_ids = collab_obj_ids + self.incoming_call_filters = incoming_call_filters + self.outgoing_call_filters = outgoing_call_filters + self.incoming_result_filters = incoming_result_filters + self.outgoing_result_filters = outgoing_result_filters + + def process_config(self, app: App, fl_ctx: FLContext): + app.update_props(self.props) + app.set_resource_dirs(self.resource_dirs) + + engine = fl_ctx.get_engine() + if self.collab_obj_ids: + for cid in self.collab_obj_ids: + obj = engine.get_component(cid) + if not obj: + return f"component {cid} does not exist" + + app.add_collab_object(cid, obj) + + err = self._parse_filters("incoming_call_filters", 
app.add_incoming_call_filters, fl_ctx) + if err: + return err + + err = self._parse_filters("outgoing_call_filters", app.add_outgoing_call_filters, fl_ctx) + if err: + return err + + err = self._parse_filters("incoming_result_filters", app.add_incoming_result_filters, fl_ctx) + if err: + return err + + err = self._parse_filters("outgoing_result_filters", app.add_outgoing_result_filters, fl_ctx) + if err: + return err + + return None + + def _parse_filters(self, name, add_f, fl_ctx): + filters = getattr(self, name) + if not filters: + return None + + if not isinstance(filters, list): + return f"{name} must be a list but got {type(filters)}" + + for chain_dict in filters: + pattern, filter_components, err = self._parse_filter_chain(name, chain_dict, fl_ctx) + if err: + return err + add_f(pattern, filter_components) + return None + + @staticmethod + def _parse_filter_chain(chain_name, chain_dict: dict, fl_ctx): + if not isinstance(chain_dict, dict): + return None, None, f"element in {chain_name} must be dict but got {type(chain_dict)}" + + pattern = chain_dict.get("pattern") + if not pattern: + return None, None, f"missing 'pattern' in {chain_name}" + + filter_ids = chain_dict.get("filters") + if not filter_ids: + return None, None, f"missing 'filters' in {chain_name}" + + if not isinstance(filter_ids, list): + return None, None, f"invalid 'filters' in {chain_name}: expect list got {type(filter_ids)}" + + engine = fl_ctx.get_engine() + filters = [] + for fid in filter_ids: + f = engine.get_component(fid) + if not f: + return None, None, f"component {fid} does not exist" + filters.append(f) + return pattern, filters, None diff --git a/nvflare/fox/sys/backend.py b/nvflare/fox/sys/backend.py new file mode 100644 index 0000000000..30b338264c --- /dev/null +++ b/nvflare/fox/sys/backend.py @@ -0,0 +1,148 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from nvflare.fox.api.backend import Backend +from nvflare.fox.api.call_opt import CallOption +from nvflare.fox.api.ctx import set_call_context +from nvflare.fox.api.gcc import GroupCallContext +from nvflare.fuel.f3.cellnet.defs import MessageHeaderKey, ReturnCode +from nvflare.fuel.f3.cellnet.utils import new_cell_message +from nvflare.fuel.f3.message import Message +from nvflare.security.logging import secure_log_traceback + +from .constants import MSG_CHANNEL, MSG_TOPIC, CallReplyKey, ObjectCallKey + + +class FlareBackend(Backend): + + def __init__(self, manager, engine, caller, cell, target_fqcn, abort_signal, thread_executor): + Backend.__init__(self, abort_signal) + self.manager = manager + self.engine = engine + self.caller = caller + self.cell = cell + self.target_fqcn = target_fqcn + self.thread_executor = thread_executor + + def call_target(self, context, target_name: str, call_opt: CallOption, func_name: str, *args, **kwargs): + return self._call_target( + context=context, + target_name=target_name, + call_opt=call_opt, + send_complete_cb=None, + cb_kwargs={}, + func_name=func_name, + *args, + **kwargs, + ) + + def _call_target( + self, + context, + target_name: str, + call_opt: CallOption, + send_complete_cb, + cb_kwargs, + func_name: str, + *args, + **kwargs, + ): + set_call_context(context) + + payload = { + ObjectCallKey.CALLER: self.caller, + ObjectCallKey.TARGET_NAME: target_name, + 
ObjectCallKey.METHOD_NAME: func_name, + ObjectCallKey.ARGS: args, + ObjectCallKey.KWARGS: kwargs, + } + request = new_cell_message({}, payload) + + timeout = call_opt.timeout + if call_opt.expect_result: + self.logger.info(f"send_request from {self.cell.get_fqcn()} to {self.target_fqcn}: {func_name=} {call_opt}") + + reply = self.cell.send_request( + channel=MSG_CHANNEL, + target=self.target_fqcn, + topic=MSG_TOPIC, + request=request, + timeout=timeout, + secure=call_opt.secure, + optional=call_opt.optional, + abort_signal=self.abort_signal, + send_complete_cb=send_complete_cb, + **cb_kwargs, + ) + if not isinstance(reply, Message): + self.logger.error(f"cell message reply must be Message but got {type(reply)}") + raise RuntimeError(f"function {func_name} failed with internal error") + + rc = reply.get_header(MessageHeaderKey.RETURN_CODE, ReturnCode.OK) + if rc == ReturnCode.TIMEOUT: + raise TimeoutError(f"function {func_name} timed out after {timeout} seconds") + elif rc != ReturnCode.OK: + error = None + if isinstance(reply.payload, dict): + error = reply.payload.get(CallReplyKey.ERROR) + raise RuntimeError(f"function {func_name} failed: {rc=} {error=}") + + if not isinstance(reply.payload, dict): + raise RuntimeError(f"function {func_name} failed: reply must be dict but got {type(reply.payload)}") + + error = reply.payload.get(CallReplyKey.ERROR) + if error: + raise RuntimeError(f"function {func_name} failed: {error}") + + result = reply.payload.get(CallReplyKey.RESULT) + self.logger.info(f"got result from {self.target_fqcn} {func_name=}") + return result + else: + # fire and forget + self.logger.info(f"fire_and_forget from {self.cell.get_fqcn()} to {self.target_fqcn}") + self.cell.fire_and_forget( + channel=MSG_CHANNEL, + topic=MSG_TOPIC, + targets=self.target_fqcn, + message=request, + secure=call_opt.secure, + optional=call_opt.optional, + ) + return None + + def call_target_in_group(self, gcc: GroupCallContext, func_name: str, *args, **kwargs): + 
self.thread_executor.submit(self._run_func, gcc, func_name, args, kwargs) + + def _run_func(self, gcc: GroupCallContext, func_name: str, args, kwargs): + try: + result = self._call_target( + context=gcc.context, + target_name=gcc.target_name, + call_opt=gcc.call_opt, + func_name=func_name, + send_complete_cb=self._msg_sent, + cb_kwargs={"gcc": gcc}, + *args, + **kwargs, + ) + gcc.set_result(result) + except Exception as ex: + gcc.set_exception(ex) + + def _msg_sent(self, gcc: GroupCallContext): + gcc.send_completed() + + def handle_exception(self, exception: Exception): + fl_ctx = self.engine.new_context() + secure_log_traceback(self.logger) + self.manager.system_panic(f"exception occurred: {exception}", fl_ctx) diff --git a/nvflare/fox/sys/constants.py b/nvflare/fox/sys/constants.py new file mode 100644 index 0000000000..afbb3cef79 --- /dev/null +++ b/nvflare/fox/sys/constants.py @@ -0,0 +1,36 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+SYNC_TASK_NAME = "sync" + +MSG_CHANNEL = "fox" +MSG_TOPIC = "call" + + +class SyncKey: + COLLAB_INTERFACE = "collab_interface" + + +class ObjectCallKey: + CALLER = "caller" + TARGET_NAME = "target_name" + METHOD_NAME = "method_name" + ARGS = "args" + KWARGS = "kwargs" + TIMEOUT = "timeout" + BLOCKING = "blocking" + + +class CallReplyKey: + ERROR = "error" + RESULT = "result" diff --git a/nvflare/fox/sys/controller.py b/nvflare/fox/sys/controller.py new file mode 100644 index 0000000000..160eca8bd5 --- /dev/null +++ b/nvflare/fox/sys/controller.py @@ -0,0 +1,261 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import time +from concurrent.futures import ThreadPoolExecutor +from typing import List + +from nvflare.apis.client import Client as ClientSite +from nvflare.apis.controller_spec import Client, ClientTask, Task +from nvflare.apis.fl_context import FLContext +from nvflare.apis.impl.controller import Controller +from nvflare.apis.shareable import ReturnCode, Shareable +from nvflare.apis.signal import Signal +from nvflare.fox.api.app import ServerApp +from nvflare.fox.api.constants import BackendType +from nvflare.fox.api.proxy import Proxy +from nvflare.fox.api.run_server import run_server +from nvflare.fuel.f3.cellnet.fqcn import FQCN + +from .adaptor import FoxAdaptor +from .backend import FlareBackend +from .constants import SYNC_TASK_NAME, SyncKey +from .utils import prepare_for_remote_call +from .ws import FlareWorkspace + + +class _ClientInfo: + + def __init__(self, collab_interface: dict): + """Information about a client. Reported by the client in the sync response. + + Args: + collab_interface: collab method interface of the client. 
+ """ + self.collab_interface = collab_interface + + +class FoxController(Controller, FoxAdaptor): + + def __init__( + self, + server_obj_id: str = None, + collab_obj_ids: List[str] = None, + incoming_call_filters=None, + outgoing_call_filters=None, + incoming_result_filters=None, + outgoing_result_filters=None, + props=None, + resource_dirs=None, + sync_task_timeout=5, + max_call_threads=100, + ): + Controller.__init__(self) + FoxAdaptor.__init__( + self, + props=props, + resource_dirs=resource_dirs, + collab_obj_ids=collab_obj_ids, + incoming_call_filters=incoming_call_filters, + outgoing_call_filters=outgoing_call_filters, + incoming_result_filters=incoming_result_filters, + outgoing_result_filters=outgoing_result_filters, + ) + self.server_obj_id = server_obj_id # component name + self.sync_task_timeout = sync_task_timeout + self.server_app = None + self.client_info = {} # client name => _ClientInfo + self.cell = None + self.thread_executor = ThreadPoolExecutor(max_workers=max_call_threads, thread_name_prefix="fox_call") + + def start_controller(self, fl_ctx: FLContext): + engine = fl_ctx.get_engine() + + server_obj = engine.get_component(self.server_obj_id) + if not server_obj: + self.system_panic(f"no component defined for {self.server_obj_id}", fl_ctx) + return + + app = ServerApp(server_obj) + + app.name = "server" + app.set_backend_type(BackendType.FLARE) + + err = self.process_config(app, fl_ctx) + if err: + self.system_panic(err, fl_ctx) + return + self.server_app = app + + def _prepare_client_backend(self, job_id, client: ClientSite, abort_signal: Signal, fl_ctx: FLContext): + return FlareBackend( + manager=self, + engine=fl_ctx.get_engine(), + caller=self.server_app.name, + cell=self.cell, + target_fqcn=FQCN.join([client.get_fqcn(), job_id]), + abort_signal=abort_signal, + thread_executor=self.thread_executor, + ) + + def _prepare_server_backend(self, job_id: str, abort_signal: Signal, fl_ctx: FLContext): + return FlareBackend( + manager=self, + 
engine=fl_ctx.get_engine(), + caller=self.server_app.name, + cell=self.cell, + target_fqcn=FQCN.join([FQCN.ROOT_SERVER, job_id]), + abort_signal=abort_signal, + thread_executor=self.thread_executor, + ) + + def _prepare_client_proxy( + self, + job_id: str, + client: ClientSite, + collab_interface: dict, + abort_signal, + fl_ctx: FLContext, + ): + backend = self._prepare_client_backend(job_id, client, abort_signal, fl_ctx) + proxy = Proxy( + app=self.server_app, + target_name=client.name, + target_fqn=client.get_fqsn(), + backend=backend, + target_interface=collab_interface.get(""), + ) + + for name, itf in collab_interface.items(): + if name == "": + continue + + p = Proxy( + app=self.server_app, + target_name=f"{client.name}.{name}", + target_fqn="", + backend=backend, + target_interface=itf, + ) + proxy.add_child(name, p) + return proxy + + def _prepare_server_proxy( + self, + job_id, + abort_signal, + collab_interface: dict, + fl_ctx: FLContext, + ): + server_name = self.server_app.name + backend = self._prepare_server_backend(job_id, abort_signal, fl_ctx) + proxy = Proxy( + app=self.server_app, + target_name=server_name, + target_fqn=server_name, + backend=backend, + target_interface=collab_interface.get(""), + ) + + collab_objs = self.server_app.get_collab_objects() + if collab_objs: + for name in collab_objs.keys(): + p = Proxy( + app=self.server_app, + target_name=f"{server_name}.{name}", + target_fqn="", + backend=backend, + target_interface=collab_interface.get(name), + ) + proxy.add_child(name, p) + return proxy + + def control_flow(self, abort_signal: Signal, fl_ctx: FLContext): + # configure all sites + server_collab_interface = self.server_app.get_collab_interface() + task_data = Shareable({SyncKey.COLLAB_INTERFACE: server_collab_interface}) + task = Task( + name=SYNC_TASK_NAME, + data=task_data, + timeout=int(self.sync_task_timeout), + result_received_cb=self._process_sync_reply, + ) + + engine = fl_ctx.get_engine() + self.logger.info(f"server engine 
{type(engine)}") + all_clients = engine.get_clients() + num_clients = len(all_clients) + for c in all_clients: + self.client_info[c.name] = None + + start_time = time.time() + self.broadcast_and_wait( + task=task, + min_responses=num_clients, + abort_signal=abort_signal, + fl_ctx=fl_ctx, + ) + time_taken = time.time() - start_time + self.log_info(fl_ctx, f"client sync took {time_taken} seconds") + + failed_clients = [] + for c, info in self.client_info.items(): + if not info: + failed_clients.append(c) + + if failed_clients: + self.system_panic( + f"failed to sync clients {failed_clients}", + fl_ctx, + ) + return + + self.log_info(fl_ctx, f"successfully synced clients {self.client_info.keys()}") + + # register msg CB for processing object calls + self.cell = engine.get_cell() + prepare_for_remote_call(self.cell, self.server_app, self.logger) + + # prepare proxies and backends + job_id = fl_ctx.get_job_id() + server_proxy = self._prepare_server_proxy(job_id, abort_signal, server_collab_interface, fl_ctx) + client_proxies = [] + for c in all_clients: + info = self.client_info[c.name] + # assert isinstance(info, _ClientInfo) + client_proxies.append(self._prepare_client_proxy(job_id, c, info.collab_interface, abort_signal, fl_ctx)) + + ws = FlareWorkspace(fl_ctx) + self.server_app.setup(ws, server_proxy, client_proxies, abort_signal) + run_server(self.server_app, self.logger) + + def _process_sync_reply(self, client_task: ClientTask, fl_ctx: FLContext): + result = client_task.result + client_name = client_task.client.name + + rc = result.get_return_code() + if rc == ReturnCode.OK: + self.log_info(fl_ctx, f"successfully synced client {client_name}") + collab_itf = result.get(SyncKey.COLLAB_INTERFACE) + self.client_info[client_name] = _ClientInfo(collab_itf) + else: + self.log_error(fl_ctx, f"client {client_task.client.name} failed to sync: {rc}") + self.client_info[client_name] = None + + def process_result_of_unknown_task( + self, client: Client, task_name: str, 
client_task_id: str, result: Shareable, fl_ctx: FLContext + ): + pass + + def stop_controller(self, fl_ctx: FLContext): + self.thread_executor.shutdown(wait=True, cancel_futures=True) diff --git a/nvflare/fox/sys/downloader.py b/nvflare/fox/sys/downloader.py new file mode 100644 index 0000000000..09b9fe3501 --- /dev/null +++ b/nvflare/fox/sys/downloader.py @@ -0,0 +1,142 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +import torch + +from nvflare.app_common.np.np_downloader import add_arrays +from nvflare.app_common.np.np_downloader import download_arrays as pull_arrays +from nvflare.app_opt.pt.tensor_downloader import add_tensors +from nvflare.app_opt.pt.tensor_downloader import download_tensors as pull_tensors +from nvflare.fox import fox +from nvflare.fox.sys.backend import FlareBackend +from nvflare.fuel.f3.streaming.file_downloader import add_file +from nvflare.fuel.f3.streaming.file_downloader import download_file as pull_file +from nvflare.fuel.f3.streaming.obj_downloader import ObjectDownloader + + +class DownloadRefKey: + SOURCE = "source" + REF_ID = "ref_id" + OBJECT_TYPE = "object_type" + + +class ObjectType: + FILE = "file" + TENSORS = "tensors" + ARRAYS = "arrays" + + +class Downloader(ObjectDownloader): + + def __init__( + self, + num_receivers: int, + timeout: float, + ): + ctx = fox.context + backend = ctx.backend + if not isinstance(backend, FlareBackend): + raise 
ValueError(f"backend must be FlareBackend but got {type(backend)}") + + super().__init__( + cell=backend.cell, + timeout=timeout, + num_receivers=num_receivers, + ) + + def _to_ref(self, obj_type, ref_id): + return { + DownloadRefKey.OBJECT_TYPE: obj_type, + DownloadRefKey.REF_ID: ref_id, + DownloadRefKey.SOURCE: self.cell.get_fqcn(), + } + + def add_file( + self, + file_name: str, + chunk_size=None, + file_downloaded_cb=None, + **cb_kwargs, + ): + rid = add_file(self, file_name, chunk_size=chunk_size, file_downloaded_cb=file_downloaded_cb, **cb_kwargs) + return self._to_ref(ObjectType.FILE, rid) + + def add_tensors(self, tensors: dict[str, torch.Tensor], max_chunk_size: int = 0): + rid = add_tensors(self, tensors, max_chunk_size=max_chunk_size) + return self._to_ref(ObjectType.TENSORS, rid) + + def add_arrays(self, arrays: dict[str, np.ndarray], max_chunk_size: int = 0): + rid = add_arrays(self, arrays, max_chunk_size=max_chunk_size) + return self._to_ref(ObjectType.ARRAYS, rid) + + +def download_file(ref: dict, per_request_timeout: float): + ctx = fox.context + backend = ctx.backend + if not isinstance(backend, FlareBackend): + raise ValueError(f"backend must be FlareBackend but got {type(backend)}") + + obj_type = ref.get(DownloadRefKey.OBJECT_TYPE) + if obj_type != ObjectType.FILE: + raise ValueError(f"obj_type must be {ObjectType.FILE} but got {obj_type}") + + return pull_file( + from_fqcn=ref.get(DownloadRefKey.SOURCE), + ref_id=ref.get(DownloadRefKey.REF_ID), + per_request_timeout=per_request_timeout, + cell=backend.cell, + abort_signal=ctx.abort_signal, + ) + + +def download_tensors(ref: dict, per_request_timeout: float, tensors_received_cb=None, **cb_kwargs): + ctx = fox.context + backend = ctx.backend + if not isinstance(backend, FlareBackend): + raise ValueError(f"backend must be FlareBackend but got {type(backend)}") + + obj_type = ref.get(DownloadRefKey.OBJECT_TYPE) + if obj_type != ObjectType.TENSORS: + raise ValueError(f"obj_type must be 
{ObjectType.TENSORS} but got {obj_type}") + + return pull_tensors( + from_fqcn=ref.get(DownloadRefKey.SOURCE), + ref_id=ref.get(DownloadRefKey.REF_ID), + per_request_timeout=per_request_timeout, + cell=backend.cell, + abort_signal=ctx.abort_signal, + tensors_received_cb=tensors_received_cb, + **cb_kwargs, + ) + + +def download_arrays(ref: dict, per_request_timeout: float, arrays_received_cb=None, **cb_kwargs): + ctx = fox.context + backend = ctx.backend + if not isinstance(backend, FlareBackend): + raise ValueError(f"backend must be FlareBackend but got {type(backend)}") + + obj_type = ref.get(DownloadRefKey.OBJECT_TYPE) + if obj_type != ObjectType.ARRAYS: + raise ValueError(f"obj_type must be {ObjectType.ARRAYS} but got {obj_type}") + + return pull_arrays( + from_fqcn=ref.get(DownloadRefKey.SOURCE), + ref_id=ref.get(DownloadRefKey.REF_ID), + per_request_timeout=per_request_timeout, + cell=backend.cell, + abort_signal=ctx.abort_signal, + arrays_received_cb=arrays_received_cb, + **cb_kwargs, + ) diff --git a/nvflare/fox/sys/executor.py b/nvflare/fox/sys/executor.py new file mode 100644 index 0000000000..7ade83f405 --- /dev/null +++ b/nvflare/fox/sys/executor.py @@ -0,0 +1,205 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from concurrent.futures import ThreadPoolExecutor +from typing import Any, Dict, List + +from nvflare.apis.client import Client, from_dict +from nvflare.apis.event_type import EventType +from nvflare.apis.executor import Executor +from nvflare.apis.fl_constant import FLContextKey, ReturnCode +from nvflare.apis.fl_context import FLContext +from nvflare.apis.job_def import JobMetaKey +from nvflare.apis.shareable import Shareable, make_reply +from nvflare.apis.signal import Signal +from nvflare.fox.api.app import ClientApp +from nvflare.fox.api.constants import MAKE_CLIENT_APP_METHOD, BackendType +from nvflare.fox.api.proxy import Proxy +from nvflare.fuel.f3.cellnet.fqcn import FQCN + +from .adaptor import FoxAdaptor +from .backend import FlareBackend +from .constants import SYNC_TASK_NAME, SyncKey +from .utils import prepare_for_remote_call +from .ws import FlareWorkspace + + +class FoxExecutor(Executor, FoxAdaptor): + + def __init__( + self, + client_obj_id: str, + collab_obj_ids: List[str] = None, + incoming_call_filters=None, + outgoing_call_filters=None, + incoming_result_filters=None, + outgoing_result_filters=None, + props: Dict[str, Any] = None, + resource_dirs: Dict[str, str] = None, + max_call_threads=100, + ): + Executor.__init__(self) + FoxAdaptor.__init__( + self, + collab_obj_ids=collab_obj_ids, + props=props, + resource_dirs=resource_dirs, + incoming_call_filters=incoming_call_filters, + outgoing_call_filters=outgoing_call_filters, + incoming_result_filters=incoming_result_filters, + outgoing_result_filters=outgoing_result_filters, + ) + self.client_obj_id = client_obj_id + self.register_event_handler(EventType.START_RUN, self._handle_start_run) + self.register_event_handler(EventType.END_RUN, self._handle_end_run) + self.client_app = None + self.client_ctx = None + self.thread_executor = ThreadPoolExecutor(max_workers=max_call_threads, thread_name_prefix="fox_call") + + def _handle_start_run(self, event_type: str, fl_ctx: FLContext): + 
fl_ctx.set_prop(FLContextKey.FOX_MODE, True, private=True, sticky=True) + engine = fl_ctx.get_engine() + client_obj = engine.get_component(self.client_obj_id) + if not client_obj: + self.system_panic(f"cannot get client component {self.client_obj_id}", fl_ctx) + return + + client_name = fl_ctx.get_identity_name() + + app = ClientApp(client_obj) + + # If the app contains "make_client_app" method, call it to make the app instance! + make_client_app_f = getattr(app, MAKE_CLIENT_APP_METHOD, None) + if make_client_app_f and callable(make_client_app_f): + app = make_client_app_f(client_name, BackendType.FLARE) + if not isinstance(app, ClientApp): + raise RuntimeError(f"result returned by {MAKE_CLIENT_APP_METHOD} must be ClientApp but got {type(app)}") + + app.name = client_name + app.set_backend_type(BackendType.FLARE) + self.client_app = app + + err = self.process_config(self.client_app, fl_ctx) + if err: + self.system_panic(err, fl_ctx) + + def _handle_end_run(self, event_type: str, fl_ctx: FLContext): + if self.client_ctx: + self.logger.info(f"finalizing client app {self.client_app.name}") + self.client_app.finalize(self.client_ctx) + self.thread_executor.shutdown(wait=True, cancel_futures=True) + + def _prepare_server_proxy(self, job_id, cell, collab_interface: dict, abort_signal, fl_ctx: FLContext): + server_name = "server" + backend = FlareBackend( + manager=self, + engine=fl_ctx.get_engine(), + caller=self.client_app.name, + cell=cell, + target_fqcn=FQCN.join([FQCN.ROOT_SERVER, job_id]), + abort_signal=abort_signal, + thread_executor=self.thread_executor, + ) + proxy = Proxy( + app=self.client_app, + target_name=server_name, + target_fqn=server_name, + backend=backend, + target_interface=collab_interface.get(""), + ) + + for name, itf in collab_interface.items(): + if name == "": + # this is the server app itself + continue + p = Proxy( + app=self.client_app, + target_name=f"{server_name}.{name}", + target_fqn="", + backend=backend, + target_interface=itf, + ) + 
proxy.add_child(name, p) + return proxy + + def _prepare_client_proxy(self, job_id, cell, client: Client, abort_signal, collab_interface, fl_ctx: FLContext): + backend = FlareBackend( + manager=self, + engine=fl_ctx.get_engine(), + caller=self.client_app.name, + cell=cell, + target_fqcn=FQCN.join([client.get_fqcn(), job_id]), + abort_signal=abort_signal, + thread_executor=self.thread_executor, + ) + proxy = Proxy( + app=self.client_app, + target_name=client.name, + target_fqn=client.get_fqsn(), + backend=backend, + target_interface=collab_interface.get(""), + ) + + collab_objs = self.client_app.get_collab_objects() + if collab_objs: + for name in collab_objs.keys(): + p = Proxy( + app=self.client_app, + target_name=f"{client.name}.{name}", + target_fqn="", + backend=backend, + target_interface=collab_interface.get(name), + ) + proxy.add_child(name, p) + return proxy + + def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable: + if task_name != SYNC_TASK_NAME: + self.log_error(fl_ctx, f"received unsupported task {task_name}") + return make_reply(ReturnCode.TASK_UNKNOWN) + + server_collab_interface = shareable.get(SyncKey.COLLAB_INTERFACE) + client_collab_interface = self.client_app.get_collab_interface() + self.log_info(fl_ctx, f"{client_collab_interface=} {server_collab_interface=}") + + engine = fl_ctx.get_engine() + cell = engine.get_cell() + + prepare_for_remote_call( + cell=cell, + app=self.client_app, + logger=self.logger, + ) + + # build proxies + job_id = fl_ctx.get_job_id() + server_proxy = self._prepare_server_proxy(job_id, cell, server_collab_interface, abort_signal, fl_ctx) + + job_meta = fl_ctx.get_prop(FLContextKey.JOB_META) + job_clients = job_meta.get(JobMetaKey.JOB_CLIENTS) + all_clients = [from_dict(d) for d in job_clients] + client_proxies = [] + for c in all_clients: + p = self._prepare_client_proxy(job_id, cell, c, abort_signal, client_collab_interface, fl_ctx) + client_proxies.append(p) + + 
ws = FlareWorkspace(fl_ctx) + self.client_app.setup(ws, server_proxy, client_proxies, abort_signal) + + self.client_ctx = self.client_app.new_context(self.client_app.name, self.client_app.name) + self.logger.info(f"initializing client app {self.client_app.name}") + self.client_app.initialize(self.client_ctx) + + reply = make_reply(ReturnCode.OK) + reply[SyncKey.COLLAB_INTERFACE] = client_collab_interface + return reply diff --git a/nvflare/fox/sys/recipe.py b/nvflare/fox/sys/recipe.py new file mode 100644 index 0000000000..1d7b43c9fd --- /dev/null +++ b/nvflare/fox/sys/recipe.py @@ -0,0 +1,200 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import List + +from nvflare.fox.api.app import App, ClientApp, ServerApp +from nvflare.fox.api.filter import FilterChain +from nvflare.fuel.utils.validation_utils import check_positive_int, check_positive_number, check_str +from nvflare.job_config.api import FedJob +from nvflare.recipe.spec import Recipe + +from .controller import FoxController +from .executor import FoxExecutor + + +class FoxRecipe(Recipe): + + def __init__( + self, + job_name: str, + server: object, + client: object, + server_objects: dict[str, object] = None, + client_objects: dict[str, object] = None, + sync_task_timeout=5, + max_call_threads_for_server=100, + max_call_threads_for_client=100, + min_clients: int = 1, + ): + check_str("job_name", job_name) + check_positive_number("sync_task_timeout", sync_task_timeout) + check_positive_int("max_call_threads_for_server", max_call_threads_for_server) + check_positive_int("max_call_threads_for_client", max_call_threads_for_client) + check_positive_int("min_clients", min_clients) + + self.job_name = job_name + self.server_app = ServerApp(server) + self.client_app = ClientApp(client) + + if server_objects: + for name, obj in server_objects.items(): + self.server_app.add_collab_object(name, obj) + + if client_objects: + for name, obj in client_objects.items(): + self.client_app.add_collab_object(name, obj) + + self.sync_task_timeout = sync_task_timeout + self.max_call_threads_for_server = max_call_threads_for_server + self.max_call_threads_for_client = max_call_threads_for_client + self.min_clients = min_clients + + job = FedJob(name=self.job_name, min_clients=self.min_clients) + Recipe.__init__(self, job) + + def set_server_prop(self, name: str, value): + self.server_app.set_prop(name, value) + + def set_server_resource_dirs(self, resource_dirs): + self.server_app.set_resource_dirs(resource_dirs) + + def add_server_outgoing_call_filters(self, pattern: str, filters: List[object]): + self.server_app.add_outgoing_call_filters(pattern, 
filters) + + def add_server_incoming_call_filters(self, pattern: str, filters: List[object]): + self.server_app.add_incoming_call_filters(pattern, filters) + + def add_server_outgoing_result_filters(self, pattern: str, filters: List[object]): + self.server_app.add_outgoing_result_filters(pattern, filters) + + def add_server_incoming_result_filters(self, pattern: str, filters: List[object]): + self.server_app.add_incoming_result_filters(pattern, filters) + + def add_client_outgoing_call_filters(self, pattern: str, filters: List[object]): + self.client_app.add_outgoing_call_filters(pattern, filters) + + def add_client_incoming_call_filters(self, pattern: str, filters: List[object]): + self.client_app.add_incoming_call_filters(pattern, filters) + + def add_client_outgoing_result_filters(self, pattern: str, filters: List[object]): + self.client_app.add_outgoing_result_filters(pattern, filters) + + def add_client_incoming_result_filters(self, pattern: str, filters: List[object]): + self.client_app.add_incoming_result_filters(pattern, filters) + + def set_client_prop(self, name: str, value): + self.client_app.set_prop(name, value) + + def set_client_resource_dirs(self, resource_dirs): + self.client_app.set_resource_dirs(resource_dirs) + + def finalize(self) -> FedJob: + server_obj_id = self.job.to_server(self.server_app.obj, "_server") + job = self.job + + collab_obj_ids, in_cf_arg, out_cf_arg, in_rf_arg, out_rf_arg = self._create_app_args( + self.server_app, job.to_server + ) + + controller = FoxController( + server_obj_id=server_obj_id, + collab_obj_ids=collab_obj_ids, + incoming_call_filters=in_cf_arg, + outgoing_call_filters=out_cf_arg, + incoming_result_filters=in_rf_arg, + outgoing_result_filters=out_rf_arg, + sync_task_timeout=self.sync_task_timeout, + max_call_threads=self.max_call_threads_for_server, + props=self.server_app.get_props(), + resource_dirs=self.server_app.get_resource_dirs(), + ) + + job.to_server(controller, id="controller") + + # add client config 
+ client_obj_id = job.to_clients(self.client_app.obj, "_client") + c_collab_obj_ids, c_in_cf_arg, c_out_cf_arg, c_in_rf_arg, c_out_rf_arg = self._create_app_args( + self.client_app, job.to_clients + ) + executor = FoxExecutor( + client_obj_id=client_obj_id, + collab_obj_ids=c_collab_obj_ids, + incoming_call_filters=c_in_cf_arg, + outgoing_call_filters=c_out_cf_arg, + incoming_result_filters=c_in_rf_arg, + outgoing_result_filters=c_out_rf_arg, + max_call_threads=self.max_call_threads_for_client, + props=self.client_app.get_props(), + resource_dirs=self.client_app.get_resource_dirs(), + ) + job.to_clients(executor, id="executor", tasks=["*"]) + return job + + def _create_app_args(self, app: App, to_f): + # collab objs + collab_obj_ids = [] + collab_objs = app.get_collab_objects() + for name, obj in collab_objs.items(): + if obj == app.obj: + # do not include in collab objs since it's done separately. + continue + comp_id = to_f(obj, id=name) + collab_obj_ids.append(comp_id) + + # build filter components + # since a filter object could be used multiple times, we must make sure that only one component is created + # for the same object! 
+ filter_comp_table = {} + incoming_call_filters = app.get_incoming_call_filters() + outgoing_call_filters = app.get_outgoing_call_filters() + incoming_result_filters = app.get_incoming_result_filters() + outgoing_result_filters = app.get_outgoing_result_filters() + + self._create_filter_components(to_f, incoming_call_filters, filter_comp_table) + self._create_filter_components(to_f, outgoing_call_filters, filter_comp_table) + self._create_filter_components(to_f, incoming_result_filters, filter_comp_table) + self._create_filter_components(to_f, outgoing_result_filters, filter_comp_table) + + # filters + in_cf_arg = self._create_filter_chain_arg(incoming_call_filters, filter_comp_table) + out_cf_arg = self._create_filter_chain_arg(outgoing_call_filters, filter_comp_table) + in_rf_arg = self._create_filter_chain_arg(incoming_result_filters, filter_comp_table) + out_rf_arg = self._create_filter_chain_arg(outgoing_result_filters, filter_comp_table) + return collab_obj_ids, in_cf_arg, out_cf_arg, in_rf_arg, out_rf_arg + + @staticmethod + def _create_filter_chain_arg(filter_chains: list, comp_table: dict): + result = [] + for chain in filter_chains: + assert isinstance(chain, FilterChain) + filter_ids = [] + for f in chain.filters: + f = f.get_impl_object() + comp_id = comp_table[id(f)] + filter_ids.append(comp_id) + d = {"pattern": chain.pattern, "filters": filter_ids} + result.append(d) + return result + + @staticmethod + def _create_filter_components(to_f, filter_chains: list, comp_table: dict): + for chain in filter_chains: + assert isinstance(chain, FilterChain) + for f in chain.filters: + f = f.get_impl_object() + fid = id(f) + comp_id = comp_table.get(fid) + if not comp_id: + comp_id = to_f(f, id="_filter") + comp_table[fid] = comp_id diff --git a/nvflare/fox/sys/utils.py b/nvflare/fox/sys/utils.py new file mode 100644 index 0000000000..6342c553e5 --- /dev/null +++ b/nvflare/fox/sys/utils.py @@ -0,0 +1,126 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. 
All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from nvflare.fox.api.app import App +from nvflare.fox.api.constants import CollabMethodArgName +from nvflare.fox.api.dec import adjust_kwargs +from nvflare.fox.api.utils import check_call_args +from nvflare.fuel.f3.cellnet.defs import MessageHeaderKey, ReturnCode +from nvflare.fuel.f3.cellnet.utils import new_cell_message +from nvflare.fuel.f3.message import Message + +from ...security.logging import secure_log_traceback +from .constants import MSG_CHANNEL, MSG_TOPIC, CallReplyKey, ObjectCallKey + + +def prepare_for_remote_call(cell, app, logger): + logger.info(f"register cb for cell {cell.get_fqcn()}: {type(cell)}") + cell.register_request_cb(channel=MSG_CHANNEL, topic=MSG_TOPIC, cb=_call_app_method, app=app, logger=logger) + logger.info(f"registered request CB for {MSG_CHANNEL}/{MSG_TOPIC}") + + +def _error_reply(error: str, logger) -> Message: + logger.error(error) + return new_cell_message( + headers={MessageHeaderKey.RETURN_CODE: ReturnCode.PROCESS_EXCEPTION}, payload={CallReplyKey.ERROR: error} + ) + + +def _preprocess(app: App, caller, target_obj_name, target_name, func_name, func, args, kwargs): + ctx = app.new_context(caller=caller, callee=app.name) + kwargs = app.apply_incoming_call_filters(target_name, func_name, kwargs, ctx) + + # make sure the final kwargs conforms to func interface + obj_itf = app.get_target_object_collab_interface(target_obj_name) + if not obj_itf: + raise 
RuntimeError(f"cannot find collab interface for object {target_obj_name}") + + func_itf = obj_itf.get(func_name) + if not func_itf: + raise RuntimeError(f"cannot find interface for func '{func_name}' of object {target_obj_name}") + + check_call_args(func_name, func_itf, args, kwargs) + + kwargs[CollabMethodArgName.CONTEXT] = ctx + adjust_kwargs(func, kwargs) + return ctx, kwargs + + +def _call_app_method(request: Message, app: App, logger) -> Message: + logger.debug("got a remote call") + payload = request.payload + if not isinstance(payload, dict): + raise RuntimeError(f"request payload must be dict but got {type(payload)}") + + caller = payload.get(ObjectCallKey.CALLER) + if not caller: + return _error_reply(f"missing '{ObjectCallKey.CALLER}' from call", logger) + + method_name = payload.get(ObjectCallKey.METHOD_NAME) + if not method_name: + return _error_reply(f"missing '{ObjectCallKey.METHOD_NAME}' from call", logger) + + target_name = payload.get(ObjectCallKey.TARGET_NAME) + if not isinstance(target_name, str): + return _error_reply( + f"bad '{ObjectCallKey.TARGET_NAME}' from call: expect str but got {type(target_name)}", + logger, + ) + + method_args = payload.get(ObjectCallKey.ARGS) + if not method_args: + method_args = [] + elif not isinstance(method_args, (list, tuple)): + return _error_reply(f"bad method args: should be list/tuple but got {type(method_args)}", logger) + + method_kwargs = payload.get(ObjectCallKey.KWARGS) + if not method_kwargs: + method_kwargs = {} + elif not isinstance(method_kwargs, dict): + return _error_reply(f"bad method kwargs: should be dict but got {type(method_kwargs)}", logger) + + parts = target_name.split(".") + obj_name = "" + if len(parts) >= 2: + obj_name = parts[1] + if obj_name: + target_objs = app.get_collab_objects() + target_obj = target_objs.get(obj_name) + logger.debug(f"calling target obj: {app.name}.{obj_name}") + else: + target_obj = app + logger.debug(f"calling target app: {app.name}") + + if not target_obj: + 
return _error_reply(f"no object named '{target_name}'", logger) + + m = app.find_collab_method(target_obj, method_name) + if not m: + return _error_reply(f"no method named '{method_name}' or it is not collab", logger) + else: + logger.debug(f"found method for {method_name}") + + # invoke this method + try: + ctx, method_kwargs = _preprocess(app, caller, obj_name, target_name, method_name, m, method_args, method_kwargs) + result = m(*method_args, **method_kwargs) + + # apply result filters + result = app.apply_outgoing_result_filters(target_name, method_name, result, ctx) + + return new_cell_message( + headers={MessageHeaderKey.RETURN_CODE: ReturnCode.OK}, payload={CallReplyKey.RESULT: result} + ) + except Exception as ex: + secure_log_traceback(logger) + return _error_reply(f"exception {type(ex)}", logger) diff --git a/nvflare/fox/sys/ws.py b/nvflare/fox/sys/ws.py new file mode 100644 index 0000000000..99d09d2e9f --- /dev/null +++ b/nvflare/fox/sys/ws.py @@ -0,0 +1,36 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from nvflare.apis.fl_context import FLContext +from nvflare.apis.workspace import Workspace as NVFWorkspace +from nvflare.fox.api.workspace import Workspace + + +class FlareWorkspace(Workspace): + + def __init__(self, fl_ctx: FLContext): + super().__init__() + ws_obj = fl_ctx.get_workspace() + if not isinstance(ws_obj, NVFWorkspace): + raise RuntimeError(f"the ws_obj must be NVFWorkspace but got {type(ws_obj)}") + self.flare_ws = ws_obj + self.job_id = fl_ctx.get_job_id() + + def get_root_dir(self) -> str: + return self.flare_ws.get_root_dir() + + def get_work_dir(self) -> str: + return self.flare_ws.get_run_dir(self.job_id) + + def get_experiment_dir(self) -> str: + return self.get_work_dir() diff --git a/nvflare/fox/utils/__init__.py b/nvflare/fox/utils/__init__.py new file mode 100644 index 0000000000..341a77c5bc --- /dev/null +++ b/nvflare/fox/utils/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nvflare/fox/utils/tensor_receiver.py b/nvflare/fox/utils/tensor_receiver.py new file mode 100644 index 0000000000..d6b3224f3f --- /dev/null +++ b/nvflare/fox/utils/tensor_receiver.py @@ -0,0 +1,55 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from nvflare.fox import fox +from nvflare.fox.api.gcc import GroupCallContext +from nvflare.fox.sys.downloader import download_tensors +from nvflare.fuel.utils.log_utils import get_obj_logger + + +class TensorReceiver: + """This class implements the callback function to add partially received tensors to the result queue. + This will enable the semi-in-time tensor aggregation by iterating through result queue. + To use, the application simply sets the process_resp_cb to an instance of this class when making group call. + + Example: + fox.clients( + blocking=False, + process_resp_cb=TensorReceiver(), + ).train(...) 
+ """ + + def __init__(self): + self.logger = get_obj_logger(self) + + def __call__(self, gcc: GroupCallContext, result): + self.logger.info(f"[{fox.call_info}] got train result from {fox.caller}: {result}") + model, model_type = result + if model_type == "ref": + err, _ = download_tensors( + ref=model, + per_request_timeout=5.0, + tensors_received_cb=self._receive_tensors, + gcc=gcc, + ) + if err: + raise RuntimeError(f"failed to download model {model}: {err}") + else: + return None + else: + return model + + def _receive_tensors(self, tensors, gcc: GroupCallContext): + self.logger.info(f"adding partial result: {tensors}") + gcc.add_partial_result(tensors) + return None diff --git a/nvflare/fuel/f3/cellnet/cell.py b/nvflare/fuel/f3/cellnet/cell.py index 837548bf89..d6f0c52e1e 100644 --- a/nvflare/fuel/f3/cellnet/cell.py +++ b/nvflare/fuel/f3/cellnet/cell.py @@ -318,6 +318,8 @@ def _send_request( secure=False, optional=False, abort_signal: Signal = None, + send_complete_cb=None, + **cb_kwargs, ): """Stream one request to the target @@ -335,7 +337,9 @@ def _send_request( """ self._encode_message(request, abort_signal) - return self._send_one_request(channel, target, topic, request, timeout, secure, optional, abort_signal) + return self._send_one_request( + channel, target, topic, request, timeout, secure, optional, abort_signal, send_complete_cb, **cb_kwargs + ) def _send_one_request( self, @@ -347,6 +351,8 @@ def _send_one_request( secure=False, optional=False, abort_signal=None, + send_complete_cb=None, + **cb_kwargs, ): req_id = str(uuid.uuid4()) request.add_headers({StreamHeaderKey.STREAM_REQ_ID: req_id}) @@ -368,6 +374,10 @@ def _send_one_request( # sending with progress timeout self.logger.debug(f"{req_id=}: entering sending wait {timeout=}") sending_complete = self._future_wait(future, timeout, abort_signal) + + if send_complete_cb: + send_complete_cb(**cb_kwargs) + if not sending_complete: self.logger.debug(f"{req_id=}: sending timeout {timeout=}") return 
self._get_result(req_id) diff --git a/nvflare/fuel/utils/tree_utils.py b/nvflare/fuel/utils/tree_utils.py index ebd92b8850..bda0c84a06 100644 --- a/nvflare/fuel/utils/tree_utils.py +++ b/nvflare/fuel/utils/tree_utils.py @@ -85,6 +85,7 @@ def __init__(self): """Constructor of Forest""" self.roots = [] # one or more names of the root nodes self.nodes = {} # name => Node + self.leaves = [] # names of leaf nodes def build_forest(objs: List[Any], get_name_f, get_fqn_f, **kwargs) -> Forest: @@ -140,6 +141,10 @@ def build_forest(objs: List[Any], get_name_f, get_fqn_f, **kwargs) -> Forest: # this node has no parent - it's a root forest.roots.append(name) + for name, node in forest.nodes.items(): + if not node.children: + forest.leaves.append(name) + return forest diff --git a/nvflare/private/fed/client/client_runner.py b/nvflare/private/fed/client/client_runner.py index a4226e4f27..acc0527830 100644 --- a/nvflare/private/fed/client/client_runner.py +++ b/nvflare/private/fed/client/client_runner.py @@ -734,6 +734,11 @@ def init_run(self, app_root, args): self.log_debug(fl_ctx, "firing event EventType.START_RUN") self.fire_event(EventType.START_RUN, fl_ctx) self.log_info(fl_ctx, "client runner started") + fox_mode = fl_ctx.get_prop(FLContextKey.FOX_MODE, False) + if fox_mode: + # in fox mode, all tasks go to the server + self.logger.debug(f"changed parent target from {self.parent_target} to {FQCN.ROOT_SERVER}") + self.parent_target = FQCN.ROOT_SERVER def _handle_sync_runner(self, topic: str, request: Shareable, fl_ctx: FLContext) -> Shareable: # simply ack diff --git a/nvflare/private/fed/client/utils.py b/nvflare/private/fed/client/utils.py index f18690d8cc..f5f2d8fc02 100644 --- a/nvflare/private/fed/client/utils.py +++ b/nvflare/private/fed/client/utils.py @@ -14,6 +14,7 @@ from typing import Optional from nvflare.apis.client import Client +from nvflare.apis.fl_constant import FLContextKey from nvflare.apis.fl_context import FLContext from nvflare.fuel.f3.cellnet.fqcn 
import FQCN from nvflare.private.fed.utils.identity_utils import get_parent_site_name @@ -49,6 +50,10 @@ def determine_parent_fqcn(client_config: dict, fl_ctx: FLContext) -> str: Returns: the FQCN of the parent cell """ + fox_mode = fl_ctx.get_prop(FLContextKey.FOX_MODE, False) + if fox_mode: + return FQCN.ROOT_SERVER + parent_client_name = determine_parent_name(client_config) if parent_client_name: engine = fl_ctx.get_engine() diff --git a/nvflare/private/fed/server/server_engine.py b/nvflare/private/fed/server/server_engine.py index b2187595c9..455a80b001 100644 --- a/nvflare/private/fed/server/server_engine.py +++ b/nvflare/private/fed/server/server_engine.py @@ -448,7 +448,10 @@ def set_run_manager(self, run_manager: RunManager): self.run_manager.add_handler(widget) def get_cell(self): - return self.cell + if self.cell: + return self.cell + elif self.run_manager and self.run_manager.cell: + return self.run_manager.cell def initialize_comm(self, cell: Cell): """This is called when the communication cell has been created. diff --git a/nvflare/recipe/spec.py b/nvflare/recipe/spec.py index 53307ba8a3..cf8435e666 100644 --- a/nvflare/recipe/spec.py +++ b/nvflare/recipe/spec.py @@ -109,6 +109,14 @@ def __init__(self, job: FedJob): def process_env(self, env: ExecEnv): pass + def finalize(self): + """Called to finalize the setup of the recipe. + + Returns: + + """ + pass + def add_client_input_filter( self, filter: Filter, tasks: Optional[List[str]] = None, clients: Optional[List[str]] = None ): @@ -225,6 +233,8 @@ def export( Returns: None """ + self.finalize() + if server_exec_params: self.job.to_server(server_exec_params) @@ -247,6 +257,8 @@ def execute(self, env: ExecEnv, server_exec_params: dict = None, client_exec_par Returns: Run to get job ID and execution results """ + self.finalize() + if server_exec_params: self.job.to_server(server_exec_params)