|
5 | 5 | from PIL import Image |
6 | 6 | from celery import shared_task |
7 | 7 | import cv2 |
| 8 | +from django.conf import settings |
8 | 9 | from django.core.files import File |
9 | 10 | import librosa |
10 | 11 | import matplotlib.pyplot as plt |
@@ -504,3 +505,56 @@ def predict_compressed(image_file): |
504 | 505 | confs = dict(zip(labels, outputs)) |
505 | 506 |
|
506 | 507 | return label, score, confs |
| 508 | + |
| 509 | + |
| 510 | +def train_body(experiment_name: str): |
| 511 | + import mlflow |
| 512 | + from mlflow.models import infer_signature |
| 513 | + from sklearn import datasets |
| 514 | + from sklearn.linear_model import LogisticRegression |
| 515 | + from sklearn.metrics import accuracy_score |
| 516 | + from sklearn.model_selection import train_test_split |
| 517 | + |
| 518 | + X, y = datasets.load_iris(return_X_y=True) |
| 519 | + |
| 520 | + X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) |
| 521 | + |
| 522 | + params = { |
| 523 | + 'solver': 'lbfgs', |
| 524 | + 'max_iter': 1000, |
| 525 | + 'multi_class': 'auto', |
| 526 | + 'random_state': 8888, |
| 527 | + } |
| 528 | + |
| 529 | + lr = LogisticRegression(**params) |
| 530 | + lr.fit(X_train, y_train) |
| 531 | + |
| 532 | + y_pred = lr.predict(X_test) |
| 533 | + |
| 534 | + accuracy = accuracy_score(y_test, y_pred) |
| 535 | + |
| 536 | + mlflow.set_tracking_uri(settings.MLFLOW_ENDPOINT) |
| 537 | + mlflow.set_experiment(experiment_name) |
| 538 | + |
| 539 | + print(mlflow.get_tracking_uri()) |
| 540 | + print(mlflow.get_artifact_uri()) |
| 541 | + |
| 542 | + mlflow.end_run() |
| 543 | + with mlflow.start_run(): |
| 544 | + mlflow.log_params(params) |
| 545 | + mlflow.log_metric('accuracy', accuracy) |
| 546 | + mlflow.set_tag('Training Info', 'Basic LR model for iris data') |
| 547 | + |
| 548 | + signature = infer_signature(X_train, lr.predict(X_train)) |
| 549 | + _ = mlflow.sklearn.log_model( |
| 550 | + sk_model=lr, |
| 551 | + artifact_path='iris_model', |
| 552 | + signature=signature, |
| 553 | + input_example=X_train, |
| 554 | + registered_model_name='tracking-quickstart', |
| 555 | + ) |
| 556 | + |
| 557 | + |
| 558 | +@shared_task |
| 559 | +def example_train(experiment_name: str): |
| 560 | + train_body(experiment_name) |
0 commit comments