Skip to content

Commit 151abf1

Browse files
committed
use a dataclass for mutable cache state
1 parent 47757e1 commit 151abf1

File tree

2 files changed

+73
-98
lines changed

2 files changed

+73
-98
lines changed

src/zarr/experimental/cache_store.py

Lines changed: 53 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import logging
55
import time
66
from collections import OrderedDict
7+
from dataclasses import dataclass, field
78
from typing import TYPE_CHECKING, Any, Literal, Self
89

910
from zarr.abc.store import ByteRequest, Store
@@ -15,14 +16,16 @@
1516
from zarr.core.buffer.core import Buffer, BufferPrototype
1617

1718

19+
@dataclass(slots=True)
class _CacheState:
    """Mutable bookkeeping for a ``CacheStore``.

    Held in a single object (rather than as individual store attributes) so
    that clones created by ``with_read_only`` can share one set of tracking
    structures and counters by assigning the same ``_CacheState`` instance.
    """

    # Keys in access order; the first (leftmost) entry is least recently used.
    # Values are unused — the OrderedDict is used purely for its ordering.
    cache_order: OrderedDict[str, None] = field(default_factory=OrderedDict)
    # Total size in bytes of everything currently cached.
    current_size: int = 0
    # Per-key size in bytes, so evictions can decrement ``current_size``.
    key_sizes: dict[str, int] = field(default_factory=dict)
    # Guards mutations of the tracking structures above.
    lock: asyncio.Lock = field(default_factory=asyncio.Lock)
    # Cache hit counter.
    hits: int = 0
    # Cache miss counter.
    misses: int = 0
    # Cache eviction counter.
    evictions: int = 0
    # Per-key insertion time from ``time.monotonic()``, for max-age expiry.
    key_insert_times: dict[str, float] = field(default_factory=dict)
2629

2730

2831
class CacheStore(WrapperStore[Store]):
@@ -46,9 +49,6 @@ class CacheStore(WrapperStore[Store]):
4649
Maximum size of the cache in bytes. When exceeded, least recently used
4750
items are evicted. None means unlimited size. Default is None.
4851
Note: Individual values larger than max_size will not be cached.
49-
key_insert_times : dict[str, float] | None, optional
50-
Dictionary to track insertion times (using monotonic time).
51-
Primarily for internal use. Default is None (creates new dict).
5252
cache_set_data : bool, optional
5353
Whether to cache data when it's written to the store. Default is True.
5454
@@ -79,7 +79,6 @@ class CacheStore(WrapperStore[Store]):
7979
_cache: Store
8080
max_age_seconds: int | Literal["infinity"]
8181
max_size: int | None
82-
key_insert_times: dict[str, float]
8382
cache_set_data: bool
8483
_state: _CacheState
8584

@@ -90,7 +89,6 @@ def __init__(
9089
cache_store: Store,
9190
max_age_seconds: int | str = "infinity",
9291
max_size: int | None = None,
93-
key_insert_times: dict[str, float] | None = None,
9492
cache_set_data: bool = True,
9593
) -> None:
9694
super().__init__(store)
@@ -114,18 +112,6 @@ def __init__(
114112
self.cache_set_data = cache_set_data
115113
self._state = _CacheState()
116114

117-
if key_insert_times is None:
118-
self.key_insert_times = {}
119-
else:
120-
self.key_insert_times = key_insert_times
121-
self._state._cache_order = OrderedDict()
122-
self._state._current_size = 0
123-
self._state._key_sizes = {}
124-
self._state._lock = asyncio.Lock()
125-
self._state._hits = 0
126-
self._state._misses = 0
127-
self._state._evictions = 0
128-
129115
def _with_store(self, store: Store) -> Self:
130116
# Cannot support this operation because it would share a cache, but have a new store
131117
# So cache keys would conflict
@@ -138,7 +124,6 @@ def with_read_only(self, read_only: bool = False) -> Self:
138124
cache_store=self._cache,
139125
max_age_seconds=self.max_age_seconds,
140126
max_size=self.max_size,
141-
key_insert_times=self.key_insert_times,
142127
cache_set_data=self.cache_set_data,
143128
)
144129
store._state = self._state
@@ -152,7 +137,7 @@ def _is_key_fresh(self, key: str) -> bool:
152137
if self.max_age_seconds == "infinity":
153138
return True
154139
now = time.monotonic()
155-
elapsed = now - self.key_insert_times.get(key, 0)
140+
elapsed = now - self._state.key_insert_times.get(key, 0)
156141
return elapsed < self.max_age_seconds
157142

