From 4485fa6ddd89b70c8c42073e18bc9b1dea555135 Mon Sep 17 00:00:00 2001 From: Zihann Date: Tue, 11 Jun 2024 00:27:51 +0800 Subject: [PATCH] FIX: get_device method lacks MPS judgment branch --- tensorizer/utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tensorizer/utils.py b/tensorizer/utils.py index eb9ec8c8..877a0858 100644 --- a/tensorizer/utils.py +++ b/tensorizer/utils.py @@ -75,7 +75,11 @@ def convert_bytes(num, decimal=True) -> str: def get_device() -> torch.device: - return torch.device("cuda" if torch.cuda.is_available() else "cpu") + return torch.device( + "cuda" + if torch.cuda.is_available() + else ("mps" if torch.backends.mps.is_available() else "cpu") + ) class GlobalGPUMemoryUsage(NamedTuple):