Skip to content

Commit 36285e8

Browse files
authored
Merge pull request #264 from opentensor/release/1.6.1
Release/1.6.1
2 parents d52f20f + 353ccbb commit 36285e8

File tree

11 files changed

+253
-59
lines changed

11 files changed

+253
-59
lines changed

CHANGELOG.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,12 @@
11
# Changelog
22

3+
## 1.6.1 /2025-02-03
4+
* RuntimeCache updates by @thewhaleking in https://github.com/opentensor/async-substrate-interface/pull/260
5+
* fix memory leak by @thewhaleking in https://github.com/opentensor/async-substrate-interface/pull/261
6+
* Avoid Race Condition on SQLite Table Creation by @thewhaleking in https://github.com/opentensor/async-substrate-interface/pull/263
7+
8+
**Full Changelog**: https://github.com/opentensor/async-substrate-interface/compare/v1.6.0...v1.6.1
9+
310
## 1.6.0 /2025-01-27
411
* Fix typo by @thewhaleking in https://github.com/opentensor/async-substrate-interface/pull/258
512
* Improve Disk Caching by @thewhaleking in https://github.com/opentensor/async-substrate-interface/pull/227

async_substrate_interface/async_substrate.py

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import socket
1212
import ssl
1313
import warnings
14-
from contextlib import suppress
1514
from unittest.mock import AsyncMock
1615
from hashlib import blake2b
1716
from typing import (
@@ -40,7 +39,6 @@
4039
from websockets.asyncio.client import connect, ClientConnection
4140
from websockets.exceptions import (
4241
ConnectionClosed,
43-
WebSocketException,
4442
)
4543
from websockets.protocol import State
4644

@@ -708,6 +706,10 @@ async def _cancel(self):
708706
logger.debug("Cancelling send/recv tasks")
709707
if self._send_recv_task is not None:
710708
self._send_recv_task.cancel()
709+
try:
710+
await self._send_recv_task
711+
except asyncio.CancelledError:
712+
pass
711713
except asyncio.CancelledError:
712714
pass
713715
except Exception as e:
@@ -777,16 +779,31 @@ async def _handler(self, ws: ClientConnection) -> Union[None, Exception]:
777779
logger.debug("WS handler attached")
778780
recv_task = asyncio.create_task(self._start_receiving(ws))
779781
send_task = asyncio.create_task(self._start_sending(ws))
780-
done, pending = await asyncio.wait(
781-
[recv_task, send_task],
782-
return_when=asyncio.FIRST_COMPLETED,
783-
)
782+
try:
783+
done, pending = await asyncio.wait(
784+
[recv_task, send_task],
785+
return_when=asyncio.FIRST_COMPLETED,
786+
)
787+
except asyncio.CancelledError:
788+
# Handler was cancelled, clean up child tasks
789+
for task in [recv_task, send_task]:
790+
if not task.done():
791+
task.cancel()
792+
try:
793+
await task
794+
except asyncio.CancelledError:
795+
pass
796+
raise
784797
loop = asyncio.get_running_loop()
785798
should_reconnect = False
786799
is_retry = False
787800

788801
for task in pending:
789802
task.cancel()
803+
try:
804+
await task
805+
except asyncio.CancelledError:
806+
pass
790807

791808
for task in done:
792809
task_res = task.result()
@@ -887,6 +904,14 @@ async def _exit_with_timer(self):
887904

888905
async def shutdown(self):
889906
logger.debug("Shutdown requested")
907+
# Cancel the exit timer task if it exists
908+
if self._exit_task is not None:
909+
self._exit_task.cancel()
910+
try:
911+
await self._exit_task
912+
except asyncio.CancelledError:
913+
pass
914+
self._exit_task = None
890915
try:
891916
await asyncio.wait_for(self._cancel(), timeout=10.0)
892917
except asyncio.TimeoutError:
@@ -990,8 +1015,9 @@ async def _start_sending(self, ws) -> Exception:
9901015
)
9911016
if to_send is not None:
9921017
to_send_ = json.loads(to_send)
993-
self._received[to_send_["id"]].set_exception(e)
994-
self._received[to_send_["id"]].cancel()
1018+
if to_send_["id"] in self._received:
1019+
self._received[to_send_["id"]].set_exception(e)
1020+
self._received[to_send_["id"]].cancel()
9951021
else:
9961022
for i in self._received.keys():
9971023
self._received[i].set_exception(e)
@@ -1975,7 +2001,6 @@ async def result_handler(
19752001

19762002
if subscription_result is not None:
19772003
reached = True
1978-
logger.info("REACHED!")
19792004
# Handler returned end result: unsubscribe from further updates
19802005
async with self.ws as ws:
19812006
await ws.unsubscribe(

async_substrate_interface/sync_substrate.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3398,5 +3398,12 @@ def close(self):
33983398
self.ws.close()
33993399
except AttributeError:
34003400
pass
3401+
# Clear lru_cache on instance methods to allow garbage collection
3402+
self.get_runtime_for_version.cache_clear()
3403+
self.get_parent_block_hash.cache_clear()
3404+
self.get_block_runtime_info.cache_clear()
3405+
self.get_block_runtime_version_for.cache_clear()
3406+
self.supports_rpc_method.cache_clear()
3407+
self.get_block_hash.cache_clear()
34013408

34023409
encode_scale = SubstrateMixin._encode_scale

async_substrate_interface/types.py

Lines changed: 77 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
1+
import bisect
12
import logging
3+
import os
24
from abc import ABC
35
from collections import defaultdict, deque
46
from collections.abc import Iterable
57
from contextlib import suppress
68
from dataclasses import dataclass
79
from datetime import datetime
8-
from typing import Optional, Union, Any
10+
from typing import Optional, Union, Any, Sequence
911

1012
import scalecodec.types
1113
from bt_decode import PortableRegistry, encode as encode_by_type_string
@@ -17,9 +19,11 @@
1719

1820
from .const import SS58_FORMAT
1921
from .utils import json
20-
from .utils.cache import AsyncSqliteDB
22+
from .utils.cache import AsyncSqliteDB, LRUCache
2123

2224
logger = logging.getLogger("async_substrate_interface")
25+
SUBSTRATE_RUNTIME_CACHE_SIZE = int(os.getenv("SUBSTRATE_RUNTIME_CACHE_SIZE", "16"))
26+
SUBSTRATE_CACHE_METHOD_SIZE = int(os.getenv("SUBSTRATE_CACHE_METHOD_SIZE", "512"))
2327

2428

2529
class RuntimeCache:
@@ -41,11 +45,45 @@ class RuntimeCache:
4145
versions: dict[int, "Runtime"]
4246
last_used: Optional["Runtime"]
4347

44-
def __init__(self):
45-
self.blocks = {}
46-
self.block_hashes = {}
47-
self.versions = {}
48-
self.last_used = None
48+
def __init__(self, known_versions: Optional[Sequence[tuple[int, int]]] = None):
49+
# {block: block_hash, ...}
50+
self.blocks: LRUCache = LRUCache(max_size=SUBSTRATE_CACHE_METHOD_SIZE)
51+
# {block_hash: specVersion, ...}
52+
self.block_hashes: LRUCache = LRUCache(max_size=SUBSTRATE_CACHE_METHOD_SIZE)
53+
# {specVersion: Runtime, ...}
54+
self.versions: LRUCache = LRUCache(max_size=SUBSTRATE_RUNTIME_CACHE_SIZE)
55+
# [(block, specVersion), ...]
56+
self.known_versions: list[tuple[int, int]] = []
57+
# [block, ...] for binary search (excludes last item)
58+
self._known_version_blocks: list[int] = []
59+
if known_versions:
60+
self.add_known_versions(known_versions)
61+
self.last_used: Optional["Runtime"] = None
62+
63+
def add_known_versions(self, known_versions: Sequence[tuple[int, int]]):
64+
"""
65+
Known versions are a sequence of (block, specVersion) pairs marking the blocks at which runtimes change.
66+
67+
E.g.
68+
[
69+
(561, 102),
70+
(1075, 103),
71+
...,
72+
(7257645, 367)
73+
]
74+
75+
This mapping is generally user-created or pulled from an external API, such as
76+
https://api.tao.app/docs#/chain/get_runtime_versions_api_beta_chain_runtime_version_get
77+
78+
By preloading the known versions, there can be significantly fewer chain calls to determine version.
79+
80+
Note that the last runtime in the supplied known versions will be ignored, as otherwise we would
81+
have to assume that the final known version never changes.
82+
"""
83+
known_versions = list(sorted(known_versions, key=lambda v: v[0]))
84+
self.known_versions = known_versions
85+
# Cache block numbers (excluding last) for O(log n) binary search lookups
86+
self._known_version_blocks = [v[0] for v in known_versions[:-1]]
4987

5088
def add_item(
5189
self,
@@ -59,11 +97,11 @@ def add_item(
5997
"""
6098
self.last_used = runtime
6199
if block is not None and block_hash is not None:
62-
self.blocks[block] = block_hash
100+
self.blocks.set(block, block_hash)
63101
if block_hash is not None and runtime_version is not None:
64-
self.block_hashes[block_hash] = runtime_version
102+
self.block_hashes.set(block_hash, runtime_version)
65103
if runtime_version is not None:
66-
self.versions[runtime_version] = runtime
104+
self.versions.set(runtime_version, runtime)
67105

68106
def retrieve(
69107
self,
@@ -75,26 +113,35 @@ def retrieve(
75113
Retrieves a Runtime object from the cache, using the key of its block number, block hash, or runtime version.
76114
Retrieval happens in this order. If no Runtime is found mapped to any of your supplied keys, returns `None`.
77115
"""
116+
# No reason to do this lookup if the runtime version is already supplied in this call
117+
if block is not None and runtime_version is None and self._known_version_blocks:
118+
# _known_version_blocks excludes the last item (see note in `add_known_versions`)
119+
idx = bisect.bisect_right(self._known_version_blocks, block) - 1
120+
if idx >= 0:
121+
runtime_version = self.known_versions[idx][1]
122+
78123
runtime = None
79124
if block is not None:
80125
if block_hash is not None:
81-
self.blocks[block] = block_hash
126+
self.blocks.set(block, block_hash)
82127
if runtime_version is not None:
83-
self.block_hashes[block_hash] = runtime_version
84-
with suppress(KeyError):
85-
runtime = self.versions[self.block_hashes[self.blocks[block]]]
128+
self.block_hashes.set(block_hash, runtime_version)
129+
with suppress(AttributeError):
130+
runtime = self.versions.get(
131+
self.block_hashes.get(self.blocks.get(block))
132+
)
86133
self.last_used = runtime
87134
return runtime
88135
if block_hash is not None:
89136
if runtime_version is not None:
90-
self.block_hashes[block_hash] = runtime_version
91-
with suppress(KeyError):
92-
runtime = self.versions[self.block_hashes[block_hash]]
137+
self.block_hashes.set(block_hash, runtime_version)
138+
with suppress(AttributeError):
139+
runtime = self.versions.get(self.block_hashes.get(block_hash))
93140
self.last_used = runtime
94141
return runtime
95142
if runtime_version is not None:
96-
with suppress(KeyError):
97-
runtime = self.versions[runtime_version]
143+
runtime = self.versions.get(runtime_version)
144+
if runtime is not None:
98145
self.last_used = runtime
99146
return runtime
100147
return runtime
@@ -110,16 +157,21 @@ async def load_from_disk(self, chain_endpoint: str):
110157
logger.debug("No runtime mappings in disk cache")
111158
else:
112159
logger.debug("Found runtime mappings in disk cache")
113-
self.blocks = block_mapping
114-
self.block_hashes = block_hash_mapping
115-
self.versions = {
116-
x: Runtime.deserialize(y) for x, y in runtime_version_mapping.items()
117-
}
160+
self.blocks.cache = block_mapping
161+
self.block_hashes.cache = block_hash_mapping
162+
for x, y in runtime_version_mapping.items():
163+
self.versions.cache[x] = Runtime.deserialize(y)
118164

119165
async def dump_to_disk(self, chain_endpoint: str):
120166
db = AsyncSqliteDB(chain_endpoint=chain_endpoint)
167+
blocks = self.blocks.cache
168+
block_hashes = self.block_hashes.cache
169+
versions = self.versions.cache
121170
await db.dump_runtime_cache(
122-
chain_endpoint, self.blocks, self.block_hashes, self.versions
171+
chain=chain_endpoint,
172+
block_mapping=blocks,
173+
block_hash_mapping=block_hashes,
174+
version_mapping=versions,
123175
)
124176

125177

async_substrate_interface/utils/cache.py

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import asyncio
22
import inspect
3+
import weakref
34
from collections import OrderedDict
45
import functools
56
import logging
@@ -60,6 +61,7 @@ async def _create_if_not_exists(self, chain: str, table_name: str):
6061
);
6162
"""
6263
)
64+
await self._db.commit()
6365
await self._db.execute(
6466
f"""
6567
CREATE TRIGGER IF NOT EXISTS prune_rows_trigger_{table_name} AFTER INSERT ON {table_name}
@@ -81,8 +83,8 @@ async def __call__(self, chain, other_self, func, args, kwargs) -> Optional[Any]
8183
if not self._db:
8284
_ensure_dir()
8385
self._db = await aiosqlite.connect(CACHE_LOCATION)
84-
table_name = _get_table_name(func)
85-
local_chain = await self._create_if_not_exists(chain, table_name)
86+
table_name = _get_table_name(func)
87+
local_chain = await self._create_if_not_exists(chain, table_name)
8688
key = pickle.dumps((args, kwargs or None))
8789
try:
8890
cursor: aiosqlite.Cursor = await self._db.execute(
@@ -111,9 +113,9 @@ async def load_runtime_cache(self, chain: str) -> tuple[dict, dict, dict]:
111113
if not self._db:
112114
_ensure_dir()
113115
self._db = await aiosqlite.connect(CACHE_LOCATION)
114-
block_mapping = {}
115-
block_hash_mapping = {}
116-
version_mapping = {}
116+
block_mapping = OrderedDict()
117+
block_hash_mapping = OrderedDict()
118+
version_mapping = OrderedDict()
117119
tables = {
118120
"RuntimeCache_blocks": block_mapping,
119121
"RuntimeCache_block_hashes": block_hash_mapping,
@@ -419,6 +421,26 @@ async def __call__(self, *args: Any, **kwargs: Any) -> Any:
419421
self._inflight.pop(key, None)
420422

421423

424+
class _WeakMethod:
425+
"""
426+
Weak reference to a bound method that allows the instance to be garbage collected.
427+
Preserves the method's signature for introspection.
428+
"""
429+
430+
def __init__(self, method):
431+
self._func = method.__func__
432+
self._instance_ref = weakref.ref(method.__self__)
433+
# Store the bound method's signature (without 'self') for inspect.signature() to find.
434+
# We capture this once at creation time to avoid holding references to the bound method.
435+
self.__signature__ = inspect.signature(method)
436+
437+
def __call__(self, *args, **kwargs):
438+
instance = self._instance_ref()
439+
if instance is None:
440+
raise ReferenceError("Instance has been garbage collected")
441+
return self._func(instance, *args, **kwargs)
442+
443+
422444
class _CachedFetcherMethod:
423445
"""
424446
Helper class for using CachedFetcher with method caches (rather than functions)
@@ -428,18 +450,21 @@ def __init__(self, method, max_size: int, cache_key_index: int):
428450
self.method = method
429451
self.max_size = max_size
430452
self.cache_key_index = cache_key_index
431-
self._instances = {}
453+
# Use WeakKeyDictionary to avoid preventing garbage collection of instances
454+
self._instances: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
432455

433456
def __get__(self, instance, owner):
434457
if instance is None:
435458
return self
436459

437-
# Cache per-instance
460+
# Cache per-instance (weak references allow GC when instance is no longer used)
438461
if instance not in self._instances:
439462
bound_method = self.method.__get__(instance, owner)
463+
# Use weak reference wrapper to avoid preventing GC of instance
464+
weak_method = _WeakMethod(bound_method)
440465
self._instances[instance] = CachedFetcher(
441466
max_size=self.max_size,
442-
method=bound_method,
467+
method=weak_method,
443468
cache_key_index=self.cache_key_index,
444469
)
445470
return self._instances[instance]

pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "async-substrate-interface"
3-
version = "1.6.0"
3+
version = "1.6.1"
44
description = "Asyncio library for interacting with substrate. Mostly API-compatible with py-substrate-interface"
55
readme = "README.md"
66
license = { file = "LICENSE" }
@@ -37,6 +37,8 @@ classifiers = [
3737
"Programming Language :: Python :: 3.10",
3838
"Programming Language :: Python :: 3.11",
3939
"Programming Language :: Python :: 3.12",
40+
"Programming Language :: Python :: 3.13",
41+
"Programming Language :: Python :: 3.14",
4042
"Programming Language :: Python :: 3 :: Only",
4143
]
4244

0 commit comments

Comments
 (0)