Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 90 additions & 0 deletions tests/core/pubsub/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
"""Shared fixtures and helpers for pubsub tests."""

from __future__ import annotations

from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
import dataclasses
from typing import Any

import pytest
import trio

from libp2p.abc import IHost
from libp2p.pubsub.gossipsub import GossipSub
from libp2p.pubsub.pubsub import Pubsub
from tests.utils.factories import PubsubFactory
from tests.utils.pubsub.utils import dense_connect


@dataclasses.dataclass(frozen=True, slots=True)
class GossipSubHarness:
"""Typed wrapper around a batch of GossipSub-backed pubsub instances."""

pubsubs: tuple[Pubsub, ...]

@property
def hosts(self) -> tuple[IHost, ...]:
return tuple(ps.host for ps in self.pubsubs)

@property
def routers(self) -> tuple[GossipSub, ...]:
result: list[GossipSub] = []
for ps in self.pubsubs:
r = ps.router
assert isinstance(r, GossipSub), f"Expected GossipSub, got {type(r)}"
result.append(r)
return tuple(result)

def __len__(self) -> int:
return len(self.pubsubs)


@asynccontextmanager
async def gossipsub_nodes(n: int, **kwargs: Any) -> AsyncIterator[GossipSubHarness]:
    """
    Create *n* GossipSub-backed pubsub nodes wrapped in a harness.

    Usage::

        async with gossipsub_nodes(3, heartbeat_interval=0.5) as h:
            h.pubsubs  # tuple[Pubsub, ...]
            h.hosts    # tuple[IHost, ...]
            h.routers  # tuple[GossipSub, ...]
    """
    # The factory owns node lifecycle; we only add the typed wrapper on top.
    async with PubsubFactory.create_batch_with_gossipsub(n, **kwargs) as batch:
        yield GossipSubHarness(pubsubs=batch)


@asynccontextmanager
async def connected_gossipsub_nodes(
    n: int, **kwargs: Any
) -> AsyncIterator[GossipSubHarness]:
    """
    Create *n* GossipSub nodes with dense connectivity.

    Rather than sleeping a fixed amount after dialing (which is both flaky
    and unnecessarily slow, since pubsub peer registration is asynchronous),
    this waits deterministically until every pubsub instance has registered
    one of its neighbours via ``Pubsub.wait_for_peer()``.

    :param n: number of nodes to create.
    :param kwargs: forwarded to the pubsub factory. ``peer_wait_timeout``
        (float seconds, default ``5.0``) is consumed here and bounds the
        wait for peer registration.
    :raises trio.TooSlowError: if peers do not register within
        ``peer_wait_timeout`` seconds.
    """
    # Pop our own option first so it does not leak into the factory kwargs.
    peer_wait_timeout = kwargs.pop("peer_wait_timeout", 5.0)
    async with gossipsub_nodes(n, **kwargs) as harness:
        await dense_connect(harness.hosts)
        if n > 1:
            # dense_connect dials every pair; waiting until each node sees
            # its ring-neighbour guarantees the "connected" semantics
            # without an arbitrary sleep.
            with trio.fail_after(peer_wait_timeout):
                for index, pubsub in enumerate(harness.pubsubs):
                    neighbour = harness.hosts[(index + 1) % n]
                    await pubsub.wait_for_peer(neighbour.get_id())
        yield harness


@asynccontextmanager
async def subscribed_mesh(
    topic: str, n: int, *, settle_time: float = 1.0, **kwargs: Any
) -> AsyncIterator[GossipSubHarness]:
    """
    Create *n* connected GossipSub nodes all subscribed to *topic*.

    Waits *settle_time* seconds for mesh formation before yielding.
    """
    async with connected_gossipsub_nodes(n, **kwargs) as harness:
        for node in harness.pubsubs:
            await node.subscribe(topic)
        # Allow the gossipsub heartbeat to graft mesh links before tests run.
        await trio.sleep(settle_time)
        yield harness


@pytest.fixture
async def connected_gossipsub_pair() -> AsyncIterator[GossipSubHarness]:
    """Fixture: two connected GossipSub nodes with default config."""
    async with connected_gossipsub_nodes(2) as pair:
        yield pair
80 changes: 3 additions & 77 deletions tests/core/pubsub/test_dummyaccount_demo.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,3 @@
from collections.abc import (
Callable,
)
import logging

import pytest
import trio

Expand All @@ -12,69 +7,9 @@
from tests.utils.pubsub.dummy_account_node import (
DummyAccountNode,
)

logger = logging.getLogger(__name__)


