Skip to content

Commit 6cc3afc

Browse files
Merge pull request #370 from scverse/faster-xenium
Xenium: cache zip store opening; vectorize cell id encoding
2 parents fb3af8b + 563693a commit 6cc3afc

File tree

2 files changed

+148
-66
lines changed

2 files changed

+148
-66
lines changed

src/spatialdata_io/readers/xenium.py

Lines changed: 132 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,13 @@
1414
import ome_types
1515
import packaging.version
1616
import pandas as pd
17+
import pyarrow.compute as pc
1718
import pyarrow.parquet as pq
1819
import tifffile
1920
import zarr
2021
from dask.dataframe import read_parquet
2122
from dask_image.imread import imread
2223
from geopandas import GeoDataFrame
23-
from pyarrow import Table
2424
from shapely import GeometryType, Polygon, from_ragged_array
2525
from spatialdata import SpatialData
2626
from spatialdata._core.query.relational_query import get_element_instances
@@ -44,6 +44,7 @@
4444
if TYPE_CHECKING:
4545
from collections.abc import Mapping
4646

47+
import pyarrow as pa
4748
from anndata import AnnData
4849
from spatialdata._types import ArrayLike
4950

@@ -69,6 +70,7 @@ def xenium(
6970
morphology_focus: bool = True,
7071
aligned_images: bool = True,
7172
cells_table: bool = True,
73+
n_jobs: int | None = None,
7274
gex_only: bool = True,
7375
imread_kwargs: Mapping[str, Any] = MappingProxyType({}),
7476
image_models_kwargs: Mapping[str, Any] = MappingProxyType({}),
@@ -121,6 +123,10 @@ def xenium(
121123
`False` and use the `xenium_aligned_image` function directly.
122124
cells_table
123125
Whether to read the cell annotations in the `AnnData` table.
126+
n_jobs
127+
.. deprecated::
128+
``n_jobs`` is not used anymore and will be removed in a future release. The reading time of shapes is now
129+
greatly improved and does not require parallelization.
124130
gex_only
125131
Whether to load only the "Gene Expression" feature type.
126132
imread_kwargs
@@ -153,6 +159,13 @@ def xenium(
153159
... )
154160
>>> sdata.write("path/to/data.zarr")
155161
"""
162+
if n_jobs is not None:
163+
warnings.warn(
164+
"The `n_jobs` parameter is deprecated and will be removed in a future release. "
165+
"The reading time of shapes is now greatly improved and does not require parallelization.",
166+
DeprecationWarning,
167+
stacklevel=2,
168+
)
156169
image_models_kwargs, labels_models_kwargs = _initialize_raster_models_kwargs(
157170
image_models_kwargs, labels_models_kwargs
158171
)
@@ -188,18 +201,42 @@ def xenium(
188201
else:
189202
table = None
190203

204+
# open cells.zarr.zip once and reuse across all functions that need it
205+
cells_zarr: zarr.Group | None = None
206+
need_cells_zarr = (
207+
nucleus_labels
208+
or cells_labels
209+
or (version is not None and version >= packaging.version.parse("2.0.0") and table is not None)
210+
)
211+
if need_cells_zarr:
212+
cells_zarr_store = zarr.storage.ZipStore(path / XeniumKeys.CELLS_ZARR, read_only=True)
213+
cells_zarr = zarr.open(cells_zarr_store, mode="r")
214+
215+
# pre-compute cell_id strings from the zarr once, to avoid redundant conversion
216+
# in both _get_cells_metadata_table_from_zarr and _get_labels_and_indices_mapping.
217+
cells_zarr_cell_id_str: np.ndarray | None = None
218+
if cells_zarr is not None and version is not None and version >= packaging.version.parse("1.3.0"):
219+
cell_id_raw = cells_zarr["cell_id"][...]
220+
cell_id_prefix, dataset_suffix = cell_id_raw[:, 0], cell_id_raw[:, 1]
221+
cells_zarr_cell_id_str = cell_id_str_from_prefix_suffix_uint32(cell_id_prefix, dataset_suffix)
222+
191223
if version is not None and version >= packaging.version.parse("2.0.0") and table is not None:
192-
cell_summary_table = _get_cells_metadata_table_from_zarr(path, XeniumKeys.CELLS_ZARR, specs)
193-
if not cell_summary_table[XeniumKeys.CELL_ID].equals(table.obs[XeniumKeys.CELL_ID]):
224+
assert cells_zarr is not None
225+
cell_summary_table = _get_cells_metadata_table_from_zarr(cells_zarr, specs, cells_zarr_cell_id_str)
226+
try:
227+
_assert_arrays_equal_sampled(
228+
cell_summary_table[XeniumKeys.CELL_ID].values, table.obs[XeniumKeys.CELL_ID].values
229+
)
230+
except AssertionError:
194231
warnings.warn(
195232
'The "cell_id" column in the cells metadata table does not match the "cell_id" column in the annotation'
196233
" table. This could be due to trying to read a new version that is not supported yet. Please "
197234
"report this issue.",
198235
UserWarning,
199236
stacklevel=2,
200237
)
201-
table.obs[XeniumKeys.Z_LEVEL] = cell_summary_table[XeniumKeys.Z_LEVEL]
202-
table.obs[XeniumKeys.NUCLEUS_COUNT] = cell_summary_table[XeniumKeys.NUCLEUS_COUNT]
238+
table.obs[XeniumKeys.Z_LEVEL] = cell_summary_table[XeniumKeys.Z_LEVEL].values
239+
table.obs[XeniumKeys.NUCLEUS_COUNT] = cell_summary_table[XeniumKeys.NUCLEUS_COUNT].values
203240

204241
polygons = {}
205242
labels = {}
@@ -220,6 +257,8 @@ def xenium(
220257
mask_index=0,
221258
labels_name="nucleus_labels",
222259
labels_models_kwargs=labels_models_kwargs,
260+
cells_zarr=cells_zarr,
261+
cell_id_str=None,
223262
)
224263
if cells_labels:
225264
labels["cell_labels"], cell_labels_indices_mapping = _get_labels_and_indices_mapping(
@@ -228,9 +267,15 @@ def xenium(
228267
mask_index=1,
229268
labels_name="cell_labels",
230269
labels_models_kwargs=labels_models_kwargs,
270+
cells_zarr=cells_zarr,
271+
cell_id_str=cells_zarr_cell_id_str,
231272
)
232273
if cell_labels_indices_mapping is not None and table is not None:
233-
if not pd.DataFrame.equals(cell_labels_indices_mapping["cell_id"], table.obs[str(XeniumKeys.CELL_ID)]):
274+
try:
275+
_assert_arrays_equal_sampled(
276+
cell_labels_indices_mapping["cell_id"].values, table.obs[str(XeniumKeys.CELL_ID)].values
277+
)
278+
except AssertionError:
234279
warnings.warn(
235280
"The cell_id column in the cell_labels_table does not match the cell_id column derived from the "
236281
"cell labels data. This could be due to trying to read a new version that is not supported yet. "
@@ -239,7 +284,7 @@ def xenium(
239284
stacklevel=2,
240285
)
241286
else:
242-
table.obs["cell_labels"] = cell_labels_indices_mapping["label_index"]
287+
table.obs["cell_labels"] = cell_labels_indices_mapping["label_index"].values
243288
if not cells_as_circles:
244289
table.uns[TableModel.ATTRS_KEY][TableModel.INSTANCE_KEY] = "cell_labels"
245290

@@ -248,7 +293,7 @@ def xenium(
248293
path,
249294
XeniumKeys.NUCLEUS_BOUNDARIES_FILE,
250295
specs,
251-
idx=table.obs[str(XeniumKeys.CELL_ID)].copy(),
296+
idx=None,
252297
)
253298

254299
if cells_boundaries:
@@ -389,6 +434,13 @@ def filter(self, record: logging.LogRecord) -> bool:
389434
return _set_reader_metadata(sdata, "xenium")
390435

391436

437+
def _assert_arrays_equal_sampled(a: ArrayLike, b: ArrayLike, n: int = 100) -> None:
438+
"""Assert two arrays are equal by checking a random sample of entries."""
439+
assert len(a) == len(b), f"Array lengths differ: {len(a)} != {len(b)}"
440+
idx = np.random.default_rng(0).choice(len(a), size=min(n, len(a)), replace=False)
441+
np.testing.assert_array_equal(np.asarray(a[idx]), np.asarray(b[idx]))
442+
443+
392444
def _decode_cell_id_column(cell_id_column: pd.Series) -> pd.Series:
393445
if isinstance(cell_id_column.iloc[0], bytes):
394446
return cell_id_column.str.decode("utf-8")
@@ -403,28 +455,35 @@ def _get_polygons(
403455
specs: dict[str, Any],
404456
idx: pd.Series | None = None,
405457
) -> GeoDataFrame:
406-
# seems to be faster than pd.read_parquet
407-
df = pq.read_table(path / file).to_pandas()
408-
cell_ids = df[XeniumKeys.CELL_ID].to_numpy()
409-
x = df[XeniumKeys.BOUNDARIES_VERTEX_X].to_numpy()
410-
y = df[XeniumKeys.BOUNDARIES_VERTEX_Y].to_numpy()
458+
# Use PyArrow compute to avoid slow .to_numpy() on Arrow-backed strings in pandas >= 3.0
459+
# The original approach was:
460+
# df = pq.read_table(path / file).to_pandas()
461+
# cell_ids = df[XeniumKeys.CELL_ID].to_numpy()
462+
# which got slow with pandas >= 3.0 (Arrow-backed string .to_numpy() is ~100x slower).
463+
# By doing change detection in Arrow, we avoid allocating Python string objects for all rows.
464+
table = pq.read_table(path / file)
465+
cell_id_col = table.column(str(XeniumKeys.CELL_ID))
466+
467+
x = table.column(str(XeniumKeys.BOUNDARIES_VERTEX_X)).to_numpy()
468+
y = table.column(str(XeniumKeys.BOUNDARIES_VERTEX_Y)).to_numpy()
411469
coords = np.column_stack([x, y])
412470

413-
change_mask = np.concatenate([[True], cell_ids[1:] != cell_ids[:-1]])
471+
n = len(cell_id_col)
472+
change_mask = np.empty(n, dtype=bool)
473+
change_mask[0] = True
474+
change_mask[1:] = pc.not_equal(cell_id_col.slice(0, n - 1), cell_id_col.slice(1)).to_numpy(zero_copy_only=False)
414475
group_starts = np.where(change_mask)[0]
415-
group_ends = np.concatenate([group_starts[1:], [len(cell_ids)]])
476+
group_ends = np.concatenate([group_starts[1:], [n]])
416477

417478
# sanity check
418-
n_unique_ids = len(df[XeniumKeys.CELL_ID].drop_duplicates())
479+
n_unique_ids = pc.count_distinct(cell_id_col).as_py()
419480
if len(group_starts) != n_unique_ids:
420481
raise ValueError(
421482
f"In {file}, rows belonging to the same polygon must be contiguous. "
422483
f"Expected {n_unique_ids} group starts, but found {len(group_starts)}. "
423484
f"This indicates non-consecutive polygon rows."
424485
)
425486

426-
unique_ids = cell_ids[group_starts]
427-
428487
# offsets for ragged array:
429488
# offsets[0] (ring_offsets): describing to which rings the vertex positions belong to
430489
# offsets[1] (geom_offsets): describing to which polygons the rings belong to
@@ -433,22 +492,16 @@ def _get_polygons(
433492

434493
geoms = from_ragged_array(GeometryType.POLYGON, coords, offsets=(ring_offsets, geom_offsets))
435494

436-
index = _decode_cell_id_column(pd.Series(unique_ids))
437-
geo_df = GeoDataFrame({"geometry": geoms}, index=index.values)
438-
439-
version = _parse_version_of_xenium_analyzer(specs)
440-
if version is not None and version < packaging.version.parse("2.0.0"):
441-
assert idx is not None
442-
assert len(idx) == len(geo_df)
443-
assert np.array_equal(index.values, idx.values)
495+
# idx is not None for the cells and None for the nuclei (for xenium(cells_table=False) is None for both
496+
if idx is not None:
497+
# Cell IDs already available from the annotation table
498+
assert len(idx) == len(group_starts), f"Expected {len(group_starts)} cell IDs, got {len(idx)}"
499+
geo_df = GeoDataFrame({"geometry": geoms}, index=idx.values)
444500
else:
445-
if np.unique(geo_df.index).size != len(geo_df):
446-
warnings.warn(
447-
"Found non-unique polygon indices, this will be addressed in a future version of the reader. For the "
448-
"time being please consider merging polygons with non-unique indices into single multi-polygons.",
449-
UserWarning,
450-
stacklevel=2,
451-
)
501+
# Fall back to extracting unique cell IDs from parquet (slow for large_string columns).
502+
unique_ids = cell_id_col.filter(change_mask).to_pylist()
503+
index = _decode_cell_id_column(pd.Series(unique_ids))
504+
geo_df = GeoDataFrame({"geometry": geoms}, index=index.values)
452505

453506
scale = Scale([1.0 / specs["pixel_size"], 1.0 / specs["pixel_size"]], axes=("x", "y"))
454507
return ShapesModel.parse(geo_df, transformations={"global": scale})
@@ -459,16 +512,15 @@ def _get_labels_and_indices_mapping(
459512
specs: dict[str, Any],
460513
mask_index: int,
461514
labels_name: str,
515+
cells_zarr: zarr.Group,
516+
cell_id_str: ArrayLike,
462517
labels_models_kwargs: Mapping[str, Any] = MappingProxyType({}),
463518
) -> tuple[GeoDataFrame, pd.DataFrame | None]:
464519
if mask_index not in [0, 1]:
465520
raise ValueError(f"mask_index must be 0 or 1, found {mask_index}.")
466521

467-
zip_file = path / XeniumKeys.CELLS_ZARR
468-
store = zarr.storage.ZipStore(zip_file, read_only=True)
469-
z = zarr.open(store, mode="r")
470522
# get the labels
471-
masks = da.from_array(z["masks"][f"{mask_index}"])
523+
masks = da.from_array(cells_zarr["masks"][f"{mask_index}"])
472524
labels = Labels2DModel.parse(masks, dims=("y", "x"), transformations={"global": Identity()}, **labels_models_kwargs)
473525

474526
# build the matching table
@@ -481,11 +533,8 @@ def _get_labels_and_indices_mapping(
481533
# supported in versions < 1.3.0
482534
return labels, None
483535

484-
cell_id, dataset_suffix = z["cell_id"][...].T
485-
cell_id_str = cell_id_str_from_prefix_suffix_uint32(cell_id, dataset_suffix)
486-
487536
if version < packaging.version.parse("2.0.0"):
488-
label_index = z["seg_mask_value"][...]
537+
label_index = cells_zarr["seg_mask_value"][...]
489538
else:
490539
# For v >= 2.0.0, seg_mask_value is no longer available in the zarr;
491540
# read label_id from the corresponding parquet boundary file instead
@@ -515,42 +564,29 @@ def _get_labels_and_indices_mapping(
515564
"label_index": label_index.astype(np.int64),
516565
}
517566
)
518-
# because AnnData converts the indices to str
519-
indices_mapping.index = indices_mapping.index.astype(str)
520567
return labels, indices_mapping
521568

522569

523570
@inject_docs(xx=XeniumKeys)
524571
def _get_cells_metadata_table_from_zarr(
525-
path: Path,
526-
file: str,
572+
cells_zarr: zarr.Group,
527573
specs: dict[str, Any],
574+
cell_id_str: ArrayLike,
528575
) -> AnnData:
529576
"""Read cells metadata from ``{xx.CELLS_ZARR}``.
530577
531578
Read the cells summary table, which contains the z_level information for versions < 2.0.0, and also the
532579
nucleus_count for versions >= 2.0.0.
533580
"""
534-
# for version >= 2.0.0, in this function we could also parse the segmentation method used to obtain the masks
535-
zip_file = path / XeniumKeys.CELLS_ZARR
536-
store = zarr.storage.ZipStore(zip_file, read_only=True)
537-
538-
z = zarr.open(store, mode="r")
539-
x = z["cell_summary"][...]
540-
column_names = z["cell_summary"].attrs["column_names"]
581+
x = cells_zarr["cell_summary"][...]
582+
column_names = cells_zarr["cell_summary"].attrs["column_names"]
541583
df = pd.DataFrame(x, columns=column_names)
542-
cell_id_prefix = z["cell_id"][:, 0]
543-
dataset_suffix = z["cell_id"][:, 1]
544-
store.close()
545584

546-
cell_id_str = cell_id_str_from_prefix_suffix_uint32(cell_id_prefix, dataset_suffix)
547585
df[XeniumKeys.CELL_ID] = cell_id_str
548-
# because AnnData converts the indices to str
549-
df.index = df.index.astype(str)
550586
return df
551587

552588

553-
def _get_points(path: Path, specs: dict[str, Any]) -> Table:
589+
def _get_points(path: Path, specs: dict[str, Any]) -> pa.Table:
554590
table = read_parquet(path / XeniumKeys.TRANSCRIPTS_FILE)
555591

556592
# check if we need to decode bytes
@@ -592,10 +628,12 @@ def _get_tables_and_circles(
592628
) -> AnnData | tuple[AnnData, AnnData]:
593629
adata = _read_10x_h5(path / XeniumKeys.CELL_FEATURE_MATRIX_FILE, gex_only=gex_only)
594630
metadata = pd.read_parquet(path / XeniumKeys.CELL_METADATA_FILE)
595-
np.testing.assert_array_equal(metadata.cell_id.astype(str), adata.obs_names.values)
631+
_assert_arrays_equal_sampled(metadata.cell_id.astype(str), adata.obs_names.values)
596632
circ = metadata[[XeniumKeys.CELL_X, XeniumKeys.CELL_Y]].to_numpy()
597633
adata.obsm["spatial"] = circ
598634
metadata.drop([XeniumKeys.CELL_X, XeniumKeys.CELL_Y], axis=1, inplace=True)
635+
# avoids anndata's ImplicitModificationWarning
636+
metadata.index = adata.obs_names
599637
adata.obs = metadata
600638
adata.obs["region"] = specs["region"]
601639
adata.obs["region"] = adata.obs["region"].astype("category")
@@ -850,13 +888,18 @@ def _parse_version_of_xenium_analyzer(
850888
return None
851889

852890

853-
def cell_id_str_from_prefix_suffix_uint32(cell_id_prefix: ArrayLike, dataset_suffix: ArrayLike) -> ArrayLike:
854-
# explained here:
855-
# https://www.10xgenomics.com/support/software/xenium-onboard-analysis/latest/analysis/xoa-output-zarr#cellID
891+
def _cell_id_str_from_prefix_suffix_uint32_reference(cell_id_prefix: ArrayLike, dataset_suffix: ArrayLike) -> ArrayLike:
892+
"""Reference implementation of cell_id_str_from_prefix_suffix_uint32.
893+
894+
Readable but slow for large arrays due to Python-level string operations.
895+
Kept as ground truth for testing the optimized version.
896+
897+
See https://www.10xgenomics.com/support/software/xenium-onboard-analysis/latest/analysis/xoa-output-zarr#cellID
898+
"""
856899
# convert to hex, remove the 0x prefix
857900
cell_id_prefix_hex = [hex(x)[2:] for x in cell_id_prefix]
858901

859-
# shift the hex values
902+
# shift the hex values: '0'->'a', ..., '9'->'j', 'a'->'k', ..., 'f'->'p'
860903
hex_shift = {str(i): chr(ord("a") + i) for i in range(10)} | {
861904
chr(ord("a") + i): chr(ord("a") + 10 + i) for i in range(6)
862905
}
@@ -870,6 +913,31 @@ def cell_id_str_from_prefix_suffix_uint32(cell_id_prefix: ArrayLike, dataset_suf
870913
return np.array(cell_id_str)
871914

872915

916+
def cell_id_str_from_prefix_suffix_uint32(cell_id_prefix: ArrayLike, dataset_suffix: ArrayLike) -> ArrayLike:
917+
"""Convert cell ID prefix/suffix uint32 pairs to the Xenium string representation.
918+
919+
Each uint32 prefix is converted to 8 hex nibbles, each mapped to a character
920+
(0->'a', 1->'b', ..., 15->'p'), then joined with "-{suffix}".
921+
922+
See https://www.10xgenomics.com/support/software/xenium-onboard-analysis/latest/analysis/xoa-output-zarr#cellID
923+
"""
924+
cell_id_prefix = np.asarray(cell_id_prefix, dtype=np.uint32)
925+
dataset_suffix = np.asarray(dataset_suffix)
926+
927+
# Extract 8 hex nibbles (4 bits each) from each uint32, most significant first.
928+
# Each nibble maps to a character: 0->'a', 1->'b', ..., 9->'j', 10->'k', ..., 15->'p'.
929+
# Leading zero nibbles become 'a', equivalent to rjust(8, 'a') padding.
930+
shifts = np.array([28, 24, 20, 16, 12, 8, 4, 0], dtype=np.uint32)
931+
nibbles = (cell_id_prefix[:, np.newaxis] >> shifts) & 0xF
932+
char_codes = (nibbles + ord("a")).astype(np.uint8)
933+
934+
# View the (n, 8) uint8 array as n byte-strings of length 8
935+
prefix_strs = char_codes.view("S8").ravel().astype("U8")
936+
937+
suffix_strs = np.char.add("-", dataset_suffix.astype("U"))
938+
return np.char.add(prefix_strs, suffix_strs)
939+
940+
873941
def prefix_suffix_uint32_from_cell_id_str(
874942
cell_id_str: ArrayLike,
875943
) -> tuple[ArrayLike, ArrayLike]:

0 commit comments

Comments
 (0)