
[Kernel] Add enable_sm120_or_later for SM121 (DGX Spark) CUTLASS support #33517

Merged
vllm-bot merged 4 commits into vllm-project:main from Code4me2:fix/sm121-cpp-kernel-support
Feb 7, 2026

Conversation

@Code4me2 Code4me2 commented Feb 1, 2026

Summary

Add enable_sm120_or_later kernel wrapper to support SM121 (DGX Spark GB10) in addition to SM120 (RTX 5090) for Blackwell CUTLASS kernels.

Problem

DGX Spark GB10 (SM121) cannot use these CUTLASS kernels because enable_sm120_only requires an exact architecture match.

Root Cause

The existing enable_sm120_only wrapper uses:

#if defined __CUDA_ARCH__ && __CUDA_ARCH__ == 1200

This excludes SM121 (arch 1210), which has tensor core capabilities identical to SM120's.

Solution

Add new enable_sm120_or_later wrapper:

#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 1200

This includes both SM120 (RTX 5090) and SM121+ (DGX Spark) architectures.
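
A minimal sketch of what the complete wrapper could look like, following the guard pattern this PR describes and assuming the same body as the existing enable_sm120_only wrapper (forwarding to the wrapped kernel's operator() only when the guard passes); the exact code merged here may differ:

// Sketch only: the >= 1200 guard described above, assuming the same body
// as the existing enable_sm120_only wrapper in csrc/cutlass_extensions/common.hpp.
#include <utility>            // std::forward
#include "cutlass/cutlass.h"  // CUTLASS_DEVICE

template <typename Kernel>
struct enable_sm120_or_later : Kernel {
  template <typename... Args>
  CUTLASS_DEVICE void operator()(Args&&... args) {
    // Compile the kernel body only for SM120 (1200) and newer, which covers
    // SM121 (1210, DGX Spark GB10); on other architectures this is a no-op.
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 1200
    Kernel::operator()(std::forward<Args>(args)...);
#endif
  }
};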

Files Changed

  • csrc/cutlass_extensions/common.hpp: Add enable_sm120_or_later template
  • csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh: Use enable_sm120_or_later for FP8 blockwise GEMM

Testing

Tested on DGX Spark GB10 (SM121) with nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-NVFP4

Related Issues

Fixes #28589

github-actions bot commented Feb 1, 2026

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small, essential subset of tests to quickly catch errors.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

dosubot bot commented Feb 1, 2026

Related Documentation

No published documentation to review for changes on this repository.

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request adds support for SM121 (DGX Spark) by introducing an enable_sm120_or_later kernel wrapper. The change is logical and consistent with the existing codebase structure. I have one suggestion to improve the long-term robustness of the new wrapper by adding an upper bound to the architecture check, which will prevent potential issues with future, incompatible GPU architectures.

struct enable_sm120_or_later : Kernel {
  template <typename... Args>
  CUTLASS_DEVICE void operator()(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 1200

Severity: high

While using >= 1200 correctly enables this kernel for SM120 and SM121 as intended, it makes a strong assumption about forward compatibility with all future architectures. Highly-tuned kernels like this can be sensitive to changes in future hardware generations. To make this safer and more explicit, I recommend adding an upper bound to the check to limit it to the Blackwell architecture series (presumably SM12x). This will prevent potential hard-to-debug issues on future, incompatible hardware.

#if defined __CUDA_ARCH__ && (__CUDA_ARCH__ >= 1200 && __CUDA_ARCH__ < 1300)

Add enable_sm120_or_later kernel wrapper to support SM121 (DGX Spark GB10)
in addition to SM120 (RTX 5090/6000 Pro) for Blackwell CUTLASS kernels.

The existing enable_sm120_only wrapper uses __CUDA_ARCH__ == 1200 which
excludes SM121 (arch 1210). The new wrapper uses __CUDA_ARCH__ >= 1200 to
include both SM120 and SM121+ architectures.

Changes:
- csrc/cutlass_extensions/common.hpp: Add enable_sm120_or_later template
- scaled_mm_blockwise_sm120_fp8_dispatch.cuh: Use enable_sm120_or_later
  for FP8 blockwise GEMM kernels

SM121 shares the same tensor core capabilities as SM120, so these kernels
work correctly on both architectures.

Tested on DGX Spark GB10 (SM121) with nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-NVFP4

Signed-off-by: code4me2 <[email protected]>
@Code4me2 Code4me2 force-pushed the fix/sm121-cpp-kernel-support branch 3 times, most recently from 81450fd to 9179e0a on February 1, 2026 at 21:39

// SM12x family includes SM120 (RTX 5090) and SM121 (DGX Spark GB10)
template <typename Kernel>
struct enable_sm120_or_later : Kernel {
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's update this to enable_sm120_family since "later" sounds like >= sm120

Address reviewer feedback: rename to enable_sm120_family since
"later" sounds like >= sm120, while this specifically targets
the SM12x family (SM120, SM121).

Signed-off-by: code4me2 <[email protected]>
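
For reference, a minimal sketch of how the renamed wrapper might read if it also adopts the upper bound suggested in the review above; this is an illustration under the same assumptions as the earlier sketch, not the merged code, and the actual guard may remain the plain >= 1200 check:

// Sketch only: combines the rename with the bounded check proposed in review.
template <typename Kernel>
struct enable_sm120_family : Kernel {
  template <typename... Args>
  CUTLASS_DEVICE void operator()(Args&&... args) {
    // Restrict to the SM12x family: SM120 (RTX 5090) and SM121 (DGX Spark
    // GB10), excluding any future, possibly incompatible architecture series.
#if defined __CUDA_ARCH__ && (__CUDA_ARCH__ >= 1200 && __CUDA_ARCH__ < 1300)
    Kernel::operator()(std::forward<Args>(args)...);
#endif
  }
};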
@Code4me2 Code4me2 force-pushed the fix/sm121-cpp-kernel-support branch from 7e07daa to b21d37d on February 3, 2026 at 19:43

Code4me2 commented Feb 6, 2026

@mgoin was there anything else to do for this one? I think this PR has all the changes you wanted

@Code4me2 Code4me2 requested a review from mgoin February 6, 2026 20:30
@github-project-automation github-project-automation bot moved this to Ready in NVIDIA Feb 6, 2026
@mgoin mgoin added the ready label (ONLY add when PR is ready to merge/full CI is needed) Feb 6, 2026

mgoin commented Feb 6, 2026

Thanks @Code4me2 ! Enabled CI


Code4me2 commented Feb 7, 2026

@mgoin is there anything else for me to do here? The checks that failed seem unrelated to the implementation.


mgoin commented Feb 7, 2026

Nope just flaky CI at the moment, thanks for the ping!

@vllm-bot vllm-bot merged commit bc32444 into vllm-project:main Feb 7, 2026
106 of 109 checks passed
@github-project-automation github-project-automation bot moved this from Ready to Done in NVIDIA Feb 7, 2026

shahizat commented Feb 8, 2026

hi @mgoin, @Code4me2, thanks for the fix. I built it from source on a machine with a Blackwell 6000 Pro and on a machine with a 5090. On the Blackwell 6000 Pro, it works fine, but on the machine with the 5090, it does not. Am I doing something wrong?

Error log:

ValueError: NvFp4 MoE backend 'FLASHINFER_CUTLASS' does not support the deployment configuration since kernel does not support current device.

nvidia-smi --query-gpu=compute_cap --format=csv
compute_cap
12.0
12.0
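
As a quick cross-check (illustrative only, not part of this PR), the compute capability reported by nvidia-smi maps directly onto the guards discussed above: 12.0 corresponds to __CUDA_ARCH__ == 1200 and 12.1 to 1210. A minimal standalone CUDA program can print the same information via the runtime API:

// check_cc.cu -- illustrative helper only; build with: nvcc check_cc.cu -o check_cc
#include <cstdio>
#include <cuda_runtime.h>

int main() {
  int count = 0;
  cudaGetDeviceCount(&count);
  for (int dev = 0; dev < count; ++dev) {
    cudaDeviceProp prop{};
    cudaGetDeviceProperties(&prop, dev);
    // prop.major/minor give the SM version, e.g. 12.0 for RTX 5090 (SM120)
    // and 12.1 for DGX Spark GB10 (SM121); the corresponding __CUDA_ARCH__
    // value is major * 100 + minor * 10.
    std::printf("device %d: compute capability %d.%d (arch %d)\n", dev,
                prop.major, prop.minor, prop.major * 100 + prop.minor * 10);
  }
  return 0;
}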

5090 machine

uv pip show vllm
Using Python 3.12.3 environment at: .vllm
Name: vllm
Version: 0.15.2rc1.dev114+g4df841fe7.d20260208.cu130

Blackwell 6000 Pro

uv pip show vllm
Using Python 3.12.3 environment at: .vllm
Name: vllm
Version: 0.15.2rc1.dev113+ga263aa614.d20260208.cu130

My steps:

uv venv .vllm --python 3.12
source .vllm/bin/activate

uv pip install --force-reinstall torch torchvision torchaudio triton --index-url https://download.pytorch.org/whl/cu130

export TORCH_CUDA_ARCH_LIST="12.0"
export CUDA_HOME=/usr/local/cuda-13
export TRITON_PTXAS_PATH=/usr/local/cuda/bin/ptxas
export PATH="${CUDA_HOME}/bin:$PATH"

git clone https://github.com/vllm-project/vllm.git
cd vllm 
python3 use_existing_torch.py 
uv pip install -r requirements/build.txt
MAX_JOBS=$(nproc) python3 setup.py bdist_wheel

uv pip install --no-deps dist/vllm*.whl
uv pip install -r requirements/common.txt

wget https://huggingface.co/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-NVFP4/resolve/main/nano_v3_reasoning_parser.py

VLLM_USE_FLASHINFER_MOE_FP4=1 \
VLLM_FLASHINFER_MOE_BACKEND=latency \
vllm serve nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-NVFP4 \
  --served-model-name model \
  --max-num-seqs 8 \
  --tensor-parallel-size 1 \
  --max-model-len 262144 \
  --port 8000 \
  --trust-remote-code \
  --enable-auto-tool-choice \
  --tool-call-parser qwen3_coder \
  --reasoning-parser-plugin nano_v3_reasoning_parser.py \
  --reasoning-parser nano_v3 \
  --kv-cache-dtype fp8


Code4me2 commented Feb 8, 2026

@shahizat let me check on my setup that I was testing with. Which model were you running when it failed?

Code4me2 added a commit to Code4me2/vllm that referenced this pull request Feb 11, 2026
…and NVFP4 MoE oracle checks

PR vllm-project#33417 added is_device_capability_family(120) to
flashinfer_cutlass_moe.py and cutlass_moe.py but missed three other
NVFP4 MoE backend files that still only check family(100).

RTX 5090 (SM120) and SM110 GPUs are rejected by the oracle when
using FlashInfer TRT-LLM, CuteDSL, or NVFP4 TRT-LLM weight prep
backends.

Add the same family(110) and family(120) checks to match the pattern
established by vllm-project#33417.

Fixes the issue reported in vllm-project#33517.

Signed-off-by: code4me2 <[email protected]>
Code4me2 added a commit to Code4me2/vllm that referenced this pull request Feb 11, 2026
…and NVFP4 MoE oracle checks

PR vllm-project#33417 added is_device_capability_family(120) to
flashinfer_cutlass_moe.py and cutlass_moe.py but missed four other
checks that still only match family(100).

RTX 5090 (SM120) and SM110 GPUs are rejected by the oracle when
using FlashInfer TRT-LLM, CuteDSL, or NVFP4 TRT-LLM weight prep
backends. A fourth check in flashinfer_utils.py silently downgrades
the TRT-LLM backend to CUTLASS for non-SM100 devices.

Add the same family(110) and family(120) checks to match the pattern
established by vllm-project#33417.

Fixes the issue reported in vllm-project#33517.

Signed-off-by: code4me2 <[email protected]>

Labels

nvidia, ready (ONLY add when PR is ready to merge/full CI is needed)

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

[Bug]: V1 Engine fails on Blackwell GB10 (SM 12.1): "sink setting not supported" by all compatible attention backends

4 participants