diff --git a/.gitignore b/.gitignore index eaa04de..54c04e6 100644 --- a/.gitignore +++ b/.gitignore @@ -116,4 +116,7 @@ clean/ .pytest_cache/ __pycache__/ mypy_cache/ -.claude/ \ No newline at end of file +.claude/ +*.out +*.log +*.err \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 100fa1d..89a069b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -66,6 +66,7 @@ dev = [ "twine", "hatch", "python-semantic-release", + "objgraph", ] all = [ "cellmap-data[dev,test]", diff --git a/src/cellmap_data/dataset.py b/src/cellmap_data/dataset.py index 68aad95..6a92d9c 100644 --- a/src/cellmap_data/dataset.py +++ b/src/cellmap_data/dataset.py @@ -5,6 +5,7 @@ import logging import os import platform +from concurrent.futures import Executor as _ConcurrentExecutor from concurrent.futures import Future as _ConcurrentFuture from concurrent.futures import ThreadPoolExecutor, as_completed from typing import Any, Callable, Mapping, Optional, Sequence @@ -19,7 +20,7 @@ from .empty_image import EmptyImage from .image import CellMapImage from .mutable_sampler import MutableSubsetRandomSampler -from .read_limiter import MAX_CONCURRENT_READS, limit_tensorstore_reads +from .utils.read_limiter import MAX_CONCURRENT_READS, limit_tensorstore_reads from .utils import get_sliced_shape, is_array_2D, min_redundant_inds, split_target_path logger = logging.getLogger(__name__) @@ -42,13 +43,15 @@ ) -class _ImmediateExecutor: +class _ImmediateExecutor(_ConcurrentExecutor): """Drop-in for ThreadPoolExecutor that runs tasks in the calling thread. On Windows + TensorStore the real ThreadPoolExecutor causes native crashes. This executor avoids that by executing every submitted callable synchronously before returning, so the returned Future is already resolved. ``as_completed`` handles pre-resolved futures correctly (yields immediately). 
+ ``map`` is inherited from ``concurrent.futures.Executor`` and works correctly + because it calls ``submit`` internally (which returns pre-resolved futures). ``shutdown`` is a no-op because there are no threads to join. """ @@ -166,7 +169,7 @@ def __init__( device=self._device, ) self.target_sources = {} - self.has_data = ( + self.has_data = force_has_data or ( False if (len(self.target_arrays) > 0 and len(self.classes) > 0) else True ) for array_name, array_info in self.target_arrays.items(): @@ -424,21 +427,27 @@ def largest_voxel_sizes(self) -> Mapping[str, float]: @cached_property def bounding_box(self) -> Mapping[str, list[float]]: """Returns the bounding box of the dataset.""" - bounding_box: dict[str, list[float]] | None = None all_sources = list(self.input_sources.values()) + list( self.target_sources.values() ) + # Flatten to individual CellMapImage objects + flat_sources = [] for source in all_sources: if isinstance(source, dict): - for sub_source in source.values(): - if hasattr(sub_source, "bounding_box"): - bounding_box = self._get_box_intersection( - sub_source.bounding_box, bounding_box - ) - elif hasattr(source, "bounding_box"): - bounding_box = self._get_box_intersection( - source.bounding_box, bounding_box + flat_sources.extend( + s for s in source.values() if hasattr(s, "bounding_box") ) + elif hasattr(source, "bounding_box"): + flat_sources.append(source) + + # Prefetch bounding boxes in parallel (each triggers a zarr group open) + # Use self.executor to respect Windows+TensorStore immediate executor handling + boxes = list(self.executor.map(lambda s: s.bounding_box, flat_sources)) + + bounding_box: dict[str, list[float]] | None = None + for box in boxes: + bounding_box = self._get_box_intersection(box, bounding_box) + if bounding_box is None: logger.warning( "Bounding box is None. This may cause errors during sampling." 
@@ -454,21 +463,27 @@ def bounding_box_shape(self) -> Mapping[str, int]: @cached_property def sampling_box(self) -> Mapping[str, list[float]]: """Returns the sampling box of the dataset (i.e. where centers can be drawn from and still have full samples drawn from within the bounding box).""" - sampling_box: dict[str, list[float]] | None = None all_sources = list(self.input_sources.values()) + list( self.target_sources.values() ) + flat_sources = [] for source in all_sources: if isinstance(source, dict): - for sub_source in source.values(): - if hasattr(sub_source, "sampling_box"): - sampling_box = self._get_box_intersection( - sub_source.sampling_box, sampling_box - ) - elif hasattr(source, "sampling_box"): - sampling_box = self._get_box_intersection( - source.sampling_box, sampling_box + flat_sources.extend( + s for s in source.values() if hasattr(s, "sampling_box") ) + elif hasattr(source, "sampling_box"): + flat_sources.append(source) + + # Prefetch sampling boxes in parallel; bounding_box is already cached + # from the bounding_box property so these are cheap if called after it. + # Use self.executor to respect Windows+TensorStore immediate executor handling + boxes = list(self.executor.map(lambda s: s.sampling_box, flat_sources)) + + sampling_box: dict[str, list[float]] | None = None + for box in boxes: + sampling_box = self._get_box_intersection(box, sampling_box) + if sampling_box is None: logger.warning( "Sampling box is None. This may cause errors during sampling." 
@@ -781,7 +796,7 @@ def get_label_array( interpolation="nearest", device=self._device, ) - if not self.has_data: + if not self.has_data and not self.force_has_data: self.has_data = array.class_counts > 0 logger.debug(f"{str(self)} has data: {self.has_data}") else: diff --git a/src/cellmap_data/dataset_writer.py b/src/cellmap_data/dataset_writer.py index 7626634..d5dcca7 100644 --- a/src/cellmap_data/dataset_writer.py +++ b/src/cellmap_data/dataset_writer.py @@ -95,8 +95,8 @@ def __init__( self.target_array_writers[array_name] = self.get_target_array_writer( array_name, array_info ) + self._device: str | torch.device = device if device is not None else "cpu" if device is not None: - self._device = device self.to(device, non_blocking=True) @cached_property @@ -237,11 +237,7 @@ def loader( @property def device(self) -> str | torch.device: """Returns the device for the dataset.""" - try: - return self._device - except AttributeError: - self._device = "cpu" - return self._device + return self._device def get_center(self, idx: int) -> dict[str, float]: """ diff --git a/src/cellmap_data/image.py b/src/cellmap_data/image.py index 4b9d6f6..3b26484 100644 --- a/src/cellmap_data/image.py +++ b/src/cellmap_data/image.py @@ -84,7 +84,6 @@ def __init__( self._current_spatial_transforms = None self._current_coords: Any = None self._current_center = None - self._coord_offsets = None # Cache for coordinate offsets (optimization) if device is not None: self.device = device elif torch.cuda.is_available(): @@ -96,49 +95,68 @@ def __init__( def __getitem__(self, center: Mapping[str, float]) -> torch.Tensor: """Returns image data centered around the given point, based on the scale and shape of the target output image.""" - if isinstance(list(center.values())[0], int | float): - self._current_center = center - - # Use cached coordinate offsets + translation (much faster than np.linspace) - # This eliminates repeated coordinate grid generation - coords = {c: self.coord_offsets[c] + 
center[c] for c in self.axes} - - # Bounds checking - for c in self.axes: - if center[c] - self.output_size[c] / 2 < self.bounding_box[c][0]: - UserWarning( - f"Center {center[c]} is out of bounds for axis {c} in image {self.path}. {center[c] - self.output_size[c] / 2} would be less than {self.bounding_box[c][0]}" - ) - if center[c] + self.output_size[c] / 2 > self.bounding_box[c][1]: - UserWarning( - f"Center {center[c]} is out of bounds for axis {c} in image {self.path}. {center[c] + self.output_size[c] / 2} would be greater than {self.bounding_box[c][1]}" - ) - - # Apply any spatial transformations to the coordinates and return the image data as a PyTorch tensor - data = self.apply_spatial_transforms(coords) - else: - self._current_center = {k: np.mean(v) for k, v in center.items()} - self._current_coords = center - # Optimized tensor creation: use torch.from_numpy when possible to avoid data copying - array_data = self.return_data(self._current_coords).values - if isinstance(array_data, np.ndarray): - data = torch.from_numpy(array_data) + try: + if isinstance(list(center.values())[0], int | float): + self._current_center = center + + # Use cached coordinate offsets + translation (much faster than np.linspace) + # This eliminates repeated coordinate grid generation + coords = {c: self.coord_offsets[c] + center[c] for c in self.axes} + + # Bounds checking + for c in self.axes: + if center[c] - self.output_size[c] / 2 < self.bounding_box[c][0]: + logger.warning( + f"Center {center[c]} is out of bounds for axis {c} in image {self.path}. {center[c] - self.output_size[c] / 2} would be less than {self.bounding_box[c][0]}" + ) + if center[c] + self.output_size[c] / 2 > self.bounding_box[c][1]: + logger.warning( + f"Center {center[c]} is out of bounds for axis {c} in image {self.path}. 
{center[c] + self.output_size[c] / 2} would be greater than {self.bounding_box[c][1]}" + ) + + # Apply any spatial transformations to the coordinates and return the image data as a PyTorch tensor + data = self.apply_spatial_transforms(coords) else: - data = torch.tensor(array_data) - - # Apply any value transformations to the data - if self.value_transform is not None: - data = self.value_transform(data) - - # Return data on CPU - let the DataLoader handle device transfer with streams - # This avoids redundant transfers and allows for optimized batch transfers - return data + self._current_center = {k: np.mean(v) for k, v in center.items()} + self._current_coords = center + # Optimized tensor creation: use torch.from_numpy when possible to avoid data copying + array_data = self.return_data(self._current_coords).values + if isinstance(array_data, np.ndarray): + data = torch.from_numpy(array_data) + else: + data = torch.tensor(array_data) + + # Apply any value transformations to the data + if self.value_transform is not None: + data = self.value_transform(data) + + # Return data on CPU - let the DataLoader handle device transfer with streams + # This avoids redundant transfers and allows for optimized batch transfers + return data + finally: + # Clear cached array property to prevent memory accumulation from xarray + # operations (interp/reindex/sel) during training iterations. The array + # will be reopened on next access if needed. Use finally to ensure cleanup + # even if an exception occurs during data retrieval. + self._clear_array_cache() def __repr__(self) -> str: """Returns a string representation of the CellMapImage object.""" return f"CellMapImage({self.array_path})" - @property + def _clear_array_cache(self) -> None: + """ + Clear the cached xarray DataArray to release intermediate objects. + + xarray operations (interp, reindex, sel) create intermediate arrays that + remain referenced through the DataArray. 
Clearing the cache after each + __getitem__ releases those references without closing the underlying + TensorStore handle, which is separately cached in _ts_store and reused. + """ + if "array" in self.__dict__: + del self.__dict__["array"] + + @cached_property def coord_offsets(self) -> Mapping[str, np.ndarray]: """ Cached coordinate offsets from center. @@ -152,16 +170,14 @@ def coord_offsets(self) -> Mapping[str, np.ndarray]: Mapping[str, np.ndarray] Dictionary mapping axis names to coordinate offset arrays. """ - if self._coord_offsets is None: - self._coord_offsets = { - c: np.linspace( - -self.output_size[c] / 2 + self.scale[c] / 2, - self.output_size[c] / 2 - self.scale[c] / 2, - self.output_shape[c], - ) - for c in self.axes - } - return self._coord_offsets + return { + c: np.linspace( + -self.output_size[c] / 2 + self.scale[c] / 2, + self.output_size[c] / 2 - self.scale[c] / 2, + self.output_shape[c], + ) + for c in self.axes + } @cached_property def shape(self) -> Mapping[str, int]: @@ -223,9 +239,43 @@ def array_path(self) -> str: """Returns the path to the single-scale image array.""" return os.path.join(self.path, self.scale_level) + @cached_property + def _ts_store(self) -> ts.TensorStore: # type: ignore + """ + Opens and caches the TensorStore array handle. + + ts.open() is called exactly once per CellMapImage instance and the + resulting handle is kept alive for the instance's lifetime. The handle + is lightweight (it holds a reference to the shared context and chunk + cache) and is safe to reuse across many __getitem__ calls. + + Separating this from the `array` cached_property means that clearing + `array` after each __getitem__ (to release xarray intermediate objects) + does not trigger a new ts.open() call on the next access. 
+ """ + spec = xt._zarr_spec_from_path(self.array_path) + array_future = ts.open(spec, read=True, write=False, context=self.context) + try: + return array_future.result() + except ValueError as e: + logger.warning( + "Failed to open with default driver: %s. Falling back to zarr3 driver.", + e, + ) + spec["driver"] = "zarr3" + return ts.open(spec, read=True, write=False, context=self.context).result() + @cached_property def array(self) -> xarray.DataArray: - """Returns the image data as an xarray DataArray.""" + """ + Returns the image data as an xarray DataArray. + + This property is cached but is explicitly cleared after each __getitem__ + call to release xarray intermediate objects (from interp/reindex/sel) + that would otherwise accumulate during training. Clearing it is cheap + because the underlying TensorStore handle is separately cached in + _ts_store and is not reopened. + """ if ( os.environ.get("CELLMAP_DATA_BACKEND", "tensorstore").lower() != "tensorstore" @@ -235,22 +285,7 @@ def array(self) -> xarray.DataArray: chunks="auto", ) else: - # Construct an xarray with Tensorstore backend - spec = xt._zarr_spec_from_path(self.array_path) - array_future = ts.open(spec, read=True, write=False, context=self.context) - try: - array = array_future.result() - except ValueError as e: - logger.warning( - "Failed to open with default driver: %s. Falling back to zarr3 driver.", - e, - ) - spec["driver"] = "zarr3" - array_future = ts.open( - spec, read=True, write=False, context=self.context - ) - array = array_future.result() - data = xt._TensorStoreAdapter(array) + data = xt._TensorStoreAdapter(self._ts_store) return xarray.DataArray(data=data, coords=self.full_coords) @cached_property @@ -324,6 +359,7 @@ def class_counts(self) -> float: else: raise ValueError("s0_scale not found") except Exception as e: + # TODO: This fallback is very expensive, and ideally should be avoided. 
We should add a script to precompute class counts for all images and save them to the metadata to avoid this in the future. logger.warning( "Unable to get class counts for %s from metadata, " "falling back to calculating from array. Error: %s, %s", diff --git a/src/cellmap_data/image_writer.py b/src/cellmap_data/image_writer.py index 095d0ad..090a758 100644 --- a/src/cellmap_data/image_writer.py +++ b/src/cellmap_data/image_writer.py @@ -1,4 +1,6 @@ +import logging import os +from functools import cached_property from typing import Mapping, Optional, Sequence, Union import numpy as np @@ -14,6 +16,8 @@ from cellmap_data.utils import create_multiscale_metadata +logger = logging.getLogger(__name__) + class ImageWriter: """ @@ -73,158 +77,122 @@ def __init__( "chunk_shape": list(write_voxel_shape.values()), } - @property + @cached_property def array(self) -> xarray.DataArray: + os.makedirs(UPath(self.base_path), exist_ok=True) + group_path = str(self.base_path).split(".zarr")[0] + ".zarr" + for group in [""] + list(UPath(str(self.base_path).split(".zarr")[-1]).parts)[ + 1: + ]: + group_path = UPath(group_path) / group + with open(group_path / ".zgroup", "w") as f: + f.write('{"zarr_format": 2}') + create_multiscale_metadata( + ds_name=str(self.base_path), + voxel_size=self.metadata["voxel_size"], + translation=self.metadata["offset"], + units=self.metadata["units"], + axes=self.metadata["axes"], + base_scale_level=self.scale_level, + levels_to_add=0, + out_path=str(UPath(self.base_path) / ".zattrs"), + ) + spec = { + "driver": "zarr", + "kvstore": {"driver": "file", "path": self.path}, + } + open_kwargs = { + "read": True, + "write": True, + "create": True, + "delete_existing": self.overwrite, + "dtype": self.dtype, + "shape": list(self.shape.values()), + "fill_value": self.fill_value, + "chunk_layout": tensorstore.ChunkLayout(write_chunk_shape=self.chunk_shape), + "context": self.context, + } + array_future = tensorstore.open( + spec, + **open_kwargs, + ) try: - 
return self._array - except AttributeError: - os.makedirs(UPath(self.base_path), exist_ok=True) - group_path = str(self.base_path).split(".zarr")[0] + ".zarr" - for group in [""] + list( - UPath(str(self.base_path).split(".zarr")[-1]).parts - )[1:]: - group_path = UPath(group_path) / group - with open(group_path / ".zgroup", "w") as f: - f.write('{"zarr_format": 2}') - create_multiscale_metadata( - ds_name=str(self.base_path), - voxel_size=self.metadata["voxel_size"], - translation=self.metadata["offset"], - units=self.metadata["units"], - axes=self.metadata["axes"], - base_scale_level=self.scale_level, - levels_to_add=0, - out_path=str(UPath(self.base_path) / ".zattrs"), - ) - spec = { - "driver": "zarr", - "kvstore": {"driver": "file", "path": self.path}, - } - open_kwargs = { - "read": True, - "write": True, - "create": True, - "delete_existing": self.overwrite, - "dtype": self.dtype, - "shape": list(self.shape.values()), - "fill_value": self.fill_value, - "chunk_layout": tensorstore.ChunkLayout( - write_chunk_shape=self.chunk_shape - ), - "context": self.context, - } - array_future = tensorstore.open( - spec, - **open_kwargs, - ) - try: - array = array_future.result() - except ValueError as e: - if "ALREADY_EXISTS" in str(e): - raise FileExistsError( - f"Image already exists at {self.path}. Set overwrite=True to overwrite the image." + array = array_future.result() + except ValueError as e: + if "ALREADY_EXISTS" in str(e): + raise FileExistsError( + f"Image already exists at {self.path}. Set overwrite=True to overwrite the image." 
+ ) + logger.warning("Error opening with zarr driver: %s", e) + logger.warning("Falling back to zarr3 driver") + spec["driver"] = "zarr3" + array_future = tensorstore.open(spec, **open_kwargs) + array = array_future.result() + data = xarray.DataArray( + data=xt._TensorStoreAdapter(array), + coords=coords_from_transforms( + axes=[ + Axis( + name=c, + type="space" if c != "c" else "channel", + unit="nm" if c != "c" else "", ) - Warning(e) - UserWarning("Falling back to zarr3 driver") - spec["driver"] = "zarr3" - array_future = tensorstore.open(spec, **open_kwargs) - array = array_future.result() - from pydantic_ome_ngff.v04.axis import Axis - from pydantic_ome_ngff.v04.transform import VectorScale, VectorTranslation - from xarray_ome_ngff.v04.multiscale import coords_from_transforms - - data = xarray.DataArray( - data=xt._TensorStoreAdapter(array), - coords=coords_from_transforms( - axes=[ - Axis( - name=c, - type="space" if c != "c" else "channel", - unit="nm" if c != "c" else "", - ) - for c in self.axes - ], - transforms=( - VectorScale(scale=tuple(self.scale.values())), - VectorTranslation(translation=tuple(self.offset.values())), - ), - shape=tuple(self.shape.values()), + for c in self.axes + ], + transforms=( + VectorScale(scale=tuple(self.scale.values())), + VectorTranslation(translation=tuple(self.offset.values())), ), - ) - self._array = data - with open(UPath(self.path) / ".zattrs", "w") as f: - f.write("{}") - return self._array + shape=tuple(self.shape.values()), + ), + ) + with open(UPath(self.path) / ".zattrs", "w") as f: + f.write("{}") + return data - @property + @cached_property def chunk_shape(self) -> Sequence[int]: - try: - return self._chunk_shape - except AttributeError: - self._chunk_shape = list(self.write_voxel_shape.values()) - return self._chunk_shape + return list(self.write_voxel_shape.values()) - @property + @cached_property def world_shape(self) -> Mapping[str, float]: - try: - return self._world_shape - except AttributeError: - 
self._world_shape = { - c: self.bounding_box[c][1] - self.bounding_box[c][0] - for c in self.spatial_axes - } - return self._world_shape + return { + c: self.bounding_box[c][1] - self.bounding_box[c][0] + for c in self.spatial_axes + } - @property + @cached_property def shape(self) -> Mapping[str, int]: - try: - return self._shape - except AttributeError: - self._shape = { - c: int(np.ceil(self.world_shape[c] / self.scale[c])) - for c in self.spatial_axes - } - return self._shape + return { + c: int(np.ceil(self.world_shape[c] / self.scale[c])) + for c in self.spatial_axes + } - @property + @cached_property def center(self) -> Mapping[str, float]: - try: - return self._center - except AttributeError: - self._center = { - str(k): float(np.mean(v)) for k, v in self.array.coords.items() - } - return self._center + return {str(k): float(np.mean(v)) for k, v in self.array.coords.items()} - @property + @cached_property def offset(self) -> Mapping[str, float]: - try: - return self._offset - except AttributeError: - self._offset = {c: self.bounding_box[c][0] for c in self.spatial_axes} - return self._offset + return {c: self.bounding_box[c][0] for c in self.spatial_axes} - @property + @cached_property def full_coords(self) -> tuple[xarray.DataArray, ...]: - try: - return self._full_coords - except AttributeError: - self._full_coords = coords_from_transforms( - axes=[ - Axis( - name=c, - type="space" if c != "c" else "channel", - unit="nm" if c != "c" else "", - ) - for c in self.axes - ], - transforms=( - VectorScale(scale=tuple(self.scale.values())), - VectorTranslation(translation=tuple(self.offset.values())), - ), - shape=tuple(self.shape.values()), - ) - return self._full_coords + return coords_from_transforms( + axes=[ + Axis( + name=c, + type="space" if c != "c" else "channel", + unit="nm" if c != "c" else "", + ) + for c in self.axes + ], + transforms=( + VectorScale(scale=tuple(self.scale.values())), + VectorTranslation(translation=tuple(self.offset.values())), + 
), + shape=tuple(self.shape.values()), + ) def align_coords( self, coords: Mapping[str, tuple[Sequence, np.ndarray]] diff --git a/src/cellmap_data/multidataset.py b/src/cellmap_data/multidataset.py index 7067a08..d72c042 100644 --- a/src/cellmap_data/multidataset.py +++ b/src/cellmap_data/multidataset.py @@ -1,6 +1,8 @@ import functools +from concurrent.futures import ThreadPoolExecutor, as_completed from functools import cached_property import logging +import os from typing import Any, Callable, Mapping, Optional, Sequence import numpy as np @@ -9,7 +11,7 @@ from tqdm import tqdm from .base_dataset import CellMapBaseDataset -from .dataset import CellMapDataset +from .dataset import CellMapDataset, _USE_IMMEDIATE_EXECUTOR from .mutable_sampler import MutableSubsetRandomSampler from .utils.sampling import min_redundant_inds @@ -102,16 +104,77 @@ def class_counts(self) -> dict[str, dict[str, float]]: """ Returns the number of samples in each class for each dataset in the multi-dataset, as well as the total number of samples in each class. 
""" - class_counts = {"totals": {c: 0.0 for c in self.classes}} - class_counts["totals"].update({c + "_bg": 0.0 for c in self.classes}) - logger.info("Gathering class counts...") - for ds in tqdm(self.datasets): - for c in self.classes: - if c in ds.class_counts["totals"]: - class_counts["totals"][c] += ds.class_counts["totals"][c] - class_counts["totals"][c + "_bg"] += ds.class_counts["totals"][ - c + "_bg" - ] + classes: list[str] = list(self.classes or []) + class_counts: dict[str, dict[str, float]] = { + "totals": {c: 0.0 for c in classes} + } + class_counts["totals"].update({c + "_bg": 0.0 for c in classes}) + n_datasets = len(self.datasets) + + # Short-circuit if no classes or no datasets to avoid unnecessary computation + if not classes: + logger.info("No classes configured; returning empty totals dict") + return class_counts + if n_datasets == 0: + logger.info( + "No datasets to gather counts for; returning zero-initialized totals for configured classes" + ) + return class_counts + + logger.info("Gathering class counts for %d datasets...", n_datasets) + + # Determine number of worker threads from environment, with defensive parsing. + # Ensure we always have at least 1 worker when n_datasets > 0 to avoid + # ThreadPoolExecutor(max_workers=0) raising at runtime. + max_workers_env = os.environ.get("CELLMAP_MAX_WORKERS", "8") + try: + max_workers = int(max_workers_env) + except (TypeError, ValueError): + logger.warning( + "Invalid CELLMAP_MAX_WORKERS=%r; falling back to default of 8", + max_workers_env, + ) + max_workers = 8 + if max_workers < 1: + logger.warning( + "CELLMAP_MAX_WORKERS=%r is less than 1; using 1 worker instead", + max_workers_env, + ) + max_workers = 1 + n_workers = min(n_datasets, max_workers) + # On Windows + TensorStore, avoid ThreadPoolExecutor to prevent crashes + # when computing class_counts (which may access TensorStore arrays). + # Use the same flag as CellMapDataset for consistency. 
+ if _USE_IMMEDIATE_EXECUTOR: + # Sequential computation to avoid Windows+TensorStore crashes + logger.info( + "Using sequential computation for class counts (Windows+TensorStore)" + ) + for ds in tqdm(self.datasets, desc="Gathering class counts"): + ds_counts = ds.class_counts + for c in classes: + if c in ds_counts["totals"]: + class_counts["totals"][c] += ds_counts["totals"][c] + class_counts["totals"][c + "_bg"] += ds_counts["totals"][ + c + "_bg" + ] + return class_counts + + # Parallel computation for non-Windows or non-TensorStore backends + with ThreadPoolExecutor(max_workers=n_workers) as pool: + futures = { + pool.submit(lambda ds=ds: ds.class_counts): ds for ds in self.datasets + } + with tqdm(total=n_datasets, desc="Gathering class counts") as pbar: + for future in as_completed(futures): + ds_counts = future.result() + for c in classes: + if c in ds_counts["totals"]: + class_counts["totals"][c] += ds_counts["totals"][c] + class_counts["totals"][c + "_bg"] += ds_counts["totals"][ + c + "_bg" + ] + pbar.update(1) return class_counts @cached_property diff --git a/src/cellmap_data/read_limiter.py b/src/cellmap_data/utils/read_limiter.py similarity index 100% rename from src/cellmap_data/read_limiter.py rename to src/cellmap_data/utils/read_limiter.py diff --git a/tests/demo_memory_fix.py b/tests/demo_memory_fix.py new file mode 100755 index 0000000..762942c --- /dev/null +++ b/tests/demo_memory_fix.py @@ -0,0 +1,294 @@ +#!/usr/bin/env python +""" +Memory profiling demo for the CellMapImage array cache fix. + +Demonstrates two levels of profiling: + 1. Mock class — fast, no real data needed, shows the principle. + 2. Real CellMapImage — uses a temporary Zarr dataset to profile + actual xarray/TensorStore allocations. + +Profiling tools used: + - tracemalloc (built-in): snapshot comparison shows *what* is growing, + not just peak usage. 
+ - objgraph (optional, pip install objgraph): counts live Python objects + by type, confirming whether xarray DataArrays accumulate. + +Usage: + python tests/demo_memory_fix.py + DEMO_ITERS=200 python tests/demo_memory_fix.py +""" + +import gc +import io +import os +import sys +import tempfile +import tracemalloc +from pathlib import Path + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src")) + +import numpy as np + +try: + import objgraph + + HAS_OBJGRAPH = True +except ImportError: + HAS_OBJGRAPH = False + + +# --------------------------------------------------------------------------- +# Profiling helpers +# --------------------------------------------------------------------------- + + +def profile_iters(label, call_fn, iterations=100, snapshot_every=25): + """ + Run call_fn(i) for `iterations` steps and track memory growth. + + Prints a tracemalloc snapshot diff every `snapshot_every` steps showing + which allocation sites are growing (not just peak usage). If objgraph is + available, also prints live object-type counts so you can confirm whether + xarray DataArrays or numpy arrays are accumulating. + + Args: + label: Description printed as the section header. + call_fn: Callable taking iteration index, e.g. lambda i: img[center]. + iterations: Total number of iterations to run. + snapshot_every: How often to print an intermediate snapshot diff. + """ + print(f"\n{'─' * 64}") + print(f" {label}") + print(f"{'─' * 64}") + + gc.collect() + tracemalloc.start() + baseline = tracemalloc.take_snapshot() + + # objgraph.show_growth() tracks growth relative to the previous call; + # calling it once here establishes the baseline object counts. 
+ if HAS_OBJGRAPH: + objgraph.show_growth( + limit=10, file=io.StringIO() + ) # prime state, discard output + + for i in range(iterations): + call_fn(i) + + if (i + 1) % snapshot_every == 0: + gc.collect() + snap = tracemalloc.take_snapshot() + stats = snap.compare_to(baseline, "lineno") + growing = [s for s in stats if s.size_diff > 0] + + print(f"\n [iter {i + 1}/{iterations}] Allocations grown vs. baseline:") + if growing: + for s in growing[:6]: + kb = s.size_diff / 1024 + loc = s.traceback[0] + print(f" {kb:+8.1f} KB {loc.filename}:{loc.lineno}") + else: + print(" (none — memory is stable)") + + if HAS_OBJGRAPH: + print( + f"\n [iter {i + 1}/{iterations}] New object types since last check:" + ) + objgraph.show_growth(limit=5, shortnames=False) + + current, peak = tracemalloc.get_traced_memory() + tracemalloc.stop() + + print(f"\n Summary — current: {current / 1024:.1f} KB, peak: {peak / 1024:.1f} KB") + + +# --------------------------------------------------------------------------- +# Section 1: Mock demo (no real data needed) +# --------------------------------------------------------------------------- + + +class _MockCacheUser: + """ + Minimal stand-in simulating CellMapImage's cached_property array pattern. + + Each __getitem__ allocates a new array into self._array_cache, mirroring + how CellMapImage builds an xarray DataArray on every access. With + clear_cache=True the cache is dropped immediately (the fix); without it + the reference accumulates. 
+ """ + + def __init__(self, shape=(512, 512)): + self.shape = shape + self._array_cache = None + + def _clear_array_cache(self): + self._array_cache = None + + def __getitem__(self, idx, clear_cache=True): + self._array_cache = np.ones(self.shape, dtype=np.float32) + result = self._array_cache + if clear_cache: + self._clear_array_cache() + return result + + +def run_mock_demo(iterations): + print("\n" + "=" * 64) + print("SECTION 1: Mock demo (no real data, illustrates the principle)") + print("=" * 64) + + leaky = _MockCacheUser() + fixed = _MockCacheUser() + + profile_iters( + "WITHOUT cache clearing (leaky)", + lambda i: leaky.__getitem__(i, clear_cache=False), + iterations=iterations, + ) + profile_iters( + "WITH cache clearing (fixed)", + lambda i: fixed.__getitem__(i, clear_cache=True), + iterations=iterations, + ) + + +# --------------------------------------------------------------------------- +# Section 2: Real CellMapImage with a temporary Zarr store +# --------------------------------------------------------------------------- + + +def _build_test_zarr(root_path: Path, shape=(32, 32, 32), scale=(4.0, 4.0, 4.0)): + """Create a minimal OME-NGFF Zarr array for profiling.""" + import zarr + from pydantic_ome_ngff.v04.axis import Axis + from pydantic_ome_ngff.v04.multiscale import ( + Dataset as MultiscaleDataset, + MultiscaleMetadata, + ) + from pydantic_ome_ngff.v04.transform import VectorScale + + root_path.mkdir(parents=True, exist_ok=True) + data = np.random.rand(*shape).astype(np.float32) + store = zarr.DirectoryStore(str(root_path)) + root = zarr.group(store=store, overwrite=True) + chunks = tuple(min(16, s) for s in shape) + root.create_dataset("s0", data=data, chunks=chunks, overwrite=True) + + axes = [Axis(name=n, type="space", unit="nanometer") for n in ["z", "y", "x"]] + datasets = ( + MultiscaleDataset( + path="s0", + coordinateTransformations=(VectorScale(type="scale", scale=scale),), + ), + ) + root.attrs["multiscales"] = [ + 
MultiscaleMetadata( + version="0.4", name="test", axes=axes, datasets=datasets + ).model_dump(mode="json", exclude_none=True) + ] + return str(root_path) + + +def run_real_demo(iterations): + print("\n" + "=" * 64) + print("SECTION 2: Real CellMapImage with a temporary Zarr dataset") + print("=" * 64) + + try: + from cellmap_data.image import CellMapImage + except ImportError as e: + print(f"\n Skipping — could not import CellMapImage: {e}") + return + + try: + with tempfile.TemporaryDirectory() as tmp: + # Larger array so each DataArray is meaningfully sized (~2 MB) + shape = (64, 64, 64) + scale = [4.0, 4.0, 4.0] + voxel_shape = [16, 16, 16] + img_path = _build_test_zarr( + Path(tmp) / "raw", shape=shape, scale=tuple(scale) + ) + + # Volume spans 0–256 nm per axis; vary centers to exercise interp/reindex + rng = np.random.default_rng(42) + half = voxel_shape[0] * scale[0] / 2 # 32 nm margin + lo, hi = half, shape[0] * scale[0] - half # 32 to 224 nm + + def random_center(i): + coords = rng.uniform(lo, hi, size=3) + return { + "z": float(coords[0]), + "y": float(coords[1]), + "x": float(coords[2]), + } + + def make_image(): + return CellMapImage( + path=img_path, + target_class="raw", + target_scale=scale, + target_voxel_shape=voxel_shape, + device="cpu", + ) + + # Warmup: load all heavy imports and initialize TensorStore context + # before profiling either mode, so the comparison is not confounded + # by import costs. + print("\n Warming up (pre-loading imports and TensorStore context)...") + _warmup = make_image() + for _ in range(5): + _warmup[{"z": 128.0, "y": 128.0, "x": 128.0}] + del _warmup + gc.collect() + print(" Done.\n") + + # Leaky first (no imports to pay), then fixed — equal footing. 
+ img_leaky = make_image() + img_leaky._clear_array_cache = lambda: None + profile_iters( + "CellMapImage — WITHOUT cache clearing (leaky)", + lambda i: img_leaky[random_center(i)], + iterations=iterations, + ) + + img_fixed = make_image() + profile_iters( + "CellMapImage — WITH cache clearing (fixed)", + lambda i: img_fixed[random_center(i)], + iterations=iterations, + ) + + except Exception as e: + print(f"\n Error during real demo: {e}") + raise + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +def main(): + iterations = int(os.environ.get("DEMO_ITERS", "100")) + + print("=" * 64) + print("CellMapImage Memory Profiling Demo") + print("=" * 64) + print( + f"\n iterations : {iterations} (set DEMO_ITERS env var to change)\n" + f" tracemalloc: built-in\n" + f" objgraph : {'available' if HAS_OBJGRAPH else 'not installed — pip install objgraph'}" + ) + + run_mock_demo(iterations=iterations) + run_real_demo(iterations=iterations) + + print("\n" + "=" * 64) + print("Done.") + print("=" * 64) + + +if __name__ == "__main__": + main() diff --git a/tests/test_image_edge_cases.py b/tests/test_image_edge_cases.py index 3004e9b..02c6e94 100644 --- a/tests/test_image_edge_cases.py +++ b/tests/test_image_edge_cases.py @@ -376,3 +376,75 @@ def test_nan_pad_value(self, test_zarr_image): ) assert np.isnan(image.pad_value) + + # ----------------------------------------------------------------------- + # coord_offsets caching + # ----------------------------------------------------------------------- + + def test_coord_offsets_is_cached_property(self, test_zarr_image): + """coord_offsets must use @cached_property, not a manual null-check pattern. + + Verifies: (a) the returned dict has the expected axes, (b) successive + accesses return the exact same objects (cached, not recomputed), and + (c) the offsets are symmetric around zero for each axis. 
+ """ + path, _ = test_zarr_image + image = CellMapImage( + path=path, + target_class="test_class", + target_scale=(4.0, 4.0, 4.0), + target_voxel_shape=(8, 8, 8), + ) + + offsets1 = image.coord_offsets + offsets2 = image.coord_offsets + + # Cached: same object returned on every access + assert offsets1 is offsets2 + + # Stored in __dict__ (cached_property, not regular property) + assert "coord_offsets" in image.__dict__ + + # Correct axes present + for axis in image.axes: + assert axis in offsets1 + arr = offsets1[axis] + assert len(arr) == image.output_shape[axis] + # Symmetric around zero within float tolerance + assert abs(arr[0] + arr[-1]) < 1e-9 + + def test_coord_offsets_not_cleared_by_array_cache_clear(self, test_zarr_image): + """_clear_array_cache must only clear 'array', leaving coord_offsets intact.""" + path, _ = test_zarr_image + image = CellMapImage( + path=path, + target_class="test_class", + target_scale=(4.0, 4.0, 4.0), + target_voxel_shape=(8, 8, 8), + ) + + offsets_before = image.coord_offsets # populate cache + assert "coord_offsets" in image.__dict__ + + image._clear_array_cache() + + # coord_offsets must still be cached after cache clear + assert "coord_offsets" in image.__dict__ + assert image.coord_offsets is offsets_before + + def test_coord_offsets_values_match_output_size_and_scale(self, test_zarr_image): + """coord_offsets values must span exactly [-output_size/2+scale/2, output_size/2-scale/2].""" + path, _ = test_zarr_image + image = CellMapImage( + path=path, + target_class="test_class", + target_scale=(4.0, 4.0, 4.0), + target_voxel_shape=(8, 8, 8), + ) + + for axis in image.axes: + arr = image.coord_offsets[axis] + expected_lo = -image.output_size[axis] / 2 + image.scale[axis] / 2 + expected_hi = image.output_size[axis] / 2 - image.scale[axis] / 2 + assert abs(arr[0] - expected_lo) < 1e-9 + assert abs(arr[-1] - expected_hi) < 1e-9 diff --git a/tests/test_init_optimizations.py b/tests/test_init_optimizations.py new file mode 100644 
index 0000000..f2ab514 --- /dev/null +++ b/tests/test_init_optimizations.py @@ -0,0 +1,572 @@ +""" +Tests for initialization optimizations added to CellMapDataset and +CellMapMultiDataset. + +Covers: + - force_has_data=True sets has_data immediately (no class_counts read) + - bounding_box / sampling_box parallel computation: correctness and cleanup + - CellMapMultiDataset.class_counts parallel execution: correct aggregation, + exception propagation, CELLMAP_MAX_WORKERS env-var respected + - _ImmediateExecutor: submit/map correctness (Windows+TensorStore drop-in) + - Immediate executor code paths in bounding_box, sampling_box, and + CellMapMultiDataset.class_counts (simulated via monkeypatching) + - Consistency: dataset.py and multidataset.py share the same + _USE_IMMEDIATE_EXECUTOR flag +""" + +from unittest.mock import PropertyMock, patch + +import pytest + +from cellmap_data import CellMapDataset, CellMapMultiDataset +from cellmap_data.image import CellMapImage + +from .test_helpers import create_test_dataset + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def single_dataset_config(tmp_path): + return create_test_dataset( + tmp_path, + raw_shape=(32, 32, 32), + num_classes=2, + raw_scale=(4.0, 4.0, 4.0), + seed=0, + ) + + +@pytest.fixture +def multi_source_dataset(tmp_path): + """Dataset with two input arrays and two target arrays (four CellMapImage + objects), so the parallel bounding_box / sampling_box paths receive more + than one source to map over.""" + config = create_test_dataset( + tmp_path, + raw_shape=(32, 32, 32), + num_classes=2, + raw_scale=(4.0, 4.0, 4.0), + seed=7, + ) + return CellMapDataset( + raw_path=config["raw_path"], + target_path=config["gt_path"], + classes=config["classes"], + input_arrays={ + "raw_4nm": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}, + "raw_8nm": {"shape": (8, 8, 8), "scale": 
(8.0, 8.0, 8.0)}, + }, + target_arrays={ + "gt_4nm": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}, + "gt_8nm": {"shape": (8, 8, 8), "scale": (8.0, 8.0, 8.0)}, + }, + force_has_data=True, + ) + + +@pytest.fixture +def three_datasets(tmp_path): + datasets = [] + for i in range(3): + config = create_test_dataset( + tmp_path / f"ds_{i}", + raw_shape=(32, 32, 32), + num_classes=2, + raw_scale=(4.0, 4.0, 4.0), + seed=i, + ) + ds = CellMapDataset( + raw_path=config["raw_path"], + target_path=config["gt_path"], + classes=config["classes"], + input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, + target_arrays={"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, + ) + datasets.append(ds) + return datasets + + +# --------------------------------------------------------------------------- +# force_has_data +# --------------------------------------------------------------------------- + + +class TestForceHasData: + """force_has_data=True should set has_data=True at construction time + without ever accessing CellMapImage.class_counts.""" + + def test_has_data_true_when_force_set(self, single_dataset_config): + """has_data is True immediately after __init__ when force_has_data=True.""" + config = single_dataset_config + dataset = CellMapDataset( + raw_path=config["raw_path"], + target_path=config["gt_path"], + classes=config["classes"], + input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, + target_arrays={"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, + force_has_data=True, + ) + assert dataset.has_data is True + + def test_class_counts_not_accessed_when_force_has_data(self, single_dataset_config): + """CellMapImage.class_counts must never be accessed during __init__ + when force_has_data=True.""" + config = single_dataset_config + with patch.object( + CellMapImage, "class_counts", new_callable=PropertyMock, return_value=100.0 + ) as mock_counts: + CellMapDataset( + raw_path=config["raw_path"], + target_path=config["gt_path"], + 
classes=config["classes"], + input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, + target_arrays={"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, + force_has_data=True, + ) + mock_counts.assert_not_called() + + def test_class_counts_accessed_without_force_has_data(self, single_dataset_config): + """Without force_has_data, class_counts IS accessed (inverse check).""" + config = single_dataset_config + with patch.object( + CellMapImage, "class_counts", new_callable=PropertyMock, return_value=100.0 + ) as mock_counts: + CellMapDataset( + raw_path=config["raw_path"], + target_path=config["gt_path"], + classes=config["classes"], + input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, + target_arrays={"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, + force_has_data=False, + ) + mock_counts.assert_called() + + def test_has_data_false_without_force_for_empty_data(self, tmp_path): + """Without force_has_data and with all-zero target data, has_data=False.""" + import numpy as np + from .test_helpers import create_test_zarr_array, create_test_image_data + + # Raw array + raw_data = create_test_image_data((16, 16, 16), pattern="random") + create_test_zarr_array(tmp_path / "dataset.zarr" / "raw", raw_data) + + # All-zero target → class_counts == 0 → has_data stays False + zero_data = np.zeros((16, 16, 16), dtype=np.uint8) + create_test_zarr_array( + tmp_path / "dataset.zarr" / "class_0", + zero_data, + absent=zero_data.size, # all absent + ) + + dataset = CellMapDataset( + raw_path=str(tmp_path / "dataset.zarr" / "raw"), + target_path=str(tmp_path / "dataset.zarr" / "[class_0]"), + classes=["class_0"], + input_arrays={"raw": {"shape": (8, 8, 8), "scale": (1.0, 1.0, 1.0)}}, + target_arrays={"gt": {"shape": (8, 8, 8), "scale": (1.0, 1.0, 1.0)}}, + force_has_data=False, + ) + assert not dataset.has_data + + +# --------------------------------------------------------------------------- +# bounding_box / sampling_box parallel 
computation +# --------------------------------------------------------------------------- + + +class TestParallelBoundingBox: + """bounding_box and sampling_box must give correct results when computed + in parallel across multiple CellMapImage sources.""" + + def test_bounding_box_correct_with_multiple_sources(self, multi_source_dataset): + bbox = multi_source_dataset.bounding_box + assert isinstance(bbox, dict) + for axis in multi_source_dataset.axis_order: + assert axis in bbox + lo, hi = bbox[axis] + assert lo <= hi + + def test_sampling_box_correct_with_multiple_sources(self, multi_source_dataset): + sbox = multi_source_dataset.sampling_box + assert isinstance(sbox, dict) + for axis in multi_source_dataset.axis_order: + assert axis in sbox + assert len(sbox[axis]) == 2 + + def test_bounding_box_consistent_with_single_source(self, single_dataset_config): + """Sequential vs. parallel should yield the same bounding box.""" + config = single_dataset_config + + def make_dataset(): + return CellMapDataset( + raw_path=config["raw_path"], + target_path=config["gt_path"], + classes=config["classes"], + input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, + target_arrays={"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, + force_has_data=True, + ) + + ds1 = make_dataset() + ds2 = make_dataset() + + bbox1 = ds1.bounding_box + bbox2 = ds2.bounding_box + + for axis in ds1.axis_order: + assert pytest.approx(bbox1[axis][0]) == bbox2[axis][0] + assert pytest.approx(bbox1[axis][1]) == bbox2[axis][1] + + def test_sampling_box_inside_bounding_box(self, multi_source_dataset): + """The sampling box must be a sub-region of (or equal to) the bounding box.""" + bbox = multi_source_dataset.bounding_box + sbox = multi_source_dataset.sampling_box + for axis in multi_source_dataset.axis_order: + assert sbox[axis][0] >= bbox[axis][0] - 1e-9 + assert sbox[axis][1] <= bbox[axis][1] + 1e-9 + + def test_bounding_box_pool_does_not_leak_threads(self, 
single_dataset_config): + """Accessing bounding_box twice on fresh datasets should not raise even + if the pool from the first call was already shut down.""" + config = single_dataset_config + + for _ in range(2): + ds = CellMapDataset( + raw_path=config["raw_path"], + target_path=config["gt_path"], + classes=config["classes"], + input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, + target_arrays={"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, + force_has_data=True, + ) + bbox = ds.bounding_box + assert bbox is not None + + +# --------------------------------------------------------------------------- +# CellMapMultiDataset.class_counts parallel execution +# --------------------------------------------------------------------------- + + +class TestMultiDatasetClassCountsParallel: + """Parallel class_counts must aggregate correctly and behave robustly.""" + + def test_totals_equal_sum_of_individual_datasets(self, three_datasets): + """Aggregated totals must equal the element-wise sum of each dataset's + class_counts["totals"].""" + classes = ["class_0", "class_1"] + multi = CellMapMultiDataset( + classes=classes, + input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, + target_arrays={"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, + datasets=three_datasets, + ) + + # Compute expected totals by summing across individual datasets. 
+        expected: dict[str, float] = {c: 0.0 for c in classes}
+        expected.update({c + "_bg": 0.0 for c in classes})
+        for ds in three_datasets:
+            for key in expected:
+                expected[key] += ds.class_counts["totals"].get(key, 0.0)
+
+        actual = multi.class_counts["totals"]
+        for key, val in expected.items():
+            assert pytest.approx(actual[key], rel=1e-6) == val
+
+    def test_class_counts_has_totals_key(self, three_datasets):
+        multi = CellMapMultiDataset(
+            classes=["class_0", "class_1"],
+            input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}},
+            target_arrays={"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}},
+            datasets=three_datasets,
+        )
+        counts = multi.class_counts
+        assert "totals" in counts
+        for c in ["class_0", "class_1", "class_0_bg", "class_1_bg"]:
+            assert c in counts["totals"]
+
+    def test_exception_from_dataset_propagates(self, three_datasets):
+        """If any dataset's class_counts raises, the exception must propagate
+        out of CellMapMultiDataset.class_counts (via future.result())."""
+        multi = CellMapMultiDataset(
+            classes=["class_0", "class_1"],
+            input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}},
+            target_arrays={"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}},
+            datasets=three_datasets,
+        )
+
+        with patch.object(
+            CellMapDataset,
+            "class_counts",
+            new_callable=PropertyMock,
+            side_effect=RuntimeError("simulated failure"),
+        ):
+            with pytest.raises(RuntimeError, match="simulated failure"):
+                _ = multi.class_counts
+
+    def test_max_workers_env_var_respected(self, three_datasets, monkeypatch):
+        """Setting CELLMAP_MAX_WORKERS=1 must not break class_counts (smoke test only)."""
+        monkeypatch.setenv("CELLMAP_MAX_WORKERS", "1")
+
+        multi = CellMapMultiDataset(
+            classes=["class_0", "class_1"],
+            input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}},
+            target_arrays={"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}},
+            datasets=three_datasets,
+        )
+        # Should still produce correct results with a single worker
+        counts 
= multi.class_counts + assert "totals" in counts + + def test_single_dataset_multidataset(self, tmp_path): + """Edge case: a multi-dataset with one child returns that child's counts.""" + config = create_test_dataset( + tmp_path, raw_shape=(32, 32, 32), num_classes=2, raw_scale=(4.0, 4.0, 4.0) + ) + ds = CellMapDataset( + raw_path=config["raw_path"], + target_path=config["gt_path"], + classes=config["classes"], + input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, + target_arrays={"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, + ) + multi = CellMapMultiDataset( + classes=["class_0", "class_1"], + input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, + target_arrays={"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, + datasets=[ds], + ) + multi_totals = multi.class_counts["totals"] + ds_totals = ds.class_counts["totals"] + for c in ["class_0", "class_1"]: + assert pytest.approx(multi_totals[c]) == ds_totals.get(c, 0.0) + assert pytest.approx(multi_totals[c + "_bg"]) == ds_totals.get( + c + "_bg", 0.0 + ) + + def test_empty_classes_list(self, three_datasets): + """An empty classes list produces an empty totals dict without error.""" + multi = CellMapMultiDataset( + classes=[], + input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, + target_arrays={}, + datasets=three_datasets, + ) + counts = multi.class_counts + assert counts["totals"] == {} + + +# --------------------------------------------------------------------------- +# _ImmediateExecutor unit tests +# --------------------------------------------------------------------------- + + +class TestImmediateExecutor: + """Unit tests for _ImmediateExecutor. + + _ImmediateExecutor is the Windows+TensorStore drop-in that runs every + submitted callable synchronously in the calling thread. It must satisfy + the same interface as ThreadPoolExecutor so all existing call sites + (submit, map, as_completed, shutdown) work without modification. 
+ """ + + @pytest.fixture + def executor(self): + from cellmap_data.dataset import _ImmediateExecutor + + return _ImmediateExecutor() + + def test_submit_executes_synchronously(self, executor): + """submit() runs the callable before returning; the future is already done.""" + calls = [] + future = executor.submit(calls.append, 99) + assert future.done(), "Future should be resolved immediately" + assert calls == [99], "Callable should have run synchronously" + + def test_submit_returns_correct_result(self, executor): + """submit() stores the return value in the future.""" + future = executor.submit(lambda x, y: x + y, 3, 4) + assert future.result() == 7 + + def test_submit_captures_exception(self, executor): + """Exceptions raised by the callable are stored, not propagated.""" + future = executor.submit(lambda: 1 / 0) + assert future.exception() is not None + assert isinstance(future.exception(), ZeroDivisionError) + + def test_map_returns_results_in_order(self, executor): + """map() returns results in the same order as the input iterable.""" + results = list(executor.map(lambda x: x * 2, [1, 2, 3, 4])) + assert results == [2, 4, 6, 8] + + def test_map_with_lambda(self, executor): + """map() works with lambda functions, matching the bounding_box usage.""" + items = [{"v": i} for i in range(5)] + results = list(executor.map(lambda d: d["v"], items)) + assert results == list(range(5)) + + def test_map_propagates_exception(self, executor): + """Exceptions from map() propagate when the result is consumed.""" + with pytest.raises(ZeroDivisionError): + list(executor.map(lambda x: 1 / x, [1, 0, 2])) + + def test_shutdown_is_noop(self, executor): + """shutdown() must not raise even when called multiple times.""" + executor.shutdown(wait=True) + executor.shutdown(wait=False, cancel_futures=True) + + def test_as_completed_works_with_submit(self, executor): + """Futures from submit() are compatible with as_completed().""" + from concurrent.futures import as_completed + + 
futures = [executor.submit(lambda i=i: i * 3, i) for i in range(5)] + results = {f.result() for f in as_completed(futures)} + assert results == {0, 3, 6, 9, 12} + + def test_is_executor_subclass(self): + """_ImmediateExecutor must be a subclass of concurrent.futures.Executor + so it satisfies the Executor interface including map().""" + from concurrent.futures import Executor + + from cellmap_data.dataset import _ImmediateExecutor + + assert issubclass(_ImmediateExecutor, Executor) + + +# --------------------------------------------------------------------------- +# Immediate executor code paths (simulated via monkeypatching) +# --------------------------------------------------------------------------- + + +class TestImmediateExecutorPaths: + """Verify that bounding_box, sampling_box, and CellMapMultiDataset.class_counts + work correctly when _USE_IMMEDIATE_EXECUTOR is True. + + These tests simulate the Windows+TensorStore environment on any platform + by monkeypatching the module-level flag and singleton executor. 
+ """ + + @pytest.fixture + def patched_immediate(self, monkeypatch): + """Patch dataset module to act as if running on Windows+TensorStore.""" + import cellmap_data.dataset as ds_module + from cellmap_data.dataset import _ImmediateExecutor + + monkeypatch.setattr(ds_module, "_USE_IMMEDIATE_EXECUTOR", True) + monkeypatch.setattr(ds_module, "_IMMEDIATE_EXECUTOR", _ImmediateExecutor()) + + def test_bounding_box_uses_immediate_executor( + self, single_dataset_config, patched_immediate + ): + """bounding_box must work via executor.map() when using _ImmediateExecutor.""" + config = single_dataset_config + ds = CellMapDataset( + raw_path=config["raw_path"], + target_path=config["gt_path"], + classes=config["classes"], + input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, + target_arrays={"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, + force_has_data=True, + ) + from cellmap_data.dataset import _ImmediateExecutor + + assert isinstance(ds.executor, _ImmediateExecutor) + bbox = ds.bounding_box + assert isinstance(bbox, dict) + for axis in ds.axis_order: + assert axis in bbox + lo, hi = bbox[axis] + assert lo <= hi + + def test_sampling_box_uses_immediate_executor( + self, single_dataset_config, patched_immediate + ): + """sampling_box must work via executor.map() when using _ImmediateExecutor.""" + config = single_dataset_config + ds = CellMapDataset( + raw_path=config["raw_path"], + target_path=config["gt_path"], + classes=config["classes"], + input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, + target_arrays={"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, + force_has_data=True, + ) + sbox = ds.sampling_box + assert isinstance(sbox, dict) + for axis in ds.axis_order: + assert axis in sbox + assert len(sbox[axis]) == 2 + + def test_getitem_uses_immediate_executor( + self, single_dataset_config, patched_immediate + ): + """__getitem__ must work when _ImmediateExecutor is active.""" + config = single_dataset_config + ds = 
CellMapDataset( + raw_path=config["raw_path"], + target_path=config["gt_path"], + classes=config["classes"], + input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, + target_arrays={"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, + force_has_data=True, + ) + result = ds[0] + assert isinstance(result, dict) + assert "raw" in result + + def test_multidataset_class_counts_sequential_path( + self, three_datasets, monkeypatch + ): + """CellMapMultiDataset.class_counts takes the sequential path when + _USE_IMMEDIATE_EXECUTOR is True (shared flag from dataset.py).""" + import cellmap_data.multidataset as md_module + + monkeypatch.setattr(md_module, "_USE_IMMEDIATE_EXECUTOR", True) + + multi = CellMapMultiDataset( + classes=["class_0", "class_1"], + input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, + target_arrays={"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, + datasets=three_datasets, + ) + counts = multi.class_counts + assert "totals" in counts + for c in ["class_0", "class_1", "class_0_bg", "class_1_bg"]: + assert c in counts["totals"] + + +# --------------------------------------------------------------------------- +# Consistency: dataset.py and multidataset.py share _USE_IMMEDIATE_EXECUTOR +# --------------------------------------------------------------------------- + + +class TestImmediateExecutorFlagConsistency: + """The _USE_IMMEDIATE_EXECUTOR flag must be sourced from dataset.py in + both dataset.py and multidataset.py so they always agree on whether + to use threads.""" + + def test_flag_values_match_across_modules(self): + """Both modules read the same flag value at import time.""" + import cellmap_data.dataset as ds_module + import cellmap_data.multidataset as md_module + + assert ds_module._USE_IMMEDIATE_EXECUTOR == md_module._USE_IMMEDIATE_EXECUTOR + + def test_multidataset_imports_flag_from_dataset(self): + """multidataset module must expose _USE_IMMEDIATE_EXECUTOR (imported + from dataset), not define 
its own copy."""
+        import inspect
+
+        import cellmap_data.multidataset as md_module
+
+        assert hasattr(
+            md_module, "_USE_IMMEDIATE_EXECUTOR"
+        ), "multidataset must import _USE_IMMEDIATE_EXECUTOR from dataset"
+        # NOTE: True/False are singletons, so this identity check only shows the
+        # two module flags agree in value; it cannot prove the flag was imported.
+        import cellmap_data.dataset as ds_module
+
+        assert md_module._USE_IMMEDIATE_EXECUTOR is ds_module._USE_IMMEDIATE_EXECUTOR
diff --git a/tests/test_memory_management.py b/tests/test_memory_management.py
new file mode 100644
index 0000000..7cf889b
--- /dev/null
+++ b/tests/test_memory_management.py
@@ -0,0 +1,217 @@
+"""
+Tests for memory management in CellMapImage.
+
+Specifically tests the array cache clearing mechanism to prevent memory leaks.
+"""
+
+import pytest
+from cellmap_data import CellMapImage
+from .test_helpers import create_test_image_data, create_test_zarr_array
+
+
+class TestMemoryManagement:
+    """Test memory management features."""
+
+    @pytest.fixture
+    def test_zarr_image(self, tmp_path):
+        """Create a test Zarr image."""
+        data = create_test_image_data((32, 32, 32), pattern="gradient")
+        path = tmp_path / "test_image.zarr"
+        create_test_zarr_array(path, data, scale=(4.0, 4.0, 4.0))
+        return str(path), data
+
+    def test_array_cache_cleared_after_getitem(self, test_zarr_image):
+        """Test that array cache is cleared after __getitem__ to prevent memory leaks."""
+        path, _ = test_zarr_image
+
+        image = CellMapImage(
+            path=path,
+            target_class="test",
+            target_scale=(4.0, 4.0, 4.0),
+            target_voxel_shape=(8, 8, 8),
+            axis_order="zyx",
+        )
+
+        # Access array to populate cache
+        _ = image.array
+        assert "array" in image.__dict__, "Array should be cached after first access"
+
+        # Call __getitem__ which should clear the cache
+        center = {"z": 64.0, "y": 64.0, "x": 64.0}
+        _ = image[center]
+
+        # Check that cache was cleared
+        assert (
+            "array" not in image.__dict__
+        ), "Array cache should be cleared after 
__getitem__" + + def test_array_cache_repopulates_after_clearing(self, test_zarr_image): + """Test that array cache can be repopulated after being cleared.""" + path, _ = test_zarr_image + + image = CellMapImage( + path=path, + target_class="test", + target_scale=(4.0, 4.0, 4.0), + target_voxel_shape=(8, 8, 8), + axis_order="zyx", + ) + + # First access + center = {"z": 64.0, "y": 64.0, "x": 64.0} + data1 = image[center] + + # Array cache should be cleared + assert "array" not in image.__dict__ + + # Second access - should work without errors (cache will be repopulated) + data2 = image[center] + + # Both should produce valid tensors + assert data1.shape == data2.shape + assert data1.dtype == data2.dtype + + def test_clear_array_cache_method(self, test_zarr_image): + """Test the _clear_array_cache method directly.""" + path, _ = test_zarr_image + + image = CellMapImage( + path=path, + target_class="test", + target_scale=(4.0, 4.0, 4.0), + target_voxel_shape=(8, 8, 8), + ) + + # Populate cache + _ = image.array + assert "array" in image.__dict__ + + # Clear cache + image._clear_array_cache() + assert "array" not in image.__dict__ + + # Clearing when not cached should not raise an error + image._clear_array_cache() # Should be a no-op + + def test_multiple_getitem_calls_clear_cache_each_time(self, test_zarr_image): + """Test that cache is cleared on every __getitem__ call.""" + path, _ = test_zarr_image + + image = CellMapImage( + path=path, + target_class="test", + target_scale=(4.0, 4.0, 4.0), + target_voxel_shape=(8, 8, 8), + ) + + centers = [ + {"z": 48.0, "y": 48.0, "x": 48.0}, + {"z": 64.0, "y": 64.0, "x": 64.0}, + {"z": 80.0, "y": 80.0, "x": 80.0}, + ] + + for center in centers: + _ = image[center] + # Cache should be cleared after each call + assert ( + "array" not in image.__dict__ + ), f"Array cache should be cleared after accessing center {center}" + + def test_cache_clearing_with_spatial_transforms(self, test_zarr_image): + """Test that cache is cleared 
even with spatial transforms.""" + path, _ = test_zarr_image + + image = CellMapImage( + path=path, + target_class="test", + target_scale=(4.0, 4.0, 4.0), + target_voxel_shape=(8, 8, 8), + ) + + # Set spatial transforms + image.set_spatial_transforms({"mirror": {"x": True}, "rotate": {"z": 15}}) + + center = {"z": 64.0, "y": 64.0, "x": 64.0} + _ = image[center] + + # Cache should still be cleared + assert "array" not in image.__dict__ + + def test_cache_clearing_with_value_transforms(self, test_zarr_image): + """Test that cache is cleared when value transforms are applied.""" + path, _ = test_zarr_image + + def normalize(x): + return x / 255.0 + + image = CellMapImage( + path=path, + target_class="test", + target_scale=(4.0, 4.0, 4.0), + target_voxel_shape=(8, 8, 8), + value_transform=normalize, + ) + + center = {"z": 64.0, "y": 64.0, "x": 64.0} + _ = image[center] + + # Cache should be cleared + assert "array" not in image.__dict__ + + def test_simulated_training_loop_memory(self, test_zarr_image): + """ + Simulate a training loop to verify cache is cleared on each iteration. + + This test simulates the memory leak scenario described in the issue: + repeated calls to __getitem__ should not accumulate memory from cached arrays. + """ + path, _ = test_zarr_image + + image = CellMapImage( + path=path, + target_class="test", + target_scale=(4.0, 4.0, 4.0), + target_voxel_shape=(8, 8, 8), + ) + + # Simulate multiple training iterations + centers = [ + {"z": 48.0 + i * 4.0, "y": 48.0 + i * 4.0, "x": 48.0 + i * 4.0} + for i in range(10) + ] + + for i, center in enumerate(centers): + _ = image[center] + + # After each iteration, array cache should be cleared + assert ( + "array" not in image.__dict__ + ), f"Iteration {i}: Array cache should be cleared to prevent memory leak" + + def test_cache_clearing_with_interpolation(self, tmp_path): + """ + Test cache clearing when interpolation is used (the main memory leak scenario). 
+ + When coords require interpolation (not simple float/int), the array.interp() + method creates intermediate arrays that could accumulate memory. + """ + data = create_test_image_data((32, 32, 32), pattern="gradient") + path = tmp_path / "test_interp.zarr" + create_test_zarr_array(path, data, scale=(4.0, 4.0, 4.0)) + + image = CellMapImage( + path=str(path), + target_class="test", + target_scale=(4.0, 4.0, 4.0), + target_voxel_shape=(8, 8, 8), + interpolation="linear", # Use linear interpolation to trigger interp() + ) + + # Use spatial transforms to trigger the interpolation code path + image.set_spatial_transforms({"rotate": {"z": 15}}) + + center = {"z": 64.0, "y": 64.0, "x": 64.0} + _ = image[center] + + # Cache should be cleared even after interpolation + assert "array" not in image.__dict__ diff --git a/tests/test_windows_stress.py b/tests/test_windows_stress.py index 4ae8c05..28fdcaf 100644 --- a/tests/test_windows_stress.py +++ b/tests/test_windows_stress.py @@ -24,7 +24,7 @@ import pytest from cellmap_data import CellMapDataset -from cellmap_data.read_limiter import ( +from cellmap_data.utils.read_limiter import ( MAX_CONCURRENT_READS, _read_semaphore, limit_tensorstore_reads,