Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions docs/source/virtual_ecosystem/implementation/var_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from pydantic import fields

from virtual_ecosystem.core.base_model import _discover_models
from virtual_ecosystem.core.base_model import discover_models
from virtual_ecosystem.core.variables import VariableMetadata, load_known_variables

# TODO - merge these into a single generate_model_variable_markdown and probably move it
Expand All @@ -16,7 +16,7 @@ def generate_variable_listing(model_name: str, var_attributes: list[str]) -> str
variables = load_known_variables()

# Find the model reference
models = {m.__name__: m for m in _discover_models()}
models = {m.__name__: m for m in discover_models()}
if model_name not in models:
raise ValueError("Unknown model name")
model = models[model_name]
Expand Down Expand Up @@ -103,7 +103,7 @@ def generate_variable_table(model_name: str, var_attributes: list[str]) -> str:
variables = load_known_variables()

# Find the model reference
models = {m.__name__: m for m in _discover_models()}
models = {m.__name__: m for m in discover_models()}
if model_name not in models:
raise ValueError("Unknown model name")
model = models[model_name]
Expand Down
18 changes: 15 additions & 3 deletions tests/core/test_base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -826,9 +826,21 @@ def test_to_camel_case():


def test_discover_models():
"""Test the discover_all_variables_usage function."""
from virtual_ecosystem.core.base_model import BaseModel, _discover_models
"""Test the discover_models function."""
from virtual_ecosystem.core.base_model import BaseModel, discover_models

models = _discover_models()
models = discover_models()
assert len(models) > 0
assert all(issubclass(x, BaseModel) for x in models)


def test_discover_disturbances():
"""Test the discover_disturbances function.

TODO: Update when there are disturbance models implemented.
"""
from virtual_ecosystem.core.base_model import BaseDisturbance, discover_disturbances

models = discover_disturbances()
assert len(models) > 0
assert all(issubclass(x, BaseDisturbance) for x in models)
70 changes: 69 additions & 1 deletion tests/core/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@
),
],
)
def test_registry(caplog, module_name, raises, exp_log):
def test_register_module(caplog, module_name, raises, exp_log):
"""Test the registry loading.

This runs tests on the actual core and testing modules and then uses some local
Expand Down Expand Up @@ -146,3 +146,71 @@ def test_registry(caplog, module_name, raises, exp_log):
log_check(
caplog=caplog, expected_log=exp_log, subset=slice(-len(exp_log), None, None)
)


@pytest.mark.parametrize(
argnames="module_name, raises, exp_log",
argvalues=[
pytest.param(
"virtual_ecosystem.disturbances.disturbance_testing",
does_not_raise(),
(
(
INFO,
"Registering module: "
"virtual_ecosystem.disturbances.disturbance_testing",
),
(
INFO,
"Registering model class for "
"virtual_ecosystem.disturbances.disturbance_testing: "
"DisturbanceTestingModel",
),
(
INFO,
"Configuration class registered for "
"virtual_ecosystem.disturbances.disturbance_testing",
),
),
id="disturbance_testing_import_good",
),
],
)
def test_register_disturbance(caplog, module_name, raises, exp_log):
"""Test the registry loading.

