# Patch summary (review notes; this preamble precedes the first `diff --git` header).
#
# sageattention3_blackwell/setup.py:
#   * adds `import sys`;
#   * when sys.platform == "win32" and the DISTUTILS_USE_SDK environment variable
#     equals "1", appends "-DUSE_CUDA" and "-Usmall" to nvcc_flags;
#   * introduces default_cxx = ["-O3", "-std=c++17"], appends the same two flags
#     to it under the same condition, and passes it as the "cxx" entry of
#     extra_compile_args for the fp4attn_cuda and fp4quant_cuda extensions in
#     place of the former inline lists.
# setup.py (repository root):
#   * under the same platform/environment condition, appends "-DUSE_CUDA" to
#     CXX_FLAGS and NVCC_FLAGS.
#
# NOTE(review): the patch text below appears whitespace-collapsed (hunk headers,
#   context lines and +/- lines share physical lines); the original line
#   structure and indentation would have to be restored before it can be applied.
# NOTE(review): the root setup.py hunk uses `sys.platform`, but no `import sys`
#   is added for that file in this patch -- confirm it is already imported there.
# NOTE(review): "-Usmall" is added only in sageattention3_blackwell/setup.py, not
#   in the root setup.py -- confirm the asymmetry is intentional.
# NOTE(review): a command-line -U presumably only cancels built-in or
#   command-line definitions and would not stop a later `#define small` inside a
#   header -- verify that "-Usmall" actually has the intended effect.
# NOTE(review): "-O3" and "-std=c++17" are kept in default_cxx on the branch the
#   comments describe as MSVC -- confirm that toolchain accepts these spellings.
# NOTE(review): both CUDAExtension entries receive the same default_cxx list
#   object; presumably harmless, but verify nothing mutates
#   extra_compile_args["cxx"] per extension.
diff --git a/sageattention3_blackwell/setup.py b/sageattention3_blackwell/setup.py index 04e3888b..28384aae 100644 --- a/sageattention3_blackwell/setup.py +++ b/sageattention3_blackwell/setup.py @@ -1,5 +1,6 @@ import warnings import os +import sys from pathlib import Path from packaging.version import parse, Version from setuptools import setup, find_packages @@ -108,18 +109,35 @@ def append_nvcc_threads(nvcc_extra_args): "-DCTA256", "-DDQINRMEM", ] + + # Windows/MSVC needs -DUSE_CUDA for the same reason as in the main + # SageAttention package: nvcc-generated host code includes + # compiled_autograd.h, which will hit an ambiguous `std` error unless + # the header sees USE_CUDA defined. Only add the flag when using the + # MSVC toolchain (DISTUTILS_USE_SDK==1). + if sys.platform == "win32" and os.getenv("DISTUTILS_USE_SDK") == "1": + nvcc_flags.append("-DUSE_CUDA") + # undefine small macro from Windows headers; otherwise it renames + # function parameters in PyTorch headers and breaks compilation. 
+ nvcc_flags.append("-Usmall") include_dirs = [ repo_dir / "sageattn3", cutlass_dir / "include", cutlass_dir / "tools" / "util" / "include", ] + # base C++ flags for extensions; may need USE_CUDA on Windows/MSVC + default_cxx = ["-O3", "-std=c++17"] + if sys.platform == "win32" and os.getenv("DISTUTILS_USE_SDK") == "1": + default_cxx.append("-DUSE_CUDA") + default_cxx.append("-Usmall") + ext_modules.append( CUDAExtension( name="fp4attn_cuda", sources=["sageattn3/blackwell/api.cu"], extra_compile_args={ - "cxx": ["-O3", "-std=c++17"], + "cxx": default_cxx, "nvcc": append_nvcc_threads( nvcc_flags + ["-DEXECMODE=0"] + cc_flag ), @@ -134,7 +152,7 @@ def append_nvcc_threads(nvcc_extra_args): name="fp4quant_cuda", sources=["sageattn3/quantization/fp4_quantization_4d.cu"], extra_compile_args={ - "cxx": ["-O3", "-std=c++17"], + "cxx": default_cxx, "nvcc": append_nvcc_threads( nvcc_flags + ["-DEXECMODE=0"] + cc_flag ), diff --git a/setup.py b/setup.py index 6b2c5b43..7e8a2831 100644 --- a/setup.py +++ b/setup.py @@ -60,6 +60,11 @@ "-diag-suppress=174", ] + # No behaviour is changed on Linux or when building with another compiler. + if sys.platform == "win32" and os.getenv('DISTUTILS_USE_SDK') == '1': + CXX_FLAGS.append("-DUSE_CUDA") + NVCC_FLAGS.append("-DUSE_CUDA") + # Append flags from env if provided cxx_append = os.getenv("CXX_APPEND_FLAGS", "").strip() if cxx_append: