Skip to content

Commit e73c707

Browse files
authored
test framework: optimize key manager's memory usage (#460)
1 parent 4e17e77 commit e73c707

File tree

4 files changed

+105
-33
lines changed

4 files changed

+105
-33
lines changed

packages/testing/src/consensus_testing/keys.py

Lines changed: 91 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
from lean_spec.subspecs.koalabear import Fp
5252
from lean_spec.subspecs.xmss.aggregation import AggregatedSignatureProof
5353
from lean_spec.subspecs.xmss.constants import TARGET_CONFIG
54-
from lean_spec.subspecs.xmss.containers import PublicKey, Signature, ValidatorKeyPair
54+
from lean_spec.subspecs.xmss.containers import PublicKey, SecretKey, Signature, ValidatorKeyPair
5555
from lean_spec.subspecs.xmss.interface import (
5656
PROD_SIGNATURE_SCHEME,
5757
TEST_SIGNATURE_SCHEME,
@@ -65,6 +65,9 @@
6565
)
6666
from lean_spec.types import Bytes32, Uint64
6767

68+
SecretField = Literal["attestation_secret", "proposal_secret"]
69+
"""The two secret key field names on ValidatorKeyPair."""
70+
6871
__all__ = [
6972
"CLI_DEFAULT_MAX_SLOT",
7073
"KEY_DOWNLOAD_URLS",
@@ -173,6 +176,8 @@ def __init__(self, scheme_name: str) -> None:
173176
self._scheme_name = scheme_name
174177
self._keys_dir = get_keys_dir(scheme_name)
175178
self._cache: dict[ValidatorIndex, ValidatorKeyPair] = {}
179+
self._public_cache: dict[ValidatorIndex, tuple[PublicKey, PublicKey]] = {}
180+
self._raw_cache: dict[ValidatorIndex, dict[str, str]] = {}
176181
self._available_indices: set[ValidatorIndex] | None = None
177182

178183
def _ensure_dir_exists(self) -> None:
@@ -197,18 +202,62 @@ def _get_available_indices(self) -> set[ValidatorIndex]:
197202
)
198203
return self._available_indices
199204

205+
def _load_raw(self, idx: ValidatorIndex) -> dict[str, str]:
206+
"""Load raw JSON data from disk (cached)."""
207+
if idx not in self._raw_cache:
208+
key_file = self._keys_dir / f"{idx}.json"
209+
try:
210+
self._raw_cache[idx] = json.loads(key_file.read_text())
211+
except FileNotFoundError:
212+
raise KeyError(f"Key file not found: {key_file}") from None
213+
return self._raw_cache[idx]
214+
200215
def _load_key(self, idx: ValidatorIndex) -> ValidatorKeyPair:
201216
"""Load a single key from disk."""
202-
key_file = self._keys_dir / f"{idx}.json"
203-
if not key_file.exists():
204-
raise KeyError(f"Key file not found: {key_file}")
205-
data = json.loads(key_file.read_text())
206-
return ValidatorKeyPair.from_dict(data)
217+
return ValidatorKeyPair.from_dict(self._load_raw(idx))
218+
219+
def _load_public_keys(self, idx: ValidatorIndex) -> tuple[PublicKey, PublicKey]:
220+
"""Load only public keys from disk, skipping expensive SecretKey deserialization."""
221+
data = self._load_raw(idx)
222+
return (
223+
PublicKey.decode_bytes(bytes.fromhex(data["attestation_public"])),
224+
PublicKey.decode_bytes(bytes.fromhex(data["proposal_public"])),
225+
)
226+
227+
def get_public_keys(self, idx: ValidatorIndex) -> tuple[PublicKey, PublicKey]:
228+
"""
229+
Get (attestation_public, proposal_public) without loading secret keys.
230+
231+
Returns cached public keys if available, otherwise loads only the public
232+
key portions from disk. Avoids deserializing the heavy SecretKey objects
233+
(each ~2.7KB raw with 3 HashSubTree structures) until signing is needed.
234+
"""
235+
if idx in self._cache:
236+
kp = self._cache[idx]
237+
return (kp.attestation_public, kp.proposal_public)
238+
if idx not in self._public_cache:
239+
self._public_cache[idx] = self._load_public_keys(idx)
240+
return self._public_cache[idx]
241+
242+
def get_secret_key(self, idx: ValidatorIndex, field: SecretField) -> SecretKey:
243+
"""
244+
Load a specific secret key from disk without deserializing the other keys.
245+
246+
Only the requested SecretKey is deserialized (~370 MB in Python objects).
247+
The other three fields remain as lightweight hex strings (~2.7 KB each).
248+
"""
249+
if idx in self._cache:
250+
return getattr(self._cache[idx], field)
251+
data = self._load_raw(idx)
252+
return SecretKey.decode_bytes(bytes.fromhex(data[field]))
207253

