Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions nvflare/private/fed/client/communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,8 +368,13 @@ def pull_task(self, project_name, token, ssid, fl_ctx: FLContext, timeout=None):
)
job_id = fl_ctx.get_job_id()

# Use at least the server-required minimum (e.g. for tensor streaming). When the server
# sends MIN_GET_TASK_TIMEOUT we update self.timeout; the caller may still pass a smaller
# config value, so ensure we never use less than the required minimum.
if not timeout:
timeout = self.timeout
else:
timeout = max(timeout, self.timeout)

parent_fqcn = determine_parent_fqcn(self.client_config, fl_ctx)
self.logger.debug(f"pulling task from parent FQCN: {parent_fqcn}")
Expand Down
75 changes: 75 additions & 0 deletions tests/unit_test/app_opt/tensor_stream/timeout_management_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,81 @@ def get_header_side_effect(key, default=None):
for log_message in all_log_messages:
assert "Automatically adjusting" not in log_message

@pytest.mark.parametrize(
"initial_timeout,server_min_timeout,caller_timeout,expected_request_timeout",
[
(5.0, 360.0, 5.0, 360.0), # Caller passes small value, server min is large → use server min
(5.0, 360.0, 30.0, 360.0), # Caller passes moderate value, still below server min → use server min
(5.0, 360.0, 400.0, 400.0), # Caller passes value above server min → use caller value
(5.0, 360.0, 360.0, 360.0), # Caller passes exactly server min → use that
(5.0, 660.0, 600.0, 660.0), # Large server min, caller still below → use server min
],
)
@patch("nvflare.private.fed.client.communicator.new_cell_message")
@patch("nvflare.private.fed.client.communicator.determine_parent_fqcn")
@patch("nvflare.private.fed.client.communicator.gen_new_peer_ctx")
def test_explicit_timeout_enforces_server_minimum(
self,
mock_gen_ctx,
mock_determine_parent,
mock_new_cell_message,
initial_timeout,
server_min_timeout,
caller_timeout,
expected_request_timeout,
):
"""Test that an explicit caller-provided timeout is raised to the server minimum when needed.

When self.timeout has been bumped by a prior MIN_GET_TASK_TIMEOUT from the server,
and the caller passes an explicit (smaller) timeout, pull_task should use
max(caller_timeout, self.timeout) so the server-required minimum is respected.
"""
from nvflare.fuel.f3.cellnet.defs import MessageHeaderKey, ReturnCode

communicator = Communicator(client_config={"client_name": "test_client"}, timeout=initial_timeout)
communicator.engine = Mock()
communicator.cell = Mock()

# Simulate a prior MIN_GET_TASK_TIMEOUT bump
communicator.timeout = server_min_timeout

mock_determine_parent.return_value = "parent_fqcn"
mock_new_cell_message.return_value = Mock()

# Create a successful response (content doesn't matter for this test)
response_shareable = Shareable()
response_shareable.set_header(ServerCommandKey.TASK_NAME, "train")
response_shareable.set_header(FLContextKey.TASK_ID, "task_123")

mock_task = Mock()

def get_header_side_effect(key, default=None):
return {
MessageHeaderKey.RETURN_CODE: ReturnCode.OK,
MessageHeaderKey.PAYLOAD_LEN: 1024,
}.get(key, default)

mock_task.get_header = Mock(side_effect=get_header_side_effect)
mock_task.payload = response_shareable
communicator.cell.send_request.return_value = mock_task

mock_fl_context = Mock()
mock_fl_context.get_job_id.return_value = "job_123"
mock_fl_context.get_run_abort_signal.return_value = None
mock_fl_context.set_prop = Mock()

communicator.logger = Mock()

# Call pull_task WITH an explicit timeout
communicator.pull_task("project", "token", "ssid", mock_fl_context, timeout=caller_timeout)

# Verify the actual timeout passed to cell.send_request
actual_timeout = communicator.cell.send_request.call_args[1]["timeout"]
assert actual_timeout == expected_request_timeout, (
f"Expected send_request timeout={expected_request_timeout}, got {actual_timeout}. "
f"caller_timeout={caller_timeout}, self.timeout={server_min_timeout}"
)

@patch("nvflare.private.fed.client.communicator.new_cell_message")
@patch("nvflare.private.fed.client.communicator.determine_parent_fqcn")
@patch("nvflare.private.fed.client.communicator.gen_new_peer_ctx")
Expand Down