Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions backend/app/api/models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import uuid
import json
from hashlib import md5

from flask import jsonify, request
from flask_restful import Resource, reqparse, fields, abort
Expand All @@ -9,7 +10,7 @@
from app import db, models_uploadset
from app.models import Model, ModelSchema, ModelLesserSchema
from app.api import paginated_parser
from app.api.utils import NestedResponse
from app.api.utils import NestedResponse, str_type

from sqlalchemy.dialects.postgresql import array as postgres_array
from werkzeug.datastructures import FileStorage
Expand Down Expand Up @@ -146,12 +147,18 @@ def post(self) -> dict:
:param git_active_branch: an active branch of the uploaded model
:param git_commit_hash: hash of the most recent commit
:param file: contents of the file selected to upload
:param checksum: md5 hash computed from the file and stored in hexadecimal format
:param private: whether to mark the model as private
:returns: a newly uploaded model
"""

parser = reqparse.RequestParser()
parser.add_argument("name", type=str)
parser.add_argument("dataset_name", type=str)
parser.add_argument(
"name",
type=str_type(max_length=40),
help="Name of the model must be at most 40 characters long.",
)
parser.add_argument("dataset_name", type=str_type(max_length=120))
parser.add_argument("dataset_description", type=str)
parser.add_argument("project_id", type=int, required=True)
# user_id : deprecated
Expand All @@ -164,6 +171,7 @@ def post(self) -> dict:
parser.add_argument("git_commit_hash", type=str, default=None)
parser.add_argument("file", type=FileStorage, location="files", required=True)
parser.add_argument("private", type=bool, default=False)
parser.add_argument("checksum", type=str)
args = parser.parse_args()

if "file" in args:
Expand All @@ -182,6 +190,7 @@ def post(self) -> dict:
metrics=args["metrics"],
name=args["name"],
path=filename,
checksum=args["checksum"],
dataset_name=args["dataset_name"],
dataset_description=args["dataset_description"],
git_active_branch=args["git_active_branch"],
Expand Down
13 changes: 13 additions & 0 deletions backend/app/api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,16 @@ def dump(self, data):
response["pagination"] = self.gather_pagination_info()

return response


def str_type(min_length=None, max_length=None):
    """Build a length-checking validator for use as a parser ``type=`` callable.

    :param min_length: minimum allowed length, or ``None`` for no explicit minimum
    :param max_length: maximum allowed length, or ``None`` for no maximum
    :returns: a function that returns its argument unchanged when it is valid
        and raises ``ValueError`` with a descriptive message otherwise
    """

    def validate(str_to_check):
        length = len(str_to_check)
        # Empty input is always rejected, even when no min_length is given.
        if length == 0:
            raise ValueError("String must be at least 1 character long")
        # Compare against None (not truthiness) so an explicit bound of 0
        # is honoured instead of being silently ignored.
        if min_length is not None and length < min_length:
            raise ValueError(f"String must be at least {min_length} characters long")
        if max_length is not None and length > max_length:
            raise ValueError(f"String must be at most {max_length} characters long")
        return str_to_check

    return validate
47 changes: 34 additions & 13 deletions client/maisie/resources/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Union
from maisie import BaseAction
from maisie.utils.git import GitProvider
from hashlib import md5

import os

Expand Down Expand Up @@ -41,10 +42,15 @@ def upload(
parameters = self._determine_input(parameters)
metrics = self._determine_input(metrics)

checksum = ""
with self.config.session as session:
files = {}
try:
files["file"] = open(filename, "rb")
with open(filename, "rb") as f:
files["file"] = f.read()
checksum = md5(files["file"]).hexdigest()
# files["file"] = open(filename, "rb").read()
# checksum = md5(files["file"]).hexdigest()
except FileNotFoundError:
logger.error(f"Model `{filename}` could not be found.")

Expand All @@ -59,20 +65,23 @@ def upload(
"project_id": self.config.selected_project,
"git_active_branch": git.active_branch,
"git_commit_hash": git.latest_commit,
"dataset_name": dataset_name,
"dataset_description": dataset_description,
"checksum": checksum,
}
request = session.post(
f"{self.config.api_url}/models/", files=files, data=payload
)

results = []
# print(payload)
# print(request.text)
if "data" in request.json():
results.append(request.json()["data"])
else:
# results = []
# if "data" in request.json():
# results.append(request.json()["data"])
# else:
# logger.error("Could not upload selected model.")
if not "data" in request.json():
logger.error("Could not upload selected model.")

return results
return request.json()

def update(self, id: int, data: dict):
"""Update selected model.
Expand All @@ -83,27 +92,39 @@ def update(self, id: int, data: dict):
with self.config.session as session:
pass

def download(self, id: int, path=None):
    """Download the requested model and verify it against its stored checksum.

    :param id: id of the model to download
    :param path: directory to save the model in; when empty, the model name
        returned by the API is used as the target path as-is
    :returns: a status message describing the outcome, or ``None`` when the
        model metadata lacks the fields needed to download and verify it
    """
    with self.config.session as session:
        request = session.get(f"{self.config.api_url}/models/{id}/")
        if request.status_code == 404:
            return "Model with the specified id was not found"

        data = request.json().get("data") or {}
        links = data.get("_links") or {}
        if not ("checksum" in data and "name" in data and "download" in links):
            return None

        source_checksum = data["checksum"]
        model_name = data["name"]
        if path and model_name:
            model_name = os.path.join(path, model_name)

        source_data = session.get(links["download"])
        # Hash each chunk as it is written so the file does not have to be
        # read back into memory just to compute its checksum.
        digest = md5()
        with open(model_name, "wb") as model_file:
            for chunk in source_data.iter_content(chunk_size=8192):
                model_file.write(chunk)
                digest.update(chunk)

        # A missing/empty server checksum can never count as a match.
        if source_checksum and digest.hexdigest() == source_checksum:
            return "Model downloaded successfully"
        return "Checksums differ"

def get(self, id: int) -> list:
"""Fetches a single model.
Expand Down
1 change: 1 addition & 0 deletions client/maisie/tests/test_resources/params.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"a":"b"}
52 changes: 52 additions & 0 deletions client/maisie/tests/test_resources/test_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import pytest

from maisie.resources.models import Models


# Client resource under test, created once and shared by every test in this module.
models = Models()


def test_correct_upload():
    """Upload with valid arguments; any unexpected exception fails the test."""
    fixture = "params.json"
    upload_kwargs = {
        "name": "some_name",
        "filename": fixture,
        "hyperparameters": fixture,
        "parameters": fixture,
        "metrics": fixture,
        "dataset_name": "dataset_name",
    }
    models.upload(**upload_kwargs)


def test_too_long_name():
    """A name over the 40-character database limit is reported under message.name."""
    fixture = "params.json"
    response = models.upload(
        name="111111111111111111111111111111111111111112",
        filename=fixture,
        hyperparameters=fixture,
        parameters=fixture,
        metrics=fixture,
        dataset_name="dataset_name",
    )
    expected = "Name of the model must be at most 40 characters long."
    assert response["message"]["name"] == expected


def test_not_existing_file():
    """Uploading a file that does not exist is reported under message.file."""
    fixture = "params.json"
    response = models.upload(
        name="some_name",
        filename="nonexistent_file.json",
        hyperparameters=fixture,
        parameters=fixture,
        metrics=fixture,
        dataset_name="dataset_name",
    )
    file_error = response["message"]["file"]
    assert file_error == "Missing required parameter in an uploaded file"


# def test_too_big_file():