Skip to content

Commit 52dedd7

Browse files
committed
Add fit_args param that works with MTLassoCV
1 parent 3fb4e25 commit 52dedd7

File tree

2 files changed

+6
-5
lines changed

2 files changed

+6
-5
lines changed

src/SparseSC/fit_fast.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,8 @@ def fit_fast( # pylint: disable=unused-argument, missing-raises-doc
119119
else:
120120
custom_donor_pool = np.full((N,N0), True)
121121
custom_donor_pool = _ensure_good_donor_pool(custom_donor_pool, control_units)
122-
match_space_maker = MTLassoCV_MatchSpace_factory() if match_space_maker is None else match_space_maker
122+
fit_args = kwargs.get('fit_args', {})
123+
match_space_maker = MTLassoCV_MatchSpace_factory(fit_args=fit_args) if match_space_maker is None else match_space_maker
123124

124125
fit_units = _get_fit_units(model_type, control_units, treated_units, N)
125126
X_v = X[fit_units, :]

src/SparseSC/utils/match_space.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def _block_summ_cols(Y, Y_col_block_size):
119119
print("Can only average target across columns blocks if blocks fit evenly")
120120
return Y
121121

122-
def MTLassoCV_MatchSpace_factory(v_pens=None, n_v_cv=5, sample_frac=1, Y_col_block_size=None, se_factor=None, normalize=True):
122+
def MTLassoCV_MatchSpace_factory(v_pens=None, n_v_cv=5, sample_frac=1, Y_col_block_size=None, se_factor=None, normalize=True, fit_args={}):
123123
"""
124124
Return a MatchSpace function that will fit a MultiTaskLassoCV for Y ~ X
125125
@@ -133,14 +133,14 @@ def MTLassoCV_MatchSpace_factory(v_pens=None, n_v_cv=5, sample_frac=1, Y_col_blo
133133

134134
def _MTLassoCV_MatchSpace_wrapper(X, Y, **kwargs):
135135
return _MTLassoCV_MatchSpace(
136-
X, Y, v_pens=v_pens, n_v_cv=n_v_cv, sample_frac=sample_frac, Y_col_block_size=Y_col_block_size, se_factor=se_factor, normalize=normalize, **kwargs
136+
X, Y, v_pens=v_pens, n_v_cv=n_v_cv, sample_frac=sample_frac, Y_col_block_size=Y_col_block_size, se_factor=se_factor, normalize=normalize, fit_args=fit_args, **kwargs
137137
)
138138

139139
return _MTLassoCV_MatchSpace_wrapper
140140

141141

142142
def _MTLassoCV_MatchSpace(
143-
X, Y, v_pens=None, n_v_cv=5, sample_frac=1, Y_col_block_size=None, se_factor=None, normalize=True, **kwargs
143+
X, Y, v_pens=None, n_v_cv=5, sample_frac=1, Y_col_block_size=None, se_factor=None, normalize=True, fit_args={}, **kwargs
144144
): # pylint: disable=missing-param-doc, unused-argument
145145
# A fake MT would do Lasso on y_mean = Y.mean(axis=1)
146146
if sample_frac < 1:
@@ -150,7 +150,7 @@ def _MTLassoCV_MatchSpace(
150150
Y = Y[sample, :]
151151
if Y_col_block_size is not None:
152152
Y = _block_summ_cols(Y, Y_col_block_size)
153-
varselectorfit = MultiTaskLassoCV(normalize=normalize, cv=n_v_cv, alphas=v_pens).fit(
153+
varselectorfit = MultiTaskLassoCV(normalize=normalize, cv=n_v_cv, alphas=v_pens, **fit_args).fit(
154154
X, Y
155155
)
156156
best_v_pen = varselectorfit.alpha_

0 commit comments

Comments
 (0)