Skip to content

Commit e0d0f9f

Browse files
authored
Merge pull request #719 from h2see/windows-cuda-path-bugfix
Ensure the windows CUDA_PATH is included in DLL search paths.
2 parents 464d6e2 + e6f09e1 commit e0d0f9f

File tree

2 files changed

+49
-3
lines changed

2 files changed

+49
-3
lines changed

Python/tigre/__init__.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,20 @@
66
# https://github.com/CERN/TIGRE/issues/349
77
import os
88

9+
_CUDA_PATH_WIN_DIR_HANDLE = None
10+
def _add_cuda_path_on_windows():
11+
global _CUDA_PATH_WIN_DIR_HANDLE
12+
import sys
13+
if sys.platform == "win32" and hasattr(os, "add_dll_directory"):
14+
_cuda_path = os.environ.get("CUDA_PATH")
15+
if _cuda_path:
16+
_bin = os.path.join(_cuda_path, "bin")
17+
if os.path.isdir(_bin):
18+
# a reference to the handle returned by os.add_dll_directory is necessary,
19+
# otherwise it will be garbage collected and the directory will be removed
20+
# from the search path again.
21+
_CUDA_PATH_WIN_DIR_HANDLE = os.add_dll_directory(_bin)
22+
923
# if hasattr(os, "add_dll_directory"):
1024
# # Add all the DLL directories manually
1125
# # see:
@@ -21,7 +35,18 @@
2135
from .utilities.geometry import geometry
2236
from .utilities.geometry_default import ConeGeometryDefault as geometry_default
2337
from .utilities.geometry_default import FanGeometryDefault as fan_geometry_default
24-
from .utilities.Ax import Ax
38+
39+
# always import Ax and _Ax_ext before all other ctypes extensions (checks for DLL import errors).
40+
from .utilities.Ax import Ax, _Ax_ext, _ensure_Ax_ext_import
41+
# _Ax_ext will be None if the import failed,
42+
# in that case we try to add the CUDA path on windows and import again,
43+
# this will raise the original import error with the full message if it still fails.
44+
if _Ax_ext is None:
45+
from .utilities.Ax import _try_import_Ax_ext
46+
_add_cuda_path_on_windows()
47+
_try_import_Ax_ext()
48+
_ensure_Ax_ext_import() # check if the import was successful, if not this will raise the original import error.
49+
2550
from .utilities.Atb import Atb
2651
from .utilities.visualization.plotproj import plotproj, plotProj, plotSinogram
2752
from .utilities.visualization.plotimg import plotimg, plotImg

Python/tigre/utilities/Ax.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,33 @@
11
import copy
22

33
import numpy as np
4-
from _Ax import _Ax_ext
54

6-
from .gpu import GpuIds
5+
_Ax_ext = None
6+
GpuIds = None
7+
def _try_import_Ax_ext():
8+
global _Ax_ext, GpuIds
9+
if _Ax_ext is None:
10+
try:
11+
from _Ax import _Ax_ext as imported
12+
_Ax_ext = imported
13+
except ImportError:
14+
_Ax_ext = None
15+
if _Ax_ext is not None and GpuIds is None:
16+
# GpuIds depends on ctypes as well, so we must import it
17+
# after sucessful completion of the _Ax_ext import.
18+
from .gpu import GpuIds as imported_ids
19+
GpuIds = imported_ids
20+
21+
def _ensure_Ax_ext_import():
22+
global _Ax_ext
23+
if _Ax_ext is None:
24+
from _Ax import _Ax_ext
25+
26+
_try_import_Ax_ext()
727

828

929
def Ax(img, geo, angles, projection_type="Siddon", **kwargs):
30+
_ensure_Ax_ext_import() # check if the import was successful, if not this will raise the original import error with the full message.
1031
if img.dtype != np.float32:
1132
raise TypeError("Input data should be float32, not " + str(img.dtype))
1233
if not np.isreal(img).all():

0 commit comments

Comments
 (0)