Skip to content

Commit 7f4641a

Browse files
authored
issue(medcat-den): CU-869c96xwg Improve model card adaption (#352)
* CU-869c96xwg: Add v1 model pack to resources
* CU-869c96xwg: Use/unpack v1 test model in base package for tests
* CU-869c96xwg: Add a test for getting model card directly from unpacked v1 model
* CU-869c96xwg: Fix issue with legacy / half-filled model card
* CU-869c96xwg: Force-unpack v1 model within test (instead of globally)
* CU-869c96xwg: Improve model card tests with copies of dicts
* CU-869c96xwg: Add a couple of simple tests for model card load / validation data integrity
1 parent 276a3b7 commit 7f4641a

File tree

5 files changed

+91
-8
lines changed

5 files changed

+91
-8
lines changed

medcat-den/src/medcat_den/base.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,28 @@ def make_permissive(cls, v: dict) -> dict:
3535
defaults = {
3636
'Pipeline Description': {"core": {}, "addons": []},
3737
'Required Plugins': [],
38+
"Location": "N/A",
39+
"Basic CDB Stats": {
40+
"Unsupervised training history": [],
41+
"Supervised training history": [],
42+
},
43+
"Source Ontology": ["unknown"],
3844
}
3945
out_dict = {**defaults, **v} # v overwrites defaults
40-
if out_dict.get("Source Ontology") is None:
41-
out_dict['Source Ontology'] = ['Unknown']
46+
cls._check_key_value_recursively(out_dict, defaults)
4247
return out_dict
4348
return v
4449

50+
@classmethod
def _check_key_value_recursively(
        cls, out_dict: dict, defaults: dict) -> None:
    """Fill missing or malformed entries of ``out_dict`` from ``defaults``.

    For every key in ``defaults``:

    * a missing key, a ``None`` value, or a value that is not an instance
      of the default's type is replaced by (a copy of) the default;
    * when both the current value and the default are dicts, the same
      check is applied recursively to the nested dict.

    Keys of ``out_dict`` that do not appear in ``defaults`` are left alone.

    Args:
        out_dict: The dict to repair. It is modified in place.
        defaults: The template of keys and default values to enforce.
    """
    from copy import deepcopy

    for key, def_val in defaults.items():
        # NOTE: this should be mostly for nested stuff
        cur_val = out_dict.get(key)  # a missing key behaves like None
        # isinstance (rather than an exact type match) keeps subclass
        # values such as an OrderedDict section instead of discarding them.
        if cur_val is None or not isinstance(cur_val, type(def_val)):
            # Copy so later edits to the result cannot leak into the
            # defaults template (or into other results built from it).
            out_dict[key] = deepcopy(def_val)
        elif isinstance(def_val, dict):
            cls._check_key_value_recursively(cur_val, def_val)

medcat-den/tests/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,21 @@
77

88
MODEL_PATH = os.path.join(
99
os.path.dirname(__file__), "resources", "mct2_model_pack.zip")
10+
V1_MODEL_PATH = os.path.join(
11+
os.path.dirname(MODEL_PATH), "mct_v1_model_pack.zip"
12+
)
1013

1114

1215
# unpack
1316
_model_folder = CAT.attempt_unpack(MODEL_PATH)
17+
_v1_model_folder = CAT.attempt_unpack(V1_MODEL_PATH)
1418

1519

1620
def remove_model_folder():
    """Delete the unpacked model folders, skipping any that do not exist."""
    for unpacked_dir in (_model_folder, _v1_model_folder):
        if os.path.exists(unpacked_dir):
            shutil.rmtree(unpacked_dir)
1925

2026

2127
# cleanup
40.8 MB
Binary file not shown.

medcat-den/tests/test_model_card_validity.py

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from copy import deepcopy
2+
13
from medcat_den.base import ModelInfo
24

35

@@ -40,8 +42,8 @@
4042
def test_validates_with_old_format():
4143
model = ModelInfo(
4244
model_id="test_id",
43-
model_card=MODEL_CARD_NO_NEW_KEYS,
44-
base_model=None,
45+
model_card=deepcopy(MODEL_CARD_NO_NEW_KEYS),
46+
base_model=None,
4547
model_name="test_model",
4648
model_num=1,
4749
)
@@ -51,8 +53,8 @@ def test_validates_with_old_format():
5153
def test_validates_with_new_format():
5254
model = ModelInfo(
5355
model_id="test_id",
54-
model_card=MODEL_CARD_WITH_NEW_KEYS,
55-
base_model=None,
56+
model_card=deepcopy(MODEL_CARD_WITH_NEW_KEYS),
57+
base_model=None,
5658
model_name="test_model",
5759
model_num=1,
5860
)
@@ -62,11 +64,41 @@ def test_validates_with_new_format():
6264
def test_new_format_keeps_values():
6365
model = ModelInfo(
6466
model_id="test_id",
65-
model_card=MODEL_CARD_WITH_NEW_KEYS,
66-
base_model=None,
67+
model_card=deepcopy(MODEL_CARD_WITH_NEW_KEYS),
68+
base_model=None,
6769
model_name="test_model",
6870
model_num=1,
6971
)
7072
mc = model.model_card
7173
for key, exp_value in NEW_KV.items():
7274
assert exp_value == mc[key]
75+
76+
77+
def test_new_format_does_not_change_dict():
    """A card that already has the new keys must come out of validation equal."""
    # Hand over a copy so the module-level card cannot be touched by validation.
    card_copy = deepcopy(MODEL_CARD_WITH_NEW_KEYS)
    info = ModelInfo(
        model_id="test_id",
        model_card=card_copy,
        base_model=None,
        model_name="test_model",
        model_num=1,
    )
    assert info.model_card == MODEL_CARD_WITH_NEW_KEYS
86+
87+
88+
def test_old_format_will_not_change_twice():
    """Validating an already validated old-format card must not change it again."""
    common_kwargs = dict(
        model_id="test_id",
        base_model=None,
        model_name="test_model",
        model_num=1,
    )
    # First pass: validate a copy of the old-format card.
    first_pass = ModelInfo(
        model_card=deepcopy(MODEL_CARD_NO_NEW_KEYS), **common_kwargs)
    # Second pass: feed the resulting card straight back in.
    second_pass = ModelInfo(
        model_card=first_pass.model_card, **common_kwargs)
    assert second_pass.model_card == first_pass.model_card
    # The validated card is expected to differ from the raw old-format one.
    assert second_pass.model_card != MODEL_CARD_NO_NEW_KEYS
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import os
2+
import json
3+
4+
from medcat_den import base
5+
6+
7+
from . import V1_MODEL_PATH
8+
9+
import pytest
10+
11+
12+
@pytest.fixture
def v1_model_card():
    """Return the parsed ``model_card.json`` from the unpacked v1 model pack.

    The pack at ``V1_MODEL_PATH`` is unpacked first, then the model card is
    read from the folder named after the archive (without ``.zip``).
    """
    model_card_path = os.path.join(
        V1_MODEL_PATH.removesuffix(".zip"), "model_card.json")
    # NOTE: the unpacked folder was observed not to exist at this point,
    #       so unpack explicitly before reading from it.
    from medcat.cat import CAT
    CAT.attempt_unpack(V1_MODEL_PATH)
    # Read as UTF-8 explicitly instead of relying on the platform's
    # default text encoding.
    with open(model_card_path, encoding="utf-8") as f:
        return json.load(f)
21+
22+
23+
def test_can_get_model_card(v1_model_card):
    """``ModelInfo`` can be built directly from the raw v1 model card."""
    info = base.ModelInfo(
        model_id=v1_model_card["Model ID"],
        model_card=v1_model_card,
        base_model=None,
    )
    assert isinstance(info, base.ModelInfo)

0 commit comments

Comments
 (0)