208254
def __getitem__(self, idx: ValidatorIndex) -> ValidatorKeyPair:
209255
"""Get key pair by validator index, loading from disk if needed."""
210256
if idx not in self._cache:
211257
self._cache[idx] = self._load_key(idx)
258+
# Full pair supersedes raw/public caches for this index.
259+
self._raw_cache.pop(idx, None)
260+
self._public_cache.pop(idx, None)
212261
return self._cache[idx]
213262

214263
def __contains__(self, idx: object) -> bool:
@@ -247,7 +296,13 @@ def __init__(
247296
"""Initialize the manager with optional custom configuration."""
248297
self.max_slot = max_slot
249298
self.scheme = scheme
250-
self._state: dict[ValidatorIndex, ValidatorKeyPair] = {}
299+
self._secret_state: dict[tuple[ValidatorIndex, SecretField], bytes] = {}
300+
"""
301+
Advanced secret key state cached as raw SSZ bytes.
302+
303+
Raw bytes (~2.7 KB each) instead of deserialized SecretKey objects
304+
(~370 MB each) to avoid holding massive Pydantic model trees in memory.
305+
"""
251306

252307
try:
253308
self.scheme_name = next(
@@ -264,9 +319,7 @@ def keys(self) -> LazyKeyDict:
264319
return _LAZY_KEY_CACHE[self.scheme_name]
265320

266321
def __getitem__(self, idx: ValidatorIndex) -> ValidatorKeyPair:
267-
"""Get key pair, returning advanced state if available."""
268-
if idx in self._state:
269-
return self._state[idx]
322+
"""Get key pair. Prefer get_public_keys() or signing methods to avoid loading all keys."""
270323
if idx not in self.keys:
271324
raise KeyError(f"Validator {idx} not found (max: {len(self.keys) - 1})")
272325
return self.keys[idx]
@@ -285,27 +338,45 @@ def __iter__(self) -> Iterator[ValidatorIndex]:
285338
"""Iterate over validator indices."""
286339
return iter(self.keys)
287340

341+
def get_public_keys(self, idx: ValidatorIndex) -> tuple[PublicKey, PublicKey]:
342+
"""
343+
Get (attestation_public, proposal_public) without loading secret keys.
344+
345+
Delegates to lazy disk loading that skips SecretKey deserialization.
346+
"""
347+
return self.keys.get_public_keys(idx)
348+
288349
def _sign_with_secret(
289350
self,
290351
validator_id: ValidatorIndex,
291352
slot: Slot,
292353
message: Bytes32,
293-
secret_field: Literal["attestation_secret", "proposal_secret"],
354+
secret_field: SecretField,
294355
) -> Signature:
295356
"""
296357
Shared signing logic for attestation/proposal paths.
297358
298359
Handles XMSS state advancement until the requested slot is within the
299-
prepared interval, caches the updated secret, and produces the signature.
360+
prepared interval, caches the updated secret as raw bytes, and produces
361+
the signature.
362+
363+
Only the needed SecretKey is deserialized (~370 MB in Python objects).
364+
After signing, the advanced state is re-serialized to compact bytes
365+
(~2.7 KB) so only one SecretKey is in memory at a time.
300366
301367
Args:
302368
validator_id: Validator index whose key should be used.
303369
slot: The slot to sign for.
304370
message: The message bytes to sign.
305371
secret_field: Which secret on the key pair should advance.
306372
"""
307-
kp = self[validator_id]
308-
sk = getattr(kp, secret_field)
373+
cache_key = (validator_id, secret_field)
374+
375+
# Deserialize the secret key: from cached bytes or from disk.
376+
if cache_key in self._secret_state:
377+
sk = SecretKey.decode_bytes(self._secret_state[cache_key])
378+
else:
379+
sk = self.keys.get_secret_key(validator_id, secret_field)
309380

310381
# Advance key state until the slot is ready for signing.
311382
prepared = self.scheme.get_prepared_interval(sk)
@@ -316,10 +387,12 @@ def _sign_with_secret(
316387
sk = self.scheme.advance_preparation(sk)
317388
prepared = self.scheme.get_prepared_interval(sk)
318389

319-
# Cache advanced state (only the selected secret changes).
320-
self._state[validator_id] = kp._replace(**{secret_field: sk})
390+
signature = self.scheme.sign(sk, slot, message)
391+
392+
# Cache advanced state as raw bytes to keep memory compact.
393+
self._secret_state[cache_key] = sk.encode_bytes()
321394

322-
return self.scheme.sign(sk, slot, message)
395+
return signature
323396

324397
def sign_attestation_data(
325398
self,
@@ -397,7 +470,7 @@ def build_attestation_signatures(
397470
# Look up pre-computed signatures by attestation data and validator ID.
398471
sigs_for_data = lookup.get(agg.data, {})
399472

400-
public_keys: list[PublicKey] = [self[vid].attestation_public for vid in validator_ids]
473+
public_keys: list[PublicKey] = [self.get_public_keys(vid)[0] for vid in validator_ids]
401474
signatures: list[Signature] = [
402475
sigs_for_data.get(vid) or self.sign_attestation_data(vid, agg.data)
403476
for vid in validator_ids

packages/testing/src/consensus_testing/test_fixtures/fork_choice.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -200,10 +200,11 @@ def make_fixture(self) -> Self:
200200
updated_validators = []
201201
for i, validator in enumerate(self.anchor_state.validators):
202202
idx = ValidatorIndex(i)
203+
attestation_pubkey, proposal_pubkey = key_manager.get_public_keys(idx)
203204
validator = validator.model_copy(
204205
update={
205-
"attestation_pubkey": key_manager[idx].attestation_public.encode_bytes(),
206-
"proposal_pubkey": key_manager[idx].proposal_public.encode_bytes(),
206+
"attestation_pubkey": attestation_pubkey.encode_bytes(),
207+
"proposal_pubkey": proposal_pubkey.encode_bytes(),
207208
}
208209
)
209210
updated_validators.append(validator)

packages/testing/src/consensus_testing/test_fixtures/verify_signatures.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ def _build_block_from_spec(
230230
# Valid proof but from wrong validators
231231
# Sign with signer_ids but claim validator_ids as participants
232232
signer_public_keys = [
233-
key_manager[vid].attestation_public for vid in invalid_spec.signer_ids
233+
key_manager.get_public_keys(vid)[0] for vid in invalid_spec.signer_ids
234234
]
235235
signer_signatures = [
236236
key_manager.sign_attestation_data(vid, attestation_data)

packages/testing/src/consensus_testing/test_types/genesis.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -33,19 +33,17 @@ def generate_pre_state(**kwargs: Any) -> State:
3333
f"but the key manager has only {available_keys} keys",
3434
)
3535

36-
validators = Validators(
37-
data=[
36+
validator_list = []
37+
for i in range(num_validators):
38+
idx = ValidatorIndex(i)
39+
attestation_pubkey, proposal_pubkey = key_manager.get_public_keys(idx)
40+
validator_list.append(
3841
Validator(
39-
attestation_pubkey=Bytes52(
40-
key_manager[ValidatorIndex(i)].attestation_public.encode_bytes()
41-
),
42-
proposal_pubkey=Bytes52(
43-
key_manager[ValidatorIndex(i)].proposal_public.encode_bytes()
44-
),
45-
index=ValidatorIndex(i),
42+
attestation_pubkey=Bytes52(attestation_pubkey.encode_bytes()),
43+
proposal_pubkey=Bytes52(proposal_pubkey.encode_bytes()),
44+
index=idx,
4645
)
47-
for i in range(num_validators)
48-
]
49-
)
46+
)
47+
validators = Validators(data=validator_list)
5048

5149
return State.generate_genesis(genesis_time=genesis_time, validators=validators)

0 commit comments

Comments
 (0)