async def wait_for_convergence(
nodes: tuple[DummyAccountNode, ...],
check: Callable[[DummyAccountNode], bool],
timeout: float = 10.0,
poll_interval: float = 0.02,
log_success: bool = False,
raise_last_exception_on_timeout: bool = True,
) -> None:
"""
Wait until all nodes satisfy the check condition.

Returns as soon as convergence is reached, otherwise raises TimeoutError.
Convergence already guarantees all nodes satisfy the check, so callers need
not run a second assertion pass after this returns.
"""
start_time = trio.current_time()

last_exception: Exception | None = None
last_exception_node: int | None = None

while True:
failed_indices: list[int] = []
for i, node in enumerate(nodes):
try:
ok = check(node)
except Exception as exc:
ok = False
last_exception = exc
last_exception_node = i
if not ok:
failed_indices.append(i)

if not failed_indices:
elapsed = trio.current_time() - start_time
if log_success:
logger.debug("Converged in %.3fs with %d nodes", elapsed, len(nodes))
return

elapsed = trio.current_time() - start_time
if elapsed > timeout:
if raise_last_exception_on_timeout and last_exception is not None:
# Preserve the underlying assertion/exception signal (and its message)
# instead of hiding it behind a generic timeout.
node_hint = (
f" (node index {last_exception_node})"
if last_exception_node is not None
else ""
)
raise AssertionError(
f"Convergence failed{node_hint}: {last_exception}"
) from last_exception

raise TimeoutError(
f"Convergence timeout after {elapsed:.2f}s. "
f"Failed nodes: {failed_indices}. "
f"(Hint: run with -s and pass log_success=True for timing logs)"
)

await trio.sleep(poll_interval)
from tests.utils.pubsub.wait import (
wait_for_convergence,
)


async def perform_test(num_nodes, adjacency_map, action_func, assertion_func):
Expand Down Expand Up @@ -116,7 +51,6 @@ def _check_final(node: DummyAccountNode) -> bool:
# Success, terminate pending tasks.


@pytest.mark.trio
async def test_simple_two_nodes():
num_nodes = 2
adj_map = {0: [1]}
Expand All @@ -130,7 +64,6 @@ def assertion_func(dummy_node):
await perform_test(num_nodes, adj_map, action_func, assertion_func)


@pytest.mark.trio
async def test_simple_three_nodes_line_topography():
num_nodes = 3
adj_map = {0: [1], 1: [2]}
Expand All @@ -144,7 +77,6 @@ def assertion_func(dummy_node):
await perform_test(num_nodes, adj_map, action_func, assertion_func)


@pytest.mark.trio
async def test_simple_three_nodes_triangle_topography():
num_nodes = 3
adj_map = {0: [1, 2], 1: [2]}
Expand All @@ -158,7 +90,6 @@ def assertion_func(dummy_node):
await perform_test(num_nodes, adj_map, action_func, assertion_func)


@pytest.mark.trio
async def test_simple_seven_nodes_tree_topography():
num_nodes = 7
adj_map = {0: [1, 2], 1: [3, 4], 2: [5, 6]}
Expand All @@ -172,7 +103,6 @@ def assertion_func(dummy_node):
await perform_test(num_nodes, adj_map, action_func, assertion_func)


@pytest.mark.trio
async def test_set_then_send_from_root_seven_nodes_tree_topography():
num_nodes = 7
adj_map = {0: [1, 2], 1: [3, 4], 2: [5, 6]}
Expand All @@ -197,7 +127,6 @@ def assertion_func(dummy_node):
await perform_test(num_nodes, adj_map, action_func, assertion_func)


@pytest.mark.trio
async def test_set_then_send_from_different_leafs_seven_nodes_tree_topography():
num_nodes = 7
adj_map = {0: [1, 2], 1: [3, 4], 2: [5, 6]}
Expand All @@ -216,7 +145,6 @@ def assertion_func(dummy_node):
await perform_test(num_nodes, adj_map, action_func, assertion_func)


@pytest.mark.trio
async def test_simple_five_nodes_ring_topography():
num_nodes = 5
adj_map = {0: [1], 1: [2], 2: [3], 3: [4], 4: [0]}
Expand All @@ -230,7 +158,6 @@ def assertion_func(dummy_node):
await perform_test(num_nodes, adj_map, action_func, assertion_func)


@pytest.mark.trio
async def test_set_then_send_from_diff_nodes_five_nodes_ring_topography():
num_nodes = 5
adj_map = {0: [1], 1: [2], 2: [3], 3: [4], 4: [0]}
Expand All @@ -252,7 +179,6 @@ def assertion_func(dummy_node):
await perform_test(num_nodes, adj_map, action_func, assertion_func)


