diff --git a/tests/core/pubsub/conftest.py b/tests/core/pubsub/conftest.py new file mode 100644 index 000000000..b9e5631a5 --- /dev/null +++ b/tests/core/pubsub/conftest.py @@ -0,0 +1,90 @@ +"""Shared fixtures and helpers for pubsub tests.""" + +from __future__ import annotations + +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager +import dataclasses +from typing import Any + +import pytest +import trio + +from libp2p.abc import IHost +from libp2p.pubsub.gossipsub import GossipSub +from libp2p.pubsub.pubsub import Pubsub +from tests.utils.factories import PubsubFactory +from tests.utils.pubsub.utils import dense_connect + + +@dataclasses.dataclass(frozen=True, slots=True) +class GossipSubHarness: + """Typed wrapper around a batch of GossipSub-backed pubsub instances.""" + + pubsubs: tuple[Pubsub, ...] + + @property + def hosts(self) -> tuple[IHost, ...]: + return tuple(ps.host for ps in self.pubsubs) + + @property + def routers(self) -> tuple[GossipSub, ...]: + result: list[GossipSub] = [] + for ps in self.pubsubs: + r = ps.router + assert isinstance(r, GossipSub), f"Expected GossipSub, got {type(r)}" + result.append(r) + return tuple(result) + + def __len__(self) -> int: + return len(self.pubsubs) + + +@asynccontextmanager +async def gossipsub_nodes(n: int, **kwargs: Any) -> AsyncIterator[GossipSubHarness]: + """ + Create *n* GossipSub-backed pubsub nodes wrapped in a harness. + + Usage:: + + async with gossipsub_nodes(3, heartbeat_interval=0.5) as h: + h.pubsubs # tuple[Pubsub, ...] + h.hosts # tuple[IHost, ...] + h.routers # tuple[GossipSub, ...] + """ + async with PubsubFactory.create_batch_with_gossipsub(n, **kwargs) as pubsubs: + yield GossipSubHarness(pubsubs=pubsubs) + + +@asynccontextmanager +async def connected_gossipsub_nodes( + n: int, **kwargs: Any +) -> AsyncIterator[GossipSubHarness]: + """Create *n* GossipSub nodes with dense connectivity.""" + async with gossipsub_nodes(n, **kwargs) as harness: + await dense_connect(harness.hosts) + await trio.sleep(0.1) + yield harness + + +@asynccontextmanager +async def subscribed_mesh( + topic: str, n: int, *, settle_time: float = 1.0, **kwargs: Any +) -> AsyncIterator[GossipSubHarness]: + """ + Create *n* connected GossipSub nodes all subscribed to *topic*. + + Waits *settle_time* seconds for mesh formation before yielding. + """ + async with connected_gossipsub_nodes(n, **kwargs) as harness: + for ps in harness.pubsubs: + await ps.subscribe(topic) + await trio.sleep(settle_time) + yield harness + + +@pytest.fixture +async def connected_gossipsub_pair() -> AsyncIterator[GossipSubHarness]: + """Fixture: two connected GossipSub nodes with default config.""" + async with connected_gossipsub_nodes(2) as harness: + yield harness diff --git a/tests/core/pubsub/test_dummyaccount_demo.py b/tests/core/pubsub/test_dummyaccount_demo.py index 0018ba80f..2136f0ab3 100644 --- a/tests/core/pubsub/test_dummyaccount_demo.py +++ b/tests/core/pubsub/test_dummyaccount_demo.py @@ -1,8 +1,3 @@ -from collections.abc import ( - Callable, -) -import logging - import pytest import trio @@ -12,69 +7,9 @@ from tests.utils.pubsub.dummy_account_node import ( DummyAccountNode, ) - -logger = logging.getLogger(__name__) - - -async def wait_for_convergence( - nodes: tuple[DummyAccountNode, ...], - check: Callable[[DummyAccountNode], bool], - timeout: float = 10.0, - poll_interval: float = 0.02, - log_success: bool = False, - raise_last_exception_on_timeout: bool = True, -) -> None: - """ - Wait until all nodes satisfy the check condition. - - Returns as soon as convergence is reached, otherwise raises TimeoutError. - Convergence already guarantees all nodes satisfy the check, so callers need - not run a second assertion pass after this returns. - """ - start_time = trio.current_time() - - last_exception: Exception | None = None - last_exception_node: int | None = None - - while True: - failed_indices: list[int] = [] - for i, node in enumerate(nodes): - try: - ok = check(node) - except Exception as exc: - ok = False - last_exception = exc - last_exception_node = i - if not ok: - failed_indices.append(i) - - if not failed_indices: - elapsed = trio.current_time() - start_time - if log_success: - logger.debug("Converged in %.3fs with %d nodes", elapsed, len(nodes)) - return - - elapsed = trio.current_time() - start_time - if elapsed > timeout: - if raise_last_exception_on_timeout and last_exception is not None: - # Preserve the underlying assertion/exception signal (and its message) - # instead of hiding it behind a generic timeout. - node_hint = ( - f" (node index {last_exception_node})" - if last_exception_node is not None - else "" - ) - raise AssertionError( - f"Convergence failed{node_hint}: {last_exception}" - ) from last_exception - - raise TimeoutError( - f"Convergence timeout after {elapsed:.2f}s. " - f"Failed nodes: {failed_indices}. " - f"(Hint: run with -s and pass log_success=True for timing logs)" - ) - - await trio.sleep(poll_interval) +from tests.utils.pubsub.wait import ( + wait_for_convergence, +) async def perform_test(num_nodes, adjacency_map, action_func, assertion_func): @@ -116,7 +51,6 @@ def _check_final(node: DummyAccountNode) -> bool: # Success, terminate pending tasks. -@pytest.mark.trio async def test_simple_two_nodes(): num_nodes = 2 adj_map = {0: [1]} @@ -130,7 +64,6 @@ def assertion_func(dummy_node): await perform_test(num_nodes, adj_map, action_func, assertion_func) -@pytest.mark.trio async def test_simple_three_nodes_line_topography(): num_nodes = 3 adj_map = {0: [1], 1: [2]} @@ -144,7 +77,6 @@ def assertion_func(dummy_node): await perform_test(num_nodes, adj_map, action_func, assertion_func) -@pytest.mark.trio async def test_simple_three_nodes_triangle_topography(): num_nodes = 3 adj_map = {0: [1, 2], 1: [2]} @@ -158,7 +90,6 @@ def assertion_func(dummy_node): await perform_test(num_nodes, adj_map, action_func, assertion_func) -@pytest.mark.trio async def test_simple_seven_nodes_tree_topography(): num_nodes = 7 adj_map = {0: [1, 2], 1: [3, 4], 2: [5, 6]} @@ -172,7 +103,6 @@ def assertion_func(dummy_node): await perform_test(num_nodes, adj_map, action_func, assertion_func) -@pytest.mark.trio async def test_set_then_send_from_root_seven_nodes_tree_topography(): num_nodes = 7 adj_map = {0: [1, 2], 1: [3, 4], 2: [5, 6]} @@ -197,7 +127,6 @@ def assertion_func(dummy_node): await perform_test(num_nodes, adj_map, action_func, assertion_func) -@pytest.mark.trio async def test_set_then_send_from_different_leafs_seven_nodes_tree_topography(): num_nodes = 7 adj_map = {0: [1, 2], 1: [3, 4], 2: [5, 6]} @@ -216,7 +145,6 @@ def assertion_func(dummy_node): await perform_test(num_nodes, adj_map, action_func, assertion_func) -@pytest.mark.trio async def test_simple_five_nodes_ring_topography(): num_nodes = 5 adj_map = {0: [1], 1: [2], 2: [3], 3: [4], 4: [0]} @@ -230,7 +158,6 @@ def assertion_func(dummy_node): await perform_test(num_nodes, adj_map, action_func, assertion_func) -@pytest.mark.trio async def test_set_then_send_from_diff_nodes_five_nodes_ring_topography(): num_nodes = 5 adj_map = {0: [1], 1: [2], 2: [3], 3: [4], 4: [0]} @@ -252,7 +179,6 @@ def assertion_func(dummy_node): await perform_test(num_nodes, adj_map, action_func, assertion_func) -@pytest.mark.trio @pytest.mark.slow async def test_set_then_send_from_five_diff_nodes_five_nodes_ring_topography(): num_nodes = 5 diff --git a/tests/utils/pubsub/wait.py b/tests/utils/pubsub/wait.py new file mode 100644 index 000000000..a9290e20a --- /dev/null +++ b/tests/utils/pubsub/wait.py @@ -0,0 +1,114 @@ +"""Polling helpers for pubsub test synchronization.""" + +from __future__ import annotations + +from collections.abc import Callable +import inspect +import logging +from typing import TYPE_CHECKING + +import trio + +if TYPE_CHECKING: + from tests.utils.pubsub.dummy_account_node import DummyAccountNode + +logger = logging.getLogger(__name__) + + +async def wait_for( + predicate: Callable[[], object], + *, + timeout: float = 10.0, + poll_interval: float = 0.02, + fail_msg: str = "", +) -> None: + """ + Poll until *predicate()* returns a truthy value, or raise ``TimeoutError``. + + Supports both sync and async predicates. If the predicate raises an + exception it is treated as falsy; on timeout the last such exception is + chained to the ``TimeoutError``. + """ + _is_async = inspect.iscoroutinefunction(predicate) + start = trio.current_time() + last_exc: Exception | None = None + + while True: + try: + result = (await predicate()) if _is_async else predicate() # type: ignore[misc] + if result: + return + except Exception as exc: + last_exc = exc + + elapsed = trio.current_time() - start + if elapsed > timeout: + msg = fail_msg or f"wait_for timed out after {elapsed:.2f}s" + err = TimeoutError(msg) + if last_exc is not None: + raise err from last_exc + raise err + + await trio.sleep(poll_interval) + + +async def wait_for_convergence( + nodes: tuple[DummyAccountNode, ...], + check: Callable[[DummyAccountNode], bool], + timeout: float = 10.0, + poll_interval: float = 0.02, + log_success: bool = False, + raise_last_exception_on_timeout: bool = True, +) -> None: + """ + Wait until all *nodes* satisfy *check*. + + Returns as soon as convergence is reached, otherwise raises + ``TimeoutError`` (or ``AssertionError`` when + *raise_last_exception_on_timeout* is ``True`` and a node raised). + + Preserves the API of the original inline helper from + ``test_dummyaccount_demo.py``. + """ + start_time = trio.current_time() + + last_exception: Exception | None = None + last_exception_node: int | None = None + + while True: + failed_indices: list[int] = [] + for i, node in enumerate(nodes): + try: + ok = check(node) + except Exception as exc: + ok = False + last_exception = exc + last_exception_node = i + if not ok: + failed_indices.append(i) + + if not failed_indices: + elapsed = trio.current_time() - start_time + if log_success: + logger.debug("Converged in %.3fs with %d nodes", elapsed, len(nodes)) + return + + elapsed = trio.current_time() - start_time + if elapsed > timeout: + if raise_last_exception_on_timeout and last_exception is not None: + node_hint = ( + f" (node index {last_exception_node})" + if last_exception_node is not None + else "" + ) + raise AssertionError( + f"Convergence failed{node_hint}: {last_exception}" + ) from last_exception + + raise TimeoutError( + f"Convergence timeout after {elapsed:.2f}s. " + f"Failed nodes: {failed_indices}. " + f"(Hint: run with -s and pass log_success=True for timing logs)" + ) + + await trio.sleep(poll_interval)