Skip to content

Commit 3b4d2eb

Browse files
authored
Merge pull request #50 from mkincaid/fit_args
Add optional fit_args parameter to D_LassoCV and MTLassoMixed
2 parents e2b0aea + f9a180a commit 3b4d2eb

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

src/SparseSC/utils/match_space.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ def _MTLasso_MatchSpace(
204204
transformer = SelMatchSpace(m_sel)
205205
return transformer, V[m_sel], v_pen, (V, varselectorfit)
206206

207-
def D_LassoCV_MatchSpace_factory(v_pens=None, n_v_cv=5, sample_frac=1, y_V_share=0.5):
207+
def D_LassoCV_MatchSpace_factory(v_pens=None, n_v_cv=5, sample_frac=1, y_V_share=0.5, fit_args={}):
208208
"""
209209
Return a MatchSpace function that will fit a MultiTaskLassoCV for Y ~ X and Lasso of D_full ~ X_full
210210
and then combines the coefficients into weights using y_V_share
@@ -224,14 +224,15 @@ def _D_LassoCV_MatchSpace_wrapper(X, Y, **kwargs):
224224
n_v_cv=n_v_cv,
225225
sample_frac=sample_frac,
226226
y_V_share=y_V_share,
227+
fit_args=fit_args,
227228
**kwargs
228229
)
229230

230231
return _D_LassoCV_MatchSpace_wrapper
231232

232233

233234
def _D_LassoCV_MatchSpace(
234-
X, Y, X_full, D_full, v_pens=None, n_v_cv=5, sample_frac=1, y_V_share=0.5, **kwargs
235+
X, Y, X_full, D_full, v_pens=None, n_v_cv=5, sample_frac=1, y_V_share=0.5, fit_args={}, **kwargs
235236
): # pylint: disable=missing-param-doc, unused-argument
236237
if sample_frac < 1:
237238
N_y = X.shape[0]
@@ -242,7 +243,7 @@ def _D_LassoCV_MatchSpace(
242243
sample_d = np.random.choice(N_d, int(sample_frac * N_d), replace=False)
243244
X_full = X_full[sample_d, :]
244245
D_full = D_full[sample_d]
245-
y_varselectorfit = MultiTaskLassoCV(normalize=True, cv=n_v_cv, alphas=v_pens).fit(
246+
y_varselectorfit = MultiTaskLassoCV(normalize=True, cv=n_v_cv, alphas=v_pens, **fit_args).fit(
246247
X, Y
247248
)
248249
y_V = np.sqrt(
@@ -465,7 +466,7 @@ def transform(self, X):
465466
return M
466467

467468

468-
def MTLassoMixed_MatchSpace_factory(v_pens=None, n_v_cv=5):
469+
def MTLassoMixed_MatchSpace_factory(v_pens=None, n_v_cv=5, fit_args={}):
469470
"""
470471
Return a MatchSpace function that will fit a MultiTaskLassoCV for Y ~ X with the penalization optimized to reduce errors on goal units
471472
@@ -476,17 +477,17 @@ def MTLassoMixed_MatchSpace_factory(v_pens=None, n_v_cv=5):
476477

477478
def _MTLassoMixed_MatchSpace_wrapper(X, Y, fit_model_wrapper, **kwargs):
478479
return _MTLassoMixed_MatchSpace(
479-
X, Y, fit_model_wrapper, v_pens=v_pens, n_v_cv=n_v_cv, **kwargs
480+
X, Y, fit_model_wrapper, v_pens=v_pens, n_v_cv=n_v_cv, fit_args=fit_args, **kwargs
480481
)
481482

482483
return _MTLassoMixed_MatchSpace_wrapper
483484

484485

485486
def _MTLassoMixed_MatchSpace(
486-
X, Y, fit_model_wrapper, v_pens=None, n_v_cv=5, **kwargs
487+
X, Y, fit_model_wrapper, v_pens=None, n_v_cv=5, fit_args={}, **kwargs
487488
): # pylint: disable=missing-param-doc, unused-argument
488489
# Note that MultiTaskLasso(CV).path with the same alpha doesn't produce same results as MultiTaskLasso(CV)
489-
mtlasso_cv_fit = MultiTaskLassoCV(normalize=True, cv=n_v_cv, alphas=v_pens).fit(
490+
mtlasso_cv_fit = MultiTaskLassoCV(normalize=True, cv=n_v_cv, alphas=v_pens, **fit_args).fit(
490491
X, Y
491492
)
492493
# V_cv = np.sqrt(np.sum(np.square(mtlasso_cv_fit.coef_), axis=0)) #n_tasks x n_features -> n_feature

0 commit comments

Comments
 (0)