Skip to content

Commit 758d9d5

Browse files
Sequential curation (#4298)
Co-authored-by: chrishalcrow <[email protected]> Co-authored-by: Chris Halcrow <[email protected]>
1 parent df7a1e7 commit 758d9d5

File tree

9 files changed

+203
-17
lines changed

9 files changed

+203
-17
lines changed

src/spikeinterface/core/analyzer_extension_core.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
* ComputeNoiseLevels which is very convenient to have
1010
"""
1111

12+
from copy import copy
1213
import warnings
1314
import numpy as np
1415
from collections import namedtuple
@@ -387,6 +388,13 @@ def _handle_backward_compatibility_on_load(self):
387388
# compatibility february 2024 > july 2024
388389
self.params["ms_after"] = self.params["nafter"] * 1000.0 / self.sorting_analyzer.sampling_frequency
389390

391+
old_keys = copy(list(self.data.keys()))
392+
for operator in old_keys:
393+
if "pencentile" in operator:
394+
fixed_operator = operator.replace("pencentile", "percentile")
395+
self.data[fixed_operator] = self.data[operator]
396+
del self.data[operator]
397+
390398
def _set_params(self, ms_before: float = 1.0, ms_after: float = 2.0, operators=None):
391399
operators = operators or ["average", "std"]
392400
assert isinstance(operators, list)
@@ -485,7 +493,7 @@ def _compute_and_append_from_waveforms(self, operators):
485493
elif isinstance(operator, (list, tuple)):
486494
operator, percentile = operator
487495
assert operator == "percentile"
488-
key = f"pencentile_{percentile}"
496+
key = f"percentile_{percentile}"
489497
else:
490498
raise ValueError(f"ComputeTemplates: wrong operator {operator}")
491499
self.data[key] = np.zeros((unit_ids.size, num_samples, channel_ids.size))
@@ -516,7 +524,7 @@ def _compute_and_append_from_waveforms(self, operators):
516524
elif isinstance(operator, (list, tuple)):
517525
operator, percentile = operator
518526
arr = np.percentile(wfs, percentile, axis=0)
519-
key = f"pencentile_{percentile}"
527+
key = f"percentile_{percentile}"
520528

521529
if self.sparsity is None:
522530
self.data[key][unit_index, :, :] = arr
@@ -606,7 +614,7 @@ def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer,
606614
elif operator == "median":
607615
arr = np.median(wfs, axis=0)
608616
elif "percentile" in operator:
609-
_, percentile = operator.splot("_")
617+
_, percentile = operator.split("_")
610618
arr = np.percentile(wfs, float(percentile), axis=0)
611619
new_array[split_unit_index, ...] = arr
612620
else:
@@ -676,7 +684,7 @@ def get_templates(self, unit_ids=None, operator="average", percentile=None, save
676684
key = operator
677685
else:
678686
assert percentile is not None, "You must provide percentile=... if `operator='percentile'`"
679-
key = f"pencentile_{percentile}"
687+
key = f"percentile_{percentile}"
680688

681689
if key in self.data:
682690
templates_array = self.data[key]

src/spikeinterface/core/sortinganalyzer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2264,6 +2264,7 @@ class AnalyzerExtension:
22642264
* _run()
22652265
* _select_extension_data()
22662266
* _merge_extension_data()
2267+
* _split_extension_data()
22672268
* _get_data()
22682269
22692270
The subclass must also set an `extension_name` class attribute which is not None by default.

src/spikeinterface/core/tests/test_analyzer_extension_core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ def test_ComputeTemplates(format, sparse, create_cache_folder):
160160

161161
# they all should be in data
162162
data = sorting_analyzer.get_extension("templates").data
163-
for k in ["average", "std", "median", "pencentile_5.0", "pencentile_95.0"]:
163+
for k in ["average", "std", "median", "percentile_5.0", "percentile_95.0"]:
164164
assert k in data.keys()
165165
assert data[k].shape[0] == sorting_analyzer.unit_ids.size
166166
assert data[k].shape[2] == sorting_analyzer.channel_ids.size

src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,7 @@ def get_all_templates(
324324
ext = self.sorting_analyzer.get_extension("templates")
325325

326326
if mode == "percentile":
327-
key = f"pencentile_{percentile}"
327+
key = f"percentile_{percentile}"
328328
else:
329329
key = mode
330330

src/spikeinterface/curation/curation_format.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from itertools import chain
77

88
from spikeinterface.core import BaseSorting, SortingAnalyzer, apply_merges_to_sorting, apply_splits_to_sorting
9-
from spikeinterface.curation.curation_model import CurationModel
9+
from spikeinterface.curation.curation_model import CurationModel, SequentialCuration
1010

1111

1212
def validate_curation_dict(curation_dict: dict):
@@ -138,7 +138,7 @@ def apply_curation_labels(
138138

139139
def apply_curation(
140140
sorting_or_analyzer: BaseSorting | SortingAnalyzer,
141-
curation_dict_or_model: dict | CurationModel,
141+
curation_dict_or_model: dict | list | CurationModel | SequentialCuration,
142142
censor_ms: float | None = None,
143143
new_id_strategy: str = "append",
144144
merging_mode: str = "soft",
@@ -164,7 +164,7 @@ def apply_curation(
164164
----------
165165
sorting_or_analyzer : Sorting | SortingAnalyzer
166166
The Sorting or SortingAnalyzer object to apply merges.
167-
curation_dict : dict or CurationModel
167+
curation_dict : dict | CurationModel | SequentialCuration
168168
The curation dict or model.
169169
censor_ms : float | None, default: None
170170
When applying the merges, any consecutive spikes within the `censor_ms` are removed. This can be thought of
@@ -199,14 +199,32 @@ def apply_curation(
199199
sorting_or_analyzer, (BaseSorting, SortingAnalyzer)
200200
), f"`sorting_or_analyzer` must be a Sorting or a SortingAnalyzer, not an object of type {type(sorting_or_analyzer)}"
201201
assert isinstance(
202-
curation_dict_or_model, (dict, CurationModel)
203-
), f"`curation_dict_or_model` must be a dict or a CurationModel, not an object of type {type(curation_dict_or_model)}"
202+
curation_dict_or_model, (dict, list, CurationModel, SequentialCuration)
203+
), f"`curation_dict_or_model` must be a dict, CurationModel or a SequentialCuration not an object of type {type(curation_dict_or_model)}"
204204
if isinstance(curation_dict_or_model, dict):
205205
curation_model = CurationModel(**curation_dict_or_model)
206+
elif isinstance(curation_dict_or_model, list):
207+
curation_model = SequentialCuration(curation_steps=curation_dict_or_model)
206208
else:
207209
curation_model = curation_dict_or_model.model_copy(deep=True)
208210

209-
if not np.array_equal(np.asarray(curation_model.unit_ids), sorting_or_analyzer.unit_ids):
211+
if isinstance(curation_model, SequentialCuration):
212+
for c, single_curation_model in enumerate(curation_model.curation_steps):
213+
if verbose:
214+
print(f"Applying curation step: {c + 1} / {len(curation_model.curation_steps)}")
215+
sorting_or_analyzer = apply_curation(
216+
sorting_or_analyzer,
217+
single_curation_model,
218+
censor_ms=censor_ms,
219+
merging_mode=merging_mode,
220+
sparsity_overlap=sparsity_overlap,
221+
raise_error_if_overlap_fails=raise_error_if_overlap_fails,
222+
verbose=verbose,
223+
job_kwargs=job_kwargs,
224+
)
225+
return sorting_or_analyzer
226+
227+
if not set(curation_model.unit_ids) == set(sorting_or_analyzer.unit_ids):
210228
raise ValueError("unit_ids from the curation_dict do not match the one from Sorting or SortingAnalyzer")
211229

212230
# 1. Apply labels
@@ -228,13 +246,15 @@ def apply_curation(
228246
curated_sorting_or_analyzer, _, _ = apply_merges_to_sorting(
229247
curated_sorting_or_analyzer,
230248
merge_unit_groups=merge_unit_groups,
249+
new_unit_ids=merge_new_unit_ids,
231250
censor_ms=censor_ms,
232251
new_id_strategy=new_id_strategy,
233252
return_extra=True,
234253
)
235254
else:
236255
curated_sorting_or_analyzer, _ = curated_sorting_or_analyzer.merge_units(
237256
merge_unit_groups=merge_unit_groups,
257+
new_unit_ids=merge_new_unit_ids,
238258
censor_ms=censor_ms,
239259
merging_mode=merging_mode,
240260
sparsity_overlap=sparsity_overlap,

src/spikeinterface/curation/curation_model.py

Lines changed: 88 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import numpy as np
55

66
from spikeinterface import BaseSorting
7+
from spikeinterface.core.sorting_tools import _get_ids_after_merging, _get_ids_after_splitting
78

89

910
class LabelDefinition(BaseModel):
@@ -190,7 +191,7 @@ def check_merges(cls, values):
190191

191192
# Check new unit id not already used
192193
if merge.new_unit_id is not None:
193-
if merge.new_unit_id in unit_ids:
194+
if merge.new_unit_id in unit_ids and merge.new_unit_id not in merge.unit_ids:
194195
raise ValueError(f"New unit ID {merge.new_unit_id} is already in the unit list")
195196

196197
values["merges"] = merges
@@ -366,6 +367,52 @@ def convert_old_format(cls, values):
366367
values["removed"] = list(removed_units)
367368
return values
368369

370+
def get_final_ids_from_new_unit_ids(self) -> list:
371+
"""
372+
Returns the final unit ids of the `curation_model`, when new unit ids are
373+
given for each curation choice. Raises an error if new unit ids are missing
374+
for any curation choice.
375+
376+
Returns
377+
-------
378+
final_ids : list
379+
The ids of the sorting/analyzer after curation takes place
380+
"""
381+
final_ids = list(self.unit_ids)
382+
# 1. Remove units
383+
for unit_id in self.removed:
384+
if unit_id not in final_ids:
385+
raise ValueError(f"Removed unit_id {unit_id} is not in the unit list")
386+
final_ids.remove(unit_id)
387+
388+
# 2. Merge units
389+
merge_unit_groups = []
390+
new_merge_unit_ids = []
391+
for merge in self.merges:
392+
if merge.new_unit_id is None:
393+
raise ValueError(
394+
f"The `new_unit_id` for the merge of units {merge.unit_ids} is `None`. This must be given."
395+
)
396+
merge_unit_groups.append(merge.unit_ids)
397+
new_merge_unit_ids.append(merge.new_unit_id)
398+
final_ids = _get_ids_after_merging(
399+
final_ids, merge_unit_groups=merge_unit_groups, new_unit_ids=new_merge_unit_ids
400+
)
401+
402+
# 3. Split units
403+
split_units = {}
404+
split_new_unit_ids = []
405+
for split in self.splits:
406+
if split.new_unit_ids is None:
407+
raise ValueError(
408+
f"The `new_unit_ids` for the split of unit {split.unit_id} is `None`. These must be given."
409+
)
410+
# we only need the correct key and elements for the split, to mimic the output of the split function
411+
split_units[split.unit_id] = [[]] * len(split.new_unit_ids)
412+
split_new_unit_ids.append(split.new_unit_ids)
413+
final_ids = _get_ids_after_splitting(final_ids, split_units=split_units, new_unit_ids=split_new_unit_ids)
414+
return list(final_ids)
415+
369416
@model_validator(mode="before")
370417
def validate_fields(cls, values):
371418
values = dict(values)
@@ -430,3 +477,43 @@ def validate_curation_dict(self):
430477
)
431478

432479
return self
480+
481+
482+
class SequentialCuration(BaseModel):
483+
"""
484+
A Pydantic model which defines a sequence of curation steps. If using sequential curations,
485+
we demand that each individual curation (except the final one) has manually defined new unit ids,
486+
and that these match the unit ids of the following curation.
487+
"""
488+
489+
curation_steps: List[CurationModel] = Field(description="List of curation steps applied sequentially")
490+
491+
@model_validator(mode="after")
492+
def validate_sequential_curation(self):
493+
494+
for curation in self.curation_steps[:-1]:
495+
for merge in curation.merges:
496+
if merge.new_unit_id is None:
497+
raise ValueError(
498+
"In a sequential curation, all curation decisions must have explicit `new_unit_id`s defined."
499+
)
500+
for split in curation.splits:
501+
if split.new_unit_ids is None:
502+
raiseValueError(
503+
"In a sequential curation, all curation decisions must have explicit `new_unit_id`s defined."
504+
)
505+
506+
for curation_index in range(len(self.curation_steps))[:-1]:
507+
508+
curation_1 = self.curation_steps[curation_index]
509+
curation_2 = self.curation_steps[curation_index + 1]
510+
511+
previous_model_final_ids = curation_1.get_final_ids_from_new_unit_ids()
512+
next_model_initial_ids = curation_2.unit_ids
513+
514+
if not (set(previous_model_final_ids) == set(next_model_initial_ids)):
515+
raise ValueError(
516+
f"The initial unit_ids of curation {curation_index+1} do not match the final unit_ids of curation {curation_index}."
517+
)
518+
519+
return self

src/spikeinterface/curation/tests/test_curation_format.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,27 @@
171171
# This is a failure because unit 99 is not in the initial list
172172
unknown_removed_unit = {**curation_ids_int, "removed": [31, 42, 99]}
173173

174+
# Sequential curation test data
175+
sequential_curation = [
176+
{
177+
"format_version": "2",
178+
"unit_ids": [1, 2, 3, 4, 5],
179+
"merges": [{"unit_ids": [3, 4], "new_unit_id": 34}],
180+
},
181+
{
182+
"format_version": "2",
183+
"unit_ids": [1, 2, 34, 5],
184+
"splits": [{"unit_id": 34, "mode": "indices", "indices": [[0, 1, 2, 3]], "new_unit_ids": [340, 341]}],
185+
},
186+
{
187+
"format_version": "2",
188+
"unit_ids": [1, 2, 340, 341, 5],
189+
"removed": [2, 5],
190+
"merges": [{"unit_ids": [1, 340], "new_unit_id": 100}],
191+
"splits": [{"unit_id": 341, "mode": "indices", "indices": [[0, 1, 2]], "new_unit_ids": [3410, 3411]}],
192+
},
193+
]
194+
174195

175196
def test_curation_format_validation():
176197
# Test basic formats
@@ -412,6 +433,26 @@ def test_apply_curation_splits_with_mask():
412433
assert spike_counts[45] == num_spikes - 2 * (num_spikes // 3) # Remainder
413434

414435

436+
def test_apply_sequential_curation():
437+
recording, sorting = generate_ground_truth_recording(durations=[10.0], num_units=5, seed=2205)
438+
sorting = sorting.rename_units([1, 2, 3, 4, 5])
439+
analyzer = create_sorting_analyzer(sorting, recording, sparse=False)
440+
441+
# sequential curation steps:
442+
# 1. merge 3 and 4 -> 34
443+
# 2. split 34 -> 340, 341
444+
# 3. remove 2, 5; merge 1 and 340 -> 100; split 341 -> 3410, 3411
445+
analyzer_curated = apply_curation(analyzer, sequential_curation, verbose=True)
446+
# initial -1(merge) +1(split) -2(remove) -1(merge) +1(split)
447+
num_final_units = analyzer.get_num_units() - 1 + 1 - 2 - 1 + 1
448+
assert analyzer_curated.get_num_units() == num_final_units
449+
450+
# check final unit ids
451+
final_unit_ids = analyzer_curated.sorting.unit_ids
452+
expected_final_unit_ids = [100, 3410, 3411]
453+
assert set(final_unit_ids) == set(expected_final_unit_ids)
454+
455+
415456
if __name__ == "__main__":
416457
test_curation_format_validation()
417458
test_to_from_json()

src/spikeinterface/curation/tests/test_curation_model.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from pydantic import ValidationError
44
import numpy as np
55

6-
from spikeinterface.curation.curation_model import CurationModel, LabelDefinition
6+
from spikeinterface.curation.curation_model import CurationModel, SequentialCuration, LabelDefinition
77

88

99
# Test data for format version
@@ -282,3 +282,33 @@ def test_complete_model():
282282
assert len(model.merges) == 1
283283
assert len(model.splits) == 1
284284
assert len(model.removed) == 1
285+
286+
287+
def test_sequential_curation():
288+
sequential_curation_steps_valid = [
289+
{"format_version": "2", "unit_ids": [1, 2, 3, 4], "merges": [{"unit_ids": [1, 2], "new_unit_id": 22}]},
290+
{
291+
"format_version": "2",
292+
"unit_ids": [3, 4, 22],
293+
"splits": [
294+
{"unit_id": 22, "mode": "indices", "indices": [[0, 1, 2], [3, 4, 5]], "new_unit_ids": [222, 223]}
295+
],
296+
},
297+
{"format_version": "2", "unit_ids": [3, 4, 222, 223], "removed": [223]},
298+
]
299+
300+
# this is valid
301+
SequentialCuration(curation_steps=sequential_curation_steps_valid)
302+
303+
sequential_curation_steps_no_ids = sequential_curation_steps_valid.copy()
304+
# remove new_unit_id in merge step
305+
sequential_curation_steps_no_ids[0]["merges"][0]["new_unit_id"] = None
306+
307+
with pytest.raises(ValidationError):
308+
SequentialCuration(curation_steps=sequential_curation_steps_no_ids)
309+
310+
sequential_curation_steps_invalid = sequential_curation_steps_valid.copy()
311+
# invalid unit_ids in last step
312+
sequential_curation_steps_invalid[2]["unit_ids"] = [3, 4, 222, 224] # 224 should be 223
313+
with pytest.raises(ValidationError):
314+
SequentialCuration(curation_steps=sequential_curation_steps_invalid)

src/spikeinterface/sortingcomponents/clustering/itersplit_tools.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,9 @@ def split_clusters(
7474
peak_labels = peak_labels.copy()
7575
split_count = np.zeros(peak_labels.size, dtype=int)
7676
recursion_level = 1
77-
Executor = get_poolexecutor(n_jobs)
77+
executor = get_poolexecutor(n_jobs)
7878

79-
with Executor(
79+
with executor(
8080
max_workers=n_jobs,
8181
initializer=split_worker_init,
8282
mp_context=get_context(method=mp_context),
@@ -166,7 +166,6 @@ def split_worker_init(
166166
_ctx = {}
167167

168168
_ctx["recording"] = recording
169-
features_dict_or_folder
170169
_ctx["original_labels"] = original_labels
171170
_ctx["method"] = method
172171
_ctx["method_kwargs"] = method_kwargs

0 commit comments

Comments
 (0)