Skip to content

Commit cfaebd4

Browse files
committed
impl recursive aggregation
1 parent 096d667 commit cfaebd4

File tree

3 files changed

+124
-114
lines changed

3 files changed

+124
-114
lines changed

src/lean_spec/subspecs/containers/state/state.py

Lines changed: 70 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from __future__ import annotations
44

5-
from collections.abc import Collection, Iterable
5+
from collections.abc import Iterable
66
from collections.abc import Set as AbstractSet
77
from typing import TYPE_CHECKING
88

@@ -759,79 +759,89 @@ def build_block(
759759

760760
return final_block, post_state, aggregated_attestations, aggregated_signatures
761761

762+
def _extend_proofs_with_unique_participants(
763+
proofs: set[AggregatedSignatureProof] | None,
764+
selected: list[AggregatedSignatureProof],
765+
covered: set[ValidatorIndex],
766+
) -> None:
767+
if not proofs:
768+
return
769+
sorted_proofs = sorted(
770+
proofs,
771+
key=lambda proof: len(proof.participants.to_validator_indices()),
772+
reverse=True,
773+
)
774+
for proof in sorted_proofs:
775+
participants = set(proof.participants.to_validator_indices())
776+
if participants - covered:
777+
selected.append(proof)
778+
covered.update(participants)
779+
762780
def aggregate_gossip_signatures(
763781
self,
764-
attestations: Collection[Attestation],
765782
gossip_signatures: dict[AttestationData, set[GossipSignatureEntry]] | None = None,
783+
new_payloads: dict[AttestationData, set[AggregatedSignatureProof]] | None = None,
784+
known_payloads: dict[AttestationData, set[AggregatedSignatureProof]] | None = None,
766785
) -> list[tuple[AggregatedAttestation, AggregatedSignatureProof]]:
767786
"""
768-
Collect aggregated signatures from gossip network and aggregate them.
769-
770-
For each attestation group, attempt to collect individual XMSS signatures
771-
from the gossip network. These are fresh signatures that validators
772-
broadcast when they attest.
787+
Aggregate gossip signatures using new payloads, with known payloads as helpers.
773788
774789
Args:
775-
attestations: Individual attestations to aggregate and sign.
776-
gossip_signatures: Per-validator XMSS signatures learned from
777-
the gossip network, keyed by the attestation data they signed.
790+
gossip_signatures: Raw XMSS signatures learned from gossip keyed by attestation data.
791+
new_payloads: Aggregated proofs pending processing (child proofs).
792+
known_payloads: Known aggregated proofs already accepted.
778793
779794
Returns:
780-
List of (attestation, proof) pairs from gossip collection.
795+
List of (aggregated attestation, proof) pairs to broadcast.
781796
"""
782797
results: list[tuple[AggregatedAttestation, AggregatedSignatureProof]] = []
783798

784-
# Group individual attestations by data
785-
#
786-
# Multiple validators may attest to the same data (slot, head, target, source).
787-
# We aggregate them into groups so each group can share a single proof.
788-
for aggregated in AggregatedAttestation.aggregate_by_data(list(attestations)):
789-
# Extract the common attestation data and its hash.
790-
#
791-
# All validators in this group signed the same message (the data root).
792-
data = aggregated.data
793-
data_root = data.data_root_bytes()
799+
gossip_signatures = gossip_signatures or {}
800+
new_payloads = new_payloads or {}
801+
known_payloads = known_payloads or {}
794802

795-
# Get the list of validators who attested to this data.
796-
validator_ids = aggregated.aggregation_bits.to_validator_indices()
803+
# Use only keys from new_payloads and gossip_signatures
804+
# know_payloads can be used to extend the proof with new_payloads and gossip_signatures
805+
# but known_payloads are not recursively aggregated into their own proofs
806+
attestation_keys = set(new_payloads.keys()) | set(gossip_signatures.keys())
807+
if not attestation_keys:
808+
return results
797809

798-
# When a validator creates an attestation, it broadcasts the
799-
# individual XMSS signature over the gossip network. If we have
800-
# received these signatures, we can aggregate them ourselves.
801-
#
802-
# This is the preferred path: fresh signatures from the network.
803-
804-
# Parallel lists for signatures, public keys, and validator IDs.
805-
gossip_sigs: list[Signature] = []
806-
gossip_keys: list[PublicKey] = []
807-
gossip_ids: list[ValidatorIndex] = []
808-
809-
# Look up signatures by attestation data directly.
810-
# Sort by validator ID for deterministic aggregation order.
811-
if gossip_signatures and (entries := gossip_signatures.get(data)):
812-
for entry in sorted(entries, key=lambda e: e.validator_id):
813-
if entry.validator_id in validator_ids:
814-
gossip_sigs.append(entry.signature)
815-
gossip_keys.append(self.validators[entry.validator_id].get_pubkey())
816-
gossip_ids.append(entry.validator_id)
817-
818-
# If we collected any gossip signatures, aggregate them into a proof.
819-
#
820-
# The aggregation combines multiple XMSS signatures into a single
821-
# compact proof that can verify all participants signed the message.
822-
if gossip_ids:
823-
participants = AggregationBits.from_validator_indices(
824-
ValidatorIndices(data=gossip_ids)
825-
)
826-
proof = AggregatedSignatureProof.aggregate(
827-
participants=participants,
828-
public_keys=gossip_keys,
829-
signatures=gossip_sigs,
830-
message=data_root,
831-
slot=data.slot,
832-
)
833-
attestation = AggregatedAttestation(aggregation_bits=participants, data=data)
834-
results.append((attestation, proof))
810+
# Aggregate the proofs for each attestation data
811+
for data in attestation_keys:
812+
child_proofs: list[AggregatedSignatureProof] = []
813+
covered_validators: set[ValidatorIndex] = set()
814+
815+
self._extend_proofs_with_unique_participants(new_payloads.get(data), child_proofs, covered_validators)
816+
self._extend_proofs_with_unique_participants(known_payloads.get(data), child_proofs, covered_validators)
817+
818+
raw_entries: list[tuple[ValidatorIndex, PublicKey, Signature]] = []
819+
for entry in sorted(gossip_signatures.get(data, set()), key=lambda e: e.validator_id):
820+
if entry.validator_id in covered_validators:
821+
continue
822+
if int(entry.validator_id) >= len(self.validators):
823+
continue
824+
public_key = self.validators[entry.validator_id].get_pubkey()
825+
raw_entries.append((entry.validator_id, public_key, entry.signature))
826+
covered_validators.add(entry.validator_id)
827+
828+
if not raw_entries and len(child_proofs) < 2:
829+
results.append((data, child_proofs))
830+
continue
831+
832+
raw_entries = sorted(raw_entries, key=lambda e: e.validator_id)
833+
raw_xmss = [(pubkey, signature) for _, pubkey, signature in raw_entries]
834+
xmss_participants = AggregationBits.from_validator_indices(ValidatorIndices(data=[e.validator_id for e in raw_entries]))
835+
836+
proof = AggregatedSignatureProof.aggregate(
837+
xmss_participants=xmss_participants,
838+
children=child_proofs,
839+
raw_xmss=raw_xmss,
840+
message=data.data_root_bytes(),
841+
slot=data.slot,
842+
)
843+
attestation = AggregatedAttestation(aggregation_bits=proof.participants, data=data)
844+
results.append((attestation, proof))
835845

836846
return results
837847

src/lean_spec/subspecs/forkchoice/store.py

Lines changed: 11 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -950,61 +950,28 @@ def aggregate_committee_signatures(self) -> tuple["Store", list[SignedAggregated
950950
"""
951951
Aggregate committee signatures for attestations in committee_signatures.
952952
953-
This method aggregates signatures from the gossip_signatures map.
953+
This method aggregates gossip signatures together with new aggregated payloads.
954954
955955
Returns:
956956
Tuple of (new Store with updated payloads, list of new SignedAggregatedAttestation).
957957
"""
958-
new_aggregated_payloads = {
959-
attestation_data: set(proofs)
960-
for attestation_data, proofs in self.latest_new_aggregated_payloads.items()
961-
}
962-
963-
committee_signatures = self.gossip_signatures
964-
965-
# Extract attestations from gossip_signatures
966-
attestation_list: list[Attestation] = [
967-
Attestation(validator_id=entry.validator_id, data=attestation_data)
968-
for attestation_data, signatures in self.gossip_signatures.items()
969-
for entry in signatures
970-
]
971-
972958
head_state = self.states[self.head]
973-
# Perform aggregation
974959
aggregated_results = head_state.aggregate_gossip_signatures(
975-
attestation_list,
976-
committee_signatures,
960+
gossip_signatures=self.gossip_signatures,
961+
new_payloads=self.latest_new_aggregated_payloads,
962+
known_payloads=self.latest_known_aggregated_payloads,
977963
)
978964

979965
# Create list of aggregated attestations for broadcasting
980-
new_aggregates = [
981-
SignedAggregatedAttestation(data=att.data, proof=sig) for att, sig in aggregated_results
982-
]
966+
# and update the store with the new aggregated payloads
967+
new_aggregates: list[SignedAggregatedAttestation] = []
968+
new_aggregated_payloads: dict[AttestationData, set[AggregatedSignatureProof]] = {}
983969

984-
# Compute new aggregated payloads
985-
new_gossip_sigs = {
986-
attestation_data: set(signatures)
987-
for attestation_data, signatures in self.gossip_signatures.items()
988-
}
989-
for aggregated_attestation, aggregated_signature in aggregated_results:
990-
attestation_data = aggregated_attestation.data
991-
new_aggregated_payloads.setdefault(attestation_data, set()).add(aggregated_signature)
992-
993-
validator_ids = set(aggregated_signature.participants.to_validator_indices())
994-
existing_entries = new_gossip_sigs.get(attestation_data)
995-
if existing_entries:
996-
remaining = {e for e in existing_entries if e.validator_id not in validator_ids}
997-
if remaining:
998-
new_gossip_sigs[attestation_data] = remaining
999-
else:
1000-
del new_gossip_sigs[attestation_data]
970+
for att, proof in aggregated_results:
971+
new_aggregates.append(SignedAggregatedAttestation(data=att.data, proof=proof))
972+
new_aggregated_payloads.setdefault(att.data, set()).add(proof)
1001973

1002-
return self.model_copy(
1003-
update={
1004-
"latest_new_aggregated_payloads": new_aggregated_payloads,
1005-
"gossip_signatures": new_gossip_sigs,
1006-
}
1007-
), new_aggregates
974+
return self.model_copy(update={"latest_new_aggregated_payloads": new_aggregated_payloads, "gossip_signatures": {}}), new_aggregates
1008975

1009976
def tick_interval(
1010977
self, has_proposal: bool, is_aggregator: bool = False

src/lean_spec/subspecs/xmss/aggregation.py

Lines changed: 43 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,14 @@
1515
from lean_spec.config import LEAN_ENV, LeanEnvMode
1616
from lean_spec.subspecs.containers.attestation import AggregationBits
1717
from lean_spec.subspecs.containers.slot import Slot
18+
from lean_spec.subspecs.containers.validator import ValidatorIndex, ValidatorIndices
1819
from lean_spec.types import ByteListMiB, Bytes32, Container
1920

2021
from .containers import PublicKey, Signature
2122

23+
INV_PROOF_SIZE: int = 2
24+
"""Protocol-level inverse proof size parameter for aggregation (range 1-4)."""
25+
2226

2327
class AggregationError(Exception):
2428
"""Raised when signature aggregation or verification fails."""
@@ -44,39 +48,68 @@ class AggregatedSignatureProof(Container):
4448
proof_data: ByteListMiB
4549
"""The raw aggregated proof bytes from leanVM."""
4650

51+
bytecode_point: ByteListMiB | None = None
52+
"""
53+
Serialized bytecode-point claim data from recursive aggregation.
54+
55+
If the bytecode point is not provided, the proof is not recursive.
56+
"""
57+
4758
@classmethod
4859
def aggregate(
4960
cls,
50-
participants: AggregationBits,
51-
public_keys: Sequence[PublicKey],
52-
signatures: Sequence[Signature],
61+
xmss_participants: AggregationBits | None,
62+
children: Sequence[Self],
63+
raw_xmss: Sequence[tuple[PublicKey, Signature]],
5364
message: Bytes32,
5465
slot: Slot,
5566
mode: LeanEnvMode | None = None,
5667
) -> Self:
5768
"""
58-
Aggregate individual XMSS signatures into a single proof.
69+
Aggregate raw_xmss signatures and children proofs into a single proof.
5970
6071
Args:
61-
participants: Bitfield of validators whose signatures are included.
62-
public_keys: Public keys of the signers (must match signatures order).
63-
signatures: Individual XMSS signatures to aggregate.
72+
xmss_participants: Bitfield of validators whose raw_signatures are provided.
73+
children: Sequence of child proofs to aggregate.
74+
raw_xmss: Sequence of (public key, signature) tuples to aggregate.
6475
message: The 32-byte message that was signed.
6576
slot: The slot in which the signatures were created.
6677
mode: The mode to use for the aggregation (test or prod).
6778
6879
Returns:
69-
An aggregated signature proof covering all participants.
80+
An aggregated signature proof covering raw signers and all child participants.
7081
7182
Raises:
7283
AggregationError: If aggregation fails.
7384
"""
85+
if not raw_xmss and not children:
86+
raise AggregationError("At least one raw signature or child proof is required")
87+
88+
if raw_xmss and xmss_participants is None:
89+
raise AggregationError("xmss_participants is required when raw_xmss is provided")
90+
91+
if not raw_xmss and len(children) < 2:
92+
raise AggregationError(
93+
"At least two child proofs are required when no raw signatures are provided"
94+
)
95+
96+
aggregated_validator_ids: set[ValidatorIndex] = set()
97+
aggregated_validator_ids.update(xmss_participants.to_validator_indices())
98+
99+
if len(aggregated_validator_ids) != len(raw_xmss):
100+
raise AggregationError("The number of raw signatures does not match the number of XMSS participants")
101+
102+
# Include child participants in the aggregated participants
103+
for child in children:
104+
aggregated_validator_ids.update(child.participants.to_validator_indices())
105+
participants = AggregationBits.from_validator_indices(ValidatorIndices(data=aggregated_validator_ids))
106+
74107
mode = mode or LEAN_ENV
75108
setup_prover(mode=mode)
76109
try:
77110
proof_bytes = aggregate_signatures(
78-
[pk.encode_bytes() for pk in public_keys],
79-
[sig.encode_bytes() for sig in signatures],
111+
[pk.encode_bytes() for pk, _ in raw_xmss],
112+
[sig.encode_bytes() for _, sig in raw_xmss],
80113
message,
81114
slot,
82115
mode=mode,

0 commit comments

Comments
 (0)