5151from lean_spec .subspecs .koalabear import Fp
5252from lean_spec .subspecs .xmss .aggregation import AggregatedSignatureProof
5353from lean_spec .subspecs .xmss .constants import TARGET_CONFIG
54- from lean_spec .subspecs .xmss .containers import PublicKey , Signature , ValidatorKeyPair
54+ from lean_spec .subspecs .xmss .containers import PublicKey , SecretKey , Signature , ValidatorKeyPair
5555from lean_spec .subspecs .xmss .interface import (
5656 PROD_SIGNATURE_SCHEME ,
5757 TEST_SIGNATURE_SCHEME ,
6565)
6666from lean_spec .types import Bytes32 , Uint64
6767
68+ SecretField = Literal ["attestation_secret" , "proposal_secret" ]
69+ """The two secret key field names on ValidatorKeyPair."""
70+
6871__all__ = [
6972 "CLI_DEFAULT_MAX_SLOT" ,
7073 "KEY_DOWNLOAD_URLS" ,
@@ -173,6 +176,8 @@ def __init__(self, scheme_name: str) -> None:
173176 self ._scheme_name = scheme_name
174177 self ._keys_dir = get_keys_dir (scheme_name )
175178 self ._cache : dict [ValidatorIndex , ValidatorKeyPair ] = {}
179+ self ._public_cache : dict [ValidatorIndex , tuple [PublicKey , PublicKey ]] = {}
180+ self ._raw_cache : dict [ValidatorIndex , dict [str , str ]] = {}
176181 self ._available_indices : set [ValidatorIndex ] | None = None
177182
178183 def _ensure_dir_exists (self ) -> None :
@@ -197,18 +202,62 @@ def _get_available_indices(self) -> set[ValidatorIndex]:
197202 )
198203 return self ._available_indices
199204
205+ def _load_raw (self , idx : ValidatorIndex ) -> dict [str , str ]:
206+ """Load raw JSON data from disk (cached)."""
207+ if idx not in self ._raw_cache :
208+ key_file = self ._keys_dir / f"{ idx } .json"
209+ try :
210+ self ._raw_cache [idx ] = json .loads (key_file .read_text ())
211+ except FileNotFoundError :
212+ raise KeyError (f"Key file not found: { key_file } " ) from None
213+ return self ._raw_cache [idx ]
214+
200215 def _load_key (self , idx : ValidatorIndex ) -> ValidatorKeyPair :
201216 """Load a single key from disk."""
202- key_file = self ._keys_dir / f"{ idx } .json"
203- if not key_file .exists ():
204- raise KeyError (f"Key file not found: { key_file } " )
205- data = json .loads (key_file .read_text ())
206- return ValidatorKeyPair .from_dict (data )
217+ return ValidatorKeyPair .from_dict (self ._load_raw (idx ))
218+
219+ def _load_public_keys (self , idx : ValidatorIndex ) -> tuple [PublicKey , PublicKey ]:
220+ """Load only public keys from disk, skipping expensive SecretKey deserialization."""
221+ data = self ._load_raw (idx )
222+ return (
223+ PublicKey .decode_bytes (bytes .fromhex (data ["attestation_public" ])),
224+ PublicKey .decode_bytes (bytes .fromhex (data ["proposal_public" ])),
225+ )
226+
227+ def get_public_keys (self , idx : ValidatorIndex ) -> tuple [PublicKey , PublicKey ]:
228+ """
229+ Get (attestation_public, proposal_public) without loading secret keys.
230+
231+ Returns cached public keys if available, otherwise loads only the public
232+ key portions from disk. Avoids deserializing the heavy SecretKey objects
233+ (each ~2.7KB raw with 3 HashSubTree structures) until signing is needed.
234+ """
235+ if idx in self ._cache :
236+ kp = self ._cache [idx ]
237+ return (kp .attestation_public , kp .proposal_public )
238+ if idx not in self ._public_cache :
239+ self ._public_cache [idx ] = self ._load_public_keys (idx )
240+ return self ._public_cache [idx ]
241+
242+ def get_secret_key (self , idx : ValidatorIndex , field : SecretField ) -> SecretKey :
243+ """
244+ Load a specific secret key from disk without deserializing the other keys.
245+
246+ Only the requested SecretKey is deserialized (~370 MB in Python objects).
247+ The other three fields remain as lightweight hex strings (~2.7 KB each).
248+ """
249+ if idx in self ._cache :
250+ return getattr (self ._cache [idx ], field )
251+ data = self ._load_raw (idx )
252+ return SecretKey .decode_bytes (bytes .fromhex (data [field ]))
207253
208254 def __getitem__ (self , idx : ValidatorIndex ) -> ValidatorKeyPair :
209255 """Get key pair by validator index, loading from disk if needed."""
210256 if idx not in self ._cache :
211257 self ._cache [idx ] = self ._load_key (idx )
258+ # Full pair supersedes raw/public caches for this index.
259+ self ._raw_cache .pop (idx , None )
260+ self ._public_cache .pop (idx , None )
212261 return self ._cache [idx ]
213262
214263 def __contains__ (self , idx : object ) -> bool :
@@ -247,7 +296,13 @@ def __init__(
247296 """Initialize the manager with optional custom configuration."""
248297 self .max_slot = max_slot
249298 self .scheme = scheme
250- self ._state : dict [ValidatorIndex , ValidatorKeyPair ] = {}
299+ self ._secret_state : dict [tuple [ValidatorIndex , SecretField ], bytes ] = {}
300+ """
301+ Advanced secret key state cached as raw SSZ bytes.
302+
303+ Raw bytes (~2.7 KB each) instead of deserialized SecretKey objects
304+ (~370 MB each) to avoid holding massive Pydantic model trees in memory.
305+ """
251306
252307 try :
253308 self .scheme_name = next (
@@ -264,9 +319,7 @@ def keys(self) -> LazyKeyDict:
264319 return _LAZY_KEY_CACHE [self .scheme_name ]
265320
266321 def __getitem__ (self , idx : ValidatorIndex ) -> ValidatorKeyPair :
267- """Get key pair, returning advanced state if available."""
268- if idx in self ._state :
269- return self ._state [idx ]
322+ """Get key pair. Prefer get_public_keys() or signing methods to avoid loading all keys."""
270323 if idx not in self .keys :
271324 raise KeyError (f"Validator { idx } not found (max: { len (self .keys ) - 1 } )" )
272325 return self .keys [idx ]
@@ -285,27 +338,45 @@ def __iter__(self) -> Iterator[ValidatorIndex]:
285338 """Iterate over validator indices."""
286339 return iter (self .keys )
287340
341+ def get_public_keys (self , idx : ValidatorIndex ) -> tuple [PublicKey , PublicKey ]:
342+ """
343+ Get (attestation_public, proposal_public) without loading secret keys.
344+
345+ Delegates to lazy disk loading that skips SecretKey deserialization.
346+ """
347+ return self .keys .get_public_keys (idx )
348+
288349 def _sign_with_secret (
289350 self ,
290351 validator_id : ValidatorIndex ,
291352 slot : Slot ,
292353 message : Bytes32 ,
293- secret_field : Literal [ "attestation_secret" , "proposal_secret" ] ,
354+ secret_field : SecretField ,
294355 ) -> Signature :
295356 """
296357 Shared signing logic for attestation/proposal paths.
297358
298359 Handles XMSS state advancement until the requested slot is within the
299- prepared interval, caches the updated secret, and produces the signature.
360+ prepared interval, caches the updated secret as raw bytes, and produces
361+ the signature.
362+
363+ Only the needed SecretKey is deserialized (~370 MB in Python objects).
364+ After signing, the advanced state is re-serialized to compact bytes
365+ (~2.7 KB) so only one SecretKey is in memory at a time.
300366
301367 Args:
302368 validator_id: Validator index whose key should be used.
303369 slot: The slot to sign for.
304370 message: The message bytes to sign.
305371 secret_field: Which secret on the key pair should advance.
306372 """
307- kp = self [validator_id ]
308- sk = getattr (kp , secret_field )
373+ cache_key = (validator_id , secret_field )
374+
375+ # Deserialize the secret key: from cached bytes or from disk.
376+ if cache_key in self ._secret_state :
377+ sk = SecretKey .decode_bytes (self ._secret_state [cache_key ])
378+ else :
379+ sk = self .keys .get_secret_key (validator_id , secret_field )
309380
310381 # Advance key state until the slot is ready for signing.
311382 prepared = self .scheme .get_prepared_interval (sk )
@@ -316,10 +387,12 @@ def _sign_with_secret(
316387 sk = self .scheme .advance_preparation (sk )
317388 prepared = self .scheme .get_prepared_interval (sk )
318389
319- # Cache advanced state (only the selected secret changes).
320- self ._state [validator_id ] = kp ._replace (** {secret_field : sk })
390+ signature = self .scheme .sign (sk , slot , message )
391+
392+ # Cache advanced state as raw bytes to keep memory compact.
393+ self ._secret_state [cache_key ] = sk .encode_bytes ()
321394
322- return self . scheme . sign ( sk , slot , message )
395+ return signature
323396
324397 def sign_attestation_data (
325398 self ,
@@ -397,7 +470,7 @@ def build_attestation_signatures(
397470 # Look up pre-computed signatures by attestation data and validator ID.
398471 sigs_for_data = lookup .get (agg .data , {})
399472
400- public_keys : list [PublicKey ] = [self [ vid ]. attestation_public for vid in validator_ids ]
473+ public_keys : list [PublicKey ] = [self . get_public_keys ( vid )[ 0 ] for vid in validator_ids ]
401474 signatures : list [Signature ] = [
402475 sigs_for_data .get (vid ) or self .sign_attestation_data (vid , agg .data )
403476 for vid in validator_ids
0 commit comments