@pytest.mark.trio
@pytest.mark.slow
async def test_set_then_send_from_five_diff_nodes_five_nodes_ring_topography():
num_nodes = 5
Expand Down
114 changes: 114 additions & 0 deletions tests/utils/pubsub/wait.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
"""Polling helpers for pubsub test synchronization."""

from __future__ import annotations

from collections.abc import Callable
import inspect
import logging
from typing import TYPE_CHECKING

import trio

if TYPE_CHECKING:
from tests.utils.pubsub.dummy_account_node import DummyAccountNode

logger = logging.getLogger(__name__)


async def wait_for(
    predicate: Callable[[], object],
    *,
    timeout: float = 10.0,
    poll_interval: float = 0.02,
    fail_msg: str = "",
) -> None:
    """
    Poll until *predicate()* returns a truthy value, or raise ``TimeoutError``.

    Supports both sync and async predicates: if calling the predicate yields
    an awaitable, it is awaited before its truthiness is checked. This makes
    wrappers such as ``lambda: async_fn()`` or ``functools.partial(async_fn)``
    work correctly — detecting async-ness via ``inspect.iscoroutinefunction``
    would miss those, treating the returned coroutine object (always truthy)
    as immediate success and leaving it un-awaited.

    If the predicate raises, the exception is treated as falsy; on timeout
    the last such exception is chained to the ``TimeoutError``.

    :param predicate: zero-argument callable returning a truthy value (or an
        awaitable resolving to one) once the awaited condition holds.
    :param timeout: seconds before giving up.
    :param poll_interval: seconds to sleep between polls.
    :param fail_msg: message for the ``TimeoutError``; defaults to a message
        including the elapsed time when empty.
    """
    start = trio.current_time()
    last_exc: Exception | None = None

    while True:
        try:
            result = predicate()
            # Await lazily-produced awaitables instead of relying on
            # iscoroutinefunction(), which misses lambda/partial wrappers.
            if inspect.isawaitable(result):
                result = await result
            if result:
                return
        except Exception as exc:
            last_exc = exc

        elapsed = trio.current_time() - start
        if elapsed > timeout:
            msg = fail_msg or f"wait_for timed out after {elapsed:.2f}s"
            err = TimeoutError(msg)
            if last_exc is not None:
                raise err from last_exc
            raise err

        await trio.sleep(poll_interval)


async def wait_for_convergence(
    nodes: tuple[DummyAccountNode, ...],
    check: Callable[[DummyAccountNode], bool],
    timeout: float = 10.0,
    poll_interval: float = 0.02,
    log_success: bool = False,
    raise_last_exception_on_timeout: bool = True,
) -> None:
    """
    Wait until all *nodes* satisfy *check*.

    Returns as soon as convergence is reached, otherwise raises
    ``TimeoutError`` (or ``AssertionError`` when
    *raise_last_exception_on_timeout* is ``True`` and a node raised).

    Preserves the API of the original inline helper from
    ``test_dummyaccount_demo.py``.
    """
    started = trio.current_time()

    # Track the most recent exception raised by *check*, and which node
    # raised it, so a timeout can surface the underlying failure signal.
    last_exception: Exception | None = None
    last_exception_node: int | None = None

    while True:
        failed_indices: list[int] = []
        for index, node in enumerate(nodes):
            try:
                passed = check(node)
            except Exception as exc:
                passed = False
                last_exception = exc
                last_exception_node = index
            if not passed:
                failed_indices.append(index)

        if not failed_indices:
            # Converged: every node passed on this sweep.
            elapsed = trio.current_time() - started
            if log_success:
                logger.debug("Converged in %.3fs with %d nodes", elapsed, len(nodes))
            return

        elapsed = trio.current_time() - started
        if elapsed > timeout:
            if raise_last_exception_on_timeout and last_exception is not None:
                # Surface the original assertion/exception message instead
                # of hiding it behind a generic timeout.
                node_hint = (
                    f" (node index {last_exception_node})"
                    if last_exception_node is not None
                    else ""
                )
                raise AssertionError(
                    f"Convergence failed{node_hint}: {last_exception}"
                ) from last_exception

            raise TimeoutError(
                f"Convergence timeout after {elapsed:.2f}s. "
                f"Failed nodes: {failed_indices}. "
                f"(Hint: run with -s and pass log_success=True for timing logs)"
            )

        await trio.sleep(poll_interval)
Loading