Skip to content

Commit cfb9e23

Browse files
test and manifest updates
Signed-off-by: greg pereira <[email protected]>
1 parent 33c95d7 commit cfb9e23

File tree

2 files changed

+93
-12
lines changed

2 files changed

+93
-12
lines changed

latencypredictor/manifests/dual-server-deployment.yaml

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ data:
1313
LATENCY_TPOT_MODEL_PATH: "/models/tpot.joblib"
1414
LATENCY_TTFT_SCALER_PATH: "/models/ttft_scaler.joblib"
1515
LATENCY_TPOT_SCALER_PATH: "/models/tpot_scaler.joblib"
16-
LOCAL_TPOT_SCALER_PATH: "/models/tpot_scaler.so"
17-
LOCAL_TPOT_TREELITE_PATH: "/models/tpot_treelite.so"
16+
LATENCY_TTFT_TREELITE_PATH: "/models/ttft_treelite.so"
17+
LATENCY_TPOT_TREELITE_PATH: "/models/tpot_treelite.so"
1818
LATENCY_MODEL_TYPE: "xgboost"
1919

2020
---
@@ -24,7 +24,7 @@ metadata:
2424
name: prediction-server-config
2525
namespace: default
2626
data:
27-
MODEL_SYNC_INTERVAL_SEC: "10" # Download models every 5 seconds
27+
MODEL_SYNC_INTERVAL_SEC: "10" # Download models every 10 seconds
2828
LATENCY_MODEL_TYPE: "xgboost"
2929
PREDICT_HOST: "0.0.0.0"
3030
PREDICT_PORT: "8001"
@@ -33,6 +33,9 @@ data:
3333
LOCAL_TPOT_MODEL_PATH: "/local_models/tpot.joblib"
3434
LOCAL_TTFT_SCALER_PATH: "/local_models/ttft_scaler.joblib"
3535
LOCAL_TPOT_SCALER_PATH: "/local_models/tpot_scaler.joblib"
36+
LOCAL_TTFT_TREELITE_PATH: "/local_models/ttft_treelite.so"
37+
LOCAL_TPOT_TREELITE_PATH: "/local_models/tpot_treelite.so"
38+
USE_TREELITE: "true" # Enable TreeLite for faster inference
3639
HTTP_TIMEOUT: "30"
3740

3841
---

latencypredictor/test_dual_server_client.py

Lines changed: 87 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -119,17 +119,20 @@ def test_training_server_models_list():
119119
"""Test training server models list endpoint."""
120120
r = requests.get(f"{TRAINING_URL}/models/list")
121121
assert r.status_code == 200
122-
122+
123123
data = r.json()
124124
assert "models" in data
125125
assert "model_type" in data
126126
assert "server_time" in data
127-
127+
128128
models = data["models"]
129129
expected_models = ["ttft", "tpot"]
130130
if data["model_type"] == "bayesian_ridge":
131131
expected_models.extend(["ttft_scaler", "tpot_scaler"])
132-
132+
elif data["model_type"] in ["xgboost", "lightgbm"]:
133+
# TreeLite models should also be available for XGBoost and LightGBM
134+
expected_models.extend(["ttft_treelite", "tpot_treelite"])
135+
133136
for model_name in expected_models:
134137
assert model_name in models, f"Model {model_name} should be listed"
135138
print(f"Model {model_name}: exists={models[model_name]['exists']}, size={models[model_name]['size_bytes']} bytes")
@@ -140,7 +143,8 @@ def test_model_download_from_training_server():
140143
# First check what models are available
141144
models_r = requests.get(f"{TRAINING_URL}/models/list")
142145
models_data = models_r.json()
143-
146+
147+
# Test basic models (ttft, tpot)
144148
for model_name in ["ttft", "tpot"]:
145149
if models_data["models"][model_name]["exists"]:
146150
# Test model info endpoint
@@ -149,13 +153,13 @@ def test_model_download_from_training_server():
149153
info_data = info_r.json()
150154
assert info_data["exists"] == True
151155
assert info_data["size_bytes"] > 0
152-
156+
153157
# Test model download with retry and streaming
154158
max_retries = 3
155159
for attempt in range(max_retries):
156160
try:
157161
download_r = requests.get(
158-
f"{TRAINING_URL}/model/{model_name}/download",
162+
f"{TRAINING_URL}/model/{model_name}/download",
159163
timeout=30,
160164
stream=True # Use streaming to handle large files better
161165
)
@@ -164,7 +168,7 @@ def test_model_download_from_training_server():
164168
content_length = 0
165169
for chunk in download_r.iter_content(chunk_size=8192):
166170
content_length += len(chunk)
167-
171+
168172
assert content_length > 0, f"Downloaded {model_name} model is empty"
169173
print(f"Successfully downloaded {model_name} model ({content_length} bytes)")
170174
break
@@ -176,6 +180,79 @@ def test_model_download_from_training_server():
176180
continue
177181
time.sleep(2) # Wait before retry
178182

