Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
330 changes: 191 additions & 139 deletions spikeinterface_gui/controller.py

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions spikeinterface_gui/correlogramview.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ def _qt_make_layout(self):
self.grid = pg.GraphicsLayoutWidget()
self.layout.addWidget(self.grid)

def _reinitialize(self):
self.ccg, self.bins = self.controller.get_correlograms()
self.figure_cache = {}
self._refresh()

def _qt_refresh(self):
import pyqtgraph as pg
Expand Down
41 changes: 37 additions & 4 deletions spikeinterface_gui/curationview.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,11 @@ def _qt_make_layout(self):
but = QT.QPushButton("Save in analyzer")
tb.addWidget(but)
but.clicked.connect(self.save_in_analyzer)

but_apply = QT.QPushButton("Apply curation")
tb.addWidget(but_apply)
but_apply.clicked.connect(self.apply_curation_to_analyzer)

but = QT.QPushButton("Export JSON")
but.clicked.connect(self._qt_export_json)
tb.addWidget(but)
Expand Down Expand Up @@ -278,6 +283,10 @@ def on_manual_curation_updated(self):
def save_in_analyzer(self):
self.controller.save_curation_in_analyzer()

def apply_curation_to_analyzer(self):
with self.busy_cursor():
self.controller.apply_curation()

def _qt_export_json(self):
from .myqt import QT
fd = QT.QFileDialog(fileMode=QT.QFileDialog.AnyFile, acceptMode=QT.QFileDialog.AcceptSave)
Expand All @@ -286,10 +295,23 @@ def _qt_export_json(self):
fd.setViewMode(QT.QFileDialog.Detail)
if fd.exec_():
json_file = Path(fd.selectedFiles()[0])
curation_model = self.controller.construct_final_curation()
with json_file.open("w") as f:
f.write(curation_model.model_dump_json(indent=4))
self.controller.current_curation_saved = True
if len(self.controller.applied_curations) == 0:
curation_model = self.controller.construct_final_curation()
with json_file.open("w") as f:
f.write(curation_model.model_dump_json(indent=4))
self.controller.current_curation_saved = True
else:
# Keep this here until `SeqentialCuration` in release of spikeinterface
from spikeinterface.curation.curation_model import SequentialCuration

current_curation_model = self.controller.construct_final_curation()
applied_curations = self.controller.applied_curations
current_and_applied_curations = applied_curations + [current_curation_model]

sequential_curation_model = SequentialCuration(curation_steps=current_and_applied_curations)
with json_file.open("w") as f:
f.write(sequential_curation_model.model_dump_json(indent=4))
self.controller.current_curation_saved = True

# PANEL
def _panel_make_layout(self):
Expand Down Expand Up @@ -360,6 +382,13 @@ def _panel_make_layout(self):
)
save_button.on_click(self._panel_save_in_analyzer)

apply_button = pn.widgets.Button(
name="Apply curation",
button_type="primary",
height=30
)
apply_button.on_click(self._panel_apply_curation_to_analyzer)

download_button = pn.widgets.FileDownload(
button_type="primary",
filename="curation.json",
Expand Down Expand Up @@ -391,6 +420,7 @@ def _panel_make_layout(self):
buttons_save = pn.Row(
save_button,
download_button,
apply_button,
submit_button,
sizing_mode="stretch_width",
)
Expand Down Expand Up @@ -522,6 +552,9 @@ def _panel_restore_units(self, event):
def _panel_unmerge(self, event):
self.unmerge()

def _panel_apply_curation_to_analyzer(self, event):
self.apply_curation_to_analyzer()

def _panel_save_in_analyzer(self, event):
self.save_in_analyzer()
self.refresh()
Expand Down
20 changes: 19 additions & 1 deletion spikeinterface_gui/mainsettingsview.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
{'name': 'max_visible_units', 'type': 'int', 'value' : 10 },
{'name': 'color_mode', 'type': 'list', 'value' : 'color_by_unit',
'limits': ['color_by_unit', 'color_only_visible', 'color_by_visibility']},
{'name': 'use_times', 'type': 'bool', 'value': False}
{'name': 'use_times', 'type': 'bool', 'value': False},
{'name': 'merge_new_id_strategy', 'type': 'list', 'limits' : ['take_first', 'append', 'join']},
{'name': 'split_new_id_strategy', 'type': 'list', 'limits' : ['append', 'split']},
]


