Skip to content

Commit 880fc64

Browse files
Pr 293 (#294)
* Add enable_flash/cueq attributes to SevenNetModel class * Add enable_flash/cueq in SevenNetModel class. * chore --------- Co-authored-by: hswoo369 <hswoo369@gmail.com>
1 parent dd71896 commit 880fc64

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

sevenn/torchsim.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@ def __init__(
7474
*, # force remaining arguments to be keyword-only
7575
modal: str | None = None,
7676
neighbor_list_fn: Callable | None = None,
77+
enable_flash: bool = False,
78+
enable_cueq: bool = False,
7779
device: torch.device | str = 'auto',
7880
dtype: torch.dtype = torch.float32,
7981
) -> None:
@@ -90,6 +92,8 @@ def __init__(
9092
modal (str | None): modal (fidelity) if given model is multi-modal model.
9193
for 7net-mf-ompa, it should be one of 'mpa' (MPtrj + sAlex) or
9294
'omat24' (OMat24).
95+
enable_cueq (bool): Enable cuEquivariance backend.
96+
enable_flash (bool): Enable flashTP backend.
9397
neighbor_list_fn (Callable): Neighbor list function to use.
9498
Default is torch_nl_linked_cell.
9599
device (torch.device | str): Device to run the model on
@@ -126,7 +130,10 @@ def __init__(
126130

127131
if isinstance(model, (str, Path)):
128132
cp = load_checkpoint(model)
129-
model = cp.build_model()
133+
model = cp.build_model(
134+
enable_flash=enable_flash,
135+
enable_cueq=enable_cueq,
136+
)
130137

131138
_validate(model, modal)
132139

0 commit comments

Comments
 (0)