Skip to content

Commit d1d7243

Browse files
committed
Move example model training to tasks
1 parent 01d5a06 commit d1d7243

File tree

2 files changed

+70
-9
lines changed

2 files changed

+70
-9
lines changed
Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,20 @@
1-
from django.conf import settings
21
import djclick as click
3-
from mlflow import MlflowClient
4-
from sklearn.datasets import load_diabetes
5-
from sklearn.ensemble import RandomForestRegressor
6-
from sklearn.model_selection import train_test_split
2+
import mlflow
73

4+
from bats_ai.core.tasks import example_train
85

9-
@click.command()
10-
def command():
11-
click.echo("Running Mlflow experiment")
126

13-
client = MlflowClient(tracking_uri=settings.MLFLOW_ENDPOINT)
7+
@click.command()
8+
@click.option('--experiment-name', type=click.STRING, required=False, default='Default')
9+
def command(experiment_name):
10+
click.echo('Finding experiment')
11+
experiment = mlflow.get_experiment_by_name(experiment_name)
12+
if experiment:
13+
click.echo(f'Creating a log for experiment {experiment_name}')
14+
example_train.delay(experiment_name)
15+
# train_body(experiment_name)
16+
else:
17+
click.echo(
18+
f'Could not find experiment {experiment_name}.'
19+
' Use the create experiment command to create a new experiement.'
20+
)

bats_ai/tasks/tasks.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from PIL import Image
66
from celery import shared_task
77
import cv2
8+
from django.conf import settings
89
from django.core.files import File
910
import librosa
1011
import matplotlib.pyplot as plt
@@ -504,3 +505,56 @@ def predict_compressed(image_file):
504505
confs = dict(zip(labels, outputs))
505506

506507
return label, score, confs
508+
509+
510+
def train_body(experiment_name: str):
511+
import mlflow
512+
from mlflow.models import infer_signature
513+
from sklearn import datasets
514+
from sklearn.linear_model import LogisticRegression
515+
from sklearn.metrics import accuracy_score
516+
from sklearn.model_selection import train_test_split
517+
518+
X, y = datasets.load_iris(return_X_y=True)
519+
520+
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
521+
522+
params = {
523+
'solver': 'lbfgs',
524+
'max_iter': 1000,
525+
'multi_class': 'auto',
526+
'random_state': 8888,
527+
}
528+
529+
lr = LogisticRegression(**params)
530+
lr.fit(X_train, y_train)
531+
532+
y_pred = lr.predict(X_test)
533+
534+
accuracy = accuracy_score(y_test, y_pred)
535+
536+
mlflow.set_tracking_uri(settings.MLFLOW_ENDPOINT)
537+
mlflow.set_experiment(experiment_name)
538+
539+
print(mlflow.get_tracking_uri())
540+
print(mlflow.get_artifact_uri())
541+
542+
mlflow.end_run()
543+
with mlflow.start_run():
544+
mlflow.log_params(params)
545+
mlflow.log_metric('accuracy', accuracy)
546+
mlflow.set_tag('Training Info', 'Basic LR model for iris data')
547+
548+
signature = infer_signature(X_train, lr.predict(X_train))
549+
_ = mlflow.sklearn.log_model(
550+
sk_model=lr,
551+
artifact_path='iris_model',
552+
signature=signature,
553+
input_example=X_train,
554+
registered_model_name='tracking-quickstart',
555+
)
556+
557+
558+
@shared_task
559+
def example_train(experiment_name: str):
560+
train_body(experiment_name)

0 commit comments

Comments
 (0)