Skip to content

Commit a137673

Browse files
committed
remove redundancy
1 parent ae54591 commit a137673

File tree

4 files changed

+15
-10
lines changed

4 files changed

+15
-10
lines changed

src/squidpy/_settings/_dispatch.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@ def _inject_gpu_note(doc: str | None, func_name: str, gpu_module: str) -> str |
4545
def _get_gpu_func(gpu_module: str, func_name: str) -> Callable[..., Any]:
4646
"""Get GPU function from module, with caching.
4747
48-
4948
Raises
5049
------
5150
ImportError

src/squidpy/_settings/_settings.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
if TYPE_CHECKING:
1010
from collections.abc import Generator
1111

12-
__all__ = ["settings", "DeviceType", "GPU_UNAVAILABLE_MSG"]
12+
__all__ = ["settings", "DeviceType"]
1313

1414
DeviceType = Literal["cpu", "gpu"]
1515
GPU_UNAVAILABLE_MSG = (

tests/test_gpu.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@ class TestGPUCoOccurrence:
2020

2121
def test_co_occurrence_gpu(self, adata):
2222
"""Test co_occurrence with GPU device."""
23-
result = sq.gr.co_occurrence(adata, cluster_key="leiden", copy=True, device="gpu")
23+
with settings.use_device("gpu"):
24+
result = sq.gr.co_occurrence(adata, cluster_key="leiden", copy=True)
2425

2526
assert result is not None
2627
arr, interval = result
@@ -29,8 +30,10 @@ def test_co_occurrence_gpu(self, adata):
2930

3031
def test_co_occurrence_gpu_vs_cpu(self, adata):
3132
"""Test that GPU and CPU results are approximately equal."""
32-
cpu_arr, cpu_interval = sq.gr.co_occurrence(adata, cluster_key="leiden", copy=True, device="cpu")
33-
gpu_arr, gpu_interval = sq.gr.co_occurrence(adata, cluster_key="leiden", copy=True, device="gpu")
33+
with settings.use_device("cpu"):
34+
cpu_arr, cpu_interval = sq.gr.co_occurrence(adata, cluster_key="leiden", copy=True)
35+
with settings.use_device("gpu"):
36+
gpu_arr, gpu_interval = sq.gr.co_occurrence(adata, cluster_key="leiden", copy=True)
3437

3538
np.testing.assert_allclose(cpu_interval, gpu_interval, rtol=1e-5)
3639
np.testing.assert_allclose(cpu_arr, gpu_arr, rtol=1e-5)
@@ -42,7 +45,8 @@ class TestGPUSpatialAutocorr:
4245
def test_spatial_autocorr_gpu(self, adata):
4346
"""Test spatial_autocorr with GPU device."""
4447
sq.gr.spatial_neighbors(adata)
45-
result = sq.gr.spatial_autocorr(adata, mode="moran", copy=True, device="gpu")
48+
with settings.use_device("gpu"):
49+
result = sq.gr.spatial_autocorr(adata, mode="moran", copy=True)
4650

4751
assert result is not None
4852
assert "I" in result.columns
@@ -51,8 +55,10 @@ def test_spatial_autocorr_gpu(self, adata):
5155
def test_spatial_autocorr_gpu_vs_cpu(self, adata):
5256
"""Test that GPU and CPU results are approximately equal."""
5357
sq.gr.spatial_neighbors(adata)
54-
cpu_result = sq.gr.spatial_autocorr(adata, mode="moran", copy=True, device="cpu")
55-
gpu_result = sq.gr.spatial_autocorr(adata, mode="moran", copy=True, device="gpu")
58+
with settings.use_device("cpu"):
59+
cpu_result = sq.gr.spatial_autocorr(adata, mode="moran", copy=True)
60+
with settings.use_device("gpu"):
61+
gpu_result = sq.gr.spatial_autocorr(adata, mode="moran", copy=True)
5662

5763
np.testing.assert_allclose(cpu_result["I"].values, gpu_result["I"].values, rtol=1e-3, equal_nan=True)
5864

@@ -62,7 +68,8 @@ class TestGPULigrec:
6268

6369
def test_ligrec_gpu(self, adata):
6470
"""Test ligrec with GPU device."""
65-
result = sq.gr.ligrec(adata, cluster_key="leiden", copy=True, device="gpu")
71+
with settings.use_device("gpu"):
72+
result = sq.gr.ligrec(adata, cluster_key="leiden", copy=True)
6673

6774
assert result is not None
6875
assert "means" in result

tests/test_settings.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@ def test_set_device_invalid(self):
4747
with pytest.raises(ValueError, match="device must be one of"):
4848
settings.device = "invalid"
4949

50-
5150
def test_set_device_gpu_without_rsc(self):
5251
"""Test that setting device to 'gpu' without rapids-singlecell raises RuntimeError."""
5352
if settings.gpu_available:

0 commit comments

Comments
 (0)