@@ -346,6 +346,39 @@ def log_model_config(config: Dict[str, str]) -> None:
346346
347347 mlflow .log_params (config )
348348
349+ @staticmethod
350+ def _set_model_version_tags (
351+ client : MlflowClient ,
352+ model_name : str ,
353+ version : str ,
354+ model_type : Optional [str ] = None ,
355+ validation_status : Optional [str ] = None ,
356+ ) -> None :
357+ """
358+ Sets standard tags on a model version for serving and discovery.
359+
360+ Args:
361+ client (MlflowClient): The MLflow client to use for setting tags.
362+ model_name (str): The name of the registered model.
363+ version (str): The version of the model.
364+ model_type (Optional[str]): The type of the model (e.g., "medcat_snomed").
365+ validation_status (Optional[str]): The status of the model validation (e.g., "pending").
366+ """
367+ try :
368+ client .set_model_version_tag (
369+ name = model_name , version = version , key = "model_uri" , value = f"models:/{ model_name } /{ version } "
370+ )
371+ if model_type is not None :
372+ client .set_model_version_tag (
373+ name = model_name , version = version , key = "model_type" , value = model_type
374+ )
375+ if validation_status is not None :
376+ client .set_model_version_tag (
377+ name = model_name , version = version , key = "validation_status" , value = validation_status
378+ )
379+ except Exception :
380+ logger .warning ("Failed to set tags on version %s of model %s" , version , model_name )
381+
349382 @staticmethod
350383 def log_model (
351384 model_name : str ,
@@ -386,6 +419,7 @@ def save_pretrained_model(
386419 model_config : Optional [Dict ] = None ,
387420 model_metrics : Optional [List [Dict ]] = None ,
388421 model_tags : Optional [Dict ] = None ,
422+ model_type : Optional [str ] = None ,
389423 ) -> None :
390424 """
391425 Saves a pretrained model to the tracking backend and associated metadata.
@@ -399,6 +433,7 @@ def save_pretrained_model(
399433 model_config (Optional[Dict]): The configuration of the model to save.
400434 model_metrics (Optional[List[Dict]]): The list of dictionaries containing model metrics to save.
401435 model_tags (Optional[Dict]): The dictionary of tags to set for the model.
436+ model_type (Optional[str]): The type of the model (e.g., "medcat_snomed").
402437 """
403438
404439 experiment_name = TrackerClient .get_experiment_name (model_name , training_type )
@@ -423,6 +458,10 @@ def save_pretrained_model(
423458 mlflow .set_tags (tags )
424459 model_name = model_name .replace (" " , "_" )
425460 TrackerClient .log_model (model_name , model_path , model_manager , model_name )
461+ client = MlflowClient ()
462+ versions = client .search_model_versions (f"name='{ model_name } '" )
463+ if versions :
464+ TrackerClient ._set_model_version_tags (client , model_name , versions [0 ].version , model_type )
426465 TrackerClient .end_with_success ()
427466 except KeyboardInterrupt :
428467 TrackerClient .end_with_interruption ()
@@ -503,6 +542,7 @@ def save_model(
503542 model_name : str ,
504543 model_manager : ModelManager ,
505544 validation_status : str = "pending" ,
545+ model_type : Optional [str ] = None ,
506546 ) -> str :
507547 """
508548 Saves a model and its information to the tracking backend.
@@ -512,6 +552,7 @@ def save_model(
512552 model_name (str): The name of the model.
513553 model_manager (ModelManager): The instance of ModelManager used for model saving.
514554 validation_status (str): The status of the model validation (default: "pending").
555+ model_type (Optional[str]): The type of the model (e.g., "medcat_snomed").
515556
516557 Returns:
517558 str: The artifact URI of the saved model.
@@ -524,12 +565,10 @@ def save_model(
524565 if not mlflow .get_tracking_uri ().startswith ("file:/" ):
525566 TrackerClient .log_model (model_name , filepath , model_manager , model_name )
526567 versions = self .mlflow_client .search_model_versions (f"name='{ model_name } '" )
527- self .mlflow_client .set_model_version_tag (
528- name = model_name ,
529- version = versions [0 ].version ,
530- key = "validation_status" ,
531- value = validation_status ,
532- )
568+ if versions :
569+ TrackerClient ._set_model_version_tags (
570+ self .mlflow_client , model_name , versions [0 ].version , model_type , validation_status
571+ )
533572 else :
534573 TrackerClient .log_model (model_name , filepath , model_manager )
535574
0 commit comments