-
Notifications
You must be signed in to change notification settings - Fork 2
Open
Description
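I am trying to fine-tune the `checkpoint_sevennet_mf_ompa.pth` model. `Trainer.from_config` fails with an `AttributeError`, because the loss entry in the checkpoint's config (`config[KEY.LOSS]`) is a dict rather than the string that `get_loss_functions_from_config` expects when it calls `.lower()` on it: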
import sevenn.util as util
model, config = util.model_from_checkpoint('checkpoint_sevennet_mf_ompa.pth')
cutoff = config['cutoff']
dataset = SevenNetGraphDataset(cutoff=cutoff, root=working_dir, files=dataset_files, processed_name='train.pt')
from sevenn.train.trainer import Trainer
import torch.optim.lr_scheduler as scheduler
trainer = Trainer.from_config(model, config)
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
Cell In[6], line 4
1 from sevenn.train.trainer import Trainer
2 import torch.optim.lr_scheduler as scheduler
----> 4 trainer = Trainer.from_config(model, config)
6 # We have energy, force, stress loss function, which used to train 7net-0.
7 # We will use it as it is, with loss weight: 1.0, 1.0, and 0.01 for energy, force, and stress, respectively.
8 print(trainer.loss_functions)
File [~/miniconda3/envs/atomate2/lib/python3.12/site-packages/sevenn/train/trainer.py:88](http://localhost:3416/lab/tree/DRX/miniconda3/envs/atomate2/lib/python3.12/site-packages/sevenn/train/trainer.py#line=87), in Trainer.from_config(model, config)
84 @staticmethod
85 def from_config(model: torch.nn.Module, config: Dict[str, Any]) -> 'Trainer':
86 trainer = Trainer(
87 model,
---> 88 loss_functions=get_loss_functions_from_config(config),
89 optimizer_cls=optim_dict[config.get(KEY.OPTIMIZER, 'adam').lower()],
90 optimizer_args=config.get(KEY.OPTIM_PARAM, {}),
91 scheduler_cls=scheduler_dict[
92 config.get(KEY.SCHEDULER, 'exponentiallr').lower()
93 ],
94 scheduler_args=config.get(KEY.SCHEDULER_PARAM, {}),
95 device=config.get(KEY.DEVICE, 'auto'),
96 distributed=config.get(KEY.IS_DDP, False),
97 distributed_backend=config.get(KEY.DDP_BACKEND, 'nccl'),
98 )
99 return trainer
File [~/miniconda3/envs/atomate2/lib/python3.12/site-packages/sevenn/train/loss.py:211](http://localhost:3416/lab/tree/DRX/miniconda3/envs/atomate2/lib/python3.12/site-packages/sevenn/train/loss.py#line=210), in get_loss_functions_from_config(config)
207 from sevenn.train.optim import loss_dict
209 loss_functions = [] # list of tuples (loss_definition, weight)
--> 211 loss = loss_dict[config[KEY.LOSS].lower()]
212 loss_param = config.get(KEY.LOSS_PARAM, {})
214 use_weight = config.get(KEY.USE_WEIGHT, False)
AttributeError: 'dict' object has no attribute 'lower'
Metadata
Assignees
Labels
No labels