Skip to content

Commit ff6d9e2

Browse files
authored
[xpu][test] Port test/test_ao_models.py to intel XPU (#3481)
add test/test_ao_models.py
1 parent 17a7c37 commit ff6d9e2

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

test/test_ao_models.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
99
from torch.testing._internal import common_utils
1010

1111
from torchao._models.llama.model import Transformer
12+
from torchao.utils import get_current_accelerator_device
13+
14+
_DEVICE = get_current_accelerator_device()
1215

1316

1417
def init_model(name="stories15M", device="cpu", precision=torch.bfloat16):
@@ -22,7 +25,7 @@ class TorchAOBasicTestCase(unittest.TestCase):
2225
"""Test suite for basic Transformer inference functionality."""
2326

2427
@common_utils.parametrize(
25-
"device", ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
28+
"device", ["cpu", _DEVICE] if torch.accelerator.is_available() else ["cpu"]
2629
)
2730
@common_utils.parametrize("batch_size", [1, 4])
2831
@common_utils.parametrize("is_training", [True, False])

0 commit comments

Comments
 (0)