158143
async def _accommodate_value(self, value_size: int) -> None:
@@ -164,9 +149,9 @@ async def _accommodate_value(self, value_size: int) -> None:
164149
return
165150

166151
# Remove least recently used items until we have enough space
167-
while self._state._current_size + value_size > self.max_size and self._state._cache_order:
152+
while self._state.current_size + value_size > self.max_size and self._state.cache_order:
168153
# Get the least recently used key (first in OrderedDict)
169-
lru_key = next(iter(self._state._cache_order))
154+
lru_key = next(iter(self._state.cache_order))
170155
await self._evict_key(lru_key)
171156

172157
async def _evict_key(self, key: str) -> None:
@@ -176,15 +161,15 @@ async def _evict_key(self, key: str) -> None:
176161
Updates size tracking atomically with deletion.
177162
"""
178163
try:
179-
key_size = self._state._key_sizes.get(key, 0)
164+
key_size = self._state.key_sizes.get(key, 0)
180165

181166
# Delete from cache store
182167
await self._cache.delete(key)
183168

184169
# Update tracking after successful deletion
185170
self._remove_from_tracking(key)
186-
self._state._current_size = max(0, self._state._current_size - key_size)
187-
self._state._evictions += 1
171+
self._state.current_size = max(0, self._state.current_size - key_size)
172+
self._state.evictions += 1
188173

189174
logger.debug("_evict_key: evicted key %s, freed %d bytes", key, key_size)
190175
except Exception:
@@ -207,39 +192,39 @@ async def _cache_value(self, key: str, value: Buffer) -> None:
207192
)
208193
return
209194

210-
async with self._state._lock:
195+
async with self._state.lock:
211196
# If key already exists, subtract old size first
212-
if key in self._state._key_sizes:
213-
old_size = self._state._key_sizes[key]
214-
self._state._current_size -= old_size
197+
if key in self._state.key_sizes:
198+
old_size = self._state.key_sizes[key]
199+
self._state.current_size -= old_size
215200
logger.debug("_cache_value: updating existing key %s, old size %d", key, old_size)
216201

217202
# Make room for the new value (this calls _evict_key_locked internally)
218203
await self._accommodate_value(value_size)
219204

220205
# Update tracking atomically
221-
self._state._cache_order[key] = None # OrderedDict to track access order
222-
self._state._current_size += value_size
223-
self._state._key_sizes[key] = value_size
224-
self.key_insert_times[key] = time.monotonic()
206+
self._state.cache_order[key] = None # OrderedDict to track access order
207+
self._state.current_size += value_size
208+
self._state.key_sizes[key] = value_size
209+
self._state.key_insert_times[key] = time.monotonic()
225210

226211
logger.debug("_cache_value: cached key %s with size %d bytes", key, value_size)
227212

228213
async def _update_access_order(self, key: str) -> None:
    """Mark ``key`` as most recently used for LRU tracking.

    No-op when the key is not currently tracked in ``cache_order``.
    The ordering mutation happens under the state lock.
    """
    if key in self._state.cache_order:
        async with self._state.lock:
            # Move to end of the OrderedDict (most recently used position).
            self._state.cache_order.move_to_end(key)
234219

235220
def _remove_from_tracking(self, key: str) -> None:
    """Remove a key from all tracking structures.

    Must be called while holding self._state.lock.
    """
    # pop(key, None) so a key missing from any structure is not an error.
    self._state.cache_order.pop(key, None)
    self._state.key_insert_times.pop(key, None)
    self._state.key_sizes.pop(key, None)
243228

244229
async def _get_try_cache(
245230
self, key: str, prototype: BufferPrototype, byte_range: ByteRequest | None = None
@@ -248,20 +233,20 @@ async def _get_try_cache(
248233
maybe_cached_result = await self._cache.get(key, prototype, byte_range)
249234
if maybe_cached_result is not None:
250235
logger.debug("_get_try_cache: key %s found in cache (HIT)", key)
251-
self._state._hits += 1
236+
self._state.hits += 1
252237
# Update access order for LRU
253238
await self._update_access_order(key)
254239
return maybe_cached_result
255240
else:
256241
logger.debug(
257242
"_get_try_cache: key %s not found in cache (MISS), fetching from store", key
258243
)
259-
self._state._misses += 1
244+
self._state.misses += 1
260245
maybe_fresh_result = await super().get(key, prototype, byte_range)
261246
if maybe_fresh_result is None:
262247
# Key doesn't exist in source store
263248
await self._cache.delete(key)
264-
async with self._state._lock:
249+
async with self._state.lock:
265250
self._remove_from_tracking(key)
266251
else:
267252
# Cache the newly fetched value
@@ -273,12 +258,12 @@ async def _get_no_cache(
273258
self, key: str, prototype: BufferPrototype, byte_range: ByteRequest | None = None
274259
) -> Buffer | None:
275260
"""Get data directly from source store and update cache."""
276-
self._state._misses += 1
261+
self._state.misses += 1
277262
maybe_fresh_result = await super().get(key, prototype, byte_range)
278263
if maybe_fresh_result is None:
279264
# Key doesn't exist in source, remove from cache and tracking
280265
await self._cache.delete(key)
281-
async with self._state._lock:
266+
async with self._state.lock:
282267
self._remove_from_tracking(key)
283268
else:
284269
logger.debug("_get_no_cache: key %s found in store, setting in cache", key)
@@ -336,7 +321,7 @@ async def set(self, key: str, value: Buffer) -> None:
336321
else:
337322
logger.debug("set: deleting key %s from cache", key)
338323
await self._cache.delete(key)
339-
async with self._state._lock:
324+
async with self._state.lock:
340325
self._remove_from_tracking(key)
341326

342327
async def delete(self, key: str) -> None:
@@ -352,7 +337,7 @@ async def delete(self, key: str) -> None:
352337
await super().delete(key)
353338
logger.debug("delete: deleting key %s from cache", key)
354339
await self._cache.delete(key)
355-
async with self._state._lock:
340+
async with self._state.lock:
356341
self._remove_from_tracking(key)
357342

358343
def cache_info(self) -> dict[str, Any]:
@@ -363,20 +348,20 @@ def cache_info(self) -> dict[str, Any]:
363348
if self.max_age_seconds == "infinity"
364349
else self.max_age_seconds,
365350
"max_size": self.max_size,
366-
"current_size": self._state._current_size,
351+
"current_size": self._state.current_size,
367352
"cache_set_data": self.cache_set_data,
368-
"tracked_keys": len(self.key_insert_times),
369-
"cached_keys": len(self._state._cache_order),
353+
"tracked_keys": len(self._state.key_insert_times),
354+
"cached_keys": len(self._state.cache_order),
370355
}
371356

372357
def cache_stats(self) -> dict[str, Any]:
    """Return cache performance statistics.

    Returns
    -------
    dict
        Keys: ``hits``, ``misses``, ``evictions``, ``total_requests``,
        and ``hit_rate`` (hits / total_requests; 0.0 when there have been
        no requests, avoiding division by zero).
    """
    total_requests = self._state.hits + self._state.misses
    hit_rate = self._state.hits / total_requests if total_requests > 0 else 0.0
    return {
        "hits": self._state.hits,
        "misses": self._state.misses,
        "evictions": self._state.evictions,
        "total_requests": total_requests,
        "hit_rate": hit_rate,
    }
@@ -388,11 +373,11 @@ async def clear_cache(self) -> None:
388373
await self._cache.clear()
389374

390375
# Reset tracking
391-
async with self._state._lock:
392-
self.key_insert_times.clear()
393-
self._state._cache_order.clear()
394-
self._state._key_sizes.clear()
395-
self._state._current_size = 0
376+
async with self._state.lock:
377+
self._state.key_insert_times.clear()
378+
self._state.cache_order.clear()
379+
self._state.key_sizes.clear()
380+
self._state.current_size = 0
396381
logger.debug("clear_cache: cleared all cache data")
397382

398383
def __repr__(self) -> str:
@@ -403,6 +388,6 @@ def __repr__(self) -> str:
403388
f"cache_store={self._cache!r}, "
404389
f"max_age_seconds={self.max_age_seconds}, "
405390
f"max_size={self.max_size}, "
406-
f"current_size={self._state._current_size}, "
407-
f"cached_keys={len(self._state._cache_order)})"
391+
f"current_size={self._state.current_size}, "
392+
f"cached_keys={len(self._state.cache_order)})"
408393
)

0 commit comments

Comments
 (0)