diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py
index c9c1e0a6e..75d739dba 100644
--- a/src/spatialdata/_core/spatialdata.py
+++ b/src/spatialdata/_core/spatialdata.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import hashlib
+import json
 import os
 import warnings
 from collections.abc import Generator, Mapping
@@ -1655,6 +1656,95 @@ def write_metadata(
         if consolidate_metadata:
             self.write_consolidated_metadata()
 
+    def get_attrs(
+        self,
+        key: str,
+        return_as: Literal["dict", "json", "df"] | None = None,
+        sep: str = "_",
+        flatten: bool = True,
+    ) -> dict[str, Any] | str | pd.DataFrame:
+        """
+        Retrieve a specific key from sdata.attrs and return it in the specified format.
+
+        Parameters
+        ----------
+        key
+            The key to retrieve from the attrs.
+        return_as
+            The format in which to return the data. Options are 'dict', 'json', 'df'.
+            If None, the function returns the data in its original format.
+        sep
+            Separator for nested keys in flattened data. Defaults to "_".
+        flatten
+            If True, flatten the data if it is a mapping. Defaults to True.
+
+        Returns
+        -------
+        dict[str, Any] | str | pd.DataFrame
+            The data associated with the specified key, returned in the specified format.
+            The format can be a dictionary, JSON string, or Pandas DataFrame, depending on
+            the value of `return_as`.
+
+        Raises
+        ------
+        TypeError
+            If `key` or `sep` is not a string, or if `return_as` is 'dict' and the
+            (possibly flattened) data is not a dictionary.
+        KeyError
+            If `key` is not present in the attrs.
+        ValueError
+            If the data cannot be converted to JSON or to a DataFrame, or if
+            `return_as` is not one of 'dict', 'json', 'df', or None.
+        """
+
+        def _flatten_mapping(m: Mapping[str, Any], parent_key: str = "", sep: str = "_") -> dict[str, Any]:
+
+            items: list[tuple[str, Any]] = []
+            for k, v in m.items():
+                new_key = f"{parent_key}{sep}{k}" if parent_key else k
+                if isinstance(v, Mapping):
+                    items.extend(_flatten_mapping(v, new_key, sep=sep).items())
+                else:
+                    items.append((new_key, v))
+            return dict(items)
+
+        if not isinstance(key, str):
+            raise TypeError("The key must be a string.")
+
+        if not isinstance(sep, str):
+            raise TypeError("Parameter 'sep' must be a string.")
+
+        if key not in self.attrs:
+            raise KeyError(f"The key '{key}' was not found in sdata.attrs.")
+
+        data = self.attrs[key]
+
+        # If the data is a mapping, flatten it
+        if flatten and isinstance(data, Mapping):
+            data = _flatten_mapping(data, sep=sep)
+
+        if return_as is None:
+            return data
+
+        if return_as == "dict":
+            if not isinstance(data, dict):
+                raise TypeError("Cannot convert non-dictionary data to a dictionary.")
+            return data
+
+        if return_as == "json":
+            try:
+                return json.dumps(data)
+            except (TypeError, ValueError) as e:
+                raise ValueError(f"Failed to convert data to JSON: {e}") from e
+
+        if return_as == "df":
+            try:
+                return pd.DataFrame([data])
+            except Exception as e:
+                raise ValueError(f"Failed to convert data to DataFrame: {e}") from e
+
+        raise ValueError(f"Invalid 'return_as' value: {return_as}. Expected 'dict', 'json', 'df', or None.")
+
     @property
     def tables(self) -> Tables:
         """
diff --git a/tests/core/test_get_attrs.py b/tests/core/test_get_attrs.py
new file mode 100644
index 000000000..aedf32711
--- /dev/null
+++ b/tests/core/test_get_attrs.py
@@ -0,0 +1,74 @@
+import pandas as pd
+import pytest
+
+from spatialdata.datasets import blobs
+
+
+@pytest.fixture
+def sdata_attrs():
+    sdata = blobs()
+    sdata.attrs["test"] = {"a": {"b": 12}, "c": 8}
+    return sdata
+
+
+def test_get_attrs_as_is(sdata_attrs):
+    result = sdata_attrs.get_attrs(key="test", return_as=None, flatten=False)
+    expected = {"a": {"b": 12}, "c": 8}
+    assert result == expected
+
+
+def test_get_attrs_as_dict_flatten(sdata_attrs):
+    result = sdata_attrs.get_attrs(key="test", return_as="dict", flatten=True)
+    expected = {"a_b": 12, "c": 8}
+    assert result == expected
+
+
+def test_get_attrs_as_json_flatten_false(sdata_attrs):
+    result = sdata_attrs.get_attrs(key="test", return_as="json", flatten=False)
+    expected = '{"a": {"b": 12}, "c": 8}'
+    assert result == expected
+
+
+def test_get_attrs_as_json_flatten_true(sdata_attrs):
+    result = sdata_attrs.get_attrs(key="test", return_as="json", flatten=True)
+    expected = '{"a_b": 12, "c": 8}'
+    assert result == expected
+
+
+def test_get_attrs_as_dataframe_flatten_false(sdata_attrs):
+    result = sdata_attrs.get_attrs(key="test", return_as="df", flatten=False)
+    expected = pd.DataFrame([{"a": {"b": 12}, "c": 8}])
+    pd.testing.assert_frame_equal(result, expected)
+
+
+def test_get_attrs_as_dataframe_flatten_true(sdata_attrs):
+    result = sdata_attrs.get_attrs(key="test", return_as="df", flatten=True)
+    expected = pd.DataFrame([{"a_b": 12, "c": 8}])
+    pd.testing.assert_frame_equal(result, expected)
+
+
+# test invalid cases
+def test_invalid_key(sdata_attrs):
+    with pytest.raises(KeyError, match="was not found in sdata.attrs"):
+        sdata_attrs.get_attrs(key="non_existent_key")
+
+
+def test_invalid_return_as_value(sdata_attrs):
+    with pytest.raises(ValueError, match="Invalid 'return_as' value"):
+        sdata_attrs.get_attrs(key="test", return_as="invalid_option")
+
+
+def test_non_string_key(sdata_attrs):
+    with pytest.raises(TypeError, match="The key must be a string."):
+        sdata_attrs.get_attrs(key=123)
+
+
+def test_non_string_sep(sdata_attrs):
+    with pytest.raises(TypeError, match="Parameter 'sep' must be a string."):
+        sdata_attrs.get_attrs(key="test", sep=123)
+
+
+def test_empty_attrs():
+    sdata = blobs()
+    with pytest.raises(KeyError, match="was not found in sdata.attrs."):
+        sdata.get_attrs(key="test")