Expand Down Expand Up @@ -45,6 +47,12 @@ def on_use_times(self):
self.controller.update_time_info()
self.notify_use_times_updated()

def on_merge_new_id_strategy(self):
self.controller.main_settings['merge_new_id_strategy'] = self.main_settings['merge_new_id_strategy']

def on_split_new_id_strategy(self):
self.controller.main_settings['split_new_id_strategy'] = self.main_settings['split_new_id_strategy']

def save_current_settings(self, event=None):

backend = self.controller.backend
Expand Down Expand Up @@ -106,6 +114,8 @@ def _qt_make_layout(self):
self.main_settings.param('max_visible_units').sigValueChanged.connect(self.on_max_visible_units_changed)
self.main_settings.param('color_mode').sigValueChanged.connect(self.on_change_color_mode)
self.main_settings.param('use_times').sigValueChanged.connect(self.on_use_times)
self.main_settings.param('merge_new_id_strategy').sigValueChanged.connect(self.on_merge_new_id_strategy)
self.main_settings.param('split_new_id_strategy').sigValueChanged.connect(self.on_split_new_id_strategy)

def qt_make_settings_dict(self, view):
"""For a given view, return the current settings in a dict"""
Expand Down Expand Up @@ -141,6 +151,8 @@ def _panel_make_layout(self):
self.main_settings._parameterized.param.watch(self._panel_on_max_visible_units_changed, 'max_visible_units')
self.main_settings._parameterized.param.watch(self._panel_on_change_color_mode, 'color_mode')
self.main_settings._parameterized.param.watch(self._panel_on_use_times, 'use_times')
self.main_settings._parameterized.param.watch(self._panel_on_merge_new_id_strategy, 'merge_new_id_strategy')
self.main_settings._parameterized.param.watch(self._panel_on_split_new_id_strategy, 'split_new_id_strategy')
self.layout = pn.Column(self.save_setting_button, self.main_settings_layout, sizing_mode="stretch_both")

def panel_make_settings_dict(self, view):
Expand All @@ -160,6 +172,12 @@ def _panel_on_max_visible_units_changed(self, event):
def _panel_on_change_color_mode(self, event):
self.on_change_color_mode()

def _panel_on_merge_new_id_strategy(self, event):
self.on_merge_new_id_strategy()

def _panel_on_split_new_id_strategy(self, event):
self.on_split_new_id_strategy()

def _panel_on_use_times(self, event):
self.on_use_times()

Expand Down
5 changes: 5 additions & 0 deletions spikeinterface_gui/mergeview.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,11 @@ def accept_group_merge(self, group_ids):
self.notify_manual_curation_updated()
self.refresh()

def _reinitialize(self):
self.proposed_merge_unit_groups = []
self.merge_info = {}
self._refresh()

### QT
def _qt_get_selected_group_ids(self):
inds = self.table.selectedIndexes()
Expand Down
25 changes: 20 additions & 5 deletions spikeinterface_gui/probeview.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,14 @@ def _qt_make_layout(self):

self.roi_units.sigRegionChangeFinished.connect(self._qt_on_roi_units_changed)

def _qt_reinitialize(self):
import pyqtgraph as pg
unit_positions = self.controller.unit_positions
brush = [self.get_unit_color(u) for u in self.controller.unit_ids]
self.scatter = pg.ScatterPlotItem(pos=unit_positions, pxMode=False, size=10, brush=brush)

self._qt_refresh()

def _qt_refresh(self):
current_unit_positions = self.controller.unit_positions
# if not np.array_equal(current_unit_positions, self._unit_positions):
Expand Down Expand Up @@ -479,11 +487,14 @@ def _panel_make_layout(self):
self.should_resize_unit_circle = None

