diff --git a/.github/workflows/test_and_release.yml b/.github/workflows/test_and_release.yml index e80ed00..8be766f 100644 --- a/.github/workflows/test_and_release.yml +++ b/.github/workflows/test_and_release.yml @@ -53,7 +53,7 @@ jobs: - name: Install and Run Tests run: | - pip install ".[dev]" + pip install -e ".[dev]" pip install -r tests/requirements.txt # Run the tests with coverage so we get a coverage report too pip install coverage diff --git a/src/trame_server/state.py b/src/trame_server/state.py index 71e6048..5e2902a 100644 --- a/src/trame_server/state.py +++ b/src/trame_server/state.py @@ -3,6 +3,7 @@ import weakref from collections import deque from contextlib import contextmanager +from typing import Any, Iterable, Iterator from .utils import asynchronous, is_dunder, is_private, share from .utils.hot_reload import reload @@ -43,10 +44,44 @@ def flushing_context(self): self.flushing = False +class _OrderedSet: + """ + Lightweight ordered set implementation based on dict to preserve insertion order + without external dependencies. + """ + + def __init__(self, *args: Any) -> None: + self._data: dict[Any, None] = {} + for arg in args: + self.add(arg) + + def __bool__(self) -> bool: + return bool(self._data) + + def __contains__(self, key: Any) -> bool: + return key in self._data + + def __iter__(self) -> Iterator[Any]: + return iter(self._data) + + def add(self, key: Any) -> None: + self._data[key] = None + + def clear(self) -> None: + self._data.clear() + + def discard(self, key: Any) -> None: + self._data.pop(key, None) + + def update(self, iterable: Iterable[Any]) -> None: + for item in iterable: + self.add(item) + + class StateChangeHandler: def __init__(self, listeners): self._all_listeners = listeners - self._currents = set() + self._currents = _OrderedSet() def add(self, key): if key in self._all_listeners: @@ -70,9 +105,9 @@ class _SuppressListenersChangeStack: """ def __init__(self): - self._deque: deque[set[str]] = deque() - self._suppressed_keys: set[str] | None = None - self._listener_keys: set[str] = set() + self._deque: deque[_OrderedSet] = deque() + self._suppressed_keys: _OrderedSet | None = None + self._listener_keys: _OrderedSet = _OrderedSet() def on_pending_key_added(self, key: str) -> None: if not self._is_suppressed(key): @@ -82,7 +117,7 @@ def on_pending_key_removed(self, key: str) -> None: self._listener_keys.discard(key) def push(self, *keys: str) -> None: - self._deque.append(set(keys)) + self._deque.append(_OrderedSet(*keys)) self._update_suppressed_keys() def pop(self) -> None: @@ -93,15 +128,15 @@ def pop(self) -> None: self._update_suppressed_keys() def clear(self) -> None: - self._listener_keys = set() + self._listener_keys = _OrderedSet() - def get_change_listener_keys(self) -> set[str]: + def get_change_listener_keys(self) -> _OrderedSet: return self._listener_keys def _is_suppressed(self, key: str) -> bool: if self._suppressed_keys is None: return False - if self._suppressed_keys == set(): + if not self._suppressed_keys: return True return key in self._suppressed_keys @@ -116,9 +151,9 @@ def _update_suppressed_keys(self) -> None: self._suppressed_keys = None return - self._suppressed_keys = set() + self._suppressed_keys = _OrderedSet() for d_set in self._deque: - if d_set == set(): + if not d_set: self._suppressed_keys.clear() return self._suppressed_keys.update(d_set) diff --git a/tests/test_state.py b/tests/test_state.py index 58b8c97..34be29d 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -698,3 +698,30 @@ def on_change(a, **_): state.flush() mock.assert_not_called() + + +def test_state_change_listeners_are_triggered_in_modified_order(state): + keys = [f"k_{i}" for i in range(10)] + recorded = [] + + def create_listener(k): + def on_change(**_): + recorded.append(k) + + return on_change + + for key in keys: + state.change(key)(create_listener(key)) + + for i_trial in range(1000): + recorded.clear() + + for key in keys: + state[key] = i_trial + + state.flush() + + _assert_msg = ( + f"Order mismatch at iteration {i_trial}: expected {keys}, got {recorded}" + ) + assert recorded == keys, _assert_msg