|
6 | 6 | # https://github.com/CERN/TIGRE/issues/349 |
7 | 7 | import os |
8 | 8 |
|
| 9 | +_CUDA_PATH_WIN_DIR_HANDLE = None |
| 10 | +def _add_cuda_path_on_windows(): |
| 11 | + global _CUDA_PATH_WIN_DIR_HANDLE |
| 12 | + import sys |
| 13 | + if sys.platform == "win32" and hasattr(os, "add_dll_directory"): |
| 14 | + _cuda_path = os.environ.get("CUDA_PATH") |
| 15 | + if _cuda_path: |
| 16 | + _bin = os.path.join(_cuda_path, "bin") |
| 17 | + if os.path.isdir(_bin): |
| 18 | + # a reference to the handle returned by os.add_dll_directory is necessary, |
| 19 | + # otherwise it will be garbage collected and the directory will be removed |
| 20 | + # from the search path again. |
| 21 | + _CUDA_PATH_WIN_DIR_HANDLE = os.add_dll_directory(_bin) |
| 22 | + |
9 | 23 | # if hasattr(os, "add_dll_directory"): |
10 | 24 | # # Add all the DLL directories manually |
11 | 25 | # # see: |
|
21 | 35 | from .utilities.geometry import geometry |
22 | 36 | from .utilities.geometry_default import ConeGeometryDefault as geometry_default |
23 | 37 | from .utilities.geometry_default import FanGeometryDefault as fan_geometry_default |
24 | | -from .utilities.Ax import Ax |
| 38 | + |
| 39 | +# always import Ax and _Ax_ext before all other ctypes extensions (checks for DLL import errors). |
| 40 | +from .utilities.Ax import Ax, _Ax_ext, _ensure_Ax_ext_import |
| 41 | +# _Ax_ext will be None if the import failed, |
| 42 | +# in that case we try to add the CUDA path on windows and import again, |
| 43 | +# this will raise the original import error with the full message if it still fails. |
| 44 | +if _Ax_ext is None: |
| 45 | + from .utilities.Ax import _try_import_Ax_ext |
| 46 | + _add_cuda_path_on_windows() |
| 47 | + _try_import_Ax_ext() |
| 48 | +_ensure_Ax_ext_import() # check if the import was successful, if not this will raise the original import error. |
| 49 | + |
25 | 50 | from .utilities.Atb import Atb |
26 | 51 | from .utilities.visualization.plotproj import plotproj, plotProj, plotSinogram |
27 | 52 | from .utilities.visualization.plotimg import plotimg, plotImg |
|
0 commit comments