@@ -204,7 +204,7 @@ def _MTLasso_MatchSpace(
204204 transformer = SelMatchSpace (m_sel )
205205 return transformer , V [m_sel ], v_pen , (V , varselectorfit )
206206
207- def D_LassoCV_MatchSpace_factory (v_pens = None , n_v_cv = 5 , sample_frac = 1 , y_V_share = 0.5 ):
207+ def D_LassoCV_MatchSpace_factory (v_pens = None , n_v_cv = 5 , sample_frac = 1 , y_V_share = 0.5 , fit_args = {} ):
208208 """
209209 Return a MatchSpace function that will fit a MultiTaskLassoCV for Y ~ X and Lasso of D_full ~ X_full
210210 and then combines the coefficients into weights using y_V_share
@@ -224,14 +224,15 @@ def _D_LassoCV_MatchSpace_wrapper(X, Y, **kwargs):
224224 n_v_cv = n_v_cv ,
225225 sample_frac = sample_frac ,
226226 y_V_share = y_V_share ,
227+ fit_args = fit_args ,
227228 ** kwargs
228229 )
229230
230231 return _D_LassoCV_MatchSpace_wrapper
231232
232233
233234def _D_LassoCV_MatchSpace (
234- X , Y , X_full , D_full , v_pens = None , n_v_cv = 5 , sample_frac = 1 , y_V_share = 0.5 , ** kwargs
235+ X , Y , X_full , D_full , v_pens = None , n_v_cv = 5 , sample_frac = 1 , y_V_share = 0.5 , fit_args = {}, ** kwargs
235236): # pylint: disable=missing-param-doc, unused-argument
236237 if sample_frac < 1 :
237238 N_y = X .shape [0 ]
@@ -242,7 +243,7 @@ def _D_LassoCV_MatchSpace(
242243 sample_d = np .random .choice (N_d , int (sample_frac * N_d ), replace = False )
243244 X_full = X_full [sample_d , :]
244245 D_full = D_full [sample_d ]
245- y_varselectorfit = MultiTaskLassoCV (normalize = True , cv = n_v_cv , alphas = v_pens ).fit (
246+ y_varselectorfit = MultiTaskLassoCV (normalize = True , cv = n_v_cv , alphas = v_pens , ** fit_args ).fit (
246247 X , Y
247248 )
248249 y_V = np .sqrt (
@@ -465,7 +466,7 @@ def transform(self, X):
465466 return M
466467
467468
468- def MTLassoMixed_MatchSpace_factory (v_pens = None , n_v_cv = 5 ):
469+ def MTLassoMixed_MatchSpace_factory (v_pens = None , n_v_cv = 5 , fit_args = {} ):
469470 """
470471 Return a MatchSpace function that will fit a MultiTaskLassoCV for Y ~ X with the penalization optimized to reduce errors on goal units
471472
@@ -476,17 +477,17 @@ def MTLassoMixed_MatchSpace_factory(v_pens=None, n_v_cv=5):
476477
477478 def _MTLassoMixed_MatchSpace_wrapper (X , Y , fit_model_wrapper , ** kwargs ):
478479 return _MTLassoMixed_MatchSpace (
479- X , Y , fit_model_wrapper , v_pens = v_pens , n_v_cv = n_v_cv , ** kwargs
480+ X , Y , fit_model_wrapper , v_pens = v_pens , n_v_cv = n_v_cv , fit_args = fit_args , ** kwargs
480481 )
481482
482483 return _MTLassoMixed_MatchSpace_wrapper
483484
484485
485486def _MTLassoMixed_MatchSpace (
486- X , Y , fit_model_wrapper , v_pens = None , n_v_cv = 5 , ** kwargs
487+ X , Y , fit_model_wrapper , v_pens = None , n_v_cv = 5 , fit_args = {}, ** kwargs
487488): # pylint: disable=missing-param-doc, unused-argument
488489 # Note that MultiTaskLasso(CV).path with the same alpha doesn't produce same results as MultiTaskLasso(CV)
489- mtlasso_cv_fit = MultiTaskLassoCV (normalize = True , cv = n_v_cv , alphas = v_pens ).fit (
490+ mtlasso_cv_fit = MultiTaskLassoCV (normalize = True , cv = n_v_cv , alphas = v_pens , ** fit_args ).fit (
490491 X , Y
491492 )
492493 # V_cv = np.sqrt(np.sum(np.square(mtlasso_cv_fit.coef_), axis=0)) #n_tasks x n_features -> n_feature
0 commit comments