# Main layout
self.layout = pn.Column(
self.figure,
styles={"display": "flex", "flex-direction": "column"},
sizing_mode="stretch_both",
)
if self.layout is None:
self.layout = pn.Column(
self.figure,
styles={"display": "flex", "flex-direction": "column"},
sizing_mode="stretch_both",
)
else:
self.layout.objects = [self.figure]

def _panel_refresh(self):
# Only update unit positions if they actually changed
Expand Down Expand Up @@ -529,6 +540,10 @@ def _panel_refresh(self):
self.y_range.start = y_min - margin
self.y_range.end = y_max + margin

def _panel_reinitialize(self):
self._panel_make_layout()
self._refresh()

def _panel_update_unit_glyphs(self):
# Get current data from source
current_alphas = self.unit_glyphs.data_source.data['alpha']
Expand Down
4 changes: 4 additions & 0 deletions spikeinterface_gui/spikeamplitudeview.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ def __init__(self, controller=None, parent=None, backend="qt"):
spike_data=spike_data,
)

def _reinitialize(self):
self.spike_data = self.controller.spike_amplitudes
self._refresh()

def _qt_make_layout(self):
from .myqt import QT

Expand Down
3 changes: 3 additions & 0 deletions spikeinterface_gui/spikedepthview.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ def __init__(self, controller=None, parent=None, backend="qt"):
spike_data=spike_data,
)

def _reinitialize(self):
self.spike_data = self.controller.spike_depths
self._refresh()


SpikeDepthView._gui_help_txt = """
Expand Down
69 changes: 42 additions & 27 deletions spikeinterface_gui/unitlistview.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,6 @@ def get_selected_unit_ids(self):
def _qt_make_layout(self):

from .myqt import QT
import pyqtgraph as pg


self.menu = None
self.layout = QT.QVBoxLayout()
Expand All @@ -51,21 +49,7 @@ def _qt_make_layout(self):
but.clicked.connect(self._qt_select_columns)
tb.addWidget(but)


visible_cols = []
for col in self.controller.units_table.columns:
visible_cols.append(
{'name': str(col), 'type': 'bool', 'value': col in self.controller.displayed_unit_properties, 'default': True}
)
self.visible_columns = pg.parametertree.Parameter.create( name='visible columns', type='group', children=visible_cols)
self.tree_visible_columns = pg.parametertree.ParameterTree(parent=self.qt_widget)
self.tree_visible_columns.header().hide()
self.tree_visible_columns.setParameters(self.visible_columns, showTop=True)
# self.tree_visible_columns.setWindowTitle(u'visible columns')
# self.tree_visible_columns.setWindowFlags(QT.Qt.Window)
self.visible_columns.sigTreeStateChanged.connect(self._qt_on_visible_columns_changed)
self.layout.addWidget(self.tree_visible_columns)
self.tree_visible_columns.hide()
self._qt_set_up_visible_columns()

# h = QT.QHBoxLayout()
# self.layout.addLayout(h)
Expand Down Expand Up @@ -129,6 +113,28 @@ def _qt_make_layout(self):
self.shortcut_noise.setKey(QT.QKeySequence('n'))
self.shortcut_noise.activated.connect(lambda: self._qt_set_default_label('noise'))

def _qt_set_up_visible_columns(self):

import pyqtgraph as pg
visible_cols = []
for col in self.controller.units_table.columns:
visible_cols.append(
{'name': str(col), 'type': 'bool', 'value': col in self.controller.displayed_unit_properties, 'default': True}
)
self.visible_columns = pg.parametertree.Parameter.create( name='visible columns', type='group', children=visible_cols)
self.tree_visible_columns = pg.parametertree.ParameterTree(parent=self.qt_widget)
self.tree_visible_columns.header().hide()
self.tree_visible_columns.setParameters(self.visible_columns, showTop=True)

self.visible_columns.sigTreeStateChanged.connect(self._qt_on_visible_columns_changed)
self.layout.addWidget(self.tree_visible_columns)
self.tree_visible_columns.hide()

def _qt_reinitialize(self):

self._qt_set_up_visible_columns()
self._qt_full_table_refresh()
self._qt_refresh()

def _qt_on_column_moved(self, logical_index, old_visual_index, new_visual_index):
# Update stored column order
Expand Down Expand Up @@ -584,16 +590,22 @@ def _panel_make_layout(self):
shortcuts_component = KeyboardShortcuts(shortcuts=shortcuts)
shortcuts_component.on_msg(self._panel_handle_shortcut)

self.layout = pn.Column(
pn.Row(
self.info_text,
),
buttons,
sizing_mode="stretch_width",
)
if self.layout is None:
self.layout = pn.Column(
pn.Row(
self.info_text,
),
buttons,
sizing_mode="stretch_width",
)

self.layout.append(self.table)
self.layout.append(shortcuts_component)
self.layout.append(self.table)
self.layout.append(shortcuts_component)
else:
self.layout[0][0] = self.info_text
self.layout[1] = buttons
self.layout[2] = self.table
self.layout[3] = shortcuts_component

self.table.tabulator.on_edit(self._panel_on_edit)

Expand Down Expand Up @@ -650,6 +662,10 @@ def _panel_refresh(self):
# refresh header
self._panel_refresh_header()

def _panel_reinitialize(self):
self._panel_make_layout()
self._panel_refresh()

def _panel_refresh_header(self):
unit_ids = self.controller.unit_ids
n1 = len(unit_ids)
Expand All @@ -675,7 +691,6 @@ def _panel_merge_units_callback(self, event):
self.notifier.notify_active_view_updated()

def _panel_on_visible_checkbox_toggled(self, row):
# print("checkbox toggled on row", row)
unit_ids = self.table.value.index.values
selected_unit_id = unit_ids[row]
self.controller.set_unit_visibility(selected_unit_id, not self.controller.get_unit_visibility(selected_unit_id))
Expand Down
35 changes: 35 additions & 0 deletions spikeinterface_gui/utils_global.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
from pathlib import Path
import os
from copy import copy

def get_config_folder() -> Path:
"""Get the config folder for spikeinterface-gui settings files.
Expand Down Expand Up @@ -58,3 +59,37 @@ def get_present_zones_in_half_of_layout(layout_zone, shift):
is_present = [views is not None and len(views) > 0 for views in half_dict.values()]
present_zones = set(np.array(list(half_dict.keys()))[np.array(is_present)])
return present_zones


