Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions sageattention3_blackwell/setup.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import warnings
import os
import sys
from pathlib import Path
from packaging.version import parse, Version
from setuptools import setup, find_packages
Expand Down Expand Up @@ -108,18 +109,35 @@ def append_nvcc_threads(nvcc_extra_args):
"-DCTA256",
"-DDQINRMEM",
]

# Windows/MSVC needs -DUSE_CUDA for the same reason as in the main
# SageAttention package: nvcc-generated host code includes
# compiled_autograd.h, which will hit an ambiguous `std` error unless
# the header sees USE_CUDA defined. Only add the flag when using the
# MSVC toolchain (DISTUTILS_USE_SDK==1).
if sys.platform == "win32" and os.getenv("DISTUTILS_USE_SDK") == "1":
nvcc_flags.append("-DUSE_CUDA")
# undefine small macro from Windows headers; otherwise it renames
# function parameters in PyTorch headers and breaks compilation.
nvcc_flags.append("-Usmall")
include_dirs = [
repo_dir / "sageattn3",
cutlass_dir / "include",
cutlass_dir / "tools" / "util" / "include",
]

# base C++ flags for extensions; may need USE_CUDA on Windows/MSVC
default_cxx = ["-O3", "-std=c++17"]
if sys.platform == "win32" and os.getenv("DISTUTILS_USE_SDK") == "1":
default_cxx.append("-DUSE_CUDA")
default_cxx.append("-Usmall")

ext_modules.append(
CUDAExtension(
name="fp4attn_cuda",
sources=["sageattn3/blackwell/api.cu"],
extra_compile_args={
"cxx": ["-O3", "-std=c++17"],
"cxx": default_cxx,
"nvcc": append_nvcc_threads(
nvcc_flags + ["-DEXECMODE=0"] + cc_flag
),
Expand All @@ -134,7 +152,7 @@ def append_nvcc_threads(nvcc_extra_args):
name="fp4quant_cuda",
sources=["sageattn3/quantization/fp4_quantization_4d.cu"],
extra_compile_args={
"cxx": ["-O3", "-std=c++17"],
"cxx": default_cxx,
"nvcc": append_nvcc_threads(
nvcc_flags + ["-DEXECMODE=0"] + cc_flag
),
Expand Down
5 changes: 5 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@
"-diag-suppress=174",
]

# No behaviour is changed on Linux or when building with another compiler.
if sys.platform == "win32" and os.getenv('DISTUTILS_USE_SDK') == '1':
CXX_FLAGS.append("-DUSE_CUDA")
NVCC_FLAGS.append("-DUSE_CUDA")

# Append flags from env if provided
cxx_append = os.getenv("CXX_APPEND_FLAGS", "").strip()
if cxx_append:
Expand Down