Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions sevenn/main/sevenn_get_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
'deploy LAMMPS model from the checkpoint'
)
checkpoint_help = (
'path to the checkpoint | SevenNet-0 | 7net-0 |'
' {SevenNet-0|7net-0}_{11July2024|22May2024}'
'Pretrained model name (7net-omni, 7net-omni-i8, 7net-omni-i12, etc.) '
'or path to checkpoint file. See documentation for all available models.'
)
output_name_help = 'filename prefix'
get_parallel_help = 'deploy parallel model'
Expand Down
34 changes: 32 additions & 2 deletions sevenn/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,19 +288,49 @@ def pretrained_name_to_path(name: str) -> str:
return download_checkpoint(paths[1], url) # ~/.cache


def get_available_pretrained_models() -> List[str]:
"""
Get list of available pretrained model names by checking
which checkpoint constants are defined in _const.py.

Returns:
List of canonical pretrained model names (7net-* format)
"""
# Mapping from checkpoint constant name to canonical model name
checkpoint_to_name = {
'SEVENNET_0_11Jul2024': '7net-0',
'SEVENNET_0_22May2024': '7net-0_22may2024',
'SEVENNET_l3i5': '7net-l3i5',
'SEVENNET_MF_0': '7net-mf-0',
'SEVENNET_MF_ompa': '7net-mf-ompa',
'SEVENNET_omat': '7net-omat',
'SEVENNET_omni': '7net-omni',
'SEVENNET_omni_i8': '7net-omni-i8',
'SEVENNET_omni_i12': '7net-omni-i12',
}

models = []
for const_name, model_name in checkpoint_to_name.items():
if hasattr(_const, const_name):
models.append(model_name)

return models


def load_checkpoint(checkpoint: Union[pathlib.Path, str]) -> 'SevenNetCheckpoint':
from sevenn.checkpoint import SevenNetCheckpoint

suggests = ['7net-0, 7net-l3i5, 7net-mf-ompa, 7net-omat', '7net-omni']
suggests = get_available_pretrained_models()
if osp.isfile(checkpoint):
checkpoint_path = checkpoint
else:
try:
checkpoint_path = pretrained_name_to_path(str(checkpoint))
except ValueError:
model_list = ', '.join(suggests)
raise ValueError(
f'Given {checkpoint} does not exist.\n'
f'Valid pretrained model names: {suggests}'
f'Valid pretrained model names: {model_list}'
)
return SevenNetCheckpoint(checkpoint_path)

Expand Down
Loading