def add_new_unit_ids_to_curation_dict(curation_dict, sorting, split_new_id_strategy, merge_new_id_strategy):
"""
Explicitly adds the new unit ids to `curation_dict` based on the split and merge new id strategies.
These *should* be the ids that would have been generated during `apply_curation` with these strategies.
"""

from spikeinterface.core.sorting_tools import generate_unit_ids_for_split, generate_unit_ids_for_merge_group
from spikeinterface.curation.curation_model import CurationModel

curation_model = CurationModel(**curation_dict)
old_unit_ids = copy(curation_model.unit_ids)

if len(curation_model.splits) > 0:
unit_splits = {split.unit_id: split.get_full_spike_indices(sorting) for split in curation_model.splits}
new_split_unit_ids = generate_unit_ids_for_split(old_unit_ids, unit_splits, new_unit_ids=None, new_id_strategy=split_new_id_strategy)

all_new_unit_ids = []
for split_index, new_unit_ids in enumerate(new_split_unit_ids):
curation_dict['splits'][split_index]['new_unit_ids'] = new_unit_ids
all_new_unit_ids = all_new_unit_ids + new_unit_ids

# update old unit ids with the newly split units
old_unit_ids = np.setdiff1d(old_unit_ids, np.array(list(unit_splits.keys())))
old_unit_ids = np.concat([old_unit_ids, all_new_unit_ids])

if len(curation_model.merges) > 0:
merge_unit_groups = [m.unit_ids for m in curation_model.merges]
new_merge_unit_ids = generate_unit_ids_for_merge_group(old_unit_ids, merge_unit_groups, new_unit_ids=None, new_id_strategy=merge_new_id_strategy)
for merge_index, new_unit_id in enumerate(new_merge_unit_ids):
curation_dict['merges'][merge_index]['new_unit_id'] = new_unit_id

return curation_dict
Loading