File tree Expand file tree Collapse file tree 1 file changed +4
-1
lines changed
Expand file tree Collapse file tree 1 file changed +4
-1
lines changed Original file line number Diff line number Diff line change 99from torch .testing ._internal import common_utils
1010
1111from torchao ._models .llama .model import Transformer
12+ from torchao .utils import get_current_accelerator_device
13+
14+ _DEVICE = get_current_accelerator_device ()
1215
1316
1417def init_model (name = "stories15M" , device = "cpu" , precision = torch .bfloat16 ):
@@ -22,7 +25,7 @@ class TorchAOBasicTestCase(unittest.TestCase):
2225 """Test suite for basic Transformer inference functionality."""
2326
2427 @common_utils .parametrize (
25- "device" , ["cpu" , "cuda" ] if torch .cuda .is_available () else ["cpu" ]
28+ "device" , ["cpu" , _DEVICE ] if torch .accelerator .is_available () else ["cpu" ]
2629 )
2730 @common_utils .parametrize ("batch_size" , [1 , 4 ])
2831 @common_utils .parametrize ("is_training" , [True , False ])
You can’t perform that action at this time.
0 commit comments