@@ -45,6 +45,9 @@ class ComputeCorrelograms(AnalyzerExtension):
4545 bin size 1 ms, the correlation will be binned as -25 ms, -24 ms, ...
4646 method : "auto" | "numpy" | "numba", default: "auto"
4747 If "auto" and numba is installed, numba is used, otherwise numpy is used.
48+ fast_mode : "auto" | "on" | "off", default: "auto"
49+ If "auto", a faster multithreaded implementations is used if method is "numba" and
50+ if the number of units is greater than 300.
4851
4952 Returns
5053 -------
@@ -88,8 +91,8 @@ class ComputeCorrelograms(AnalyzerExtension):
8891 use_nodepipeline = False
8992 need_job_kwargs = False
9093
91- def _set_params (self , window_ms : float = 50.0 , bin_ms : float = 1.0 , method : str = "auto" ):
92- params = dict (window_ms = window_ms , bin_ms = bin_ms , method = method )
94+ def _set_params (self , window_ms : float = 50.0 , bin_ms : float = 1.0 , method : str = "auto" , fast_mode : str = "auto" ):
95+ params = dict (window_ms = window_ms , bin_ms = bin_ms , method = method , fast_mode = fast_mode )
9396
9497 return params
9598
@@ -215,6 +218,7 @@ def compute_correlograms(
215218 window_ms : float = 50.0 ,
216219 bin_ms : float = 1.0 ,
217220 method : str = "auto" ,
221+ fast_mode : str = "auto" ,
218222):
219223 """
220224 Compute correlograms using Numba or Numpy.
@@ -225,11 +229,11 @@ def compute_correlograms(
225229
226230 if isinstance (sorting_analyzer_or_sorting , SortingAnalyzer ):
227231 return compute_correlograms_sorting_analyzer (
228- sorting_analyzer_or_sorting , window_ms = window_ms , bin_ms = bin_ms , method = method
232+ sorting_analyzer_or_sorting , window_ms = window_ms , bin_ms = bin_ms , method = method , fast_mode = fast_mode
229233 )
230234 else :
231235 return _compute_correlograms_on_sorting (
232- sorting_analyzer_or_sorting , window_ms = window_ms , bin_ms = bin_ms , method = method
236+ sorting_analyzer_or_sorting , window_ms = window_ms , bin_ms = bin_ms , method = method , fast_mode = fast_mode
233237 )
234238
235239
@@ -299,7 +303,7 @@ def _compute_num_bins(window_size, bin_size):
299303 return num_bins , num_half_bins
300304
301305
302- def _compute_correlograms_on_sorting (sorting , window_ms , bin_ms , method = "auto" ):
306+ def _compute_correlograms_on_sorting (sorting , window_ms , bin_ms , method = "auto" , fast_mode = "auto" ):
303307 """
304308 Computes cross-correlograms from multiple units.
305309
@@ -318,6 +322,9 @@ def _compute_correlograms_on_sorting(sorting, window_ms, bin_ms, method="auto"):
318322 method : str
319323 To use "numpy" or "numba". "auto" will use numba if available,
320324 otherwise numpy.
325+ fast_mode : "auto" | "on" | "off", default: "auto"
326+ If "auto", a faster multithreaded implementations is used if method is "numba" and
327+ if the number of units is greater than 300.
321328
322329 Returns
323330 -------
@@ -333,12 +340,20 @@ def _compute_correlograms_on_sorting(sorting, window_ms, bin_ms, method="auto"):
333340 if method == "auto" :
334341 method = "numba" if HAVE_NUMBA else "numpy"
335342
343+ if method == "numba" and fast_mode == "auto" :
344+ num_units = len (sorting .unit_ids )
345+ fast_mode = num_units > 300
346+ elif fast_mode == "off" :
347+ fast_mode = False
348+ elif fast_mode == "on" :
349+ fast_mode = True
350+
336351 bins , window_size , bin_size = _make_bins (sorting , window_ms , bin_ms )
337352
338353 if method == "numpy" :
339354 correlograms = _compute_correlograms_numpy (sorting , window_size , bin_size )
340355 if method == "numba" :
341- correlograms = _compute_correlograms_numba (sorting , window_size , bin_size )
356+ correlograms = _compute_correlograms_numba (sorting , window_size , bin_size , fast_mode = fast_mode )
342357
343358 return correlograms , bins
344359
@@ -483,7 +498,7 @@ def correlogram_for_one_segment(spike_times, spike_unit_indices, window_size, bi
483498 return correlograms
484499
485500
486- def _compute_correlograms_numba (sorting , window_size , bin_size ):
501+ def _compute_correlograms_numba (sorting , window_size , bin_size , fast_mode ):
487502 """
488503 Computes cross-correlograms between all units in `sorting`.
489504
@@ -499,6 +514,9 @@ def _compute_correlograms_numba(sorting, window_size, bin_size):
499514 The window size over which to perform the cross-correlation, in samples
500515 bin_size : int
501516 The size of which to bin lags, in samples.
517+ fast_mode : bool
518+ If True, use faster implementations (currently only if method is 'numba'),
519+ at the cost of possible minor numerical differences.
502520
503521 Returns
504522 -------
@@ -516,6 +534,11 @@ def _compute_correlograms_numba(sorting, window_size, bin_size):
516534 spikes = sorting .to_spike_vector (concatenated = False )
517535 correlograms = np .zeros ((num_units , num_units , num_bins ), dtype = np .int64 )
518536
537+ if fast_mode :
538+ num_threads = mp .cpu_count ()
539+ else :
540+ num_threads = 1
541+
519542 for seg_index in range (sorting .get_num_segments ()):
520543 spike_times = spikes [seg_index ]["sample_index" ]
521544 spike_unit_indices = spikes [seg_index ]["unit_index" ]
@@ -527,6 +550,7 @@ def _compute_correlograms_numba(sorting, window_size, bin_size):
527550 window_size ,
528551 bin_size ,
529552 num_half_bins ,
553+ num_threads ,
530554 )
531555
532556 return correlograms
@@ -539,9 +563,10 @@ def _compute_correlograms_numba(sorting, window_size, bin_size):
539563 nopython = True ,
540564 nogil = True ,
541565 cache = False ,
566+ parallel = True ,
542567 )
543568 def _compute_correlograms_one_segment_numba (
544- correlograms , spike_times , spike_unit_indices , window_size , bin_size , num_half_bins
569+ correlograms , spike_times , spike_unit_indices , window_size , bin_size , num_half_bins , num_threads
545570 ):
546571 """
547572 Compute the correlograms using `numba` for speed.
@@ -570,9 +595,12 @@ def _compute_correlograms_one_segment_numba(
570595 The window size over which to perform the cross-correlation, in samples
571596 bin_size : int
572597 The size of which to bin lags, in samples.
598+ num_threads : int
599+ The number of threads to use in parallel.
573600 """
601+ numba .set_num_threads (num_threads )
574602 start_j = 0
575- for i in range (spike_times .size ):
603+ for i in numba . prange (spike_times .size ):
576604 for j in range (start_j , spike_times .size ):
577605 if i == j :
578606 continue
0 commit comments