Skip to content

Commit 5c48ca3

Browse files
committed
Add time getters for sorting
1 parent 4c3a6f9 commit 5c48ca3

File tree

2 files changed

+103
-0
lines changed

2 files changed

+103
-0
lines changed

src/spikeinterface/core/basesorting.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,11 @@ def register_recording(self, recording, check_spike_frames: bool = True):
322322
"Might be necessary for further postprocessing."
323323
)
324324
self._recording = recording
325+
# The recording is now the source of truth for timestamps.
326+
# Reset the sorting's own time offset so it doesn't conflict
327+
# with the recording's t_start when accessed through get_start_time/get_end_time.
328+
for segment in self.segments:
329+
segment._t_start = 0
325330

326331
@property
327332
def sorting_info(self):
@@ -347,6 +352,51 @@ def has_time_vector(self, segment_index: int | None = None) -> bool:
347352
else:
348353
return False
349354

355+
def get_start_time(self, segment_index: int | None = None) -> float:
356+
"""Get the start time of the sorting segment.
357+
358+
If a recording is registered, returns the recording's start time.
359+
Otherwise returns the sorting segment's own t_start (or 0.0).
360+
361+
Parameters
362+
----------
363+
segment_index : int or None, default: None
364+
The segment index (required for multi-segment)
365+
366+
Returns
367+
-------
368+
float
369+
The start time in seconds
370+
"""
371+
segment_index = self._check_segment_index(segment_index)
372+
if self.has_recording():
373+
return self._recording.get_start_time(segment_index=segment_index)
374+
else:
375+
segment = self.segments[segment_index]
376+
return segment._t_start if segment._t_start is not None else 0.0
377+
378+
def get_end_time(self, segment_index: int | None = None) -> float | None:
379+
"""Get the end time of the sorting segment.
380+
381+
If a recording is registered, returns the recording's end time.
382+
Otherwise returns None (the sorting doesn't know the recording duration).
383+
384+
Parameters
385+
----------
386+
segment_index : int or None, default: None
387+
The segment index (required for multi-segment)
388+
389+
Returns
390+
-------
391+
float or None
392+
The end time in seconds, or None if no recording is registered.
393+
"""
394+
segment_index = self._check_segment_index(segment_index)
395+
if self.has_recording():
396+
return self._recording.get_end_time(segment_index=segment_index)
397+
else:
398+
return None
399+
350400
def get_times(self, segment_index=None):
351401
"""
352402
Get time vector for a registered recording segment.

src/spikeinterface/core/tests/test_time_handling.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -445,3 +445,56 @@ def test_shift_times_with_None_as_t_start():
445445
assert recording.segments[0].t_start is None
446446
recording.shift_times(shift=1.0) # Shift by one seconds should not generate an error
447447
assert recording.get_start_time() == 1.0
448+
449+
450+
class TestSortingTimeNoRecording:
451+
"""Tests for time methods on BaseSorting without a registered recording."""
452+
453+
def test_get_start_time_default(self):
454+
sorting = generate_sorting(num_units=5, durations=[10])
455+
assert sorting.get_start_time(segment_index=0) == 0.0
456+
457+
def test_get_end_time_default(self):
458+
sorting = generate_sorting(num_units=5, durations=[10])
459+
assert sorting.get_end_time(segment_index=0) is None
460+
461+
def test_get_start_time_with_t_start(self):
462+
sorting = generate_sorting(num_units=5, durations=[10])
463+
sorting.segments[0]._t_start = 100.0
464+
assert sorting.get_start_time(segment_index=0) == 100.0
465+
466+
467+
class TestSortingTimeWithRecording:
468+
"""
469+
Tests for time methods on BaseSorting with a registered recording.
470+
The key invariant: the recording is the source of truth for timestamps.
471+
"""
472+
473+
def test_get_start_end_time(self):
474+
recording = generate_recording(num_channels=4, durations=[10])
475+
sorting = generate_sorting(num_units=5, durations=[10])
476+
sorting.register_recording(recording)
477+
478+
assert sorting.get_start_time(segment_index=0) == recording.get_start_time(segment_index=0)
479+
assert sorting.get_end_time(segment_index=0) == recording.get_end_time(segment_index=0)
480+
481+
def test_register_recording_resets_t_start(self):
482+
"""Registering a recording resets _t_start so the recording is the sole source of truth."""
483+
sorting = generate_sorting(num_units=5, durations=[10])
484+
sorting.segments[0]._t_start = 100.0
485+
486+
recording = generate_recording(num_channels=4, durations=[10])
487+
sorting.register_recording(recording)
488+
489+
assert sorting.segments[0]._t_start == 0
490+
assert sorting.get_start_time(segment_index=0) == recording.get_start_time(segment_index=0)
491+
492+
def test_with_recording_shifted_start(self):
493+
"""Recording with a non-zero t_start is reflected in the sorting."""
494+
recording = generate_recording(num_channels=4, durations=[10])
495+
recording.shift_times(shift=50.0)
496+
497+
sorting = generate_sorting(num_units=5, durations=[10])
498+
sorting.register_recording(recording)
499+
500+
assert sorting.get_start_time(segment_index=0) == 50.0

0 commit comments

Comments
 (0)