Skip to content

Commit 0c6b977

Browse files
committed
refactor(prod_model): wrapping, removal of unnecessary code
1 parent 76c2db2 commit 0c6b977

File tree

1 file changed

+24
-20
lines changed

1 file changed

+24
-20
lines changed

github_dagger_workflow_project/04_prod_model.py

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -8,36 +8,40 @@
88
model_name = "lead_model"
99
experiment_best = pd.read_pickle(PROD_BEST_EXPERIMENT_PATH)
1010

11-
prod_model = [
12-
model
13-
for model in client.search_model_versions(f"name='{model_name}'")
14-
if dict(model)["current_stage"] == "Production"
15-
]
16-
prod_model_exists = len(prod_model) > 0
1711

18-
if prod_model_exists:
19-
prod_model_version = dict(prod_model[0])["version"]
20-
prod_model_run_id = dict(prod_model[0])["run_id"]
12+
def get_production_model(model_name):
13+
return [
14+
model
15+
for model in client.search_model_versions(f"name='{model_name}'")
16+
if dict(model)["current_stage"] == "Production"
17+
]
18+
19+
20+
def get_model_score(run_id):
21+
data, _ = mlflow.get_run(run_id)
22+
return data[1]["metrics.f1_score"]
23+
24+
25+
def register_and_wait_model(run_id, artifact_path, model_name):
26+
model_uri = f"runs:/{run_id}/{artifact_path}"
27+
model_details = mlflow.register_model(model_uri=model_uri, name=model_name)
28+
utils.wait_until_ready(model_details.name, model_details.version)
29+
return dict(model_details)
30+
2131

2232
train_model_score = experiment_best["metrics.f1_score"]
23-
model_details = {}
24-
model_status = {}
2533
run_id = None
34+
prod_model = get_production_model(model_name)
35+
prod_model_exists = len(prod_model) > 0
2636

2737
if prod_model_exists:
28-
data, details = mlflow.get_run(prod_model_run_id)
29-
prod_model_score = data[1]["metrics.f1_score"]
30-
31-
model_status["current"] = train_model_score
32-
model_status["prod"] = prod_model_score
38+
prod_model_run_id = dict(prod_model[0])["run_id"]
39+
prod_model_score = get_model_score(prod_model_run_id)
3340

3441
if train_model_score > prod_model_score:
3542
run_id = experiment_best["run_id"]
3643
else:
3744
run_id = experiment_best["run_id"]
3845

3946
if run_id is not None:
40-
model_uri = "runs:/{run_id}/{artifact_path}".format(run_id=run_id, artifact_path=artifact_path)
41-
model_details = mlflow.register_model(model_uri=model_uri, name=model_name)
42-
utils.wait_until_ready(model_details.name, model_details.version)
43-
model_details = dict(model_details)
47+
model_details = register_and_wait_model(run_id, artifact_path, model_name)

0 commit comments

Comments
 (0)