Skip to content

Commit a20bbaf

Browse files
Merge pull request #376 from scverse/more-xenium-performance
Improve Xenium performance, fix multinucleate cells bug
2 parents 9011cfd + c144815 commit a20bbaf

File tree

3 files changed

+135
-107
lines changed

3 files changed

+135
-107
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@ node_modules/
3939
# memray report
4040
*.bin
4141

42+
# speedscope report
43+
profile.speedscope.json
44+
4245
# test datasets (e.g. Xenium ones)
4346
# symlinks
4447
data

src/spatialdata_io/readers/xenium.py

Lines changed: 126 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,6 @@
2222
from geopandas import GeoDataFrame
2323
from shapely import GeometryType, Polygon, from_ragged_array
2424
from spatialdata import SpatialData
25-
from spatialdata._core.query.relational_query import get_element_instances
26-
from spatialdata._logging import logger
2725
from spatialdata.models import (
2826
Image2DModel,
2927
Labels2DModel,
@@ -203,16 +201,17 @@ def xenium(
203201
# open cells.zarr.zip once and reuse across all functions that need it
204202
cells_zarr: zarr.Group | None = None
205203
need_cells_zarr = (
206-
nucleus_labels
204+
nucleus_boundaries
205+
or nucleus_labels
206+
or cells_boundaries
207207
or cells_labels
208208
or (version is not None and version >= packaging.version.parse("2.0.0") and table is not None)
209209
)
210210
if need_cells_zarr:
211211
cells_zarr_store = zarr.storage.ZipStore(path / XeniumKeys.CELLS_ZARR, read_only=True)
212212
cells_zarr = zarr.open(cells_zarr_store, mode="r")
213213

214-
# pre-compute cell_id strings from the zarr once, to avoid redundant conversion
215-
# in both _get_cells_metadata_table_from_zarr and _get_labels_and_indices_mapping.
214+
# pre-compute cell_id strings from the zarr once, to avoid redundant conversion.
216215
cells_zarr_cell_id_str: np.ndarray | None = None
217216
if cells_zarr is not None and version is not None and version >= packaging.version.parse("1.3.0"):
218217
cell_id_raw = cells_zarr["cell_id"][...]
@@ -221,7 +220,7 @@ def xenium(
221220

222221
if version is not None and version >= packaging.version.parse("2.0.0") and table is not None:
223222
assert cells_zarr is not None
224-
cell_summary_table = _get_cells_metadata_table_from_zarr(cells_zarr, specs, cells_zarr_cell_id_str)
223+
cell_summary_table = _get_cells_metadata_table_from_zarr(cells_zarr, cells_zarr_cell_id_str)
225224
try:
226225
_assert_arrays_equal_sampled(
227226
cell_summary_table[XeniumKeys.CELL_ID].values, table.obs[XeniumKeys.CELL_ID].values
@@ -243,36 +242,31 @@ def xenium(
243242
points = {}
244243
images = {}
245244

246-
# From the public release notes here:
247-
# https://www.10xgenomics.com/support/software/xenium-onboard-analysis/latest/release-notes/release-notes-for-xoa
248-
# we see that for distinguishing between the nuclei of multinucleated cells, the `label_id` column is used.
249-
# This column is currently not found in the preview data, while I think it is needed in order to unambiguously match
250-
# nuclei to cells. Therefore for the moment we only link the table to the cell labels, and not to the nucleus
251-
# labels.
245+
# Build the label_index <-> cell_id mappings from the zarr once, reuse for both labels
246+
# and boundaries. For v2.0+ this is deterministic from the zarr polygon_sets
247+
# (label_id = cell_index + 1). For older versions, use seg_mask_value (cells only).
248+
# For nuclei in v2.0+, this correctly handles multinucleate cells: each nucleus gets its
249+
# own label_index, avoiding the bug of merging multiple nuclei into a single polygon.
250+
# Older versions do not support multinucleate cells, so cell_id-based grouping is correct.
251+
nucleus_indices_mapping: pd.DataFrame | None = None
252+
cell_indices_mapping: pd.DataFrame | None = None
253+
if cells_zarr_cell_id_str is not None and cells_zarr is not None and "polygon_sets" in cells_zarr:
254+
if nucleus_boundaries or nucleus_labels:
255+
nucleus_indices_mapping = _get_indices_mapping_from_zarr(cells_zarr, cells_zarr_cell_id_str, mask_index=0)
256+
if cells_boundaries or cells_labels:
257+
cell_indices_mapping = _get_indices_mapping_from_zarr(cells_zarr, cells_zarr_cell_id_str, mask_index=1)
258+
elif cells_zarr_cell_id_str is not None:
259+
if cells_boundaries or cells_labels:
260+
cell_indices_mapping = _get_indices_mapping_legacy(cells_zarr, cells_zarr_cell_id_str, specs=specs)
261+
252262
if nucleus_labels:
253-
labels["nucleus_labels"], _ = _get_labels_and_indices_mapping(
254-
path=path,
255-
specs=specs,
256-
mask_index=0,
257-
labels_name="nucleus_labels",
258-
labels_models_kwargs=labels_models_kwargs,
259-
cells_zarr=cells_zarr,
260-
cell_id_str=None,
261-
)
263+
labels["nucleus_labels"] = _get_labels(cells_zarr, mask_index=0, labels_models_kwargs=labels_models_kwargs)
262264
if cells_labels:
263-
labels["cell_labels"], cell_labels_indices_mapping = _get_labels_and_indices_mapping(
264-
path=path,
265-
specs=specs,
266-
mask_index=1,
267-
labels_name="cell_labels",
268-
labels_models_kwargs=labels_models_kwargs,
269-
cells_zarr=cells_zarr,
270-
cell_id_str=cells_zarr_cell_id_str,
271-
)
272-
if cell_labels_indices_mapping is not None and table is not None:
265+
labels["cell_labels"] = _get_labels(cells_zarr, mask_index=1, labels_models_kwargs=labels_models_kwargs)
266+
if cell_indices_mapping is not None and table is not None:
273267
try:
274268
_assert_arrays_equal_sampled(
275-
cell_labels_indices_mapping["cell_id"].values, table.obs[str(XeniumKeys.CELL_ID)].values
269+
cell_indices_mapping["cell_id"].values, table.obs[str(XeniumKeys.CELL_ID)].values
276270
)
277271
except AssertionError:
278272
warnings.warn(
@@ -283,7 +277,7 @@ def xenium(
283277
stacklevel=2,
284278
)
285279
else:
286-
table.obs["cell_labels"] = cell_labels_indices_mapping["label_index"].values
280+
table.obs["cell_labels"] = cell_indices_mapping["label_index"].values
287281
if not cells_as_circles:
288282
table.uns[TableModel.ATTRS_KEY][TableModel.INSTANCE_KEY] = "cell_labels"
289283

@@ -292,15 +286,16 @@ def xenium(
292286
path,
293287
XeniumKeys.NUCLEUS_BOUNDARIES_FILE,
294288
specs,
295-
idx=None,
289+
indices_mapping=nucleus_indices_mapping,
290+
is_nucleus=True,
296291
)
297292

298293
if cells_boundaries:
299294
polygons["cell_boundaries"] = _get_polygons(
300295
path,
301296
XeniumKeys.CELL_BOUNDARIES_FILE,
302297
specs,
303-
idx=table.obs[str(XeniumKeys.CELL_ID)].copy(),
298+
indices_mapping=cell_indices_mapping,
304299
)
305300

306301
if transcripts:
@@ -455,37 +450,57 @@ def _get_polygons(
455450
path: Path,
456451
file: str,
457452
specs: dict[str, Any],
458-
idx: pd.Series | None = None,
453+
indices_mapping: pd.DataFrame | None = None,
454+
is_nucleus: bool = False,
459455
) -> GeoDataFrame:
460-
# Use PyArrow compute to avoid slow .to_numpy() on Arrow-backed strings in pandas >= 3.0
461-
# The original approach was:
462-
# df = pq.read_table(path / file).to_pandas()
463-
# cell_ids = df[XeniumKeys.CELL_ID].to_numpy()
464-
# which got slow with pandas >= 3.0 (Arrow-backed string .to_numpy() is ~100x slower).
465-
# By doing change detection in Arrow, we avoid allocating Python string objects for all rows.
466-
table = pq.read_table(path / file)
467-
cell_id_col = table.column(str(XeniumKeys.CELL_ID))
456+
"""Parse boundary polygons from a parquet file.
457+
458+
Parameters
459+
----------
460+
indices_mapping
461+
When provided (from ``_get_indices_mapping_from_zarr`` or ``_get_indices_mapping_legacy``),
462+
contains ``cell_id`` and ``label_index`` columns. The parquet ``label_id`` column is used
463+
for fast integer-based change detection (to locate all the vertices of each polygon).
464+
When None, falls back to cell_id-based grouping from the parquet (Xenium < 2.0).
465+
is_nucleus
466+
When True (nucleus boundaries), use ``label_index`` as the GeoDataFrame index and store
467+
``cell_id`` as a column. This gives each nucleus a distinct integer id matching the raster
468+
labels, correctly handling multinucleate cells.
469+
When False (cell boundaries), use ``cell_id`` as the GeoDataFrame index.
470+
"""
471+
# Check whether the parquet has a label_id column (v2.0+). When present, use it for
472+
# fast integer-based change detection. Otherwise fall back to cell_id strings.
473+
parquet_schema = pq.read_schema(path / file)
474+
has_label_id = "label_id" in parquet_schema.names
475+
476+
columns_to_read = [str(XeniumKeys.BOUNDARIES_VERTEX_X), str(XeniumKeys.BOUNDARIES_VERTEX_Y)]
477+
columns_to_read.append("label_id" if has_label_id else str(XeniumKeys.CELL_ID))
478+
table = pq.read_table(path / file, columns=columns_to_read)
468479

469480
x = table.column(str(XeniumKeys.BOUNDARIES_VERTEX_X)).to_numpy()
470481
y = table.column(str(XeniumKeys.BOUNDARIES_VERTEX_Y)).to_numpy()
471482
coords = np.column_stack([x, y])
472483

473-
n = len(cell_id_col)
474-
change_mask = np.empty(n, dtype=bool)
475-
change_mask[0] = True
476-
change_mask[1:] = pc.not_equal(cell_id_col.slice(0, n - 1), cell_id_col.slice(1)).to_numpy(zero_copy_only=False)
477-
group_starts = np.where(change_mask)[0]
478-
group_ends = np.concatenate([group_starts[1:], [n]])
484+
n = len(x)
479485

480-
# sanity check
481-
n_unique_ids = pc.count_distinct(cell_id_col).as_py()
486+
if has_label_id:
487+
id_col = table.column("label_id")
488+
id_arr = id_col.to_numpy()
489+
change_mask = id_arr[1:] != id_arr[:-1]
490+
else:
491+
id_col = table.column(str(XeniumKeys.CELL_ID))
492+
change_mask = pc.not_equal(id_col.slice(0, n - 1), id_col.slice(1)).to_numpy(zero_copy_only=False)
493+
group_starts = np.where(np.concatenate([[True], change_mask]))[0]
494+
n_unique_ids = pc.count_distinct(id_col).as_py()
482495
if len(group_starts) != n_unique_ids:
483496
raise ValueError(
484497
f"In {file}, rows belonging to the same polygon must be contiguous. "
485498
f"Expected {n_unique_ids} group starts, but found {len(group_starts)}. "
486499
f"This indicates non-consecutive polygon rows."
487500
)
488501

502+
group_ends = np.concatenate([group_starts[1:], [n]])
503+
489504
# offsets for ragged array:
490505
# offsets[0] (ring_offsets): describing to which rings the vertex positions belong to
491506
# offsets[1] (geom_offsets): describing to which polygons the rings belong to
@@ -494,85 +509,92 @@ def _get_polygons(
494509

495510
geoms = from_ragged_array(GeometryType.POLYGON, coords, offsets=(ring_offsets, geom_offsets))
496511

497-
# idx is not None for the cells and None for the nuclei (for xenium(cells_table=False) is None for both
498-
if idx is not None:
499-
# Cell IDs already available from the annotation table
500-
assert len(idx) == len(group_starts), f"Expected {len(group_starts)} cell IDs, got {len(idx)}"
501-
geo_df = GeoDataFrame({"geometry": geoms}, index=idx.values)
512+
if indices_mapping is not None:
513+
assert len(indices_mapping) == len(group_starts), (
514+
f"Expected {len(group_starts)} polygons, but indices_mapping has {len(indices_mapping)} entries."
515+
)
516+
if is_nucleus:
517+
# Use label_index (int) as GeoDataFrame index, cell_id as column.
518+
geo_df = GeoDataFrame(
519+
{"geometry": geoms, str(XeniumKeys.CELL_ID): indices_mapping["cell_id"].values},
520+
index=indices_mapping["label_index"].values,
521+
)
522+
else:
523+
# Use cell_id (str) as GeoDataFrame index.
524+
geo_df = GeoDataFrame({"geometry": geoms}, index=indices_mapping["cell_id"].values)
502525
else:
503526
# Fall back to extracting unique cell IDs from parquet (slow for large_string columns).
504-
unique_ids = cell_id_col.filter(change_mask).to_pylist()
527+
unique_ids = id_col.filter(np.concatenate([[True], change_mask])).to_pylist()
505528
index = _decode_cell_id_column(pd.Series(unique_ids))
506529
geo_df = GeoDataFrame({"geometry": geoms}, index=index.values)
507530

508531
scale = Scale([1.0 / specs["pixel_size"], 1.0 / specs["pixel_size"]], axes=("x", "y"))
509532
return ShapesModel.parse(geo_df, transformations={"global": scale})
510533

511534

512-
def _get_labels_and_indices_mapping(
513-
path: Path,
514-
specs: dict[str, Any],
515-
mask_index: int,
516-
labels_name: str,
535+
def _get_labels(
517536
cells_zarr: zarr.Group,
518-
cell_id_str: ArrayLike,
537+
mask_index: int,
519538
labels_models_kwargs: Mapping[str, Any] = MappingProxyType({}),
520-
) -> tuple[GeoDataFrame, pd.DataFrame | None]:
539+
) -> DataArray:
540+
"""Read the labels raster from cells.zarr.zip masks/{mask_index}."""
521541
if mask_index not in [0, 1]:
522542
raise ValueError(f"mask_index must be 0 or 1, found {mask_index}.")
523-
524-
# get the labels
525543
masks = da.from_array(cells_zarr["masks"][f"{mask_index}"])
526-
labels = Labels2DModel.parse(masks, dims=("y", "x"), transformations={"global": Identity()}, **labels_models_kwargs)
544+
return Labels2DModel.parse(masks, dims=("y", "x"), transformations={"global": Identity()}, **labels_models_kwargs)
527545

528-
# build the matching table
529-
version = _parse_version_of_xenium_analyzer(specs)
530-
if mask_index == 0:
531-
# nuclei currently not supported
532-
return labels, None
533-
if version is None or version is not None and version < packaging.version.parse("1.3.0"):
534-
# supported in version 1.3.0 and not supported in version 1.0.2; conservatively, let's assume it is not
535-
# supported in versions < 1.3.0
536-
return labels, None
537-
538-
if version < packaging.version.parse("2.0.0"):
539-
label_index = cells_zarr["seg_mask_value"][...]
540-
else:
541-
# For v >= 2.0.0, seg_mask_value is no longer available in the zarr;
542-
# read label_id from the corresponding parquet boundary file instead
543-
boundaries_file = XeniumKeys.NUCLEUS_BOUNDARIES_FILE if mask_index == 0 else XeniumKeys.CELL_BOUNDARIES_FILE
544-
boundary_columns = pq.read_schema(path / boundaries_file).names
545-
if "label_id" in boundary_columns:
546-
boundary_df = pq.read_table(path / boundaries_file, columns=[XeniumKeys.CELL_ID, "label_id"]).to_pandas()
547-
unique_pairs = boundary_df.drop_duplicates(subset=[XeniumKeys.CELL_ID, "label_id"]).copy()
548-
unique_pairs[XeniumKeys.CELL_ID] = _decode_cell_id_column(unique_pairs[XeniumKeys.CELL_ID])
549-
cell_id_to_label_id = unique_pairs.set_index(XeniumKeys.CELL_ID)["label_id"]
550-
label_index = cell_id_to_label_id.loc[cell_id_str].values
551-
else:
552-
# fallback for dev versions around 2.0.0 that lack both seg_mask_value and label_id
553-
logger.warn(
554-
f"Could not find the labels ids from the metadata for version {version}. Using a fallback (slower) implementation."
555-
)
556-
label_index = get_element_instances(labels).values
557546

558-
if label_index[0] == 0:
559-
label_index = label_index[1:]
547+
def _get_indices_mapping_from_zarr(
548+
cells_zarr: zarr.Group,
549+
cells_zarr_cell_id_str: np.ndarray,
550+
mask_index: int,
551+
) -> pd.DataFrame:
552+
"""Build the label_index <-> cell_id mapping from the zarr polygon_sets.
560553
561-
# labels_index is an uint32, so let's cast to np.int64 to avoid the risk of overflow on some systems
562-
indices_mapping = pd.DataFrame(
554+
From the 10x Genomics docs: "the label ID is equal to the cell index + 1",
555+
where cell_index is polygon_sets/{mask_index}/cell_index. This is deterministic
556+
and avoids reading the slow parquet boundary files.
557+
558+
For cells (mask_index=1): cell_index is 0..N-1 (1:1 with cells), so
559+
label_index = arange(1, N+1).
560+
For nuclei (mask_index=0): cell_index maps each nucleus to its parent cell,
561+
so label_index = arange(1, M+1) and cell_id = cell_id_str[cell_index[i]].
562+
"""
563+
cell_index = cells_zarr[f"polygon_sets/{mask_index}/cell_index"][...]
564+
label_index = np.arange(1, len(cell_index) + 1, dtype=np.int64)
565+
cell_id = cells_zarr_cell_id_str[cell_index]
566+
return pd.DataFrame(
567+
{
568+
"cell_id": cell_id,
569+
"label_index": label_index,
570+
}
571+
)
572+
573+
574+
def _get_indices_mapping_legacy(
575+
cells_zarr: zarr.Group,
576+
cell_id_str: ArrayLike,
577+
specs: dict[str, Any],
578+
) -> pd.DataFrame | None:
579+
"""Build the label_index <-> cell_id mapping for versions < 2.0.0.
580+
581+
Uses seg_mask_value from the zarr (available in v1.3.0+).
582+
"""
583+
version = _parse_version_of_xenium_analyzer(specs)
584+
if version is None or version < packaging.version.parse("1.3.0"):
585+
return None
586+
label_index = cells_zarr["seg_mask_value"][...]
587+
return pd.DataFrame(
563588
{
564-
"region": labels_name,
565589
"cell_id": cell_id_str,
566590
"label_index": label_index.astype(np.int64),
567591
}
568592
)
569-
return labels, indices_mapping
570593

571594

572595
@inject_docs(xx=XeniumKeys)
573596
def _get_cells_metadata_table_from_zarr(
574597
cells_zarr: zarr.Group,
575-
specs: dict[str, Any],
576598
cell_id_str: ArrayLike,
577599
) -> AnnData:
578600
"""Read cells metadata from ``{xx.CELLS_ZARR}``.

tests/test_xenium.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,8 @@ def test_example_data_index_integrity(dataset: str) -> None:
113113
assert sdata["nucleus_labels"]["scale0"]["image"].sel(y=3515.5, x=4618.5).data.compute() == 6392
114114
assert np.allclose(sdata['transcripts'].compute().loc[[0, 10000, 1113949]]['x'], [2.608911, 194.917831, 1227.499268])
115115
assert np.isclose(sdata['cell_boundaries'].loc['oipggjko-1'].geometry.centroid.x,736.4864931162789)
116-
assert np.isclose(sdata['nucleus_boundaries'].loc['oipggjko-1'].geometry.centroid.x,736.4931256878282)
116+
index = sdata['nucleus_boundaries']['cell_id'].index[sdata['nucleus_boundaries']['cell_id'].eq('oipggjko-1')][0]
117+
assert np.isclose(sdata['nucleus_boundaries'].loc[index].geometry.centroid.x,736.4931256878282)
117118
assert np.array_equal(sdata['table'].X.indices[:3], [1, 3, 34])
118119
# fmt: on
119120

@@ -138,7 +139,8 @@ def test_example_data_index_integrity(dataset: str) -> None:
138139
assert sdata["nucleus_labels"]["scale0"]["image"].sel(y=18.5, x=3015.5).data.compute() == 2764
139140
assert np.allclose(sdata['transcripts'].compute().loc[[0, 10000, 20000]]['x'], [174.258392, 12.210024, 214.759186])
140141
assert np.isclose(sdata['cell_boundaries'].loc['aaanbaof-1'].geometry.centroid.x, 43.96894317275074)
141-
assert np.isclose(sdata['nucleus_boundaries'].loc['aaanbaof-1'].geometry.centroid.x,43.31874577809517)
142+
index = sdata['nucleus_boundaries']['cell_id'].index[sdata['nucleus_boundaries']['cell_id'].eq('aaanbaof-1')][0]
143+
assert np.isclose(sdata['nucleus_boundaries'].loc[index].geometry.centroid.x,43.31874577809517)
142144
assert np.array_equal(sdata['table'].X.indices[:3], [1, 8, 19])
143145
# fmt: on
144146

@@ -164,7 +166,8 @@ def test_example_data_index_integrity(dataset: str) -> None:
164166
assert sdata["nucleus_labels"]["scale0"]["image"].sel(y=4039.5, x=93.5).data.compute() == 274
165167
assert np.allclose(sdata['transcripts'].compute().loc[[0, 10000, 20000]]['x'], [43.296875, 62.484375, 93.125])
166168
assert np.isclose(sdata['cell_boundaries'].loc['aadmbfof-1'].geometry.centroid.x, 64.54541104696033)
167-
assert np.isclose(sdata['nucleus_boundaries'].loc['aadmbfof-1'].geometry.centroid.x, 65.43305896114295)
169+
index = sdata['nucleus_boundaries']['cell_id'].index[sdata['nucleus_boundaries']['cell_id'].eq('aadmbfof-1')][0]
170+
assert np.isclose(sdata['nucleus_boundaries'].loc[index].geometry.centroid.x, 65.43305896114295)
168171
assert np.array_equal(sdata['table'].X.indices[:3], [3, 49, 53])
169172
# fmt: on
170173

0 commit comments

Comments
 (0)