Skip to content

Commit ae1a0d8

Browse files
authored
Big clean in components and reorganize in folder (#4140)
1 parent 2e1f280 commit ae1a0d8

95 files changed

Lines changed: 3874 additions & 5399 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

doc/how_to/analyze_neuropixels.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -442,7 +442,7 @@ Let’s use here the ``locally_exclusive`` method for detection and the
442442
443443
from spikeinterface.sortingcomponents.peak_localization import localize_peaks
444444
445-
peak_locations = localize_peaks(rec, peaks, method='center_of_mass', radius_um=50., **job_kwargs)
445+
peak_locations = localize_peaks(rec, peaks, method='center_of_mass', method_kwargs=dict(radius_um=50.), **job_kwargs)
446446
447447
448448

doc/modules/motion_correction.rst

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -175,8 +175,9 @@ The high-level :py:func:`~spikeinterface.preprocessing.correct_motion()` is inte
175175
peaks = detect_peaks(recording=rec, method="locally_exclusive", detect_threshold=8.0, **job_kwargs)
176176
# (optional) sub-select some peaks to speed up the localization
177177
peaks = select_peaks(peaks=peaks, ...)
178-
peak_locations = localize_peaks(recording=rec, peaks=peaks, method="monopolar_triangulation",radius_um=75.0,
179-
max_distance_um=150.0, **job_kwargs)
178+
peak_locations = localize_peaks(recording=rec, peaks=peaks, method="monopolar_triangulation",
179+
method_kwargs=dict(radius_um=75.0, max_distance_um=150.0),
180+
job_kwargs=job_kwargs)
180181
181182
# Step 2: motion inference
182183
motion = estimate_motion(

doc/modules/sortingcomponents.rst

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,7 @@ follows:
5656
exclude_sweep_ms=0.2,
5757
noise_levels=None,
5858
random_chunk_kwargs={},
59-
gather_mode='memory',
60-
**job_kwargs,
59+
job_kwargs=job_kwargs,
6160
)
6261
6362
The output :code:`peaks` is a NumPy array with a length of the number of peaks found and the following dtype:
@@ -99,10 +98,12 @@ follows:
9998
recording=recording,
10099
peaks=peaks,
101100
method='center_of_mass',
102-
radius_um=70.,
103-
ms_before=0.3,
104-
ms_after=0.6,
105-
**job_kwargs
101+
method_kwargs=dict(
102+
radius_um=70.,
103+
ms_before=0.3,
104+
ms_after=0.6,
105+
),
106+
job_kwargs=job_kwargs,
106107
)
107108
108109

examples/how_to/analyze_neuropixels.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -169,14 +169,18 @@
169169
from spikeinterface.sortingcomponents.peak_detection import detect_peaks
170170

171171
job_kwargs = dict(n_jobs=40, chunk_duration='1s', progress_bar=True)
172-
peaks = detect_peaks(rec, method='locally_exclusive', noise_levels=noise_levels_int16,
173-
detect_threshold=5, radius_um=50., **job_kwargs)
172+
peaks = detect_peaks(rec,
173+
method='locally_exclusive',
174+
method_kwargs=dict(
175+
noise_levels=noise_levels_int16,
176+
detect_threshold=5, radius_um=50.),
177+
job_kwargs=job_kwargs)
174178
peaks
175179

176180
# +
177181
from spikeinterface.sortingcomponents.peak_localization import localize_peaks
178182

179-
peak_locations = localize_peaks(rec, peaks, method='center_of_mass', radius_um=50., **job_kwargs)
183+
peak_locations = localize_peaks(rec, peaks, method='center_of_mass', method_kwargs=dict(radius_um=50.), job_kwargs=job_kwargs)
180184
# -
181185

182186
# ### Check for drifts

examples/tutorials/widgets/plot_4_peaks_gallery.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,20 +23,20 @@
2323

2424
from spikeinterface.sortingcomponents.peak_detection import detect_peaks
2525

26+
si.set_global_job_kwargs(chunk_memory="10M", n_jobs=1, progress_bar=True)
2627
rec_filtered = si.bandpass_filter(recording=recording, freq_min=300.0, freq_max=6000.0, margin_ms=5.0)
2728
print(rec_filtered)
2829
peaks = detect_peaks(
2930
recording=rec_filtered,
3031
method="locally_exclusive",
31-
peak_sign="neg",
32-
detect_threshold=6,
33-
exclude_sweep_ms=0.3,
34-
radius_um=100,
35-
noise_levels=None,
36-
random_chunk_kwargs={},
37-
chunk_memory="10M",
38-
n_jobs=1,
39-
progress_bar=True,
32+
method_kwargs=dict(
33+
peak_sign="neg",
34+
detect_threshold=6,
35+
exclude_sweep_ms=0.3,
36+
radius_um=100,
37+
noise_levels=None,
38+
)
39+
4040
)
4141

4242
##############################################################################

