Skip to content

Commit 5d3bd03

Browse files
committed
fix: account for anndata bug
1 parent 1e4ae99 commit 5d3bd03

File tree

2 files changed

+12
-11
lines changed

2 files changed

+12
-11
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ and this project adheres to [Semantic Versioning][].
1010

1111
## [0.0.4]
1212

13-
- Load into memory nullables/categoricals from `obs` by default when shuffling (i.e., no custom `load_adata` argument to {meth}`annbatch.Datasetcollection.add_adatas`)
13+
- Load into memory nullables/categoricals from `obs` by default when shuffling (i.e., no custom `load_adata` argument to {meth}`annbatch.DatasetCollection.add_adatas`)
1414

1515
## [0.0.3]
1616

src/annbatch/io.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,14 @@ def _default_load_adata[T: zarr.Group | h5py.Group | PathLike[str] | str](x: T)
4141
group = x
4242
# -1 indicates that all of each `obs` column should just be loaded, but this is probably fine since it goes column by column and discards.
4343
# Only one column at a time will be loaded so we will hopefully pick up the benefit of loading into memory by the cache without having memory pressure.
44-
adata.obs = ad.experimental.read_elem_lazy(group["obs"], chunks=(-1,))
45-
for col in adata.obs.columns:
46-
# Nullables / categoricals have bad performance characteristics when concatenating using dask
47-
if pd.api.types.is_extension_array_dtype(adata.obs[col].dtype):
48-
adata.obs[col] = adata.obs[col].data
44+
# https://github.com/scverse/anndata/pull/2307
45+
for attr in ["obs", "var"]:
46+
if len(getattr(adata, attr).columns) > 0:
47+
setattr(adata, attr, ad.experimental.read_elem_lazy(group[attr], chunks=(-1,)))
48+
for col in getattr(adata, attr).columns:
49+
# Nullables / categoricals have bad performance characteristics when concatenating using dask
50+
if pd.api.types.is_extension_array_dtype(adata.obs[col].dtype):
51+
adata.obs[col] = adata.obs[col].data
4952
return adata
5053

5154

@@ -299,8 +302,6 @@ def _persist_adata_in_memory(adata: ad.AnnData) -> ad.AnnData:
299302

300303
return adata.to_memory()
301304

302-
return adata
303-
304305

305306
DATASET_PREFIX = "dataset"
306307

@@ -386,11 +387,11 @@ def is_empty(self) -> bool:
386387
)
387388

388389
@_with_settings
389-
def add_adatas[T: zarr.Group | h5py.Group | PathLike[str] | str](
390+
def add_adatas(
390391
self,
391-
adata_paths: Iterable[T],
392+
adata_paths: Iterable[zarr.Group | h5py.Group | PathLike[str] | str],
392393
*,
393-
load_adata: Callable[[T], ad.AnnData] = _default_load_adata,
394+
load_adata: Callable[[zarr.Group | h5py.Group | PathLike[str] | str], ad.AnnData] = _default_load_adata,
394395
var_subset: Iterable[str] | None = None,
395396
zarr_sparse_chunk_size: int = 32768,
396397
zarr_sparse_shard_size: int = 134_217_728,

0 commit comments

Comments
 (0)