@@ -119,17 +119,20 @@ def test_training_server_models_list():
119119 """Test training server models list endpoint."""
120120 r = requests .get (f"{ TRAINING_URL } /models/list" )
121121 assert r .status_code == 200
122-
122+
123123 data = r .json ()
124124 assert "models" in data
125125 assert "model_type" in data
126126 assert "server_time" in data
127-
127+
128128 models = data ["models" ]
129129 expected_models = ["ttft" , "tpot" ]
130130 if data ["model_type" ] == "bayesian_ridge" :
131131 expected_models .extend (["ttft_scaler" , "tpot_scaler" ])
132-
132+ elif data ["model_type" ] in ["xgboost" , "lightgbm" ]:
133+ # TreeLite models should also be available for XGBoost and LightGBM
134+ expected_models .extend (["ttft_treelite" , "tpot_treelite" ])
135+
133136 for model_name in expected_models :
134137 assert model_name in models , f"Model { model_name } should be listed"
135138 print (f"Model { model_name } : exists={ models [model_name ]['exists' ]} , size={ models [model_name ]['size_bytes' ]} bytes" )
@@ -140,7 +143,8 @@ def test_model_download_from_training_server():
140143 # First check what models are available
141144 models_r = requests .get (f"{ TRAINING_URL } /models/list" )
142145 models_data = models_r .json ()
143-
146+
147+ # Test basic models (ttft, tpot)
144148 for model_name in ["ttft" , "tpot" ]:
145149 if models_data ["models" ][model_name ]["exists" ]:
146150 # Test model info endpoint
@@ -149,13 +153,13 @@ def test_model_download_from_training_server():
149153 info_data = info_r .json ()
150154 assert info_data ["exists" ] == True
151155 assert info_data ["size_bytes" ] > 0
152-
156+
153157 # Test model download with retry and streaming
154158 max_retries = 3
155159 for attempt in range (max_retries ):
156160 try :
157161 download_r = requests .get (
158- f"{ TRAINING_URL } /model/{ model_name } /download" ,
162+ f"{ TRAINING_URL } /model/{ model_name } /download" ,
159163 timeout = 30 ,
160164 stream = True # Use streaming to handle large files better
161165 )
@@ -164,7 +168,7 @@ def test_model_download_from_training_server():
164168 content_length = 0
165169 for chunk in download_r .iter_content (chunk_size = 8192 ):
166170 content_length += len (chunk )
167-
171+
168172 assert content_length > 0 , f"Downloaded { model_name } model is empty"
169173 print (f"Successfully downloaded { model_name } model ({ content_length } bytes)" )
170174 break
@@ -176,6 +180,79 @@ def test_model_download_from_training_server():
176180 continue
177181 time .sleep (2 ) # Wait before retry
178182
183+ # Test TreeLite models for XGBoost and LightGBM
184+ model_type = models_data ["model_type" ]
185+ if model_type in ["xgboost" , "lightgbm" ]:
186+ for model_name in ["ttft_treelite" , "tpot_treelite" ]:
187+ if models_data ["models" ].get (model_name , {}).get ("exists" ):
188+ # Test model info endpoint
189+ info_r = requests .get (f"{ TRAINING_URL } /model/{ model_name } /info" )
190+ assert info_r .status_code == 200
191+ info_data = info_r .json ()
192+ assert info_data ["exists" ] == True
193+ assert info_data ["size_bytes" ] > 0
194+
195+ # Test model download with retry and streaming
196+ max_retries = 3
197+ for attempt in range (max_retries ):
198+ try :
199+ download_r = requests .get (
200+ f"{ TRAINING_URL } /model/{ model_name } /download" ,
201+ timeout = 30 ,
202+ stream = True
203+ )
204+ if download_r .status_code == 200 :
205+ # Read content in chunks to avoid memory issues
206+ content_length = 0
207+ for chunk in download_r .iter_content (chunk_size = 8192 ):
208+ content_length += len (chunk )
209+
210+ assert content_length > 0 , f"Downloaded { model_name } model is empty"
211+ print (f"Successfully downloaded { model_name } TreeLite model ({ content_length } bytes)" )
212+ break
213+ except requests .exceptions .ChunkedEncodingError as e :
214+ print (f"Download attempt { attempt + 1 } /{ max_retries } failed for { model_name } : { e } " )
215+ if attempt == max_retries - 1 :
216+ print (f"⚠️ TreeLite model download test skipped for { model_name } due to connection issues" )
217+ continue
218+ time .sleep (2 ) # Wait before retry
219+
220+ def test_treelite_models_on_training_server ():
221+ """Test TreeLite model endpoints on training server for XGBoost and LightGBM."""
222+ model_info_r = requests .get (f"{ TRAINING_URL } /model/download/info" )
223+ model_type = model_info_r .json ().get ("model_type" )
224+
225+ if model_type not in ["xgboost" , "lightgbm" ]:
226+ print (f"Skipping TreeLite tests - model type is { model_type } " )
227+ return
228+
229+ print (f"Testing TreeLite models for { model_type } ..." )
230+
231+ # Test TTFT TreeLite model
232+ ttft_info_r = requests .get (f"{ TRAINING_URL } /model/ttft_treelite/info" )
233+ if ttft_info_r .status_code == 200 :
234+ ttft_info = ttft_info_r .json ()
235+ if ttft_info .get ("exists" ):
236+ print (f"✓ TTFT TreeLite model available ({ ttft_info ['size_bytes' ]} bytes)" )
237+ assert ttft_info ["size_bytes" ] > 0 , "TTFT TreeLite model should have non-zero size"
238+ else :
239+ print (f"TTFT TreeLite model not yet generated" )
240+ else :
241+ print (f"TTFT TreeLite model endpoint returned status { ttft_info_r .status_code } " )
242+
243+ # Test TPOT TreeLite model
244+ tpot_info_r = requests .get (f"{ TRAINING_URL } /model/tpot_treelite/info" )
245+ if tpot_info_r .status_code == 200 :
246+ tpot_info = tpot_info_r .json ()
247+ if tpot_info .get ("exists" ):
248+ print (f"✓ TPOT TreeLite model available ({ tpot_info ['size_bytes' ]} bytes)" )
249+ assert tpot_info ["size_bytes" ] > 0 , "TPOT TreeLite model should have non-zero size"
250+ else :
251+ print (f"TPOT TreeLite model not yet generated" )
252+ else :
253+ print (f"TPOT TreeLite model endpoint returned status { tpot_info_r .status_code } " )
254+
255+
179256def test_lightgbm_endpoints_on_training_server ():
180257 """Test LightGBM endpoints on training server if LightGBM is being used."""
181258 model_info_r = requests .get (f"{ TRAINING_URL } /model/download/info" )
@@ -1370,6 +1447,7 @@ def test_training_server_flush_error_handling():
13701447 ("Training Server Model Info" , test_training_server_model_info ),
13711448 ("Training Server Models List" , test_training_server_models_list ),
13721449 ("Model Download" , test_model_download_from_training_server ),
1450+ ("TreeLite Models" , test_treelite_models_on_training_server ),
13731451 ("Send Training Data" , test_add_training_data_to_training_server ),
13741452 ("Model Sync" , test_prediction_server_model_sync ),
13751453 ("Predictions" , test_prediction_via_prediction_server ),
@@ -1380,9 +1458,9 @@ def test_training_server_flush_error_handling():
13801458 ("Training Metrics" , test_training_server_metrics ),
13811459 ("Model Consistency" , test_model_consistency_between_servers ),
13821460 ("XGBoost Trees" , test_model_specific_endpoints_on_training_server ),
1383- ("Flush API" , test_training_server_flush_api ),
1461+ ("Flush API" , test_training_server_flush_api ),
13841462 ("Flush Error Handling" , test_training_server_flush_error_handling ),
1385-
1463+
13861464 ("Dual Server Model Learns Equation" , test_dual_server_quantile_regression_learns_distribution ),
13871465 ("End-to-End Workflow" , test_end_to_end_workflow ),
13881466 ("Prediction Stress Test" , test_prediction_server_stress_test ),
0 commit comments