src/spikeinterface/benchmark/benchmark_clustering.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from spikeinterface.sortingcomponents.clustering import find_cluster_from_peaks
3+
from spikeinterface.sortingcomponents.clustering import find_clusters_from_peaks
44
from spikeinterface.core import NumpySorting
55
from spikeinterface.comparison import GroundTruthComparison
66
from spikeinterface.widgets import (
@@ -30,8 +30,8 @@ def __init__(self, recording, gt_sorting, params, indices, peaks, exhaustive_gt=
3030
self.result = {}
3131

3232
def run(self, **job_kwargs):
33-
labels, peak_labels = find_cluster_from_peaks(
34-
self.recording, self.peaks, method=self.method, method_kwargs=self.method_kwargs, **job_kwargs
33+
labels, peak_labels = find_clusters_from_peaks(
34+
self.recording, self.peaks, method=self.method, method_kwargs=self.method_kwargs, job_kwargs=job_kwargs
3535
)
3636
self.result["peak_labels"] = peak_labels
3737

src/spikeinterface/benchmark/benchmark_matching.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,14 @@ def __init__(self, recording, gt_sorting, params):
2121
self.recording = recording
2222
self.gt_sorting = gt_sorting
2323
self.method = params["method"]
24-
self.templates = params["method_kwargs"]["templates"]
24+
# self.templates = params["method_kwargs"]["templates"]
25+
self.templates = params["templates"]
2526
self.method_kwargs = params["method_kwargs"]
2627
self.result = {}
2728

2829
def run(self, **job_kwargs):
2930
spikes = find_spikes_from_templates(
30-
self.recording, method=self.method, method_kwargs=self.method_kwargs, **job_kwargs
31+
self.recording, self.templates, method=self.method, method_kwargs=self.method_kwargs, job_kwargs=job_kwargs
3132
)
3233
unit_ids = self.templates.unit_ids
3334
sorting = np.zeros(spikes.size, dtype=minimum_spike_dtype)

src/spikeinterface/benchmark/benchmark_motion_estimation.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,15 +89,19 @@ def run(self, **job_kwargs):
8989
noise_levels = get_noise_levels(self.recording, return_in_uV=False, **job_kwargs)
9090

9191
t0 = time.perf_counter()
92-
peaks = detect_peaks(self.recording, noise_levels=noise_levels, **p["detect_kwargs"], **job_kwargs)
92+
detect_kwargs = p["detect_kwargs"].copy()
93+
detect_kwargs["noise_levels"] = noise_levels
94+
peaks = detect_peaks(self.recording, method_kwargs=detect_kwargs, job_kwargs=job_kwargs)
9395
t1 = time.perf_counter()
9496
if p["select_kwargs"] is not None:
9597
selected_peaks = select_peaks(self.peaks, **p["select_kwargs"], **job_kwargs)
9698
else:
9799
selected_peaks = peaks
98100

99101
t2 = time.perf_counter()
100-
peak_locations = localize_peaks(self.recording, selected_peaks, **p["localize_kwargs"], **job_kwargs)
102+
peak_locations = localize_peaks(
103+
self.recording, selected_peaks, method_kwargs=p["localize_kwargs"], job_kwargs=job_kwargs
104+
)
101105
t3 = time.perf_counter()
102106
motion = estimate_motion(self.recording, selected_peaks, peak_locations, **p["estimate_motion_kwargs"])
103107
t4 = time.perf_counter()
@@ -240,6 +244,8 @@ def plot_drift(
240244

241245
# ax0.set_ylim()
242246

247+
return fig
248+
243249
def plot_errors(self, case_keys=None, figsize=None, lim=None):
244250
import matplotlib.pyplot as plt
245251

@@ -305,6 +311,8 @@ def plot_errors(self, case_keys=None, figsize=None, lim=None):
305311
if lim is not None:
306312
ax.set_ylim(0, lim)
307313

314+
return fig
315+
308316
def plot_summary_errors(self, case_keys=None, show_legend=True, figsize=(15, 5)):
309317
import matplotlib.pyplot as plt
310318

@@ -368,6 +376,8 @@ def plot_summary_errors(self, case_keys=None, show_legend=True, figsize=(15, 5))
368376

369377
despine(ax2)
370378

379+
return fig
380+
371381
# ax1.sharey(ax0)
372382
# ax2.sharey(ax0)
373383

src/spikeinterface/benchmark/benchmark_peak_detection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def __init__(self, recording, gt_sorting, params, gt_peaks, exhaustive_gt=True,
3333
self.result = {}
3434

3535
def run(self, **job_kwargs):
36-
peaks = detect_peaks(self.recording, self.method, **self.params, **job_kwargs)
36+
peaks = detect_peaks(self.recording, self.method, method_kwargs=self.params, job_kwargs=job_kwargs)
3737
self.result["peaks"] = peaks
3838

3939
def compute_result(self, **result_params):

src/spikeinterface/benchmark/benchmark_peak_selection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from spikeinterface.sortingcomponents.clustering import find_cluster_from_peaks
3+
from spikeinterface.sortingcomponents.clustering import find_clusters_from_peaks
44
from spikeinterface.core import NumpySorting
55
from spikeinterface.comparison import GroundTruthComparison
66
from spikeinterface.comparison.comparisontools import make_matching_events

0 commit comments

Comments
 (0)