ncrf/_model.py: 37 changes (26 additions, 11 deletions)
@@ -8,7 +8,7 @@
 from math import sqrt, log10
 from multiprocessing import current_process
 from operator import attrgetter
-from typing import Sequence, Union
+from typing import Iterator, Sequence, Union
 
 from eelbrain import fmtxt, UTS, NDVar
 import numpy as np
@@ -234,29 +234,38 @@ class RegressionData:
 
     Parameters
     ----------
-    tstart : float | list[float]
+    tstart
         Start of the TRF in seconds. Can define multiple tstarts for more than 1 predictor.
-    tstop : float | list[float]
+    tstop
         Stop of the TRF in seconds. Can define multiple tstops for more than 1 predictor.
-    nlevel : int
+    nlevel
         Decides the density of Gabor atoms. Bigger nlevel -> less dense basis.
         By default it is set to 1. ``nlevel > 2`` should be used with caution.
-    baseline: list | None
+    baseline
         Mean that will be subtracted from ``stim``.
-    scaling: list | None
+    scaling
         Scale by which ``stim`` was divided.
     """
     _n_predictor_variables = 1
     _prewhitened = None
 
-    def __init__(self, tstart, tstop, nlevel=1, baseline=None, scaling=None, stim_is_single=None, gaussian_fwhm=20.0):
+    def __init__(
+        self,
+        tstart: Union[float, list[float]],
+        tstop: Union[float, list[float]],
+        nlevel: int = 1,
+        baseline: list[float] = None,
+        scaling: list[float] = None,
+        stim_is_single: bool = None,
+        gaussian_fwhm: float = 20.0,
+    ):
         self.tstart = tstart if isinstance(tstart, collections.abc.Sequence) else [tstart]
         self.tstop = tstop if isinstance(tstop, collections.abc.Sequence) else [tstop]
         self.nlevel = nlevel
         self.s_baseline = baseline
         self.s_scaling = scaling
         self.s_normalization = []
-        self.meg = []
+        self.meg: list[np.ndarray] = []  # (sensor, time)
         self.covariates = []
         self.tstep = None
         self.filter_length = None
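A quick way to sanity-check the newly annotated constructor is a usage sketch. The example below is not part of the PR; it only exercises the lines visible in this hunk and assumes `RegressionData` is importable from `ncrf._model` and that nothing in the collapsed remainder of `__init__` rejects these arguments.

```python
# Hypothetical sketch, not part of the PR: exercising the annotated __init__,
# assuming RegressionData is importable from ncrf._model.
from ncrf._model import RegressionData

# Single predictor: scalar tstart/tstop are wrapped into one-element lists.
data = RegressionData(tstart=0.0, tstop=0.8)
assert data.tstart == [0.0] and data.tstop == [0.8]

# Two predictors with separate TRF windows and per-predictor baseline/scaling.
data2 = RegressionData(
    tstart=[0.0, -0.1],
    tstop=[0.8, 0.5],
    nlevel=1,
    baseline=[0.0, 0.0],
    scaling=[1.0, 1.0],
)
assert len(data2.tstart) == len(data2.tstop) == 2
```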
@@ -344,6 +353,11 @@ def add_data(self, meg, stim):
         m = max([basis.shape[0] for basis in self.basis])
         y = meg.get_data(('sensor', 'time'))
         y = y[:, m - 1:].astype(np.float64)
+        ch_var = np.var(y, axis=1)
+        zero_var = ch_var == 0
+        if zero_var.any():
+            flat_channels = self.sensor_dim.names[np.flatnonzero(zero_var)]
+            raise ValueError(f"{meg=}: data contains flat channels ({', '.join(flat_channels)})")
         self.meg.append(y / sqrt(y.shape[1]))  # Mind the normalization
 
         if self._norm_factor is None:
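The new guard in `add_data` rejects recordings containing flat (zero-variance) channels before they reach the normalization that follows. Below is a standalone sketch of the same check in plain NumPy; the channel names and data are made up, and the real code takes the names from the eelbrain sensor dimension (`self.sensor_dim.names`). Raising early with the offending channel names makes the failure easier to trace than whatever a zero-variance channel would do to the downstream normalization.

```python
# Standalone sketch of the flat-channel check (hypothetical names and data).
import numpy as np

names = np.array(['MEG 001', 'MEG 002', 'MEG 003'])  # stand-in for sensor_dim.names
y = np.random.default_rng(0).normal(size=(3, 1000))  # (sensor, time)
y[1] = 0.0                                           # simulate a flat channel

ch_var = np.var(y, axis=1)   # per-channel variance over time
zero_var = ch_var == 0       # True where a channel carries no signal at all
if zero_var.any():
    flat = names[np.flatnonzero(zero_var)]
    print(f"would raise ValueError: data contains flat channels ({', '.join(flat)})")
```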
@@ -386,7 +400,7 @@ def _precompute(self):
             self._bE.append(np.dot(b, E))
             self._EtE.append(np.dot(E.T, E))
 
-    def __iter__(self):
+    def __iter__(self) -> Iterator[tuple[np.ndarray, np.ndarray]]:
         return zip(self.meg, self.covariates)
 
     def __len__(self):
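The new return annotation on `__iter__` documents that iterating a `RegressionData` instance yields one `(meg, covariates)` pair of arrays per trial, straight from `zip(self.meg, self.covariates)`. A consumption sketch, assuming `data` is an instance already populated via `add_data`:

```python
# `data` is assumed to be a RegressionData instance filled via add_data();
# each iteration yields one trial's (meg, covariates) arrays.
for meg_trial, covariates_trial in data:
    # meg_trial is the (sensor, time) float64 block stored by add_data()
    print(meg_trial.shape, covariates_trial.shape)
```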
@@ -1043,12 +1057,13 @@ def compute_explained_variance(self, data):
         logger.debug(f'{self.mu}: {1 - temp / len(data)}')
         return 1 - temp / len(data)
 
-    def _compute_voxelwise_explained_variance(self, data):
+    def _compute_voxelwise_explained_variance(self, data: RegressionData):
         """evaluates explained_variance
 
         Parameters
         ---------
-        data : REG_Data instance
+        data
+            Regression data.
 
         Returns
         -------