1+ from __future__ import annotations
2+
3+ import numpy as np
4+
5+ from spikeinterface .core .core_tools import define_function_handling_dict_from_class
6+ from spikeinterface .preprocessing .silence_periods import SilencedPeriodsRecording
7+ from spikeinterface .preprocessing .rectify import RectifyRecording
8+ from spikeinterface .preprocessing .common_reference import CommonReferenceRecording
9+ from spikeinterface .preprocessing .filter_gaussian import GaussianFilterRecording
10+ from spikeinterface .core .job_tools import split_job_kwargs , fix_job_kwargs
11+ from spikeinterface .core .recording_tools import get_noise_levels
12+ from spikeinterface .core .node_pipeline import PeakDetector , base_peak_dtype
13+ import numpy as np
14+
15+
16+ class DetectSaturation (PeakDetector ):
17+
18+ name = "detect_saturation"
19+ preferred_mp_context = None
20+
21+ def __init__ (
22+ self ,
23+ recording ,
24+ saturation_threshold = 5 , # TODO: FIX, max_voltage = max_voltage if max_voltage is not None else sr.range_volts[:-1]
25+ noise_levels = None , # TODO: REMOVE?
26+ seed = None ,
27+ noise_levels_kwargs = dict (),
28+ ):
29+
30+ # TODO: fix name
31+ # TODO: review this
32+ EVENT_VECTOR_TYPE = [
33+ ('start_sample_index' , 'int64' ),
34+ ('stop_sample_index' , 'int64' ),
35+ ('segment_index' , 'int64' ),
36+ ('channel_x_start' , 'float64' ),
37+ ('channel_x_stop' , 'float64' ),
38+ ('channel_y_start' , 'float64' ),
39+ ('channel_y_stop' , 'float64' ),
40+ ('method_id' , 'U128' )
41+ ]
42+ self .saturation_threshold = saturation_threshold
43+ self ._dtype = np .dtype (base_peak_dtype + [("front" , "bool" )])
44+
45+ def get_trace_margin (self ): # TODO: add margin
46+ return 0
47+
48+ def get_dtype (self ):
49+ return self ._dtype
50+
51+ def compute (self , traces , start_frame , end_frame , segment_index , max_margin ): # TODO: required arguments
52+ """
53+ Computes
54+ :param data: [nc, ns]: voltage traces array
55+ :param max_voltage: maximum value of the voltage: scalar or array of size nc (same units as data)
56+ :param v_per_sec: maximum derivative of the voltage in V/s (or units/s)
57+ :param fs: sampling frequency Hz (defaults to 30kHz)
58+ :param proportion: 0 < proportion <1 of channels above threshold to consider the sample as saturated (0.2)
59+ :param mute_window_samples=7: number of samples for the cosine taper applied to the saturation
60+ :return:
61+ saturation [ns]: boolean array indicating the saturated samples
62+ mute [ns]: float array indicating the mute function to apply to the data [0-1]
63+ """
64+ import scipy # TODO: handle import
65+ max_voltage = self .saturation_threshold
66+ data = traces .T # TODO: handle
67+
68+ # first computes the saturated samples
69+ max_voltage = np .atleast_1d (max_voltage )[:, np .newaxis ]
70+ saturation = np .mean (np .abs (data ) > max_voltage * 0.98 , axis = 0 )
71+
72+ # then compute the derivative of the voltage saturation
73+ n_diff_saturated = np .mean (np .abs (np .diff (data , axis = - 1 )) / fs >= v_per_sec , axis = 0 )
74+ n_diff_saturated = np .r_ [n_diff_saturated , 0 ]
75+
76+ # if either of those reaches more than the proportion of channels labels the sample as saturated
77+ saturation = np .logical_or (saturation > proportion , n_diff_saturated > proportion )
78+
79+ # apply a cosine taper to the saturation to create a mute function
80+ win = scipy .signal .windows .cosine (mute_window_samples )
81+ mute = np .maximum (0 , 1 - scipy .signal .convolve (saturation , win , mode = "same" ))
82+ return saturation , mute
83+
84+
85+
86+ #z = np.median(traces / self.abs_thresholds, 1)
87+ #threshold_mask = np.diff((z > 1) != 0, axis=0)
88+ #indices = np.flatnonzero(threshold_mask)
89+ #threshold_crossings = np.zeros(indices.size, dtype=self._dtype)
90+ #threshold_crossings["sample_index"] = indices
91+ #threshold_crossings["front"][::2] = True
92+ #threshold_crossings["front"][1::2] = False
93+ #return (threshold_crossings,)
94+
95+
96+ def detect_period_artifacts_by_envelope (
97+ recording ,
98+ detect_threshold = 5 ,
99+ min_duration_ms = 50 ,
100+ freq_max = 20.0 ,
101+ seed = None ,
102+ noise_levels = None ,
103+ ** noise_levels_kwargs ,
104+ ):
105+ """
106+ Docstring for detect_period_artifacts. Function to detect putative artifact periods as threshold crossings of
107+ a global envelope of the channels.
108+
109+ Parameters
110+ ----------
111+ recording : RecordingExtractor
112+ The recording extractor to detect putative artifacts
113+ detect_threshold : float, default: 5
114+ The threshold to detect artifacts. The threshold is computed as `detect_threshold * noise_level`
115+ freq_max : float, default: 20
116+ The maximum frequency for the low pass filter used
117+ min_duration_ms : float, default: 50
118+ The minimum duration for a threshold crossing to be considered as an artefact.
119+ noise_levels : array
120+ Noise levels if already computed
121+ seed : int | None, default: None
122+ Random seed for `get_noise_levels`.
123+ If none, `get_noise_levels` uses `seed=0`.
124+ **noise_levels_kwargs : Keyword arguments for `spikeinterface.core.get_noise_levels()` function
125+
126+ """
127+
128+ envelope = RectifyRecording (recording )
129+ envelope = GaussianFilterRecording (envelope , freq_min = None , freq_max = freq_max )
130+ envelope = CommonReferenceRecording (envelope )
131+
132+ from spikeinterface .core .node_pipeline import (
133+ run_node_pipeline ,
134+ )
135+
136+ _ , job_kwargs = split_job_kwargs (noise_levels_kwargs )
137+ job_kwargs = fix_job_kwargs (job_kwargs )
138+
139+ node0 = DetectThresholdCrossing (
140+ recording , detect_threshold = detect_threshold , noise_levels = noise_levels , seed = seed , ** noise_levels_kwargs
141+ )
142+
143+ threshold_crossings = run_node_pipeline (
144+ recording ,
145+ [node0 ],
146+ job_kwargs ,
147+ job_name = "detect threshold crossings" ,
148+ )
149+
150+ order = np .lexsort ((threshold_crossings ["sample_index" ], threshold_crossings ["segment_index" ]))
151+ threshold_crossings = threshold_crossings [order ]
152+
153+ periods = []
154+ fs = recording .sampling_frequency
155+ max_duration_samples = int (min_duration_ms * fs / 1000 )
156+ num_seg = recording .get_num_segments ()
157+
158+ for seg_index in range (num_seg ):
159+ sub_periods = []
160+ mask = threshold_crossings ["segment_index" ] == seg_index
161+ sub_thr = threshold_crossings [mask ]
162+ if len (sub_thr ) > 0 :
163+ local_thr = np .zeros (1 , dtype = np .dtype (base_peak_dtype + [("front" , "bool" )]))
164+ if not sub_thr ["front" ][0 ]:
165+ local_thr ["sample_index" ] = 0
166+ local_thr ["front" ] = True
167+ sub_thr = np .hstack ((local_thr , sub_thr ))
168+ if sub_thr ["front" ][- 1 ]:
169+ local_thr ["sample_index" ] = recording .get_num_samples (seg_index )
170+ local_thr ["front" ] = False
171+ sub_thr = np .hstack ((sub_thr , local_thr ))
172+
173+ indices = np .flatnonzero (np .diff (sub_thr ["front" ]))
174+ for i , j in zip (indices [:- 1 ], indices [1 :]):
175+ if sub_thr ["front" ][i ]:
176+ start = sub_thr ["sample_index" ][i ]
177+ end = sub_thr ["sample_index" ][j ]
178+ if end - start > max_duration_samples :
179+ sub_periods .append ((start , end ))
180+
181+ periods .append (sub_periods )
182+
183+ return periods , envelope
184+
185+
186+ class SilencedArtifactsRecording (SilencedPeriodsRecording ):
187+ """
188+ Silence user-defined periods from recording extractor traces. The code will construct
189+ an enveloppe of the recording (as a low pass filtered version of the traces) and detect
190+ threshold crossings to identify the periods to silence. The periods are then silenced either
191+ on a per channel basis or across all channels by replacing the values by zeros or by
192+ adding gaussian noise with the same variance as the one in the recordings
193+
194+ Parameters
195+ ----------
196+ recording : RecordingExtractor
197+ The recording extractor to silence putative artifacts
198+ detect_threshold : float, default: 5
199+ The threshold to detect artifacts. The threshold is computed as `detect_threshold * noise_level`
200+ freq_max : float, default: 20
201+ The maximum frequency for the low pass filter used
202+ min_duration_ms : float, default: 50
203+ The minimum duration for a threshold crossing to be considered as an artefact.
204+ noise_levels : array
205+ Noise levels if already computed
206+ seed : int | None, default: None
207+ Random seed for `get_noise_levels` and `NoiseGeneratorRecording`.
208+ If none, `get_noise_levels` uses `seed=0` and `NoiseGeneratorRecording` generates a random seed using `numpy.random.default_rng`.
209+ mode : "zeros" | "noise", default: "zeros"
210+ Determines what periods are replaced by. Can be one of the following:
211+
212+ - "zeros": Artifacts are replaced by zeros.
213+
214+ - "noise": The periods are filled with a gaussion noise that has the
215+ same variance that the one in the recordings, on a per channel
216+ basis
217+ **noise_levels_kwargs : Keyword arguments for `spikeinterface.core.get_noise_levels()` function
218+
219+ Returns
220+ -------
221+ silenced_recording : SilencedArtifactsRecording
222+ The recording extractor after silencing detected artifacts
223+ """
224+
225+ _precomputable_kwarg_names = ["list_periods" ]
226+
227+ def __init__ (
228+ self ,
229+ recording ,
230+ detect_threshold = 5 ,
231+ verbose = False ,
232+ freq_max = 20.0 ,
233+ min_duration_ms = 50 ,
234+ mode = "zeros" ,
235+ noise_levels = None ,
236+ seed = None ,
237+ list_periods = None ,
238+ ** noise_levels_kwargs ,
239+ ):
240+
241+ if list_periods is None :
242+ list_periods , _ = detect_period_artifacts_by_envelope (
243+ recording ,
244+ detect_threshold = detect_threshold ,
245+ min_duration_ms = min_duration_ms ,
246+ freq_max = freq_max ,
247+ seed = seed ,
248+ noise_levels = noise_levels ,
249+ ** noise_levels_kwargs ,
250+ )
251+
252+ if verbose :
253+ for i , periods in enumerate (list_periods ):
254+ total_time = np .sum ([end - start for start , end in periods ])
255+ percentage = 100 * total_time / recording .get_num_samples (i )
256+ print (f"{ percentage } % of segment { i } has been flagged as artifactual" )
257+
258+ SilencedPeriodsRecording .__init__ (
259+ self , recording , list_periods , mode = mode , noise_levels = noise_levels , seed = seed , ** noise_levels_kwargs
260+ )
261+
262+
263+ # function for API
264+ silence_artifacts = define_function_handling_dict_from_class (
265+ source_class = SilencedArtifactsRecording , name = "silence_artifacts"
266+ )
0 commit comments