Skip to content

Commit 8b2b769

Browse files
jhnwu3John Wu
andauthored
legacy bug fix for the trainer assuming there exists a model.mode attribute when its not really required (#545)
Co-authored-by: John Wu <johnwu3@sunlab-serv-03.cs.illinois.edu>
1 parent 5200ac4 commit 8b2b769

File tree

2 files changed

+13
-11
lines changed

2 files changed

+13
-11
lines changed

pixi.lock

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyhealth/models/base_model.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ def __init__(self, dataset: SampleDataset):
3030
# used to query the device of the model
3131
self._dummy_param = nn.Parameter(torch.empty(0))
3232

33+
self.mode = None # legacy API for backward compatibility with the trainer.
34+
3335
@property
3436
def device(self) -> torch.device:
3537
"""
@@ -50,9 +52,9 @@ def get_output_size(self) -> int:
5052
Returns:
5153
int: The output size of the model.
5254
"""
53-
assert len(self.label_keys) == 1, (
54-
"Only one label key is supported if get_output_size is called"
55-
)
55+
assert (
56+
len(self.label_keys) == 1
57+
), "Only one label key is supported if get_output_size is called"
5658
output_size = self.dataset.output_processors[self.label_keys[0]].size()
5759
return output_size
5860

@@ -69,9 +71,9 @@ def get_loss_function(self) -> Callable:
6971
Returns:
7072
Callable: The default loss function.
7173
"""
72-
assert len(self.label_keys) == 1, (
73-
"Only one label key is supported if get_loss_function is called"
74-
)
74+
assert (
75+
len(self.label_keys) == 1
76+
), "Only one label key is supported if get_loss_function is called"
7577
label_key = self.label_keys[0]
7678
mode = self.dataset.output_schema[label_key]
7779
if mode == "binary":
@@ -106,9 +108,9 @@ def prepare_y_prob(self, logits: torch.Tensor) -> torch.Tensor:
106108
Returns:
107109
torch.Tensor: The predicted probability tensor.
108110
"""
109-
assert len(self.label_keys) == 1, (
110-
"Only one label key is supported if get_loss_function is called"
111-
)
111+
assert (
112+
len(self.label_keys) == 1
113+
), "Only one label key is supported if get_loss_function is called"
112114
label_key = self.label_keys[0]
113115
mode = self.dataset.output_schema[label_key]
114116
if mode in ["binary"]:

0 commit comments

Comments
 (0)