Skip to content

Commit a3d1f16

Browse files
committed
mlflow: Add version tags for registered models
Add the following model version tags when logging a model to MLflow: * model_uri: The URI of the model artifact * model_type: The type of the model (e.g. 'medcat_snomed') * validation_status: The validation status of the model (e.g. 'pending') Signed-off-by: Phoevos Kalemkeris <[email protected]>
1 parent 40cb37d commit a3d1f16

File tree

8 files changed

+105
-14
lines changed

8 files changed

+105
-14
lines changed

app/cli/cli.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,7 @@ def register_model(
304304
model_config=m_config,
305305
model_metrics=m_metrics,
306306
model_tags=m_tags,
307+
model_type=model_type.value,
307308
)
308309
typer.echo(f"Pushed {model_path} as a new model version ({run_name})")
309310

app/management/tracker_client.py

Lines changed: 45 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,39 @@ def log_model_config(config: Dict[str, str]) -> None:
346346

347347
mlflow.log_params(config)
348348

349+
@staticmethod
350+
def _set_model_version_tags(
351+
client: MlflowClient,
352+
model_name: str,
353+
version: str,
354+
model_type: Optional[str] = None,
355+
validation_status: Optional[str] = None,
356+
) -> None:
357+
"""
358+
Sets standard tags on a model version for serving and discovery.
359+
360+
Args:
361+
client (MlflowClient): The MLflow client to use for setting tags.
362+
model_name (str): The name of the registered model.
363+
version (str): The version of the model.
364+
model_type (Optional[str]): The type of the model (e.g., "medcat_snomed").
365+
validation_status (Optional[str]): The status of the model validation (e.g., "pending").
366+
"""
367+
try:
368+
client.set_model_version_tag(
369+
name=model_name, version=version, key="model_uri", value=f"models:/{model_name}/{version}"
370+
)
371+
if model_type is not None:
372+
client.set_model_version_tag(
373+
name=model_name, version=version, key="model_type", value=model_type
374+
)
375+
if validation_status is not None:
376+
client.set_model_version_tag(
377+
name=model_name, version=version, key="validation_status", value=validation_status
378+
)
379+
except Exception:
380+
logger.warning("Failed to set tags on version %s of model %s", version, model_name)
381+
349382
@staticmethod
350383
def log_model(
351384
model_name: str,
@@ -386,6 +419,7 @@ def save_pretrained_model(
386419
model_config: Optional[Dict] = None,
387420
model_metrics: Optional[List[Dict]] = None,
388421
model_tags: Optional[Dict] = None,
422+
model_type: Optional[str] = None,
389423
) -> None:
390424
"""
391425
Saves a pretrained model to the tracking backend and associated metadata.
@@ -399,6 +433,7 @@ def save_pretrained_model(
399433
model_config (Optional[Dict]): The configuration of the model to save.
400434
model_metrics (Optional[List[Dict]]): The list of dictionaries containing model metrics to save.
401435
model_tags (Optional[Dict]): The dictionary of tags to set for the model.
436+
model_type (Optional[str]): The type of the model (e.g., "medcat_snomed").
402437
"""
403438

404439
experiment_name = TrackerClient.get_experiment_name(model_name, training_type)
@@ -423,6 +458,10 @@ def save_pretrained_model(
423458
mlflow.set_tags(tags)
424459
model_name = model_name.replace(" ", "_")
425460
TrackerClient.log_model(model_name, model_path, model_manager, model_name)
461+
client = MlflowClient()
462+
versions = client.search_model_versions(f"name='{model_name}'")
463+
if versions:
464+
TrackerClient._set_model_version_tags(client, model_name, versions[0].version, model_type)
426465
TrackerClient.end_with_success()
427466
except KeyboardInterrupt:
428467
TrackerClient.end_with_interruption()
@@ -503,6 +542,7 @@ def save_model(
503542
model_name: str,
504543
model_manager: ModelManager,
505544
validation_status: str = "pending",
545+
model_type: Optional[str] = None,
506546
) -> str:
507547
"""
508548
Saves a model and its information to the tracking backend.
@@ -512,6 +552,7 @@ def save_model(
512552
model_name (str): The name of the model.
513553
model_manager (ModelManager): The instance of ModelManager used for model saving.
514554
validation_status (str): The status of the model validation (default: "pending").
555+
model_type (Optional[str]): The type of the model (e.g., "medcat_snomed").
515556
516557
Returns:
517558
str: The artifact URI of the saved model.
@@ -524,12 +565,10 @@ def save_model(
524565
if not mlflow.get_tracking_uri().startswith("file:/"):
525566
TrackerClient.log_model(model_name, filepath, model_manager, model_name)
526567
versions = self.mlflow_client.search_model_versions(f"name='{model_name}'")
527-
self.mlflow_client.set_model_version_tag(
528-
name=model_name,
529-
version=versions[0].version,
530-
key="validation_status",
531-
value=validation_status,
532-
)
568+
if versions:
569+
TrackerClient._set_model_version_tags(
570+
self.mlflow_client, model_name, versions[0].version, model_type, validation_status
571+
)
533572
else:
534573
TrackerClient.log_model(model_name, filepath, model_manager)
535574

app/trainers/huggingface_llm_trainer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -436,6 +436,7 @@ def run(
436436
retrained_model_pack_path,
437437
self._model_name,
438438
self._model_manager,
439+
model_type=self._model_service.info().model_type.value,
439440
)
440441
logger.info(f"Retrained model saved: {model_uri}")
441442
else:

app/trainers/huggingface_ner_trainer.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,7 @@ def run(
237237
retrained_model_pack_path,
238238
self._model_name,
239239
self._model_manager,
240+
model_type=self._model_service.info().model_type.value,
240241
)
241242
logger.info(f"Retrained model saved: {model_uri}")
242243
else:
@@ -664,6 +665,7 @@ def run(
664665
retrained_model_pack_path,
665666
self._model_name,
666667
self._model_manager,
668+
model_type=self._model_service.info().model_type.value,
667669
)
668670
logger.info(f"Retrained model saved: {model_uri}")
669671
else:

app/trainers/medcat_deid_trainer.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,12 @@ def run(
185185
)
186186
with open(cdb_config_path, "w") as f:
187187
json.dump(dump_pydantic_object_to_dict(model.config), f)
188-
model_uri = self._tracker_client.save_model(model_pack_path, self._model_name, self._model_manager)
188+
model_uri = self._tracker_client.save_model(
189+
model_pack_path,
190+
self._model_name,
191+
self._model_manager,
192+
model_type=self._model_service.info().model_type.value,
193+
)
189194
logger.info("Retrained model saved: %s", model_uri)
190195
self._tracker_client.save_model_artifact(cdb_config_path, self._model_name)
191196
else:

app/trainers/medcat_trainer.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,12 @@ def run(
211211
)
212212
with open(cdb_config_path, "w") as f:
213213
json.dump(dump_pydantic_object_to_dict(model.config), f)
214-
model_uri = self._tracker_client.save_model(model_pack_path, self._model_name, self._model_manager)
214+
model_uri = self._tracker_client.save_model(
215+
model_pack_path,
216+
self._model_name,
217+
self._model_manager,
218+
model_type=self._model_service.info().model_type.value,
219+
)
215220
logger.info("Retrained model saved: %s", model_uri)
216221
self._tracker_client.save_model_artifact(cdb_config_path, self._model_name)
217222
else:
@@ -472,7 +477,12 @@ def run(
472477
)
473478
with open(cdb_config_path, "w") as f:
474479
json.dump(dump_pydantic_object_to_dict(model.config), f)
475-
model_uri = self._tracker_client.save_model(model_pack_path, self._model_name, self._model_manager)
480+
model_uri = self._tracker_client.save_model(
481+
model_pack_path,
482+
self._model_name,
483+
self._model_manager,
484+
model_type=self._model_service.info().model_type.value,
485+
)
476486
logger.info(f"Retrained model saved: {model_uri}")
477487
self._tracker_client.save_model_artifact(cdb_config_path, self._model_name)
478488
else:

app/trainers/metacat_trainer.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,12 @@ def run(
159159
)
160160
with open(cdb_config_path, "w") as f:
161161
json.dump(dump_pydantic_object_to_dict(model.config), f)
162-
model_uri = self._tracker_client.save_model(model_pack_path, self._model_name, self._model_manager)
162+
model_uri = self._tracker_client.save_model(
163+
model_pack_path,
164+
self._model_name,
165+
self._model_manager,
166+
model_type=self._model_service.info().model_type.value,
167+
)
163168
logger.info("Retrained model saved: %s", model_uri)
164169
self._tracker_client.save_model_artifact(cdb_config_path, self._model_name)
165170
else:

tests/app/monitoring/test_tracker_client.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import datasets
44
import pytest
55
import pandas as pd
6-
from unittest.mock import Mock, call, ANY
6+
from unittest.mock import Mock, call, patch, ANY
77
from app.management.tracker_client import TrackerClient
88
from app.data import doc_dataset
99
from app.domain import TrainerBackend
@@ -161,11 +161,23 @@ def test_save_model(mlflow_fixture):
161161
mlflow_client.search_model_versions.return_value = [version]
162162
tracker_client.mlflow_client = mlflow_client
163163

164-
artifact_uri = tracker_client.save_model("path/to/file.zip", "model_name", model_manager, "validation_status")
164+
artifact_uri = tracker_client.save_model(
165+
"path/to/file.zip", "model_name", model_manager, "validation_status", "model_type"
166+
)
165167

166168
assert "artifacts/model_name" in artifact_uri
167169
model_manager.log_model.assert_called_once_with("model_name", "path/to/file.zip", "model_name")
168-
mlflow_client.set_model_version_tag.assert_called_once_with(name="model_name", version="1", key="validation_status", value="validation_status")
170+
mlflow_client.search_model_versions.assert_called_once_with("name='model_name'")
171+
assert mlflow_client.set_model_version_tag.call_count == 3
172+
mlflow_client.set_model_version_tag.assert_any_call(
173+
name="model_name", version="1", key="model_uri", value="models:/model_name/1"
174+
)
175+
mlflow_client.set_model_version_tag.assert_any_call(
176+
name="model_name", version="1", key="model_type", value="model_type"
177+
)
178+
mlflow_client.set_model_version_tag.assert_any_call(
179+
name="model_name", version="1", key="validation_status", value="validation_status"
180+
)
169181
mlflow.set_tag.has_calls(
170182
[
171183
call("training.output.package", "file.zip"),
@@ -184,9 +196,15 @@ def test_save_model_local(mlflow_fixture):
184196
model_manager.save_model.assert_called_once_with("local_dir", "filepath")
185197

186198

187-
def test_save_pretrained_model(mlflow_fixture):
199+
@patch("app.management.tracker_client.MlflowClient")
200+
def test_save_pretrained_model(mock_mlflow_client_class, mlflow_fixture):
188201
tracker_client = TrackerClient("")
189202
model_manager = Mock()
203+
mlflow_client = Mock()
204+
version = Mock()
205+
version.version = "1"
206+
mlflow_client.search_model_versions.return_value = [version]
207+
mock_mlflow_client_class.return_value = mlflow_client
190208

191209
tracker_client.save_pretrained_model(
192210
"model_name",
@@ -197,6 +215,7 @@ def test_save_pretrained_model(mlflow_fixture):
197215
{"param": "value"},
198216
[{"p": 0.8, "r": 0.8}, {"p": 0.9, "r": 0.9}],
199217
{"tag_name": "tag_value"},
218+
"model_type",
200219
)
201220

202221
mlflow.get_experiment_by_name.assert_called_once_with("model_name_training_type")
@@ -212,6 +231,15 @@ def test_save_pretrained_model(mlflow_fixture):
212231
assert len(mlflow.set_tags.call_args.args[0]["mlflow.source.name"]) > 0
213232
assert mlflow.set_tags.call_args.args[0]["tag_name"] == "tag_value"
214233

234+
mlflow_client.search_model_versions.assert_called_once_with("name='model_name'")
235+
assert mlflow_client.set_model_version_tag.call_count == 2
236+
mlflow_client.set_model_version_tag.assert_any_call(
237+
name="model_name", version="1", key="model_uri", value="models:/model_name/1"
238+
)
239+
mlflow_client.set_model_version_tag.assert_any_call(
240+
name="model_name", version="1", key="model_type", value="model_type"
241+
)
242+
215243

216244
def test_log_single_exception(mlflow_fixture):
217245
tracker_client = TrackerClient("")

0 commit comments

Comments
 (0)