This runs tests on the actual core and testing modules and then uses some local
badly formatted models to check error handling.
"""

from virtual_ecosystem.core.base_model import BaseDisturbance
from virtual_ecosystem.core.configuration import Configuration
from virtual_ecosystem.core.registry import (
DISTURBANCE_REGISTRY,
ModuleInfo,
register_disturbance,
)

# Get the short name
_, _, short_name = module_name.rpartition(".")

caplog.clear()

with raises:
register_disturbance(module_name=module_name)

if isinstance(raises, does_not_raise):
# Test the detailed structure of the registry for the module
assert short_name in DISTURBANCE_REGISTRY
mod_info = DISTURBANCE_REGISTRY[short_name]
assert isinstance(mod_info, ModuleInfo)

if not mod_info.is_core:
assert issubclass(mod_info.model, BaseDisturbance)

assert issubclass(mod_info.config, Configuration)

# Check the last N entries in the log match the expectation.
log_check(
caplog=caplog, expected_log=exp_log, subset=slice(-len(exp_log), None, None)
)
113 changes: 100 additions & 13 deletions virtual_ecosystem/core/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,8 @@
import pkgutil
from abc import ABC, abstractmethod
from importlib import import_module
from typing import Any
from types import ModuleType
from typing import Any, TypeVar

import pint

Expand Down Expand Up @@ -782,9 +783,17 @@ def to_camel_case(snake_str: str) -> str:
return "".join(x.capitalize() for x in snake_str.lower().split("_"))


def _discover_models() -> list[type[BaseModel]]:
"""Discover all the models in Virtual Ecosystem."""
import virtual_ecosystem.models as models
T = TypeVar("T")


def _discover_models(models: ModuleType, of_type: type[T]) -> list[type[T]]:
Comment on lines +786 to +789
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think what is going on here is that we'd "ideally" have this as:

def _discover_models(models: ModuleType, of_type: type[BaseModel] | type[BaseDisturbance]) ->  list[type[BaseModel] | type[BaseDisturbance)]:

But because we can't import those two classes outside of the functions because of circularity (which bugs me but 🤷 ) we instead use this to tell the function that it is generically getting a type of something (not an instance etc.).

Is that right - if so, could you add a comment to that effect? It's probably pretty entry level but right now generic typing is out of my comfort zone, so it would be useful to have an explanation to hand 😄 .

Of course, if my explanation isn't right, then it needs a comment even more!

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, no. We do want type[T] - i.e. to use generics - because that way I am locking the types of the inputs and outputs together. It's like saying that whatever the input, the output is a list if those same things. In your example, from the typing perspective there is no relationship between input and output. I could have type[BaseModel] as input and have a list of type[BaseDisturbance] as outputs.

Generics are really powerful, but can also be quite confusing.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it. Thanks.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Docstring expanded with some details.

"""Discover all the models in Virtual Ecosystem.

We use the generic T type to ensure that the types of the inputs and the
outputs are linked together. In practice, T will be either
:attr:`~virtual_ecosystem.core.base_model.BaseModel` or
:attr:`~virtual_ecosystem.core.base_model.BaseDisturbance`.
"""

models_found = []
for mod in pkgutil.iter_modules(models.__path__):
Expand All @@ -800,18 +809,27 @@ def _discover_models() -> list[type[BaseModel]]:
continue

mod_class_name = to_camel_case(mod.name) + "Model"
if hasattr(module, mod_class_name):
if hasattr(module, mod_class_name) and issubclass(
getattr(module, mod_class_name), of_type
):
models_found.append(getattr(module, mod_class_name))
else:
LOGGER.warning(
f"No model class '{mod_class_name}' found in module "
f"'{models.__name__}.{mod.name}.{mod.name}_model'."
f"No model class '{mod_class_name}' of type `{of_type}` found in module"
f" '{models.__name__}.{mod.name}.{mod.name}_model'."
)
continue

return models_found


def discover_models() -> list[type[BaseModel]]:
"""Discover all the models in Virtual Ecosystem."""
import virtual_ecosystem.models as models

return _discover_models(models, BaseModel) # type: ignore[type-abstract]


class BaseDisturbance(ABC):
"""A superclass for all Virtual Ecosystem disturbance models.

Expand All @@ -830,22 +848,22 @@ class BaseDisturbance(ABC):
instance containing shared core elements used throughout models.
"""

disturbance_name: str
model_name: str
"""The model name.

This class attribute sets the name used to refer to identify the disturbance class
in the disturbance registry, within the configuration settings and in logging
messages.
"""

disturbed_models: list[str]
disturbed_models: tuple[str, ...]
"""A list of model names that this disturbance will affect.

This list will be used to validate the configuration - check that all the models
to disturb are available in the simulation - as well at runtime to select those
models when creating an instance of the disturbance."""

data_variables_disturbed: list[str]
data_variables_disturbed: tuple[str, ...]
"""A list of data variables that will be updated.

