Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
610f926
Initial plan
Copilot Feb 25, 2026
3ee90bc
Fix memory leak by clearing array cache after data retrieval
Copilot Feb 25, 2026
a1924d1
Add documentation for array cache clearing and more comprehensive tests
Copilot Feb 25, 2026
7031c62
Clarify that TensorStore cache is preserved when clearing array wrapper
Copilot Feb 25, 2026
df15e57
Add documentation and demo for memory leak fix
Copilot Feb 25, 2026
5ea10f2
Update demo_memory_fix.py
rhoadesScholar Feb 25, 2026
b29b4ec
fix: resolve memory leak by clearing array cache in CellMapImage
rhoadesScholar Feb 26, 2026
aae7e30
fix: optimize device handling and refactor properties in CellMapDatas…
rhoadesScholar Feb 26, 2026
ae9d561
black format
rhoadesScholar Feb 26, 2026
35b3665
fix: optimize bounding box and sampling box computations in CellMapDa…
rhoadesScholar Feb 26, 2026
ab0fbea
black format
rhoadesScholar Feb 26, 2026
ecf1631
Update tests/test_memory_management.py
rhoadesScholar Feb 26, 2026
5577a12
Update src/cellmap_data/multidataset.py
rhoadesScholar Feb 26, 2026
709451f
fix: address PR review feedback - add try/finally for cache clearing …
Copilot Feb 26, 2026
24fb1b6
black format
rhoadesScholar Feb 26, 2026
8438cde
fix: implement _ImmediateExecutor to prevent crashes on Windows+Tenso…
rhoadesScholar Feb 26, 2026
a6e96e8
black format
rhoadesScholar Feb 26, 2026
ce9054a
fix: properly emit warnings using logger instead of constructing unus…
Copilot Feb 26, 2026
b065568
Update multidataset.py
rhoadesScholar Feb 26, 2026
0bae283
Update multidataset.py
rhoadesScholar Feb 26, 2026
907b28a
black format
rhoadesScholar Feb 27, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -116,4 +116,7 @@ clean/
.pytest_cache/
__pycache__/
mypy_cache/
.claude/
.claude/
*.out
*.log
*.err
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ dev = [
"twine",
"hatch",
"python-semantic-release",
"objgraph",
]
all = [
"cellmap-data[dev,test]",
Expand Down
2 changes: 1 addition & 1 deletion src/cellmap_data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from .empty_image import EmptyImage
from .image import CellMapImage
from .mutable_sampler import MutableSubsetRandomSampler
from .read_limiter import MAX_CONCURRENT_READS, limit_tensorstore_reads
from .utils.read_limiter import MAX_CONCURRENT_READS, limit_tensorstore_reads
from .utils import get_sliced_shape, is_array_2D, min_redundant_inds, split_target_path

logger = logging.getLogger(__name__)
Expand Down
8 changes: 2 additions & 6 deletions src/cellmap_data/dataset_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@ def __init__(
self.target_array_writers[array_name] = self.get_target_array_writer(
array_name, array_info
)
self._device: str | torch.device = device if device is not None else "cpu"
if device is not None:
self._device = device
self.to(device, non_blocking=True)

@cached_property
Expand Down Expand Up @@ -237,11 +237,7 @@ def loader(
@property
def device(self) -> str | torch.device:
"""Returns the device for the dataset."""
try:
return self._device
except AttributeError:
self._device = "cpu"
return self._device
return self._device

def get_center(self, idx: int) -> dict[str, float]:
"""
Expand Down
71 changes: 54 additions & 17 deletions src/cellmap_data/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,11 @@ def __getitem__(self, center: Mapping[str, float]) -> torch.Tensor:
if self.value_transform is not None:
data = self.value_transform(data)

# Clear cached array property to prevent memory accumulation from xarray
# operations (interp/reindex/sel) during training iterations. The array
# will be reopened on next access if needed.
self._clear_array_cache()

# Return data on CPU - let the DataLoader handle device transfer with streams
# This avoids redundant transfers and allows for optimized batch transfers
return data
Expand All @@ -138,6 +143,18 @@ def __repr__(self) -> str:
"""Returns a string representation of the CellMapImage object."""
return f"CellMapImage({self.array_path})"

def _clear_array_cache(self) -> None:
"""
Clear the cached xarray DataArray to release intermediate objects.

xarray operations (interp, reindex, sel) create intermediate arrays that
remain referenced through the DataArray. Clearing the cache after each
__getitem__ releases those references without closing the underlying
TensorStore handle, which is separately cached in _ts_store and reused.
"""
if "array" in self.__dict__:
del self.__dict__["array"]

@property
def coord_offsets(self) -> Mapping[str, np.ndarray]:
"""
Expand Down Expand Up @@ -223,9 +240,43 @@ def array_path(self) -> str:
"""Returns the path to the single-scale image array."""
return os.path.join(self.path, self.scale_level)

@cached_property
def _ts_store(self) -> ts.TensorStore:  # type: ignore
    """
    Open the TensorStore array once and keep the handle for this instance.

    The handle returned by ``ts.open(...).result()`` is cached for the
    lifetime of the object; it is cheap to hold (a reference to the shared
    context and chunk cache) and safe to reuse across many ``__getitem__``
    calls. Keeping it separate from the ``array`` cached_property means
    that evicting ``array`` (to free xarray intermediates) never triggers
    another ``ts.open()`` on the next access.

    Falls back to the ``zarr3`` driver when the default driver rejects the
    spec with a ``ValueError``.
    """
    store_spec = xt._zarr_spec_from_path(self.array_path)
    pending = ts.open(store_spec, read=True, write=False, context=self.context)
    try:
        return pending.result()
    except ValueError as err:
        # Default driver could not interpret the spec; retry as zarr v3.
        logger.warning(
            "Failed to open with default driver: %s. Falling back to zarr3 driver.",
            err,
        )
        store_spec["driver"] = "zarr3"
        retry = ts.open(store_spec, read=True, write=False, context=self.context)
        return retry.result()

@cached_property
def array(self) -> xarray.DataArray:
"""Returns the image data as an xarray DataArray."""
"""
Returns the image data as an xarray DataArray.

This property is cached but is explicitly cleared after each __getitem__
call to release xarray intermediate objects (from interp/reindex/sel)
that would otherwise accumulate during training. Clearing it is cheap
because the underlying TensorStore handle is separately cached in
_ts_store and is not reopened.
"""
if (
os.environ.get("CELLMAP_DATA_BACKEND", "tensorstore").lower()
!= "tensorstore"
Expand All @@ -235,22 +286,7 @@ def array(self) -> xarray.DataArray:
chunks="auto",
)
else:
# Construct an xarray with Tensorstore backend
spec = xt._zarr_spec_from_path(self.array_path)
array_future = ts.open(spec, read=True, write=False, context=self.context)
try:
array = array_future.result()
except ValueError as e:
logger.warning(
"Failed to open with default driver: %s. Falling back to zarr3 driver.",
e,
)
spec["driver"] = "zarr3"
array_future = ts.open(
spec, read=True, write=False, context=self.context
)
array = array_future.result()
data = xt._TensorStoreAdapter(array)
data = xt._TensorStoreAdapter(self._ts_store)
return xarray.DataArray(data=data, coords=self.full_coords)

@cached_property
Expand Down Expand Up @@ -324,6 +360,7 @@ def class_counts(self) -> float:
else:
raise ValueError("s0_scale not found")
except Exception as e:
# TODO: This fallback is very expensive, and ideally should be avoided. We should add a script to precompute class counts for all images and save them to the metadata to avoid this in the future.
logger.warning(
"Unable to get class counts for %s from metadata, "
"falling back to calculating from array. Error: %s, %s",
Expand Down
Loading