Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions statgpt/admin/services/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,6 +735,7 @@ async def _propagate_config_to_channel_datasets(
await self._session.refresh(dataset, attribute_names=["mapped_channels"])

if not dataset.mapped_channels:
_log.info(f"Dataset(id={dataset.id_}): no mapped channels, skipping config propagation")
return []

results: list[schemas.ChannelDatasetUpdateResult] = []
Expand Down Expand Up @@ -762,8 +763,16 @@ async def _propagate_config_to_channel_datasets(
and latest_version.preprocessing_status not in StatusEnum.final_statuses()
):
status = schemas.ChannelDatasetUpdateStatus.INDEXING_IN_PROGRESS
_log.info(
f"ChannelDataset(dataset={dataset.id_}, channel={channel.deployment_id!r}):"
f" indexing in progress, skipping"
)
elif last_completed is None:
status = schemas.ChannelDatasetUpdateStatus.NO_VERSION
_log.info(
f"ChannelDataset(dataset={dataset.id_}, channel={channel.deployment_id!r}):"
f" no completed version, skipping"
)
else:
_, resolved_config = await handler.resolve_config(
config=dataset.details,
Expand All @@ -779,8 +788,16 @@ async def _propagate_config_to_channel_datasets(
)
other_fields['new_version'] = new_version
status = schemas.ChannelDatasetUpdateStatus.AUTO_UPDATED
_log.info(
f"ChannelDataset(dataset={dataset.id_}, channel={channel.deployment_id!r}):"
f" auto-updated with new version"
)
else:
status = schemas.ChannelDatasetUpdateStatus.NEEDS_REINDEX
_log.info(
f"ChannelDataset(dataset={dataset.id_}, channel={channel.deployment_id!r}):"
f" indexing hash changed, needs reindex"
)

results.append(
schemas.ChannelDatasetUpdateResult(
Expand Down
1 change: 0 additions & 1 deletion statgpt/common/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ application-specific environment variables.
| LOG_LEVEL_HTTPCORE | No | Level of the logs for the HTTPCore library | `DEBUG`, `INFO`, `WARN`, `ERROR`, `CRITICAL` | `WARNING` |
| LOG_MULTILINE_MODE_ENABLED | No | Whether to enable multiline mode for logs | `true`, `false` | `false` |
| SDMX_CACHE_DIR | No | Directory to store the SDMX cache files. SDMX Cache is designed to speed up the application startup for development. If set, the application will take some data (dataset definitions, available values, etc.) from existing files instead of querying the SDMX portal. | | |
| QUANTHUB_DATASET_CACHE_TTL | No | The time in seconds after which the QuantHub SDMX dataset cache is invalidated | | `3600` |
| OAUTH2_TOKEN_ENDPOINT_URL | No | OAuth2 Token endpoint Url. Required if auth enabled | | |
| SERVICES_CHAT_SCOPE | No | Auth scope for the chat | | |
| SERVICES_CHAT_CLIENT_ID | No | Chat Client ID | | |
Expand Down
35 changes: 10 additions & 25 deletions statgpt/common/data/quanthub/v21/datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@
from statgpt.common.data.sdmx.v21.ratelimiter import SdmxRateLimiterFactory
from statgpt.common.data.sdmx.v21.schemas import Urn
from statgpt.common.schemas.dataset import Status
from statgpt.common.settings.sdmx import quanthub_settings
from statgpt.common.utils import Cache
from statgpt.common.utils import AsyncLoadingCache
from statgpt.common.utils.timer import debug_timer

from .qh_sdmx_client import AsyncQuanthubClient
Expand All @@ -27,8 +26,7 @@
# todo: add generic typing with QuanthubInMemorySdmx21DataSet
class QuanthubSdmx21DataSourceHandler(Sdmx21DataSourceHandler):

# TEMP fix:
_dataset_cache: Cache[QuanthubSdmx21DataSet] = Cache(ttl=quanthub_settings.dataset_cache_ttl)
_dataset_cache: AsyncLoadingCache[QuanthubSdmx21DataSet] = AsyncLoadingCache()

def __init__(self, config: QuanthubSdmxDataSourceConfig):
super().__init__(config)
Expand Down Expand Up @@ -76,18 +74,9 @@ async def _get_dataset(
config: dict,
auth_context: AuthContext,
allow_offline: bool = False,
allow_cached: bool = False,
) -> QuanthubSdmx21DataSet | SdmxOfflineDataSet:
dataset_config = self.parse_data_set_config(config)

if allow_cached and not self._config.auth_enabled:
# If auth is disabled, we can cache datasets for all users
if ds := self._dataset_cache.get(str(entity_id)):
logger.debug(
f"Returning cached dataset(id={entity_id}, urn={dataset_config.urn!r})"
)
return ds

logger.info(f"Loading dataset urn={dataset_config.urn!r}.")

sdmx_client = await self.create_sdmx_client(auth_context)
Expand Down Expand Up @@ -191,12 +180,6 @@ async def _get_dataset(
else:
raise e

if allow_cached and not self._config.auth_enabled:
# If auth is disabled, cache the dataset for all users
# NOTE: we do not cache offline datasets
self._dataset_cache.set(str(entity_id), res)
logger.info(f"Cached dataset(id={entity_id}, urn={dataset_config.urn!r}).")

return res

async def get_dataset(
Expand All @@ -209,13 +192,15 @@ async def get_dataset(
allow_cached: bool = False,
) -> QuanthubSdmx21DataSet | SdmxOfflineDataSet:
with debug_timer(f"QuanthubSdmx21DataSourceHandler.get_dataset: {title}"):
if allow_cached and not self._config.auth_enabled:
dataset_config = self.parse_data_set_config(config)
return await self._dataset_cache.get(
key=str(entity_id),
loader=lambda: self._get_dataset(entity_id, title, config, auth_context),
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This behavior is not correct. If allow_offline=True and any problems occur, get_dataset should return an OfflineDataset without caching it

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed. The loader now receives allow_offline, so the server is only hit once. If loading fails and allow_offline=True, the SdmxOfflineDataSet is stored in cache but the validator rejects it on the next access — triggering a fresh load in case the upstream recovered.

validator=lambda ds: ds.config == dataset_config,
)
return await self._get_dataset(
entity_id,
title,
config,
auth_context,
allow_offline=allow_offline,
allow_cached=allow_cached,
entity_id, title, config, auth_context, allow_offline=allow_offline
)

@staticmethod
Expand Down
6 changes: 3 additions & 3 deletions statgpt/common/data/quanthub/v21/qh_sdmx_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
)
from statgpt.common.data.sdmx.v21.ratelimiter import SdmxRateLimiter
from statgpt.common.data.sdmx.v21.sdmx_client import AsyncSdmxClient, init_sdmx
from statgpt.common.utils import Cache
from statgpt.common.utils import TtlCache

from .attributes_parser import AttributesParser
from .authorizer import QuanthubAuthorizer, QuanthubAuthorizerFactory
Expand All @@ -29,8 +29,8 @@ class AsyncQuanthubClient(AsyncSdmxClient):
Contains methods unique to QuantHub, such as fetching dynamic annotations.
"""

_annotation_cache: Cache[list[QhAnnotation]] = Cache()
_attributes_cache: Cache[dict[str, str | None]] = Cache()
_annotation_cache: TtlCache[list[QhAnnotation]] = TtlCache()
_attributes_cache: TtlCache[dict[str, str | None]] = TtlCache()

@classmethod
def from_config( # type: ignore[override]
Expand Down
14 changes: 0 additions & 14 deletions statgpt/common/settings/sdmx.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,5 @@ class SdmxSettings(BaseSettings):
)


class QuantHubSettings(BaseSettings):
"""
QuantHub specific settings
"""

model_config = SettingsConfigDict(env_prefix="quanthub_")

dataset_cache_ttl: int = Field(
default=3600,
description="Cache TTL for QuantHub datasets in seconds",
)


# Create singleton instances
sdmx_settings = SdmxSettings()
quanthub_settings = QuantHubSettings()
3 changes: 2 additions & 1 deletion statgpt/common/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .cache import Cache
from .async_loading_cache import AsyncLoadingCache
from .db_mixins import DateMixin, IdMixin
from .dial import (
AttachmentResponse,
Expand Down Expand Up @@ -47,3 +47,4 @@
get_ts_utcnow,
get_ts_utcnow_str,
)
from .ttl_cache import TtlCache
42 changes: 42 additions & 0 deletions statgpt/common/utils/async_loading_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import asyncio
from collections.abc import Awaitable, Callable
from typing import Generic, TypeVar

T = TypeVar('T')


class AsyncLoadingCache(Generic[T]):
"""A cache that loads values asynchronously on cache miss,
with optional validation of cached entries.

Concurrent requests for the same key are deduplicated: only one
load runs while other callers await its result.
"""

def __init__(self) -> None:
self._cache: dict[str, asyncio.Future[T]] = {}

async def get(
self,
key: str,
loader: Callable[[], Awaitable[T]],
validator: Callable[[T], bool] | None = None,
) -> T:
if key in self._cache:
value = await self._cache[key]
if validator is None or validator(value):
return value
self._cache.pop(key, None)

self._cache[key] = asyncio.ensure_future(loader())
try:
return await self._cache[key]
except BaseException: # includes CancelledError to avoid caching canceled futures
self._cache.pop(key, None)
raise

def remove(self, key: str) -> None:
self._cache.pop(key, None)

def clear(self) -> None:
self._cache.clear()
4 changes: 2 additions & 2 deletions statgpt/common/utils/dial/model_pricing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
from statgpt.common.config import multiline_logger as logger
from statgpt.common.schemas.dial import Pricing
from statgpt.common.settings.dial import dial_settings
from statgpt.common.utils import Cache, DialCore
from statgpt.common.utils import DialCore, TtlCache

_CACHE: Cache[Pricing] = Cache(ttl=24 * 3600) # 24 hours
_CACHE: TtlCache[Pricing] = TtlCache(ttl=24 * 3600) # 24 hours


class ModelPricingAuthContext(AuthContext):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class CacheItem(NamedTuple, Generic[T]):
expiry: float


class Cache(Generic[T]):
class TtlCache(Generic[T]):
def __init__(self, ttl: int = 3600):
self._cache: dict[str, CacheItem[T]] = {}
self._ttl = ttl
Expand All @@ -24,7 +24,7 @@ def get(self, key: str, default: T | None = None) -> T | None:
if time.time() < item.expiry:
return item.value
else:
self._remove_expired_item(key)
self.remove(key)
return default

def clear(self) -> None:
Expand All @@ -36,7 +36,8 @@ def cleanup(self) -> None:
current_time = time.time()
expired_keys = [key for key, item in self._cache.items() if current_time >= item.expiry]
for key in expired_keys:
self._remove_expired_item(key)
self.remove(key)

def _remove_expired_item(self, key: str) -> None:
def remove(self, key: str) -> None:
"""Remove a specific item from the cache by key."""
self._cache.pop(key, None)
Empty file added tests/unit/common/__init__.py
Empty file.
Empty file.
115 changes: 115 additions & 0 deletions tests/unit/common/utils/test_async_loading_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
"""Unit tests for the AsyncLoadingCache utility class."""

import asyncio
from unittest.mock import AsyncMock

import pytest

from statgpt.common.utils.async_loading_cache import AsyncLoadingCache


class TestAsyncLoadingCacheGet:
    """Behaviour of AsyncLoadingCache.get: miss, hit, validation, dedup, failure."""

    @pytest.mark.asyncio
    async def test_get_loads_on_miss(self) -> None:
        subject: AsyncLoadingCache[str] = AsyncLoadingCache()
        load = AsyncMock(return_value="value")

        assert await subject.get("k", load) == "value"
        load.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_get_returns_cached_on_hit(self) -> None:
        subject: AsyncLoadingCache[str] = AsyncLoadingCache()
        load = AsyncMock(return_value="value")

        await subject.get("k", load)

        # Second call must be served from cache without touching the loader.
        assert await subject.get("k", load) == "value"
        load.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_get_reloads_when_validator_fails(self) -> None:
        subject: AsyncLoadingCache[str] = AsyncLoadingCache()
        load = AsyncMock(side_effect=["old", "new"])

        await subject.get("k", load)
        refreshed = await subject.get("k", load, validator=lambda v: v == "new")

        assert refreshed == "new"
        assert load.await_count == 2

    @pytest.mark.asyncio
    async def test_get_without_validator_always_hits(self) -> None:
        subject: AsyncLoadingCache[str] = AsyncLoadingCache()
        load = AsyncMock(return_value="value")

        await subject.get("k", load)

        assert await subject.get("k", load) == "value"
        load.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_concurrent_get_deduplicates(self) -> None:
        subject: AsyncLoadingCache[str] = AsyncLoadingCache()
        load = AsyncMock(return_value="value")

        # Both callers race on the same key; only one load may run.
        outcome = await asyncio.gather(
            subject.get("k", load),
            subject.get("k", load),
        )

        assert outcome == ["value", "value"]
        load.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_load_failure_not_cached(self) -> None:
        subject: AsyncLoadingCache[str] = AsyncLoadingCache()
        load = AsyncMock(side_effect=[ValueError("fail"), "value"])

        with pytest.raises(ValueError, match="fail"):
            await subject.get("k", load)

        # The failure must be evicted, so the retry triggers a fresh load.
        assert await subject.get("k", load) == "value"
        assert load.await_count == 2


class TestAsyncLoadingCacheRemove:
    """Behaviour of AsyncLoadingCache.remove."""

    @pytest.mark.asyncio
    async def test_remove_triggers_reload(self) -> None:
        subject: AsyncLoadingCache[str] = AsyncLoadingCache()
        load = AsyncMock(side_effect=["first", "second"])

        await subject.get("k", load)
        subject.remove("k")

        # Evicted key must be loaded afresh on the next access.
        assert await subject.get("k", load) == "second"
        assert load.await_count == 2

    @pytest.mark.asyncio
    async def test_remove_nonexistent_key(self) -> None:
        subject: AsyncLoadingCache[str] = AsyncLoadingCache()
        # Removing an absent key is a silent no-op.
        subject.remove("nonexistent")


class TestAsyncLoadingCacheClear:
    """Behaviour of AsyncLoadingCache.clear."""

    @pytest.mark.asyncio
    async def test_clear_removes_all(self) -> None:
        subject: AsyncLoadingCache[str] = AsyncLoadingCache()
        load = AsyncMock(side_effect=["a1", "b1", "a2", "b2"])

        await subject.get("a", load)
        await subject.get("b", load)
        subject.clear()

        # After clear(), every key must be loaded again from scratch.
        await subject.get("a", load)
        await subject.get("b", load)

        assert load.await_count == 4
Loading
Loading