Skip to content

Commit 017b7ce

Browse files
committed
in-progress detect saturation rough.
1 parent bffc5f2 commit 017b7ce

File tree

1 file changed

+266
-0
lines changed

1 file changed

+266
-0
lines changed
Lines changed: 266 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,266 @@
1+
from __future__ import annotations
2+
3+
import numpy as np
4+
5+
from spikeinterface.core.core_tools import define_function_handling_dict_from_class
6+
from spikeinterface.preprocessing.silence_periods import SilencedPeriodsRecording
7+
from spikeinterface.preprocessing.rectify import RectifyRecording
8+
from spikeinterface.preprocessing.common_reference import CommonReferenceRecording
9+
from spikeinterface.preprocessing.filter_gaussian import GaussianFilterRecording
10+
from spikeinterface.core.job_tools import split_job_kwargs, fix_job_kwargs
11+
from spikeinterface.core.recording_tools import get_noise_levels
12+
from spikeinterface.core.node_pipeline import PeakDetector, base_peak_dtype
13+
import numpy as np
14+
15+
16+
class DetectSaturation(PeakDetector):
17+
18+
name = "detect_saturation"
19+
preferred_mp_context = None
20+
21+
def __init__(
22+
self,
23+
recording,
24+
saturation_threshold=5, # TODO: FIX, max_voltage = max_voltage if max_voltage is not None else sr.range_volts[:-1]
25+
noise_levels=None, # TODO: REMOVE?
26+
seed=None,
27+
noise_levels_kwargs=dict(),
28+
):
29+
30+
# TODO: fix name
31+
# TODO: review this
32+
EVENT_VECTOR_TYPE = [
33+
('start_sample_index', 'int64'),
34+
('stop_sample_index', 'int64'),
35+
('segment_index', 'int64'),
36+
('channel_x_start', 'float64'),
37+
('channel_x_stop', 'float64'),
38+
('channel_y_start', 'float64'),
39+
('channel_y_stop', 'float64'),
40+
('method_id', 'U128')
41+
]
42+
self.saturation_threshold = saturation_threshold
43+
self._dtype = np.dtype(base_peak_dtype + [("front", "bool")])
44+
45+
def get_trace_margin(self): # TODO: add margin
46+
return 0
47+
48+
def get_dtype(self):
49+
return self._dtype
50+
51+
def compute(self, traces, start_frame, end_frame, segment_index, max_margin): # TODO: required arguments
52+
"""
53+
Computes
54+
:param data: [nc, ns]: voltage traces array
55+
:param max_voltage: maximum value of the voltage: scalar or array of size nc (same units as data)
56+
:param v_per_sec: maximum derivative of the voltage in V/s (or units/s)
57+
:param fs: sampling frequency Hz (defaults to 30kHz)
58+
:param proportion: 0 < proportion <1 of channels above threshold to consider the sample as saturated (0.2)
59+
:param mute_window_samples=7: number of samples for the cosine taper applied to the saturation
60+
:return:
61+
saturation [ns]: boolean array indicating the saturated samples
62+
mute [ns]: float array indicating the mute function to apply to the data [0-1]
63+
"""
64+
import scipy # TODO: handle import
65+
max_voltage = self.saturation_threshold
66+
data = traces.T # TODO: handle
67+
68+
# first computes the saturated samples
69+
max_voltage = np.atleast_1d(max_voltage)[:, np.newaxis]
70+
saturation = np.mean(np.abs(data) > max_voltage * 0.98, axis=0)
71+
72+
# then compute the derivative of the voltage saturation
73+
n_diff_saturated = np.mean(np.abs(np.diff(data, axis=-1)) / fs >= v_per_sec, axis=0)
74+
n_diff_saturated = np.r_[n_diff_saturated, 0]
75+
76+
# if either of those reaches more than the proportion of channels labels the sample as saturated
77+
saturation = np.logical_or(saturation > proportion, n_diff_saturated > proportion)
78+
79+
# apply a cosine taper to the saturation to create a mute function
80+
win = scipy.signal.windows.cosine(mute_window_samples)
81+
mute = np.maximum(0, 1 - scipy.signal.convolve(saturation, win, mode="same"))
82+
return saturation, mute
83+
84+
85+
86+
#z = np.median(traces / self.abs_thresholds, 1)
87+
#threshold_mask = np.diff((z > 1) != 0, axis=0)
88+
#indices = np.flatnonzero(threshold_mask)
89+
#threshold_crossings = np.zeros(indices.size, dtype=self._dtype)
90+
#threshold_crossings["sample_index"] = indices
91+
#threshold_crossings["front"][::2] = True
92+
#threshold_crossings["front"][1::2] = False
93+
#return (threshold_crossings,)
94+
95+
96+
def detect_period_artifacts_by_envelope(
97+
recording,
98+
detect_threshold=5,
99+
min_duration_ms=50,
100+
freq_max=20.0,
101+
seed=None,
102+
noise_levels=None,
103+
**noise_levels_kwargs,
104+
):
105+
"""
106+
Docstring for detect_period_artifacts. Function to detect putative artifact periods as threshold crossings of
107+
a global envelope of the channels.
108+
109+
Parameters
110+
----------
111+
recording : RecordingExtractor
112+
The recording extractor to detect putative artifacts
113+
detect_threshold : float, default: 5
114+
The threshold to detect artifacts. The threshold is computed as `detect_threshold * noise_level`
115+
freq_max : float, default: 20
116+
The maximum frequency for the low pass filter used
117+
min_duration_ms : float, default: 50
118+
The minimum duration for a threshold crossing to be considered as an artefact.
119+
noise_levels : array
120+
Noise levels if already computed
121+
seed : int | None, default: None
122+
Random seed for `get_noise_levels`.
123+
If none, `get_noise_levels` uses `seed=0`.
124+
**noise_levels_kwargs : Keyword arguments for `spikeinterface.core.get_noise_levels()` function
125+
126+
"""
127+
128+
envelope = RectifyRecording(recording)
129+
envelope = GaussianFilterRecording(envelope, freq_min=None, freq_max=freq_max)
130+
envelope = CommonReferenceRecording(envelope)
131+
132+
from spikeinterface.core.node_pipeline import (
133+
run_node_pipeline,
134+
)
135+
136+
_, job_kwargs = split_job_kwargs(noise_levels_kwargs)
137+
job_kwargs = fix_job_kwargs(job_kwargs)
138+
139+
node0 = DetectThresholdCrossing(
140+
recording, detect_threshold=detect_threshold, noise_levels=noise_levels, seed=seed, **noise_levels_kwargs
141+
)
142+
143+
threshold_crossings = run_node_pipeline(
144+
recording,
145+
[node0],
146+
job_kwargs,
147+
job_name="detect threshold crossings",
148+
)
149+
150+
order = np.lexsort((threshold_crossings["sample_index"], threshold_crossings["segment_index"]))
151+
threshold_crossings = threshold_crossings[order]
152+
153+
periods = []
154+
fs = recording.sampling_frequency
155+
max_duration_samples = int(min_duration_ms * fs / 1000)
156+
num_seg = recording.get_num_segments()
157+
158+
for seg_index in range(num_seg):
159+
sub_periods = []
160+
mask = threshold_crossings["segment_index"] == seg_index
161+
sub_thr = threshold_crossings[mask]
162+
if len(sub_thr) > 0:
163+
local_thr = np.zeros(1, dtype=np.dtype(base_peak_dtype + [("front", "bool")]))
164+
if not sub_thr["front"][0]:
165+
local_thr["sample_index"] = 0
166+
local_thr["front"] = True
167+
sub_thr = np.hstack((local_thr, sub_thr))
168+
if sub_thr["front"][-1]:
169+
local_thr["sample_index"] = recording.get_num_samples(seg_index)
170+
local_thr["front"] = False
171+
sub_thr = np.hstack((sub_thr, local_thr))
172+
173+
indices = np.flatnonzero(np.diff(sub_thr["front"]))
174+
for i, j in zip(indices[:-1], indices[1:]):
175+
if sub_thr["front"][i]:
176+
start = sub_thr["sample_index"][i]
177+
end = sub_thr["sample_index"][j]
178+
if end - start > max_duration_samples:
179+
sub_periods.append((start, end))
180+
181+
periods.append(sub_periods)
182+
183+
return periods, envelope
184+
185+
186+
class SilencedArtifactsRecording(SilencedPeriodsRecording):
187+
"""
188+
Silence user-defined periods from recording extractor traces. The code will construct
189+
an enveloppe of the recording (as a low pass filtered version of the traces) and detect
190+
threshold crossings to identify the periods to silence. The periods are then silenced either
191+
on a per channel basis or across all channels by replacing the values by zeros or by
192+
adding gaussian noise with the same variance as the one in the recordings
193+
194+
Parameters
195+
----------
196+
recording : RecordingExtractor
197+
The recording extractor to silence putative artifacts
198+
detect_threshold : float, default: 5
199+
The threshold to detect artifacts. The threshold is computed as `detect_threshold * noise_level`
200+
freq_max : float, default: 20
201+
The maximum frequency for the low pass filter used
202+
min_duration_ms : float, default: 50
203+
The minimum duration for a threshold crossing to be considered as an artefact.
204+
noise_levels : array
205+
Noise levels if already computed
206+
seed : int | None, default: None
207+
Random seed for `get_noise_levels` and `NoiseGeneratorRecording`.
208+
If none, `get_noise_levels` uses `seed=0` and `NoiseGeneratorRecording` generates a random seed using `numpy.random.default_rng`.
209+
mode : "zeros" | "noise", default: "zeros"
210+
Determines what periods are replaced by. Can be one of the following:
211+
212+
- "zeros": Artifacts are replaced by zeros.
213+
214+
- "noise": The periods are filled with a gaussion noise that has the
215+
same variance that the one in the recordings, on a per channel
216+
basis
217+
**noise_levels_kwargs : Keyword arguments for `spikeinterface.core.get_noise_levels()` function
218+
219+
Returns
220+
-------
221+
silenced_recording : SilencedArtifactsRecording
222+
The recording extractor after silencing detected artifacts
223+
"""
224+
225+
_precomputable_kwarg_names = ["list_periods"]
226+
227+
def __init__(
228+
self,
229+
recording,
230+
detect_threshold=5,
231+
verbose=False,
232+
freq_max=20.0,
233+
min_duration_ms=50,
234+
mode="zeros",
235+
noise_levels=None,
236+
seed=None,
237+
list_periods=None,
238+
**noise_levels_kwargs,
239+
):
240+
241+
if list_periods is None:
242+
list_periods, _ = detect_period_artifacts_by_envelope(
243+
recording,
244+
detect_threshold=detect_threshold,
245+
min_duration_ms=min_duration_ms,
246+
freq_max=freq_max,
247+
seed=seed,
248+
noise_levels=noise_levels,
249+
**noise_levels_kwargs,
250+
)
251+
252+
if verbose:
253+
for i, periods in enumerate(list_periods):
254+
total_time = np.sum([end - start for start, end in periods])
255+
percentage = 100 * total_time / recording.get_num_samples(i)
256+
print(f"{percentage}% of segment {i} has been flagged as artifactual")
257+
258+
SilencedPeriodsRecording.__init__(
259+
self, recording, list_periods, mode=mode, noise_levels=noise_levels, seed=seed, **noise_levels_kwargs
260+
)
261+
262+
263+
# function for API
264+
silence_artifacts = define_function_handling_dict_from_class(
265+
source_class=SilencedArtifactsRecording, name="silence_artifacts"
266+
)

0 commit comments

Comments
 (0)