Skip to content

Commit 2bc9006

Browse files
authored
Improve multithreading and multi processing : lock + nogil + mp_context (#4333)
1 parent 4dc7b28 commit 2bc9006

File tree

11 files changed

+104
-38
lines changed

11 files changed

+104
-38
lines changed

src/spikeinterface/core/job_tools.py

Lines changed: 51 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -418,11 +418,25 @@ def __init__(
418418

419419
if pool_engine == "process":
420420
if mp_context is None:
421-
mp_context = recording.get_preferred_mp_context()
422-
if mp_context is not None and platform.system() == "Windows":
423-
assert mp_context != "fork", "'fork' mp_context not supported on Windows!"
424-
elif mp_context == "fork" and platform.system() == "Darwin":
425-
warnings.warn('As of Python 3.8 "fork" is no longer considered safe on macOS')
421+
# auto choice
422+
if platform.system() == "Windows":
423+
mp_context = "spawn"
424+
elif platform.system() == "Linux":
425+
mp_context = "fork"
426+
elif platform.system() == "Darwin":
427+
# We used to force spawn for macOS; this is unfortunate, but in some cases fork on macOS
428+
# is very unstable and leads to crashes.
429+
mp_context = "spawn"
430+
else:
431+
mp_context = "spawn"
432+
433+
preferred_mp_context = recording.get_preferred_mp_context()
434+
if preferred_mp_context is not None and preferred_mp_context != mp_context:
435+
warnings.warn(
436+
f"Your processing chain using pool_engine='process' and mp_context='{mp_context}' is not possible."
437+
f"So use mp_context='{preferred_mp_context}' instead"
438+
)
439+
mp_context = preferred_mp_context
426440

427441
self.mp_context = mp_context
428442

@@ -485,9 +499,14 @@ def run(self, recording_slices=None):
485499
recording_slices, desc=f"{self.job_name} (no parallelization)", total=len(recording_slices)
486500
)
487501

488-
worker_dict = self.init_func(*self.init_args)
502+
init_args = self.init_args
503+
if self.need_worker_index:
504+
worker_index = 0
505+
init_args = init_args + (worker_index,)
506+
507+
worker_dict = self.init_func(*init_args)
489508
if self.need_worker_index:
490-
worker_dict["worker_index"] = 0
509+
worker_dict["worker_index"] = worker_index
491510

492511
for segment_index, frame_start, frame_stop in recording_slices:
493512
res = self.func(segment_index, frame_start, frame_stop, worker_dict)
@@ -502,6 +521,8 @@ def run(self, recording_slices=None):
502521
if self.pool_engine == "process":
503522

504523
if self.need_worker_index:
524+
525+
multiprocessing.set_start_method(self.mp_context, force=True)
505526
lock = multiprocessing.Lock()
506527
array_pid = multiprocessing.Array("i", n_jobs)
507528
for i in range(n_jobs):
@@ -529,7 +550,9 @@ def run(self, recording_slices=None):
529550

530551
if self.progress_bar:
531552
results = tqdm(
532-
results, desc=f"{self.job_name} (workers: {n_jobs} processes)", total=len(recording_slices)
553+
results,
554+
desc=f"{self.job_name} (workers: {n_jobs} processes {self.mp_context})",
555+
total=len(recording_slices),
533556
)
534557

535558
for res in results:
@@ -618,11 +641,6 @@ def __call__(self, args):
618641

619642
def process_worker_initializer(func, init_func, init_args, max_threads_per_worker, need_worker_index, lock, array_pid):
620643
global _process_func_wrapper
621-
if max_threads_per_worker is None:
622-
worker_dict = init_func(*init_args)
623-
else:
624-
with threadpool_limits(limits=max_threads_per_worker):
625-
worker_dict = init_func(*init_args)
626644

627645
if need_worker_index:
628646
child_process = multiprocessing.current_process()
@@ -633,9 +651,19 @@ def process_worker_initializer(func, init_func, init_args, max_threads_per_worke
633651
worker_index = i
634652
array_pid[i] = child_process.ident
635653
break
636-
worker_dict["worker_index"] = worker_index
637654
lock.release()
638655

656+
init_args = init_args + (worker_index,)
657+
658+
if max_threads_per_worker is None:
659+
worker_dict = init_func(*init_args)
660+
else:
661+
with threadpool_limits(limits=max_threads_per_worker):
662+
worker_dict = init_func(*init_args)
663+
664+
if need_worker_index:
665+
worker_dict["worker_index"] = worker_index
666+
639667
_process_func_wrapper = WorkerFuncWrapper(func, worker_dict, max_threads_per_worker)
640668

641669

@@ -651,19 +679,23 @@ def process_function_wrapper(args):
651679
def thread_worker_initializer(
652680
func, init_func, init_args, max_threads_per_worker, thread_local_data, need_worker_index, lock
653681
):
682+
683+
if need_worker_index:
684+
lock.acquire()
685+
global _thread_started
686+
worker_index = _thread_started
687+
_thread_started += 1
688+
lock.release()
689+
init_args = init_args + (worker_index,)
690+
654691
if max_threads_per_worker is None:
655692
worker_dict = init_func(*init_args)
656693
else:
657694
with threadpool_limits(limits=max_threads_per_worker):
658695
worker_dict = init_func(*init_args)
659696

660697
if need_worker_index:
661-
lock.acquire()
662-
global _thread_started
663-
worker_index = _thread_started
664-
_thread_started += 1
665698
worker_dict["worker_index"] = worker_index
666-
lock.release()
667699

668700
thread_local_data.func_wrapper = WorkerFuncWrapper(func, worker_dict, max_threads_per_worker)
669701

src/spikeinterface/core/node_pipeline.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import Optional, Type
55

66
import struct
7+
import copy
78

89
from pathlib import Path
910

@@ -71,6 +72,11 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin, *ar
7172

7273
class PeakSource(PipelineNode):
7374

75+
# this is an important hack: it forces a node.compute() before the machinery is started
76+
# this eventually triggers some numba JIT compilation and avoids compilation races
77+
# between processes or threads
78+
need_first_call_before_pipeline = False
79+
7480
def get_trace_margin(self):
7581
raise NotImplementedError
7682

@@ -86,6 +92,12 @@ def get_peak_slice(
8692
# not needed for PeakDetector
8793
raise NotImplementedError
8894

95+
def _first_call_before_pipeline(self):
96+
# see need_first_call_before_pipeline = True
97+
margin = self.get_trace_margin()
98+
traces = self.recording.get_traces(start_frame=0, end_frame=margin * 2 + 1, segment_index=0)
99+
self.compute(traces, 0, margin * 2 + 1, 0, margin)
100+
89101

90102
# this is used in sorting components
91103
class PeakDetector(PeakSource):
@@ -601,6 +613,11 @@ def run_node_pipeline(
601613
else:
602614
raise ValueError(f"wrong gather_mode : {gather_mode}")
603615

616+
node0 = nodes[0]
617+
if isinstance(node0, PeakSource) and node0.need_first_call_before_pipeline:
618+
# See need_first_call_before_pipeline: this triggers numba compilation before the run
619+
node0._first_call_before_pipeline()
620+
604621
init_args = (recording, nodes, skip_after_n_peaks_per_worker)
605622

606623
processor = ChunkRecordingExecutor(

src/spikeinterface/core/numpyextractors.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,8 @@ def __init__(
157157
shm = SharedMemory(shm_name, create=False)
158158
self.shms.append(shm)
159159
traces = np.ndarray(shape=shape, dtype=dtype, buffer=shm.buf)
160+
# Force read only
161+
traces.flags.writeable = False
160162
traces_list.append(traces)
161163

162164
if channel_ids is None:

src/spikeinterface/core/recording_tools.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -902,6 +902,9 @@ def get_chunk_with_margin(
902902
taper = taper[:, np.newaxis]
903903
traces_chunk2[:margin] *= taper
904904
traces_chunk2[-margin:] *= taper[::-1]
905+
# enforce non-writable when the original chunk was not writable
906+
# (this helps numba keep the same signature and avoid compiling twice)
907+
traces_chunk2.flags.writeable = traces_chunk.flags.writeable
905908
traces_chunk = traces_chunk2
906909
elif add_reflect_padding:
907910
# in this case, we don't want to taper

src/spikeinterface/core/tests/test_job_tools.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -242,8 +242,9 @@ def func2(segment_index, start_frame, end_frame, worker_dict):
242242
return worker_dict["worker_index"]
243243

244244

245-
def init_func2():
245+
def init_func2(worker_index):
246246
# this leaves time for other threads/processes to start
247+
# print('in init_func2 with worker_index', worker_index)
247248
time.sleep(0.010)
248249
worker_dict = {}
249250
return worker_dict
@@ -256,6 +257,7 @@ def test_worker_index():
256257
for i in range(2):
257258
# making this 2 times ensure to test that global variables are correctly reset
258259
for pool_engine in ("process", "thread"):
260+
# print(pool_engine)
259261
processor = ChunkRecordingExecutor(
260262
recording,
261263
func2,
@@ -323,7 +325,7 @@ def test_get_best_job_kwargs():
323325
# test_ChunkRecordingExecutor()
324326
# test_fix_job_kwargs()
325327
# test_split_job_kwargs()
326-
# test_worker_index()
327-
test_get_best_job_kwargs()
328+
test_worker_index()
329+
# test_get_best_job_kwargs()
328330

329331
# quick_becnhmark()

src/spikeinterface/core/waveform_tools.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -996,6 +996,7 @@ def _init_worker_estimate_templates(
996996
nafter,
997997
return_in_uV,
998998
sparsity_mask,
999+
worker_index,
9991000
):
10001001
worker_dict = {}
10011002
worker_dict["recording"] = recording

src/spikeinterface/sortingcomponents/clustering/isosplit_isocut.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
##########################
4848
# isocut zone
4949

50-
@numba.jit(nopython=True)
50+
@numba.jit(nopython=True, nogil=True)
5151
def jisotonic5(x, weights):
5252
N = x.shape[0]
5353

@@ -100,7 +100,7 @@ def jisotonic5(x, weights):
100100

101101
return y, MSE
102102

103-
@numba.jit(nopython=True)
103+
@numba.jit(nopython=True, nogil=True)
104104
def updown_arange(num_bins, dtype=np.int_):
105105
num_bins_1 = int(np.ceil(num_bins / 2))
106106
num_bins_2 = num_bins - num_bins_1
@@ -111,7 +111,7 @@ def updown_arange(num_bins, dtype=np.int_):
111111
)
112112
)
113113

114-
@numba.jit(nopython=True)
114+
@numba.jit(nopython=True, nogil=True)
115115
def compute_ks4(counts1, counts2):
116116
c1s = counts1.sum()
117117
c2s = counts2.sum()
@@ -123,7 +123,7 @@ def compute_ks4(counts1, counts2):
123123
ks *= np.sqrt((c1s + c2s) / 2)
124124
return ks
125125

126-
@numba.jit(nopython=True)
126+
@numba.jit(nopython=True, nogil=True)
127127
def compute_ks5(counts1, counts2):
128128
best_ks = -np.inf
129129
length = counts1.size
@@ -138,7 +138,7 @@ def compute_ks5(counts1, counts2):
138138

139139
return best_ks, best_length
140140

141-
@numba.jit(nopython=True)
141+
@numba.jit(nopython=True, nogil=True)
142142
def up_down_isotonic_regression(x, weights=None):
143143
# determine switch point
144144
_, mse1 = jisotonic5(x, weights)
@@ -153,14 +153,14 @@ def up_down_isotonic_regression(x, weights=None):
153153

154154
return np.hstack((y1, y2))
155155

156-
@numba.jit(nopython=True)
156+
@numba.jit(nopython=True, nogil=True)
157157
def down_up_isotonic_regression(x, weights=None):
158158
return -up_down_isotonic_regression(-x, weights=weights)
159159

160160
# num_bins_factor = 1
161161
float_0 = np.array([0.0])
162162

163-
@numba.jit(nopython=True)
163+
@numba.jit(nopython=True, nogil=True)
164164
def isocut(samples): # , sample_weights=None isosplit6 not handle weight anymore
165165
"""
166166
Compute a dip-test to check if 1-d samples are unimodal or not.
@@ -464,7 +464,7 @@ def ensure_continuous_labels(labels):
464464

465465
if HAVE_NUMBA:
466466

467-
@numba.jit(nopython=True)
467+
@numba.jit(nopython=True, nogil=True)
468468
def compute_centroids_and_covmats(X, centroids, covmats, labels, label_set, to_compute_mask):
469469
## manual loop with numba to be faster
470470

@@ -498,7 +498,7 @@ def compute_centroids_and_covmats(X, centroids, covmats, labels, label_set, to_c
498498
if to_compute_mask[i] and count[i] > 0:
499499
covmats[i, :, :] /= count[i]
500500

501-
@numba.jit(nopython=True)
501+
@numba.jit(nopython=True, nogil=True)
502502
def get_pairs_to_compare(centroids, comparisons_made, active_labels_mask):
503503
n = centroids.shape[0]
504504

@@ -526,7 +526,7 @@ def get_pairs_to_compare(centroids, comparisons_made, active_labels_mask):
526526

527527
return pairs
528528

529-
@numba.jit(nopython=True)
529+
@numba.jit(nopython=True, nogil=True)
530530
def compute_distances(centroids, comparisons_made, active_labels_mask):
531531
n = centroids.shape[0]
532532
dists = np.zeros((n, n), dtype=centroids.dtype)
@@ -548,7 +548,7 @@ def compute_distances(centroids, comparisons_made, active_labels_mask):
548548

549549
return dists
550550

551-
@numba.jit(nopython=True)
551+
@numba.jit(nopython=True, nogil=True)
552552
def merge_test(X1, X2, centroid1, centroid2, covmat1, covmat2, isocut_threshold):
553553

554554
if X1.size == 0 or X2.size == 0:
@@ -584,7 +584,7 @@ def merge_test(X1, X2, centroid1, centroid2, covmat1, covmat2, isocut_threshold)
584584

585585
return do_merge, L12
586586

587-
@numba.jit(nopython=True)
587+
@numba.jit(nopython=True, nogil=True)
588588
def compare_pairs(X, labels, pairs, centroids, covmats, min_cluster_size, isocut_threshold):
589589

590590
clusters_changed_mask = np.zeros(centroids.shape[0], dtype="bool")

src/spikeinterface/sortingcomponents/matching/nearest.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@ class NearestTemplatesPeeler(BaseTemplateMatching):
1414

1515
name = "nearest"
1616
need_noise_levels = True
17+
# needed because of numba JIT compilation (see need_first_call_before_pipeline)
18+
need_first_call_before_pipeline = True
19+
1720
params_doc = """
1821
peak_sign : 'neg' | 'pos' | 'both'
1922
The peak sign to use for detection

src/spikeinterface/sortingcomponents/matching/tdc_peeler.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ class TridesclousPeeler(BaseTemplateMatching):
4747

4848
name = "tdc-peeler"
4949
need_noise_levels = True
50+
# needed because of numba JIT compilation (see need_first_call_before_pipeline)
51+
need_first_call_before_pipeline = True
5052
params_doc = """
5153
peak_sign : str
5254
'neg', 'pos' or 'both'
@@ -901,7 +903,7 @@ def fit_one_amplitude_with_neighbors(
901903
if HAVE_NUMBA:
902904
from numba import jit, prange
903905

904-
@jit(nopython=True)
906+
@jit(nopython=True, nogil=True)
905907
def construct_prediction_sparse(
906908
spikes, traces, sparse_templates_array, template_sparsity_mask, wanted_channel_mask, nbefore, additive
907909
):
@@ -931,7 +933,7 @@ def construct_prediction_sparse(
931933
if template_sparsity_mask[cluster_index, chan]:
932934
chan_in_template += 1
933935

934-
@jit(nopython=True)
936+
@jit(nopython=True, nogil=True)
935937
def numba_sparse_distance(
936938
wf, sparse_templates_array, template_sparsity_mask, wanted_channel_mask, possible_clusters
937939
):
@@ -967,7 +969,7 @@ def numba_sparse_distance(
967969
distances[i] = sum_dist
968970
return distances
969971

970-
@jit(nopython=True)
972+
@jit(nopython=True, nogil=True)
971973
def numba_best_shift_sparse(
972974
traces, sparse_template, sample_index, nbefore, possible_shifts, distances_shift, chan_sparsity
973975
):

src/spikeinterface/sortingcomponents/peak_detection/locally_exclusive.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@ class LocallyExclusivePeakDetector(PeakDetector):
3939
engine = "numba"
4040
need_noise_levels = True
4141
preferred_mp_context = None
42+
43+
# needed because of numba JIT compilation (see need_first_call_before_pipeline)
44+
need_first_call_before_pipeline = True
4245
params_doc = ByChannelPeakDetector.params_doc + """
4346
radius_um: float
4447
The radius to use to select neighbour channels for locally exclusive detection.

0 commit comments

Comments
 (0)