@@ -74,6 +74,8 @@ def __init__(
7474 * , # force remaining arguments to be keyword-only
7575 modal : str | None = None ,
7676 neighbor_list_fn : Callable | None = None ,
77+ enable_flash : bool = False ,
78+ enable_cueq : bool = False ,
7779 device : torch .device | str = 'auto' ,
7880 dtype : torch .dtype = torch .float32 ,
7981 ) -> None :
@@ -90,6 +92,8 @@ def __init__(
9092 modal (str | None): modal (fidelity) if given model is multi-modal model.
9193 for 7net-mf-ompa, it should be one of 'mpa' (MPtrj + sAlex) or
9294 'omat24' (OMat24).
95+ enable_cueq (bool): Enable cuEquivariance backend.
96+ enable_flash (bool): Enable flashTP backend.
9397 neighbor_list_fn (Callable): Neighbor list function to use.
9498 Default is torch_nl_linked_cell.
9599 device (torch.device | str): Device to run the model on
@@ -126,7 +130,10 @@ def __init__(
126130
127131 if isinstance (model , (str , Path )):
128132 cp = load_checkpoint (model )
129- model = cp .build_model ()
133+ model = cp .build_model (
134+ enable_flash = enable_flash ,
135+ enable_cueq = enable_cueq ,
136+ )
130137
131138 _validate (model , modal )
132139
0 commit comments