Bug when fine-tuning OMat24 checkpoint model #2

@jinlhr542

Description

I am trying to fine-tune the checkpoint_sevennet_mf_ompa.pth model:

import sevenn.util as util
from sevenn.train.graph_dataset import SevenNetGraphDataset

model, config = util.model_from_checkpoint('checkpoint_sevennet_mf_ompa.pth')
cutoff = config['cutoff']

# working_dir and dataset_files are defined earlier in my notebook
dataset = SevenNetGraphDataset(cutoff=cutoff, root=working_dir, files=dataset_files, processed_name='train.pt')
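
For reference, the config returned by model_from_checkpoint is a plain Python dict (as the config['cutoff'] lookup above shows). A quick way to see which entries hold nested dicts rather than scalars, which turns out to matter below (sketch, nothing checkpoint-specific assumed):

# Inspect the checkpoint config: print each key with its value's type
for k, v in sorted(config.items()):
    print(k, '->', type(v).__name__)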

from sevenn.train.trainer import Trainer
import torch.optim.lr_scheduler as scheduler

trainer = Trainer.from_config(model, config)
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[6], line 4
      1 from sevenn.train.trainer import Trainer
      2 import torch.optim.lr_scheduler as scheduler
----> 4 trainer = Trainer.from_config(model, config)
      6 # We have energy, force, stress loss function, which used to train 7net-0.
      7 # We will use it as it is, with loss weight: 1.0, 1.0, and 0.01 for energy, force, and stress, respectively.
      8 print(trainer.loss_functions)

File ~/miniconda3/envs/atomate2/lib/python3.12/site-packages/sevenn/train/trainer.py:88, in Trainer.from_config(model, config)
     84 @staticmethod
     85 def from_config(model: torch.nn.Module, config: Dict[str, Any]) -> 'Trainer':
     86     trainer = Trainer(
     87         model,
---> 88         loss_functions=get_loss_functions_from_config(config),
     89         optimizer_cls=optim_dict[config.get(KEY.OPTIMIZER, 'adam').lower()],
     90         optimizer_args=config.get(KEY.OPTIM_PARAM, {}),
     91         scheduler_cls=scheduler_dict[
     92             config.get(KEY.SCHEDULER, 'exponentiallr').lower()
     93         ],
     94         scheduler_args=config.get(KEY.SCHEDULER_PARAM, {}),
     95         device=config.get(KEY.DEVICE, 'auto'),
     96         distributed=config.get(KEY.IS_DDP, False),
     97         distributed_backend=config.get(KEY.DDP_BACKEND, 'nccl'),
     98     )
     99     return trainer

File ~/miniconda3/envs/atomate2/lib/python3.12/site-packages/sevenn/train/loss.py:211, in get_loss_functions_from_config(config)
    207 from sevenn.train.optim import loss_dict
    209 loss_functions = []  # list of tuples (loss_definition, weight)
--> 211 loss = loss_dict[config[KEY.LOSS].lower()]
    212 loss_param = config.get(KEY.LOSS_PARAM, {})
    214 use_weight = config.get(KEY.USE_WEIGHT, False)

AttributeError: 'dict' object has no attribute 'lower'
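
The failing line is loss = loss_dict[config[KEY.LOSS].lower()], so the multi-fidelity checkpoint apparently stores a dict (per-fidelity loss settings, I assume) under KEY.LOSS, while Trainer.from_config expects a plain loss-name string. Below is a minimal sketch of a workaround that gets past the AttributeError, assuming KEY is sevenn._keys (the module sevenn imports as KEY) and that 'huber' with a delta parameter is a valid loss_dict entry in this version; both are assumptions, and loss_dict.keys() shows what is actually accepted:

import sevenn._keys as KEY  # assumption: the KEY module used inside sevenn.train
from sevenn.train.optim import loss_dict  # same import the traceback shows

print(type(config[KEY.LOSS]), config[KEY.LOSS])  # confirm a dict is stored here
print(list(loss_dict.keys()))  # loss names this sevenn version accepts

# Overwrite the dict with a single loss name before building the trainer.
# 'huber' and {'delta': 0.01} are illustrative values, not checkpoint defaults.
config[KEY.LOSS] = 'huber'
config[KEY.LOSS_PARAM] = {'delta': 0.01}
trainer = Trainer.from_config(model, config)

Whether discarding the per-fidelity loss dict this way is actually correct for fine-tuning the multi-fidelity model is something I cannot judge from the outside; it only avoids the crash.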
