Skip to content

Commit 2c7bceb

Browse files
committed
wip: add censor_period_ms to slay
1 parent 5eff246 commit 2c7bceb

1 file changed

Lines changed: 15 additions & 3 deletions

File tree

src/spikeinterface/curation/auto_merge.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@
8888
"censored_period_ms": 0.3,
8989
},
9090
"quality_score": {"firing_contamination_balance": 1.5, "refractory_period_ms": 1.0, "censored_period_ms": 0.3},
91-
"slay_score": {"k1": 0.25, "k2": 1, "slay_threshold": 0.5},
91+
"slay_score": {"k1": 0.25, "k2": 1, "slay_threshold": 0.5, "censored_period_ms": 0.2},
9292
}
9393

9494

@@ -1552,6 +1552,7 @@ def compute_slay_matrix(
15521552
sorting_analyzer: SortingAnalyzer,
15531553
k1: float,
15541554
k2: float,
1555+
censor_period_ms: float,
15551556
templates_diff: np.ndarray | None,
15561557
pair_mask: np.ndarray | None = None,
15571558
):
@@ -1569,6 +1570,9 @@ def compute_slay_matrix(
15691570
Coefficient determining the importance of the cross-correlation significance
15701571
k2 : float
15711572
Coefficient determining the importance of the sliding rp violation
1573+
censor_period_ms : float
1574+
The censored period to exclude from the refractory period computation to discard
1575+
duplicated spikes.
15721576
templates_diff : np.ndarray | None
15731577
Pre-computed template similarity difference matrix. If None, it will be retrieved from the sorting_analyzer.
15741578
pair_mask : None | np.ndarray, default: None
@@ -1592,14 +1596,14 @@ def compute_slay_matrix(
15921596
sigma_ij = 1 - templates_diff
15931597
else:
15941598
sigma_ij = sorting_analyzer.get_extension("template_similarity").get_data()
1595-
rho_ij, eta_ij = compute_xcorr_and_rp(sorting_analyzer, pair_mask)
1599+
rho_ij, eta_ij = compute_xcorr_and_rp(sorting_analyzer, pair_mask, censor_period_ms)
15961600

15971601
M_ij = sigma_ij + k1 * rho_ij - k2 * eta_ij
15981602

15991603
return M_ij
16001604

16011605

1602-
def compute_xcorr_and_rp(sorting_analyzer: SortingAnalyzer, pair_mask: np.ndarray):
1606+
def compute_xcorr_and_rp(sorting_analyzer: SortingAnalyzer, pair_mask: np.ndarray, censor_period_ms: float):
16031607
"""
16041608
Computes a cross-correlation significance measure and a sliding refractory period violation
16051609
measure for all units in the `sorting_analyzer`.
@@ -1610,6 +1614,9 @@ def compute_xcorr_and_rp(sorting_analyzer: SortingAnalyzer, pair_mask: np.ndarra
16101614
The sorting analyzer object containing the spike sorting data
16111615
pair_mask : np.ndarray
16121616
A bool matrix describing which pairs are possible merges based on previous steps
1617+
censor_period_ms : float
1618+
The censored period to exclude from the refractory period computation to discard
1619+
duplicated spikes.
16131620
"""
16141621

16151622
correlograms_extension = sorting_analyzer.get_extension("correlograms")
@@ -1628,7 +1635,12 @@ def compute_xcorr_and_rp(sorting_analyzer: SortingAnalyzer, pair_mask: np.ndarra
16281635
if not pair_mask[unit_index_1, unit_index_2]:
16291636
continue
16301637

1638+
# TODO: test this
16311639
xgram = ccgs[unit_index_1, unit_index_2, :]
1640+
if censor_period_ms > 0:
1641+
center_bin = len(xgram) // 2
1642+
censor_bins = int(round(censor_period_ms / bin_size_ms))
1643+
xgram[center_bin - censor_bins : center_bin + censor_bins + 1] = 0
16321644

16331645
rho_ij[unit_index_1, unit_index_2] = _compute_xcorr_pair(
16341646
xgram, bin_size_s=bin_size_ms / 1000, min_xcorr_rate=0

0 commit comments

Comments
 (0)