Skip to content

Commit d39d101

Browse files
authored
typing: strengthen types in networking and discovery transport (#452)
* typing: strengthen types in networking and discovery transport * fix(discovery): use RequestId types in transport internal state and tests * fix(tests): update RequestId equality assertions in test_transport * style: clean up stray inline comments in transport.py
1 parent 65938cf commit d39d101

File tree

4 files changed

+41
-40
lines changed

4 files changed

+41
-40
lines changed

src/lean_spec/subspecs/networking/client/event_source.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -874,7 +874,8 @@ async def _dial_quic(self, multiaddr: str) -> QuicConnection:
874874
QuicTransportError: If connection fails.
875875
"""
876876
await self._ensure_quic_manager()
877-
return await self.quic_manager.connect(multiaddr) # type: ignore[union-attr]
877+
assert self.quic_manager is not None
878+
return await self.quic_manager.connect(multiaddr)
878879

879880
async def listen(self, multiaddr: str) -> None:
880881
"""Start listening for incoming connections.
@@ -905,7 +906,8 @@ async def _listen_quic(self, multiaddr: str) -> None:
905906
multiaddr: QUIC address to listen on.
906907
"""
907908
await self._ensure_quic_manager()
908-
await self.quic_manager.listen( # type: ignore[union-attr]
909+
assert self.quic_manager is not None
910+
await self.quic_manager.listen(
909911
multiaddr,
910912
on_connection=self._handle_inbound_quic_connection,
911913
)

src/lean_spec/subspecs/networking/discovery/transport.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@
6868
class PendingRequest:
6969
"""Tracks a pending request awaiting response."""
7070

71-
request_id: bytes
71+
request_id: RequestId
7272
"""Request ID for matching responses."""
7373

7474
dest_node_id: NodeId
@@ -95,7 +95,7 @@ class PendingMultiRequest:
9595
across UDP packets when results exceed MTU.
9696
"""
9797

98-
request_id: bytes
98+
request_id: RequestId
9999
"""Request ID for matching responses."""
100100

101101
dest_node_id: NodeId
@@ -193,8 +193,8 @@ def __init__(
193193

194194
self._protocol: DiscoveryProtocol | None = None
195195
self._transport: asyncio.DatagramTransport | None = None
196-
self._pending_requests: dict[bytes, PendingRequest] = {}
197-
self._pending_multi_requests: dict[bytes, PendingMultiRequest] = {}
196+
self._pending_requests: dict[RequestId, PendingRequest] = {}
197+
self._pending_multi_requests: dict[RequestId, PendingMultiRequest] = {}
198198
self._node_addresses: dict[NodeId, tuple[str, int]] = {}
199199

200200
self._message_handler: (
@@ -384,12 +384,11 @@ async def _send_multi_response_request(
384384

385385
# Create collector for multiple responses.
386386
loop = asyncio.get_running_loop()
387-
request_id_bytes = bytes(message.request_id)
388387

389388
# Use a queue to collect multiple responses.
390389
response_queue: asyncio.Queue[DiscoveryMessage] = asyncio.Queue()
391390
pending = PendingMultiRequest(
392-
request_id=request_id_bytes,
391+
request_id=message.request_id,
393392
dest_node_id=dest_node_id,
394393
sent_at=loop.time(),
395394
nonce=nonce,
@@ -398,7 +397,7 @@ async def _send_multi_response_request(
398397
expected_total=None,
399398
received_count=0,
400399
)
401-
self._pending_multi_requests[request_id_bytes] = pending
400+
self._pending_multi_requests[message.request_id] = pending
402401

403402
# Send packet.
404403
self._transport.sendto(packet, dest_addr)
@@ -434,7 +433,7 @@ async def _send_multi_response_request(
434433
break
435434

436435
finally:
437-
self._pending_multi_requests.pop(request_id_bytes, None)
436+
self._pending_multi_requests.pop(message.request_id, None)
438437

439438
return responses
440439

@@ -494,17 +493,16 @@ async def _send_request(
494493
# Create pending request.
495494
loop = asyncio.get_running_loop()
496495
future: asyncio.Future[DiscoveryMessage | None] = loop.create_future()
497-
498-
request_id_bytes = bytes(message.request_id)
496+
request_id = message.request_id
499497
pending = PendingRequest(
500-
request_id=request_id_bytes,
498+
request_id=request_id,
501499
dest_node_id=dest_node_id,
502500
sent_at=loop.time(),
503501
nonce=nonce,
504502
message=message,
505503
future=future,
506504
)
507-
self._pending_requests[request_id_bytes] = pending
505+
self._pending_requests[request_id] = pending
508506

509507
# Send packet.
510508
self._transport.sendto(packet, dest_addr)
@@ -518,7 +516,7 @@ async def _send_request(
518516
except asyncio.TimeoutError:
519517
return None
520518
finally:
521-
self._pending_requests.pop(request_id_bytes, None)
519+
self._pending_requests.pop(request_id, None)
522520

523521
def _build_message_packet(
524522
self,
@@ -790,7 +788,7 @@ async def _handle_decoded_message(
790788
self._session_cache.touch(remote_node_id, ip, Port(port))
791789

792790
# Check if this is a response to a pending request.
793-
request_id = bytes(message.request_id)
791+
request_id = message.request_id
794792

795793
# Check for multi-response requests first (e.g., FINDNODE -> NODES).
796794
multi_pending = self._pending_multi_requests.get(request_id)

src/lean_spec/subspecs/networking/transport/quic/connection.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,7 @@ def __init__(self, *args, **kwargs) -> None:
324324
"""Initialize the libp2p QUIC protocol handler."""
325325
super().__init__(*args, **kwargs)
326326
self.connection: QuicConnection | None = None
327-
self.peer_identity: bytes | None = None
327+
self.peer_identity: PeerId | None = None
328328
self.handshake_complete = asyncio.Event()
329329
self._buffered_events: list[QuicEvent] = []
330330

@@ -532,7 +532,8 @@ async def connect(self, multiaddr: str) -> QuicConnection:
532532
self._context_managers.append(cm)
533533

534534
# Cast to our protocol type to access custom attributes.
535-
protocol: LibP2PQuicProtocol = base_protocol # type: ignore[assignment]
535+
assert isinstance(base_protocol, LibP2PQuicProtocol)
536+
protocol: LibP2PQuicProtocol = base_protocol
536537

537538
# Wait for handshake to complete.
538539
await protocol.handshake_complete.wait()

tests/lean_spec/subspecs/networking/discovery/test_transport.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ async def test_stop_cancels_pending_requests(self, local_node_id, local_private_
276276
loop = asyncio.get_running_loop()
277277
future: asyncio.Future = loop.create_future()
278278
pending = PendingRequest(
279-
request_id=b"\x01\x02\x03\x04",
279+
request_id=RequestId(data=b"\x01\x02\x03\x04"),
280280
dest_node_id=NodeId(bytes(32)),
281281
sent_at=loop.time(),
282282
nonce=Nonce(bytes(12)),
@@ -315,15 +315,15 @@ def test_create_pending_request(self):
315315
message = Ping(request_id=RequestId(data=b"\x01"), enr_seq=SeqNumber(1))
316316

317317
pending = PendingRequest(
318-
request_id=b"\x01\x02\x03\x04",
318+
request_id=RequestId(data=b"\x01\x02\x03\x04"),
319319
dest_node_id=NodeId(bytes(32)),
320320
sent_at=123.456,
321321
nonce=Nonce(bytes(12)),
322322
message=message,
323323
future=future,
324324
)
325325

326-
assert pending.request_id == b"\x01\x02\x03\x04"
326+
assert pending.request_id == RequestId(data=b"\x01\x02\x03\x04")
327327
assert pending.dest_node_id == bytes(32)
328328
assert pending.sent_at == 123.456
329329
assert pending.nonce == bytes(12)
@@ -407,7 +407,7 @@ def test_pending_multi_request_creation(self, local_node_id, local_private_key,
407407
queue: asyncio.Queue = asyncio.Queue()
408408

409409
pending = PendingMultiRequest(
410-
request_id=b"\x01\x02\x03\x04",
410+
request_id=RequestId(data=b"\x01\x02\x03\x04"),
411411
dest_node_id=NodeId(bytes(32)),
412412
sent_at=123.456,
413413
nonce=Nonce(bytes(12)),
@@ -417,7 +417,7 @@ def test_pending_multi_request_creation(self, local_node_id, local_private_key,
417417
received_count=0,
418418
)
419419

420-
assert pending.request_id == b"\x01\x02\x03\x04"
420+
assert pending.request_id == RequestId(data=b"\x01\x02\x03\x04")
421421
assert pending.expected_total is None
422422
assert pending.received_count == 0
423423

@@ -430,7 +430,7 @@ def test_pending_multi_request_expected_total_tracking(self):
430430
queue: asyncio.Queue = asyncio.Queue()
431431

432432
pending = PendingMultiRequest(
433-
request_id=b"\x01\x02\x03\x04",
433+
request_id=RequestId(data=b"\x01\x02\x03\x04"),
434434
dest_node_id=NodeId(bytes(32)),
435435
sent_at=123.456,
436436
nonce=Nonce(bytes(12)),
@@ -466,7 +466,7 @@ async def test_queue():
466466
queue: asyncio.Queue = asyncio.Queue()
467467

468468
pending = PendingMultiRequest(
469-
request_id=b"\x01\x02\x03\x04",
469+
request_id=RequestId(data=b"\x01\x02\x03\x04"),
470470
dest_node_id=NodeId(bytes(32)),
471471
sent_at=123.456,
472472
nonce=Nonce(bytes(12)),
@@ -572,7 +572,7 @@ def test_pending_request_stores_request_id(self):
572572
message = Ping(request_id=RequestId(data=b"\x01\x02\x03\x04"), enr_seq=SeqNumber(1))
573573

574574
pending = PendingRequest(
575-
request_id=b"\x01\x02\x03\x04",
575+
request_id=RequestId(data=b"\x01\x02\x03\x04"),
576576
dest_node_id=NodeId(bytes(32)),
577577
sent_at=123.456,
578578
nonce=Nonce(bytes(12)),
@@ -581,8 +581,8 @@ def test_pending_request_stores_request_id(self):
581581
)
582582

583583
# Request ID should be stored for matching.
584-
assert pending.request_id == b"\x01\x02\x03\x04"
585-
assert bytes(pending.message.request_id) == b"\x01\x02\x03\x04"
584+
assert pending.request_id == RequestId(data=b"\x01\x02\x03\x04")
585+
assert pending.message.request_id == RequestId(data=b"\x01\x02\x03\x04")
586586

587587
loop.close()
588588

@@ -596,7 +596,7 @@ async def test_future():
596596

597597
message = Ping(request_id=RequestId(data=b"\x01"), enr_seq=SeqNumber(1))
598598
pending = PendingRequest(
599-
request_id=b"\x01",
599+
request_id=RequestId(data=b"\x01"),
600600
dest_node_id=NodeId(bytes(32)),
601601
sent_at=loop.time(),
602602
nonce=Nonce(bytes(12)),
@@ -632,7 +632,7 @@ def test_pending_request_future_cancellation(self):
632632

633633
message = Ping(request_id=RequestId(data=b"\x01"), enr_seq=SeqNumber(1))
634634
pending = PendingRequest(
635-
request_id=b"\x01",
635+
request_id=RequestId(data=b"\x01"),
636636
dest_node_id=NodeId(bytes(32)),
637637
sent_at=loop.time(),
638638
nonce=Nonce(bytes(12)),
@@ -663,7 +663,7 @@ def test_request_id_bytes_for_dict_lookup(self):
663663
message2 = Ping(request_id=RequestId(data=request_id_2), enr_seq=SeqNumber(2))
664664

665665
pending1 = PendingRequest(
666-
request_id=request_id_1,
666+
request_id=RequestId(data=request_id_1),
667667
dest_node_id=NodeId(bytes(32)),
668668
sent_at=loop.time(),
669669
nonce=Nonce(bytes(12)),
@@ -672,7 +672,7 @@ def test_request_id_bytes_for_dict_lookup(self):
672672
)
673673

674674
pending2 = PendingRequest(
675-
request_id=request_id_2,
675+
request_id=RequestId(data=request_id_2),
676676
dest_node_id=NodeId(bytes(32)),
677677
sent_at=loop.time(),
678678
nonce=Nonce(bytes(12)),
@@ -703,7 +703,7 @@ def test_pending_request_stores_nonce_for_whoareyou_matching(self):
703703

704704
message = Ping(request_id=RequestId(data=b"\x01"), enr_seq=SeqNumber(1))
705705
pending = PendingRequest(
706-
request_id=b"\x01",
706+
request_id=RequestId(data=b"\x01"),
707707
dest_node_id=NodeId(bytes(32)),
708708
sent_at=loop.time(),
709709
nonce=nonce,
@@ -724,7 +724,7 @@ def test_pending_request_stores_message_for_retransmission(self):
724724

725725
message = Ping(request_id=RequestId(data=b"\x01"), enr_seq=SeqNumber(42))
726726
pending = PendingRequest(
727-
request_id=b"\x01",
727+
request_id=RequestId(data=b"\x01"),
728728
dest_node_id=NodeId(bytes(32)),
729729
sent_at=loop.time(),
730730
nonce=Nonce(bytes(12)),
@@ -782,7 +782,7 @@ async def test_pending_requests_cleared_on_stop(
782782
for i in range(3):
783783
future: asyncio.Future = loop.create_future()
784784
pending = PendingRequest(
785-
request_id=bytes([i]),
785+
request_id=RequestId(data=bytes([i])),
786786
dest_node_id=NodeId(bytes(32)),
787787
sent_at=loop.time(),
788788
nonce=Nonce(bytes(12)),
@@ -825,7 +825,7 @@ async def test_pending_request_futures_cancelled_on_stop(
825825
future: asyncio.Future = loop.create_future()
826826
futures.append(future)
827827
pending = PendingRequest(
828-
request_id=bytes([i]),
828+
request_id=RequestId(data=bytes([i])),
829829
dest_node_id=NodeId(bytes(32)),
830830
sent_at=loop.time(),
831831
nonce=Nonce(bytes(12)),
@@ -1086,7 +1086,7 @@ async def test_response_completes_pending_request_future(
10861086

10871087
loop = asyncio.get_running_loop()
10881088
future: asyncio.Future[Pong | None] = loop.create_future()
1089-
request_id = b"\x01\x02\x03\x04"
1089+
request_id = RequestId(data=b"\x01\x02\x03\x04")
10901090

10911091
pending = PendingRequest(
10921092
request_id=request_id,
@@ -1099,7 +1099,7 @@ async def test_response_completes_pending_request_future(
10991099
transport._pending_requests[request_id] = pending
11001100

11011101
pong = Pong(
1102-
request_id=RequestId(data=request_id),
1102+
request_id=request_id,
11031103
enr_seq=SeqNumber(1),
11041104
recipient_ip=IPv4(b"\x7f\x00\x00\x01"),
11051105
recipient_port=Port(9000),
@@ -1121,7 +1121,7 @@ async def test_response_enqueued_for_multi_request(
11211121
local_enr=local_enr,
11221122
)
11231123

1124-
request_id = b"\x01\x02\x03\x04"
1124+
request_id = RequestId(data=b"\x01\x02\x03\x04")
11251125
queue: asyncio.Queue = asyncio.Queue()
11261126

11271127
multi_pending = PendingMultiRequest(
@@ -1137,7 +1137,7 @@ async def test_response_enqueued_for_multi_request(
11371137
transport._pending_multi_requests[request_id] = multi_pending
11381138

11391139
nodes = Nodes(
1140-
request_id=RequestId(data=request_id),
1140+
request_id=request_id,
11411141
total=Uint8(1),
11421142
enrs=[b"enr1"],
11431143
)

0 commit comments

Comments
 (0)