Skip to content

Commit 3fa0779

Browse files
tayheau, tayheau, pre-commit-ci[bot], chrishalcrow, samuelgarcia
authored
Speed up template similarity computing using numba (#4343)
Co-authored-by: tayheau <thopsore@WD25-1022.corp.pasteur.fr>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Chris Halcrow <57948917+chrishalcrow@users.noreply.github.com>
Co-authored-by: Samuel Garcia <sam.garcia.die@gmail.com>
1 parent b1fdeab commit 3fa0779

1 file changed

Lines changed: 87 additions & 59 deletions

File tree

src/spikeinterface/postprocessing/template_similarity.py

Lines changed: 87 additions & 59 deletions
Original file line number | Diff line number | Diff line change
@@ -81,6 +81,7 @@ def _merge_extension_data(
8181
new_sorting_analyzer.sparsity.mask[keep, :], new_unit_ids, new_sorting_analyzer.channel_ids
8282
)
8383

84+
# TODO: soft merge template similarity
8485
new_similarity, _ = compute_similarity_with_templates_array(
8586
new_templates_array,
8687
all_templates_array,
@@ -225,6 +226,8 @@ def _compute_similarity_matrix_numpy(
225226
overlapping_templates = np.flatnonzero(np.sum(local_mask, 1))
226227
tgt_templates = tgt_sliced_templates[overlapping_templates]
227228
for gcount, j in enumerate(overlapping_templates):
229+
if j < i and same_array:
230+
continue
228231
src = src_template[:, local_mask[j]].reshape(1, -1)
229232
tgt = (tgt_templates[gcount][:, local_mask[j]]).reshape(1, -1)
230233

@@ -246,7 +249,10 @@ def _compute_similarity_matrix_numpy(
246249
distances[count, i, j] = 1 - distances[count, i, j]
247250

248251
if same_array:
249-
distances[num_shifts_both_sides - count - 1, j, i] = distances[count, i, j]
252+
distances[count, j, i] = distances[count, i, j]
253+
254+
if same_array and shift != 0:
255+
distances[num_shifts_both_sides - count - 1] = distances[count].T
250256

251257
return distances
252258

@@ -258,14 +264,20 @@ def _compute_similarity_matrix_numpy(
258264

259265
@numba.jit(nopython=True, parallel=True, fastmath=True, nogil=True)
260266
def _compute_similarity_matrix_numba(
261-
templates_array, other_templates_array, num_shifts, method, sparsity_mask, other_sparsity_mask, support="union"
267+
templates_array,
268+
other_templates_array,
269+
num_shifts,
270+
method,
271+
sparsity_mask,
272+
other_sparsity_mask,
273+
support="union",
262274
):
263275
num_templates = templates_array.shape[0]
264276
num_samples = templates_array.shape[1]
265277
num_channels = templates_array.shape[2]
266278
other_num_templates = other_templates_array.shape[0]
267-
268279
num_shifts_both_sides = 2 * num_shifts + 1
280+
269281
distances = np.ones((num_shifts_both_sides, num_templates, other_num_templates), dtype=np.float32)
270282
same_array = np.array_equal(templates_array, other_templates_array)
271283

@@ -284,74 +296,90 @@ def _compute_similarity_matrix_numba(
284296
elif method == "cosine":
285297
metric = 2
286298

287-
for count in range(len(shift_loop)):
288-
shift = shift_loop[count]
289-
src_sliced_templates = templates_array[:, num_shifts : num_samples - num_shifts]
290-
tgt_sliced_templates = other_templates_array[:, num_shifts + shift : num_samples - num_shifts + shift]
291-
for i in numba.prange(num_templates):
292-
src_template = src_sliced_templates[i]
299+
overlapping_j_list = numba.typed.List()
300+
active_channels_list = numba.typed.List()
293301

294-
## Ideally we would like to use this but numba does not support well function with numpy and boolean arrays
295-
## So we inline the function here
296-
# local_mask = get_overlapping_mask_for_one_template(i, sparsity, other_sparsity, support=support)
302+
for src_unit in range(num_templates):
303+
overlapping_ids = numba.typed.List()
304+
overlapping_chs = numba.typed.List()
305+
306+
start = src_unit if same_array else 0
307+
for tgt_unit in range(start, other_num_templates):
297308

298309
if support == "intersection":
299-
local_mask = np.logical_and(
300-
sparsity_mask[i, :], other_sparsity_mask
301-
) # shape (other_num_templates, num_channels)
310+
ch = np.where(sparsity_mask[src_unit] & other_sparsity_mask[tgt_unit])[0].astype(np.uint16)
311+
302312
elif support == "union":
303-
connected_mask = np.logical_and(sparsity_mask[i, :], other_sparsity_mask)
304-
not_connected_mask = np.sum(connected_mask, axis=1) == 0
305-
local_mask = np.logical_or(
306-
sparsity_mask[i, :], other_sparsity_mask
307-
) # shape (other_num_templates, num_channels)
308-
for local_i in np.flatnonzero(not_connected_mask):
309-
local_mask[local_i] = False
313+
connected = False
314+
for c in range(num_channels):
315+
if sparsity_mask[src_unit, c] and other_sparsity_mask[tgt_unit, c]:
316+
connected = True
317+
break
318+
if not connected:
319+
ch = np.empty(0, dtype=np.uint16)
320+
else:
321+
ch_list = []
322+
for c in range(num_channels):
323+
if sparsity_mask[src_unit, c] or other_sparsity_mask[tgt_unit, c]:
324+
ch_list.append(c)
325+
ch = np.array(ch_list, dtype=np.uint16)
310326

311327
elif support == "dense":
312-
local_mask = np.ones((other_num_templates, num_channels), dtype=np.bool_)
313-
314-
overlapping_templates = np.flatnonzero(np.sum(local_mask, 1))
315-
tgt_templates = tgt_sliced_templates[overlapping_templates]
316-
for gcount in range(len(overlapping_templates)):
317-
318-
j = overlapping_templates[gcount]
319-
src = src_template[:, local_mask[j]].flatten()
320-
tgt = (tgt_templates[gcount][:, local_mask[j]]).flatten()
321-
322-
norm_i = 0
323-
norm_j = 0
324-
distances[count, i, j] = 0
325-
326-
for k in range(len(src)):
327-
if metric == 0:
328-
norm_i += abs(src[k])
329-
norm_j += abs(tgt[k])
330-
distances[count, i, j] += abs(src[k] - tgt[k])
331-
elif metric == 1:
332-
norm_i += src[k] ** 2
333-
norm_j += tgt[k] ** 2
334-
distances[count, i, j] += (src[k] - tgt[k]) ** 2
335-
elif metric == 2:
336-
distances[count, i, j] += src[k] * tgt[k]
337-
norm_i += src[k] ** 2
338-
norm_j += tgt[k] ** 2
328+
ch = np.arange(num_channels, dtype=np.uint16)
329+
330+
if len(ch) > 0:
331+
overlapping_ids.append(np.uint16(tgt_unit))
332+
overlapping_chs.append(ch)
333+
334+
overlapping_j_list.append(overlapping_ids)
335+
active_channels_list.append(overlapping_chs)
336+
337+
for count in range(len(shift_loop)):
338+
shift = shift_loop[count]
339+
340+
src_sliced = templates_array[:, num_shifts : num_samples - num_shifts]
341+
tgt_sliced = other_templates_array[:, num_shifts + shift : num_samples - num_shifts + shift]
342+
343+
for i in numba.prange(num_templates):
344+
src_template = src_sliced[i]
345+
overlapping_ids = overlapping_j_list[i]
346+
overlapping_chs = active_channels_list[i]
347+
348+
for pair_idx in range(len(overlapping_ids)):
349+
j = np.uint16(overlapping_ids[pair_idx])
350+
ch = overlapping_chs[pair_idx]
351+
352+
src_ch = src_template[:, ch]
353+
tgt_ch = tgt_sliced[j][:, ch]
339354

340355
if metric == 0:
341-
distances[count, i, j] /= norm_i + norm_j
356+
# l1
357+
norm_i = np.sum(np.abs(src_ch))
358+
norm_j = np.sum(np.abs(tgt_ch))
359+
dist = np.sum(np.abs(src_ch - tgt_ch))
360+
distances[count, i, j] = dist / (norm_i + norm_j)
361+
342362
elif metric == 1:
343-
norm_i = sqrt(norm_i)
344-
norm_j = sqrt(norm_j)
345-
distances[count, i, j] = sqrt(distances[count, i, j])
346-
distances[count, i, j] /= norm_i + norm_j
363+
# l2
364+
norm_i = sqrt(np.sum(src_ch**2))
365+
norm_j = sqrt(np.sum(tgt_ch**2))
366+
dist = sqrt(np.sum((src_ch - tgt_ch) ** 2))
367+
distances[count, i, j] = dist / (norm_i + norm_j)
368+
347369
elif metric == 2:
348-
norm_i = sqrt(norm_i)
349-
norm_j = sqrt(norm_j)
350-
distances[count, i, j] /= norm_i * norm_j
351-
distances[count, i, j] = 1 - distances[count, i, j]
370+
# cosine
371+
dot = np.sum(src_ch * tgt_ch)
372+
norm_i = sqrt(np.sum(src_ch**2))
373+
norm_j = sqrt(np.sum(tgt_ch**2))
374+
denom = norm_i * norm_j
375+
if denom > 0.0:
376+
distances[count, i, j] = 1.0 - dot / denom
352377

353378
if same_array:
354-
distances[num_shifts_both_sides - count - 1, j, i] = distances[count, i, j]
379+
distances[count, j, i] = distances[count, i, j]
380+
381+
if same_array and shift != 0:
382+
distances[num_shifts_both_sides - count - 1] = distances[count].T
355383

356384
return distances
357385

0 commit comments

Comments
 (0)