@@ -30,6 +30,8 @@ def __init__(self, dataset: SampleDataset):
3030 # used to query the device of the model
3131 self ._dummy_param = nn .Parameter (torch .empty (0 ))
3232
33+ self .mode = None # legacy API for backward compatibility with the trainer.
34+
3335 @property
3436 def device (self ) -> torch .device :
3537 """
@@ -50,9 +52,9 @@ def get_output_size(self) -> int:
5052 Returns:
5153 int: The output size of the model.
5254 """
53- assert len ( self . label_keys ) == 1 , (
54- "Only one label key is supported if get_output_size is called"
55- )
55+ assert (
56+ len ( self . label_keys ) == 1
57+ ), "Only one label key is supported if get_output_size is called"
5658 output_size = self .dataset .output_processors [self .label_keys [0 ]].size ()
5759 return output_size
5860
@@ -69,9 +71,9 @@ def get_loss_function(self) -> Callable:
6971 Returns:
7072 Callable: The default loss function.
7173 """
72- assert len ( self . label_keys ) == 1 , (
73- "Only one label key is supported if get_loss_function is called"
74- )
74+ assert (
75+ len ( self . label_keys ) == 1
76+ ), "Only one label key is supported if get_loss_function is called"
7577 label_key = self .label_keys [0 ]
7678 mode = self .dataset .output_schema [label_key ]
7779 if mode == "binary" :
@@ -106,9 +108,9 @@ def prepare_y_prob(self, logits: torch.Tensor) -> torch.Tensor:
106108 Returns:
107109 torch.Tensor: The predicted probability tensor.
108110 """
109- assert len ( self . label_keys ) == 1 , (
110- "Only one label key is supported if get_loss_function is called"
111- )
111+ assert (
112+ len ( self . label_keys ) == 1
113+ ), "Only one label key is supported if get_loss_function is called"
112114 label_key = self .label_keys [0 ]
113115 mode = self .dataset .output_schema [label_key ]
114116 if mode in ["binary" ]:
0 commit comments