This list will be used to validate the configuration and ensure that all the
Expand Down Expand Up @@ -879,7 +897,7 @@ def __init__(
missing = set(self.disturbed_models).difference(models.keys())
if missing:
raise ConfigurationError(
f"Models {missing} required by disturbance {self.disturbance_name}"
f"Models {missing} required by disturbance {self.model_name}"
"not available."
)
self.models = {
Expand All @@ -889,14 +907,76 @@ def __init__(
}
"""The models this disturbance will disturb."""

def __init_subclass__(self, disturbance_name: str, disturbed_models: list[str]):
@classmethod
def __init_subclass__(
cls,
model_name: str,
disturbed_models: tuple[str, ...],
data_variables_disturbed: tuple[str, ...],
):
"""Checks the disturbed models and variables are all known.

If so, it adds the disturbance to the registry.
"""
cls.model_name = cls._check_model_name(model_name)
cls.disturbed_models = cls._check_attributes(disturbed_models)
cls.data_variables_disturbed = cls._check_attributes(data_variables_disturbed)

@classmethod
def _check_model_name(cls, model_name: str) -> str:
"""Check the model_name attribute is valid.

Args:
model_name: The
:attr:`~virtual_ecosystem.core.base_model.BaseModel.model_name`
attribute to be used for a subclass.

Raises:
ValueError: the model_name is not a string.

TODO: Complete when the registry is implemented in #1368.
Returns:
The provided ``model_name`` if valid
"""

if not isinstance(model_name, str):
excep = TypeError(
f"Class attribute model_name in {cls.__name__} is not a string"
)
LOGGER.error(excep)
raise excep

return model_name

@classmethod
def _check_attributes(cls, attribute_value: tuple[str, ...]) -> tuple[str, ...]:
"""Check that disturbance variables and models attributes are valid.

They both need to be tuples of strings, so we make sure that is the case
when creating the class.

Args:
attribute_value: The provided value for the attribute

Raises:
TypeError: the value of the model attribute has the wrong type structure.

Returns:
The validated variables attribute value
"""

# Check the structure
if isinstance(attribute_value, tuple) and all(
isinstance(vname, str) for vname in attribute_value
):
return attribute_value

to_raise = TypeError(
f"Class attribute {attribute_value} has the wrong "
f"structure in {cls.__name__}"
)
LOGGER.error(to_raise)
raise to_raise

@classmethod
@abstractmethod
def from_config(
Expand Down Expand Up @@ -929,3 +1009,10 @@ def _disturb(self, time_index: int) -> None:
Args:
time_index: The index of the current timestep.
"""


def discover_disturbances() -> list[type[BaseDisturbance]]:
"""Discover all the disturbances in Virtual Ecosystem."""
import virtual_ecosystem.disturbances as disturbances

return _discover_models(disturbances, BaseDisturbance) # type: ignore[type-abstract]
42 changes: 42 additions & 0 deletions virtual_ecosystem/core/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
found :doc:`here </using_the_ve/configuration/config>`.
""" # noqa: D205

from __future__ import annotations

from collections.abc import Callable
from pathlib import Path
from typing import Annotated, Any, ClassVar, TypeAlias, TypeVar
Expand All @@ -26,6 +28,7 @@
DirectoryPath,
Field,
FilePath,
model_validator,
)
from pydantic._internal._model_construction import ModelMetaclass
from pydantic_core import PydanticUndefined
Expand Down Expand Up @@ -173,6 +176,45 @@ class ModelConfigurationRoot(Configuration):
"""The model static mode setting."""


class DisturbanceConfigurationRoot(Configuration):
"""Root configuration class for disturbance Virtual Ecosystem models.

This model provides a common Pydantic base class that must be used to define
the root configuration class of a Virtual Ecosystem disturbance model. Each
disturbance must define an object ``model_name.model_config.ModelConfiguration``
that inherits from :class:`DisturbanceConfigurationRoot`. The
``model_name.model_config`` module can then include other :class:`Configuration`
classes that are used as nested fields within the root configuration but can be only
one :class:`ModelConfigurationRoot` class per model. This base model sets common
shared attributes across models: currently just the timing options.

It also validates the timing fields to ensure that at least one of them is set.
"""

run_at: int | list[int] | None = None
"""Define time indices to run at specific times.

Either a single integer or a list of integers indicating the time indices when the
disturbance is to run.
"""
run_every: tuple[int, ...] | None = None
"""Define a range of indices to run the disturbance.

A tuple of integers indicating (start), or (start, step), or (start, step, stop),
from where a list of integers indicating the time indices when the disturbance is to
run can be constructed. If not provided, 'step' defaults to 1 and 'stop' defaults to
the last time index. 'start' must always be provided."""

@model_validator(mode="after")
def timing_options_are_not_both_none(self) -> DisturbanceConfigurationRoot:
"""Validate the timing options of the configuration."""
if self.run_at is None and self.run_every is None:
raise ValueError(
"Timing options 'run_at' and 'run_every' cannot be both None."
)
return self


def model_config_to_html(
model_name: str, config_object: type[Configuration], rows_only: bool = False
):
Expand Down
Loading
Loading