Skip to content

Commit 25940b4

Browse files
ygerpre-commit-ci[bot]alejoe91
authored
Option to turn on parallel computation for CCG in numba (#4305)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Alessio Buccino <alejoe9187@gmail.com>
1 parent bbdc2ed commit 25940b4

File tree

2 files changed

+59
-9
lines changed

2 files changed

+59
-9
lines changed

src/spikeinterface/postprocessing/correlograms.py

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,9 @@ class ComputeCorrelograms(AnalyzerExtension):
4545
bin size 1 ms, the correlation will be binned as -25 ms, -24 ms, ...
4646
method : "auto" | "numpy" | "numba", default: "auto"
4747
If "auto" and numba is installed, numba is used, otherwise numpy is used.
48+
fast_mode : "auto" | "on" | "off", default: "auto"
49+
If "auto", a faster multithreaded implementations is used if method is "numba" and
50+
if the number of units is greater than 300.
4851
4952
Returns
5053
-------
@@ -88,8 +91,8 @@ class ComputeCorrelograms(AnalyzerExtension):
8891
use_nodepipeline = False
8992
need_job_kwargs = False
9093

91-
def _set_params(self, window_ms: float = 50.0, bin_ms: float = 1.0, method: str = "auto"):
92-
params = dict(window_ms=window_ms, bin_ms=bin_ms, method=method)
94+
def _set_params(self, window_ms: float = 50.0, bin_ms: float = 1.0, method: str = "auto", fast_mode: str = "auto"):
95+
params = dict(window_ms=window_ms, bin_ms=bin_ms, method=method, fast_mode=fast_mode)
9396

9497
return params
9598

@@ -215,6 +218,7 @@ def compute_correlograms(
215218
window_ms: float = 50.0,
216219
bin_ms: float = 1.0,
217220
method: str = "auto",
221+
fast_mode: str = "auto",
218222
):
219223
"""
220224
Compute correlograms using Numba or Numpy.
@@ -225,11 +229,11 @@ def compute_correlograms(
225229

226230
if isinstance(sorting_analyzer_or_sorting, SortingAnalyzer):
227231
return compute_correlograms_sorting_analyzer(
228-
sorting_analyzer_or_sorting, window_ms=window_ms, bin_ms=bin_ms, method=method
232+
sorting_analyzer_or_sorting, window_ms=window_ms, bin_ms=bin_ms, method=method, fast_mode=fast_mode
229233
)
230234
else:
231235
return _compute_correlograms_on_sorting(
232-
sorting_analyzer_or_sorting, window_ms=window_ms, bin_ms=bin_ms, method=method
236+
sorting_analyzer_or_sorting, window_ms=window_ms, bin_ms=bin_ms, method=method, fast_mode=fast_mode
233237
)
234238

235239

@@ -299,7 +303,7 @@ def _compute_num_bins(window_size, bin_size):
299303
return num_bins, num_half_bins
300304

301305

302-
def _compute_correlograms_on_sorting(sorting, window_ms, bin_ms, method="auto"):
306+
def _compute_correlograms_on_sorting(sorting, window_ms, bin_ms, method="auto", fast_mode="auto"):
303307
"""
304308
Computes cross-correlograms from multiple units.
305309
@@ -318,6 +322,9 @@ def _compute_correlograms_on_sorting(sorting, window_ms, bin_ms, method="auto"):
318322
method : str
319323
To use "numpy" or "numba". "auto" will use numba if available,
320324
otherwise numpy.
325+
fast_mode : "auto" | "on" | "off", default: "auto"
326+
If "auto", a faster multithreaded implementations is used if method is "numba" and
327+
if the number of units is greater than 300.
321328
322329
Returns
323330
-------
@@ -333,12 +340,20 @@ def _compute_correlograms_on_sorting(sorting, window_ms, bin_ms, method="auto"):
333340
if method == "auto":
334341
method = "numba" if HAVE_NUMBA else "numpy"
335342

343+
if method == "numba" and fast_mode == "auto":
344+
num_units = len(sorting.unit_ids)
345+
fast_mode = num_units > 300
346+
elif fast_mode == "off":
347+
fast_mode = False
348+
elif fast_mode == "on":
349+
fast_mode = True
350+
336351
bins, window_size, bin_size = _make_bins(sorting, window_ms, bin_ms)
337352

338353
if method == "numpy":
339354
correlograms = _compute_correlograms_numpy(sorting, window_size, bin_size)
340355
if method == "numba":
341-
correlograms = _compute_correlograms_numba(sorting, window_size, bin_size)
356+
correlograms = _compute_correlograms_numba(sorting, window_size, bin_size, fast_mode=fast_mode)
342357

343358
return correlograms, bins
344359

@@ -483,7 +498,7 @@ def correlogram_for_one_segment(spike_times, spike_unit_indices, window_size, bi
483498
return correlograms
484499

485500

486-
def _compute_correlograms_numba(sorting, window_size, bin_size):
501+
def _compute_correlograms_numba(sorting, window_size, bin_size, fast_mode):
487502
"""
488503
Computes cross-correlograms between all units in `sorting`.
489504
@@ -499,6 +514,9 @@ def _compute_correlograms_numba(sorting, window_size, bin_size):
499514
The window size over which to perform the cross-correlation, in samples
500515
bin_size : int
501516
The size of which to bin lags, in samples.
517+
fast_mode : bool
518+
If True, use faster implementations (currently only if method is 'numba'),
519+
at the cost of possible minor numerical differences.
502520
503521
Returns
504522
-------
@@ -516,6 +534,11 @@ def _compute_correlograms_numba(sorting, window_size, bin_size):
516534
spikes = sorting.to_spike_vector(concatenated=False)
517535
correlograms = np.zeros((num_units, num_units, num_bins), dtype=np.int64)
518536

537+
if fast_mode:
538+
num_threads = mp.cpu_count()
539+
else:
540+
num_threads = 1
541+
519542
for seg_index in range(sorting.get_num_segments()):
520543
spike_times = spikes[seg_index]["sample_index"]
521544
spike_unit_indices = spikes[seg_index]["unit_index"]
@@ -527,6 +550,7 @@ def _compute_correlograms_numba(sorting, window_size, bin_size):
527550
window_size,
528551
bin_size,
529552
num_half_bins,
553+
num_threads,
530554
)
531555

532556
return correlograms
@@ -539,9 +563,10 @@ def _compute_correlograms_numba(sorting, window_size, bin_size):
539563
nopython=True,
540564
nogil=True,
541565
cache=False,
566+
parallel=True,
542567
)
543568
def _compute_correlograms_one_segment_numba(
544-
correlograms, spike_times, spike_unit_indices, window_size, bin_size, num_half_bins
569+
correlograms, spike_times, spike_unit_indices, window_size, bin_size, num_half_bins, num_threads
545570
):
546571
"""
547572
Compute the correlograms using `numba` for speed.
@@ -570,9 +595,12 @@ def _compute_correlograms_one_segment_numba(
570595
The window size over which to perform the cross-correlation, in samples
571596
bin_size : int
572597
The size of which to bin lags, in samples.
598+
num_threads : int
599+
The number of threads to use in parallel.
573600
"""
601+
numba.set_num_threads(num_threads)
574602
start_j = 0
575-
for i in range(spike_times.size):
603+
for i in numba.prange(spike_times.size):
576604
for j in range(start_j, spike_times.size):
577605
if i == j:
578606
continue

src/spikeinterface/postprocessing/tests/test_correlograms.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,28 @@ def test_equal_results_correlograms(window_and_bin_ms):
104104
assert np.array_equal(result_numpy, result_numba)
105105

106106

107+
@pytest.mark.skipif(not HAVE_NUMBA, reason="Numba not available")
108+
@pytest.mark.parametrize("window_and_bin_ms", [(60.0, 2.0), (3.57, 1.6421)])
109+
def test_equal_results_fast_correlograms(window_and_bin_ms):
110+
"""
111+
Test that the 2 methods have same results with some varied time bins
112+
that are not tested in other tests.
113+
"""
114+
115+
window_ms, bin_ms = window_and_bin_ms
116+
sorting = generate_sorting(num_units=5, sampling_frequency=30000.0, durations=[10.325, 3.5], seed=0)
117+
118+
result_numba_fast, bins_numba_fast = _compute_correlograms_on_sorting(
119+
sorting, window_ms=window_ms, bin_ms=bin_ms, method="numba", fast_mode=True
120+
)
121+
result_numba, bins_numba = _compute_correlograms_on_sorting(
122+
sorting, window_ms=window_ms, bin_ms=bin_ms, method="numba", fast_mode=False
123+
)
124+
from numpy.testing import assert_almost_equal
125+
126+
assert_almost_equal(result_numba_fast, result_numba)
127+
128+
107129
@pytest.mark.parametrize("method", ["numpy", param("numba", marks=SKIP_NUMBA)])
108130
def test_flat_cross_correlogram(method):
109131
"""

0 commit comments

Comments
 (0)