@@ -81,6 +81,7 @@ def _merge_extension_data(
8181 new_sorting_analyzer .sparsity .mask [keep , :], new_unit_ids , new_sorting_analyzer .channel_ids
8282 )
8383
84+ # TODO: soft merge template similarity
8485 new_similarity , _ = compute_similarity_with_templates_array (
8586 new_templates_array ,
8687 all_templates_array ,
@@ -225,6 +226,8 @@ def _compute_similarity_matrix_numpy(
225226 overlapping_templates = np .flatnonzero (np .sum (local_mask , 1 ))
226227 tgt_templates = tgt_sliced_templates [overlapping_templates ]
227228 for gcount , j in enumerate (overlapping_templates ):
229+ if j < i and same_array :
230+ continue
228231 src = src_template [:, local_mask [j ]].reshape (1 , - 1 )
229232 tgt = (tgt_templates [gcount ][:, local_mask [j ]]).reshape (1 , - 1 )
230233
@@ -246,7 +249,10 @@ def _compute_similarity_matrix_numpy(
246249 distances [count , i , j ] = 1 - distances [count , i , j ]
247250
248251 if same_array :
249- distances [num_shifts_both_sides - count - 1 , j , i ] = distances [count , i , j ]
252+ distances [count , j , i ] = distances [count , i , j ]
253+
254+ if same_array and shift != 0 :
255+ distances [num_shifts_both_sides - count - 1 ] = distances [count ].T
250256
251257 return distances
252258
@@ -258,14 +264,20 @@ def _compute_similarity_matrix_numpy(
258264
259265 @numba .jit (nopython = True , parallel = True , fastmath = True , nogil = True )
260266 def _compute_similarity_matrix_numba (
261- templates_array , other_templates_array , num_shifts , method , sparsity_mask , other_sparsity_mask , support = "union"
267+ templates_array ,
268+ other_templates_array ,
269+ num_shifts ,
270+ method ,
271+ sparsity_mask ,
272+ other_sparsity_mask ,
273+ support = "union" ,
262274 ):
263275 num_templates = templates_array .shape [0 ]
264276 num_samples = templates_array .shape [1 ]
265277 num_channels = templates_array .shape [2 ]
266278 other_num_templates = other_templates_array .shape [0 ]
267-
268279 num_shifts_both_sides = 2 * num_shifts + 1
280+
269281 distances = np .ones ((num_shifts_both_sides , num_templates , other_num_templates ), dtype = np .float32 )
270282 same_array = np .array_equal (templates_array , other_templates_array )
271283
@@ -284,74 +296,90 @@ def _compute_similarity_matrix_numba(
284296 elif method == "cosine" :
285297 metric = 2
286298
287- for count in range (len (shift_loop )):
288- shift = shift_loop [count ]
289- src_sliced_templates = templates_array [:, num_shifts : num_samples - num_shifts ]
290- tgt_sliced_templates = other_templates_array [:, num_shifts + shift : num_samples - num_shifts + shift ]
291- for i in numba .prange (num_templates ):
292- src_template = src_sliced_templates [i ]
299+ overlapping_j_list = numba .typed .List ()
300+ active_channels_list = numba .typed .List ()
293301
294- ## Ideally we would like to use this but numba does not support well function with numpy and boolean arrays
295- ## So we inline the function here
296- # local_mask = get_overlapping_mask_for_one_template(i, sparsity, other_sparsity, support=support)
302+ for src_unit in range (num_templates ):
303+ overlapping_ids = numba .typed .List ()
304+ overlapping_chs = numba .typed .List ()
305+
306+ start = src_unit if same_array else 0
307+ for tgt_unit in range (start , other_num_templates ):
297308
298309 if support == "intersection" :
299- local_mask = np .logical_and (
300- sparsity_mask [i , :], other_sparsity_mask
301- ) # shape (other_num_templates, num_channels)
310+ ch = np .where (sparsity_mask [src_unit ] & other_sparsity_mask [tgt_unit ])[0 ].astype (np .uint16 )
311+
302312 elif support == "union" :
303- connected_mask = np .logical_and (sparsity_mask [i , :], other_sparsity_mask )
304- not_connected_mask = np .sum (connected_mask , axis = 1 ) == 0
305- local_mask = np .logical_or (
306- sparsity_mask [i , :], other_sparsity_mask
307- ) # shape (other_num_templates, num_channels)
308- for local_i in np .flatnonzero (not_connected_mask ):
309- local_mask [local_i ] = False
313+ connected = False
314+ for c in range (num_channels ):
315+ if sparsity_mask [src_unit , c ] and other_sparsity_mask [tgt_unit , c ]:
316+ connected = True
317+ break
318+ if not connected :
319+ ch = np .empty (0 , dtype = np .uint16 )
320+ else :
321+ ch_list = []
322+ for c in range (num_channels ):
323+ if sparsity_mask [src_unit , c ] or other_sparsity_mask [tgt_unit , c ]:
324+ ch_list .append (c )
325+ ch = np .array (ch_list , dtype = np .uint16 )
310326
311327 elif support == "dense" :
312- local_mask = np .ones ((other_num_templates , num_channels ), dtype = np .bool_ )
313-
314- overlapping_templates = np .flatnonzero (np .sum (local_mask , 1 ))
315- tgt_templates = tgt_sliced_templates [overlapping_templates ]
316- for gcount in range (len (overlapping_templates )):
317-
318- j = overlapping_templates [gcount ]
319- src = src_template [:, local_mask [j ]].flatten ()
320- tgt = (tgt_templates [gcount ][:, local_mask [j ]]).flatten ()
321-
322- norm_i = 0
323- norm_j = 0
324- distances [count , i , j ] = 0
325-
326- for k in range (len (src )):
327- if metric == 0 :
328- norm_i += abs (src [k ])
329- norm_j += abs (tgt [k ])
330- distances [count , i , j ] += abs (src [k ] - tgt [k ])
331- elif metric == 1 :
332- norm_i += src [k ] ** 2
333- norm_j += tgt [k ] ** 2
334- distances [count , i , j ] += (src [k ] - tgt [k ]) ** 2
335- elif metric == 2 :
336- distances [count , i , j ] += src [k ] * tgt [k ]
337- norm_i += src [k ] ** 2
338- norm_j += tgt [k ] ** 2
328+ ch = np .arange (num_channels , dtype = np .uint16 )
329+
330+ if len (ch ) > 0 :
331+ overlapping_ids .append (np .uint16 (tgt_unit ))
332+ overlapping_chs .append (ch )
333+
334+ overlapping_j_list .append (overlapping_ids )
335+ active_channels_list .append (overlapping_chs )
336+
337+ for count in range (len (shift_loop )):
338+ shift = shift_loop [count ]
339+
340+ src_sliced = templates_array [:, num_shifts : num_samples - num_shifts ]
341+ tgt_sliced = other_templates_array [:, num_shifts + shift : num_samples - num_shifts + shift ]
342+
343+ for i in numba .prange (num_templates ):
344+ src_template = src_sliced [i ]
345+ overlapping_ids = overlapping_j_list [i ]
346+ overlapping_chs = active_channels_list [i ]
347+
348+ for pair_idx in range (len (overlapping_ids )):
349+ j = np .uint16 (overlapping_ids [pair_idx ])
350+ ch = overlapping_chs [pair_idx ]
351+
352+ src_ch = src_template [:, ch ]
353+ tgt_ch = tgt_sliced [j ][:, ch ]
339354
340355 if metric == 0 :
341- distances [count , i , j ] /= norm_i + norm_j
356+ # l1
357+ norm_i = np .sum (np .abs (src_ch ))
358+ norm_j = np .sum (np .abs (tgt_ch ))
359+ dist = np .sum (np .abs (src_ch - tgt_ch ))
360+ distances [count , i , j ] = dist / (norm_i + norm_j )
361+
342362 elif metric == 1 :
343- norm_i = sqrt (norm_i )
344- norm_j = sqrt (norm_j )
345- distances [count , i , j ] = sqrt (distances [count , i , j ])
346- distances [count , i , j ] /= norm_i + norm_j
363+ # l2
364+ norm_i = sqrt (np .sum (src_ch ** 2 ))
365+ norm_j = sqrt (np .sum (tgt_ch ** 2 ))
366+ dist = sqrt (np .sum ((src_ch - tgt_ch ) ** 2 ))
367+ distances [count , i , j ] = dist / (norm_i + norm_j )
368+
347369 elif metric == 2 :
348- norm_i = sqrt (norm_i )
349- norm_j = sqrt (norm_j )
350- distances [count , i , j ] /= norm_i * norm_j
351- distances [count , i , j ] = 1 - distances [count , i , j ]
370+ # cosine
371+ dot = np .sum (src_ch * tgt_ch )
372+ norm_i = sqrt (np .sum (src_ch ** 2 ))
373+ norm_j = sqrt (np .sum (tgt_ch ** 2 ))
374+ denom = norm_i * norm_j
375+ if denom > 0.0 :
376+ distances [count , i , j ] = 1.0 - dot / denom
352377
353378 if same_array :
354- distances [num_shifts_both_sides - count - 1 , j , i ] = distances [count , i , j ]
379+ distances [count , j , i ] = distances [count , i , j ]
380+
381+ if same_array and shift != 0 :
382+ distances [num_shifts_both_sides - count - 1 ] = distances [count ].T
355383
356384 return distances
357385
0 commit comments