183+
# Test TreeLite models for XGBoost and LightGBM
184+
model_type = models_data["model_type"]
185+
if model_type in ["xgboost", "lightgbm"]:
186+
for model_name in ["ttft_treelite", "tpot_treelite"]:
187+
if models_data["models"].get(model_name, {}).get("exists"):
188+
# Test model info endpoint
189+
info_r = requests.get(f"{TRAINING_URL}/model/{model_name}/info")
190+
assert info_r.status_code == 200
191+
info_data = info_r.json()
192+
assert info_data["exists"] == True
193+
assert info_data["size_bytes"] > 0
194+
195+
# Test model download with retry and streaming
196+
max_retries = 3
197+
for attempt in range(max_retries):
198+
try:
199+
download_r = requests.get(
200+
f"{TRAINING_URL}/model/{model_name}/download",
201+
timeout=30,
202+
stream=True
203+
)
204+
if download_r.status_code == 200:
205+
# Read content in chunks to avoid memory issues
206+
content_length = 0
207+
for chunk in download_r.iter_content(chunk_size=8192):
208+
content_length += len(chunk)
209+
210+
assert content_length > 0, f"Downloaded {model_name} model is empty"
211+
print(f"Successfully downloaded {model_name} TreeLite model ({content_length} bytes)")
212+
break
213+
except requests.exceptions.ChunkedEncodingError as e:
214+
print(f"Download attempt {attempt + 1}/{max_retries} failed for {model_name}: {e}")
215+
if attempt == max_retries - 1:
216+
print(f"⚠️ TreeLite model download test skipped for {model_name} due to connection issues")
217+
continue
218+
time.sleep(2) # Wait before retry
219+
220+
def test_treelite_models_on_training_server():
221+
"""Test TreeLite model endpoints on training server for XGBoost and LightGBM."""
222+
model_info_r = requests.get(f"{TRAINING_URL}/model/download/info")
223+
model_type = model_info_r.json().get("model_type")
224+
225+
if model_type not in ["xgboost", "lightgbm"]:
226+
print(f"Skipping TreeLite tests - model type is {model_type}")
227+
return
228+
229+
print(f"Testing TreeLite models for {model_type}...")
230+
231+
# Test TTFT TreeLite model
232+
ttft_info_r = requests.get(f"{TRAINING_URL}/model/ttft_treelite/info")
233+
if ttft_info_r.status_code == 200:
234+
ttft_info = ttft_info_r.json()
235+
if ttft_info.get("exists"):
236+
print(f"✓ TTFT TreeLite model available ({ttft_info['size_bytes']} bytes)")
237+
assert ttft_info["size_bytes"] > 0, "TTFT TreeLite model should have non-zero size"
238+
else:
239+
print(f"TTFT TreeLite model not yet generated")
240+
else:
241+
print(f"TTFT TreeLite model endpoint returned status {ttft_info_r.status_code}")
242+
243+
# Test TPOT TreeLite model
244+
tpot_info_r = requests.get(f"{TRAINING_URL}/model/tpot_treelite/info")
245+
if tpot_info_r.status_code == 200:
246+
tpot_info = tpot_info_r.json()
247+
if tpot_info.get("exists"):
248+
print(f"✓ TPOT TreeLite model available ({tpot_info['size_bytes']} bytes)")
249+
assert tpot_info["size_bytes"] > 0, "TPOT TreeLite model should have non-zero size"
250+
else:
251+
print(f"TPOT TreeLite model not yet generated")
252+
else:
253+
print(f"TPOT TreeLite model endpoint returned status {tpot_info_r.status_code}")
254+
255+
179256
def test_lightgbm_endpoints_on_training_server():
180257
"""Test LightGBM endpoints on training server if LightGBM is being used."""
181258
model_info_r = requests.get(f"{TRAINING_URL}/model/download/info")
@@ -1370,6 +1447,7 @@ def test_training_server_flush_error_handling():
13701447
("Training Server Model Info", test_training_server_model_info),
13711448
("Training Server Models List", test_training_server_models_list),
13721449
("Model Download", test_model_download_from_training_server),
1450+
("TreeLite Models", test_treelite_models_on_training_server),
13731451
("Send Training Data", test_add_training_data_to_training_server),
13741452
("Model Sync", test_prediction_server_model_sync),
13751453
("Predictions", test_prediction_via_prediction_server),
@@ -1380,9 +1458,9 @@ def test_training_server_flush_error_handling():
13801458
("Training Metrics", test_training_server_metrics),
13811459
("Model Consistency", test_model_consistency_between_servers),
13821460
("XGBoost Trees", test_model_specific_endpoints_on_training_server),
1383-
("Flush API", test_training_server_flush_api),
1461+
("Flush API", test_training_server_flush_api),
13841462
("Flush Error Handling", test_training_server_flush_error_handling),
1385-
1463+
13861464
("Dual Server Model Learns Equation", test_dual_server_quantile_regression_learns_distribution),
13871465
("End-to-End Workflow", test_end_to_end_workflow),
13881466
("Prediction Stress Test", test_prediction_server_stress_test),

0 commit comments

Comments (0)