@@ -119,7 +119,7 @@ def _block_summ_cols(Y, Y_col_block_size):
119119 print ("Can only average target across columns blocks if blocks fit evenly" )
120120 return Y
121121
122- def MTLassoCV_MatchSpace_factory (v_pens = None , n_v_cv = 5 , sample_frac = 1 , Y_col_block_size = None , se_factor = None , normalize = True ):
122+ def MTLassoCV_MatchSpace_factory (v_pens = None , n_v_cv = 5 , sample_frac = 1 , Y_col_block_size = None , se_factor = None , normalize = True , fit_args = {} ):
123123 """
124124 Return a MatchSpace function that will fit a MultiTaskLassoCV for Y ~ X
125125
@@ -133,14 +133,14 @@ def MTLassoCV_MatchSpace_factory(v_pens=None, n_v_cv=5, sample_frac=1, Y_col_blo
133133
134134 def _MTLassoCV_MatchSpace_wrapper (X , Y , ** kwargs ):
135135 return _MTLassoCV_MatchSpace (
136- X , Y , v_pens = v_pens , n_v_cv = n_v_cv , sample_frac = sample_frac , Y_col_block_size = Y_col_block_size , se_factor = se_factor , normalize = normalize , ** kwargs
136+ X , Y , v_pens = v_pens , n_v_cv = n_v_cv , sample_frac = sample_frac , Y_col_block_size = Y_col_block_size , se_factor = se_factor , normalize = normalize , fit_args = fit_args , ** kwargs
137137 )
138138
139139 return _MTLassoCV_MatchSpace_wrapper
140140
141141
142142def _MTLassoCV_MatchSpace (
143- X , Y , v_pens = None , n_v_cv = 5 , sample_frac = 1 , Y_col_block_size = None , se_factor = None , normalize = True , ** kwargs
143+ X , Y , v_pens = None , n_v_cv = 5 , sample_frac = 1 , Y_col_block_size = None , se_factor = None , normalize = True , fit_args = {}, ** kwargs
144144): # pylint: disable=missing-param-doc, unused-argument
145145 # A fake MT would do Lasso on y_mean = Y.mean(axis=1)
146146 if sample_frac < 1 :
@@ -150,7 +150,7 @@ def _MTLassoCV_MatchSpace(
150150 Y = Y [sample , :]
151151 if Y_col_block_size is not None :
152152 Y = _block_summ_cols (Y , Y_col_block_size )
153- varselectorfit = MultiTaskLassoCV (normalize = normalize , cv = n_v_cv , alphas = v_pens ).fit (
153+ varselectorfit = MultiTaskLassoCV (normalize = normalize , cv = n_v_cv , alphas = v_pens , ** fit_args ).fit (
154154 X , Y
155155 )
156156 best_v_pen = varselectorfit .alpha_
0 commit comments