
Conversation

@zhongbozhu (Collaborator) commented Jan 3, 2026

Description

Note: #2558 reported a bug in #2411. The fix is in #2564; make sure you cherry-pick that one as well until it lands in main.

Previously, a similar optimization was applied to the MoE grouped quantize with RHT in #2411. This PR targets dense linear layers and shared experts when they are quantized to NVFP4. With this fusion, the high-precision input only needs to be read once; without it, the input is read twice.

Similarly, the NVTE_USE_FAST_MATH environment variable controls the numerical behavior of the RHT quant fusion kernel to accelerate it further. Fast math is applied only to the high-precision math, so it should have minimal impact on training convergence.

What the fast-math toggle controls (a sketch illustrating the numerical effect follows this list):

  1. replace x / y with x * (1/y)
  2. replace 1 / x with reciprocal_approximate_ftz(x)
  3. when RHT cast fusion is available, the fusion allows the NVFP4 quantize to be performed directly on FP32 data in register files, which essentially removes a round trip from FP32 to BF16 and back to FP32.
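
To see the numerical effect, here is a minimal PyTorch sketch (illustrative only; the kernel applies these rewrites to FP32 values in registers on the GPU). Because x * (1/y) rounds twice, it is not guaranteed to match x / y bit-for-bit, which is exactly what the zero-tolerance tests below trip over:

```python
import torch

# Illustrative only: compare exact division against the fast-math rewrite
# x * (1/y). The kernel additionally uses an approximate flush-to-zero
# reciprocal, which diverges from exact division even more often.
torch.manual_seed(0)
x = torch.randn(1 << 20, dtype=torch.float32)
y = torch.randn(1 << 20, dtype=torch.float32).abs() + 0.1

exact = x / y
fast = x * (1.0 / y)  # two roundings instead of one

mismatch = (exact != fast).float().mean().item()
print(f"fraction of bitwise mismatches: {mismatch:.4%}")  # nonzero in general
assert torch.allclose(exact, fast, rtol=1e-6)  # close, but not bit-equal
```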

Therefore, I DO recommend turning it on, since it significantly improves RHT kernel performance.
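
To opt in, set the environment variable before the quantize kernels run; a minimal sketch (only the variable name comes from this PR):

```python
import os

# Enable fast math for the RHT quant fusion kernel. Equivalent to
# launching with `NVTE_USE_FAST_MATH=1 python train.py`; set it before
# Transformer Engine configures the quantize kernels.
os.environ["NVTE_USE_FAST_MATH"] = "1"
```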

The only reason it is not on by default is that we want ZERO TOLERANCE tests between our CUDA quantize kernels and our PyTorch-based emulated quantize references. With the fast-math toggle turned on, it is hard to pass those tests at zero tolerance without further investigation into how to relax the test conditions while still keeping high confidence in the test cases.

TODO items:

  • Merge the bug-fix PR #2564 first.
  • Some deprecated Cutlass APIs are being used, which produces many warnings.
  • Maybe turn on fast math by default and use NVTE_DISABLE_RHT_FAST_MATH instead of NVTE_USE_FAST_MATH? @timmoon10 for opinions.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@zhongbozhu zhongbozhu requested a review from timmoon10 January 3, 2026 01:23
@zhongbozhu zhongbozhu self-assigned this Jan 3, 2026
@zhongbozhu zhongbozhu added the fp4 label Jan 3, 2026
greptile-apps bot (Contributor) commented Jan 3, 2026

Greptile Summary

This PR integrates a new Cutlass-based CUDA kernel that fuses Row-Cast and Column-RHT-Transpose-Cast operations for NVFP4 quantization in dense linear layers. The optimization reduces memory bandwidth by reading the high-precision input only once instead of twice (see the back-of-envelope after the list below).

  • Added nvte_hadamard_transform_cast_fusion API function that performs rowwise quantization and columnwise RHT+quantization+transpose in a single kernel
  • Kernel uses MMA hardware for efficient Hadamard transform computation and is eligible when input is BF16 with dimensions divisible by 64x128
  • Refactored NVFP4Quantizer::quantize_impl() to use the fused kernel when eligible, with extracted helper method for unfused fallback path
  • Added NVTE_USE_FAST_MATH environment variable support to accelerate RHT operations (replaces division with multiplication by reciprocal, uses approximate reciprocal)
  • Extended test coverage to include columnwise-only quantization mode
  • Added benchmark script for profiling linear layer performance across quantization recipes
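
As a rough back-of-envelope on that saving (assuming both unfused passes read the full BF16 input; the shape is hypothetical):

```python
# BF16 is 2 bytes per element.
rows, cols = 8192, 8192           # hypothetical dense-layer activation shape
tensor_bytes = rows * cols * 2

unfused_reads = 2 * tensor_bytes  # rowwise pass + RHT/columnwise pass each read the input
fused_reads = tensor_bytes        # the fused kernel reads the input once

print(f"unfused: {unfused_reads / 2**20:.0f} MiB read, "
      f"fused: {fused_reads / 2**20:.0f} MiB read")  # 256 MiB vs. 128 MiB
```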

Confidence Score: 4/5

  • This PR is safe to merge - it adds a new performance optimization path with proper fallback to existing behavior for unsupported shapes.
  • Score of 4 reflects well-structured code with proper compile guards, shape validation checks, and fallback paths. The CUDA kernel follows established patterns in the codebase. Minor unused variables exist but do not affect functionality.
  • The new CUDA kernel file (row_cast_col_hadamard_transform_cast_fusion.cu) is the most complex addition and should be reviewed for numerical correctness in production workloads.

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/common/hadamard_transform/row_cast_col_hadamard_transform_cast_fusion.cu | New CUDA kernel implementing fused Row-Cast-Col-RHT-Transpose-Cast for NVFP4 quantization, leveraging MMA hardware for the Hadamard transform. Contains minor unused variables but no functional issues. |
| transformer_engine/pytorch/csrc/quantizer.cpp | Refactored quantization logic to use the fused kernel when eligible (rows % 64 == 0, cols % 128 == 0), with an unfused fallback. Added fast-math toggle support and extracted an unfused helper method. |
| transformer_engine/common/include/transformer_engine/hadamard_transform.h | Added new API function nvte_hadamard_transform_cast_fusion for row-cast and column-RHT-transpose-cast fusion; marked the old columnwise-only function for deprecation. |
| tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py | Extended test coverage to support columnwise-only quantization mode in addition to the existing rowwise and combined modes. |

Sequence Diagram

```mermaid
sequenceDiagram
    participant User as User Code
    participant NVFP4Q as NVFP4Quantizer
    participant QImpl as quantize_impl()
    participant FusedK as nvte_hadamard_transform_cast_fusion
    participant UnfusedH as quantize_with_rht_unfused_helper
    participant RowQuant as nvte_quantize_v2 (rowwise)
    participant RHT as nvte_hadamard_transform
    participant ColQuant as nvte_quantize_v2 (columnwise)

    User->>NVFP4Q: quantize(input, output)
    NVFP4Q->>QImpl: quantize_impl(input, output)

    alt eligible_for_rht_cast_fusion (BF16, rows%64==0, cols%128==0)
        QImpl->>FusedK: nvte_hadamard_transform_cast_fusion()
        Note over FusedK: Single kernel does:<br/>1. Rowwise quantization<br/>2. RHT + columnwise quant + transpose
        FusedK-->>QImpl: rowwise + columnwise output
    else not eligible (irregular shapes)
        QImpl->>UnfusedH: quantize_with_rht_unfused_helper()
        alt rowwise_usage
            UnfusedH->>RowQuant: nvte_quantize_v2()
            RowQuant-->>UnfusedH: rowwise output
        end
        alt columnwise_usage
            UnfusedH->>RHT: nvte_hadamard_transform()
            RHT-->>UnfusedH: RHT(input.T)
            UnfusedH->>ColQuant: nvte_quantize_v2()
            ColQuant-->>UnfusedH: columnwise output
        end
        UnfusedH-->>QImpl: combined output
    end

    QImpl-->>NVFP4Q: quantized tensor
    NVFP4Q-->>User: NVFP4Tensor
```

@zhongbozhu zhongbozhu changed the title [NVFP4][Dense] Integrate Cutlass NVFP4 Row-Cast-Col-RHT-Transpose Fusion Kernel [NVFP4][Dense] Integrate Cutlass NVFP4 Row-Cast-Col-RHT-Transpose-Cast Fusion Kernel Jan 3, 2026
@zhongbozhu zhongbozhu force-pushed the zhongbo/dense_row_col_rht_fp4_quant branch from c80932f to fc42825 Compare January 3, 2026 04:16
@zhongbozhu (Collaborator, Author) commented:

/te-ci arm L1

greptile-apps bot (Contributor) left a comment

Additional Comments (2)

  1. benchmarks/linear/benchmark_linear.py, line 141

    logic: NVTX range is pushed but never popped in the benchmark function (a balanced pattern is sketched below)

  2. transformer_engine/common/hadamard_transform/row_cast_col_hadamard_transform_cast_fusion.cu, line 346

    syntax: Typo in comment: 'SMEMork' should be 'SMEM work'
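
For the first point, a balanced push/pop pattern might look like the following (hypothetical sketch using torch.cuda.nvtx; the benchmark's actual structure may differ):

```python
import torch

def timed_forward(model, inp):
    # Balance every range_push with a range_pop, even if the step raises,
    # so NVTX ranges nest correctly in the profile.
    torch.cuda.nvtx.range_push("linear_forward")
    try:
        return model(inp)
    finally:
        torch.cuda.nvtx.range_pop()
```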

8 files reviewed, 2 comments


@zhongbozhu zhongbozhu added the MoE label Jan 6, 2026
@zhongbozhu zhongbozhu changed the title [NVFP4][Dense] Integrate Cutlass NVFP4 Row-Cast-Col-RHT-Transpose-Cast Fusion Kernel [NVFP4][Dense/MoE] Integrate Cutlass NVFP4 Row-Cast-Col-RHT-Transpose-Cast Fusion Kernel, Fixing NVFP4 Group Quant Bug Jan 6, 2026
@zhongbozhu zhongbozhu changed the title [NVFP4][Dense/MoE] Integrate Cutlass NVFP4 Row-Cast-Col-RHT-Transpose-Cast Fusion Kernel, Fixing NVFP4 Group Quant Bug [NVFP4][Dense/MoE] Integrate Cutlass NVFP4 Row-Cast-Col-RHT-Transpose-Cast Fusion Kernel Jan 6, 2026
Signed-off-by: Zhongbo Zhu <[email protected]>
@zhongbozhu zhongbozhu force-pushed the zhongbo/dense_row_col_rht_fp4_quant branch from 2bc695e to 6ea9dab Compare January 9, 2026 23:14
greptile-apps bot (Contributor) left a comment

Greptile Summary

This PR integrates a Cutlass-based fusion kernel that combines row-wise quantization and column-wise RHT (Random Hadamard Transform) + quantization + transpose operations for NVFP4 dense linear layers and shared experts. The key optimization reduces memory bandwidth by reading high-precision input data once instead of twice.

Key Changes

New Fusion Kernel (row_cast_col_hadamard_transform_cast_fusion.cu):

  • Implements nvte_hadamard_transform_cast_fusion API that performs both rowwise and columnwise quantization in a single pass
  • Uses MMA hardware for efficient Hadamard transform computation
  • Eligible when input is BF16 with dimensions divisible by 64×128
  • Reads pre-computed amax values to calculate FP8 scaling factors
  • Supports stochastic rounding and fast math optimization flags

Refactored Quantizer Logic (quantizer.cpp):

  • Moved unfused RHT path into quantize_with_rht_unfused_helper method for cleaner code organization
  • Improved RNG state handling: single RNG state when fusion is used, separate states for rowwise/columnwise when unfused
  • Added NVTE_USE_FAST_MATH environment variable support for accelerating high-precision math operations
  • Eligibility check moved before RNG state generation to avoid unnecessary work

Extended Test Coverage (test_nvfp4_rht_quantize_exact.py):

  • Added "columnwise-only" quantization mode testing alongside existing "quantize" and "quantize_transpose" modes
  • Tests now validate rowwise/columnwise results conditionally based on the quantization mode

Grouped Quantization Support (cast.cpp):

  • Split-quantize path now uses fused kernel when all tensors have 128-aligned dimensions
  • Bulk RNG state generation for grouped kernels (single state shared across splits)
  • Fast math flag propagation to all quantization configs

Architecture Notes

The fusion provides optimal performance when:

  1. Input dtype is BF16
  2. Rows are divisible by 64 (MMA tile requirement)
  3. Columns are divisible by 128 (MMA tile requirement)

When these conditions aren't met, the code gracefully falls back to the unfused path with separate kernel launches for rowwise and columnwise quantization.
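
A runnable restatement of that eligibility check (illustrative; the real check lives in quantizer.cpp):

```python
import torch

def eligible_for_rht_cast_fusion(x: torch.Tensor) -> bool:
    """Mirror of the fusion eligibility conditions described above."""
    rows, cols = x.shape
    return (
        x.dtype == torch.bfloat16
        and rows % 64 == 0    # MMA tile requirement
        and cols % 128 == 0   # MMA tile requirement
    )

# A 4096x4096 BF16 tensor takes the fused path; 100x100 falls back to the
# unfused rowwise-quantize + RHT/columnwise-quantize kernels.
print(eligible_for_rht_cast_fusion(torch.empty(4096, 4096, dtype=torch.bfloat16)))  # True
print(eligible_for_rht_cast_fusion(torch.empty(100, 100, dtype=torch.bfloat16)))    # False
```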

Confidence Score: 4/5

  • This PR is safe to merge with minimal risk after addressing documentation and TODO items mentioned in the PR description
  • Score of 4 reflects a well-engineered feature with thorough implementation. The code demonstrates good software practices: clean refactoring with extracted helper methods, proper error handling, graceful fallback paths, and comprehensive test coverage including the new columnwise-only mode. The fusion kernel follows established patterns from the grouped quantization PR #2411. Deducted 1 point due to: (1) PR author notes cutlass deprecation warnings need addressing, (2) TODOs remain about potentially defaulting fast math on, and (3) the ~1400 line CUDA kernel file has limited inline documentation for complex template logic
  • The main CUDA kernel file (row_cast_col_hadamard_transform_cast_fusion.cu) would benefit from additional inline comments explaining the template parameter switches and MMA computation flow, but no files have critical issues requiring immediate attention

Important Files Changed

File Analysis

| Filename | Score | Overview |
| --- | --- | --- |
| transformer_engine/pytorch/csrc/quantizer.cpp | 4/5 | Refactored NVFP4 quantize_impl to use the new fused RHT cast kernel, extracted an unfused helper, improved RNG state handling for fused vs. unfused paths |
| transformer_engine/common/hadamard_transform/row_cast_col_hadamard_transform_cast_fusion.cu | 4/5 | New CUDA kernel implementing fused row-cast and column-RHT-transpose-cast using Cutlass MMA hardware for BF16 inputs with 64x128 alignment |
| transformer_engine/common/include/transformer_engine/hadamard_transform.h | 5/5 | Added new API function nvte_hadamard_transform_cast_fusion for dense-layer quantization, marked the old columnwise function for future deprecation |
| transformer_engine/pytorch/csrc/extensions/cast.cpp | 4/5 | Added NVTE_USE_FAST_MATH env var support in split_quantize for grouped NVFP4 kernels, improved RNG state setup with a bulk-generation flag |
| tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py | 5/5 | Extended test coverage to support columnwise-only quantization mode, added return_identity parameter to test all three modes |

Sequence Diagram

```mermaid
sequenceDiagram
    participant User as User Code
    participant Quantizer as NVFP4Quantizer
    participant API as nvte_hadamard_transform_cast_fusion
    participant Kernel as row_col_rht_gemm_ntt_w_sfc
    participant AmaxKernel as nvte_hadamard_transform_amax

    User->>Quantizer: quantize(input, output)
    Quantizer->>Quantizer: Check eligibility (BF16, rows%64==0, cols%128==0)

    alt With RHT and eligible for fusion
        Quantizer->>AmaxKernel: Compute rowwise & columnwise amax
        AmaxKernel-->>Quantizer: amax values populated

        alt Stochastic rounding enabled
            Quantizer->>Quantizer: Generate RNG state
        end

        alt Fast math enabled (NVTE_USE_FAST_MATH)
            Quantizer->>Quantizer: Set use_fast_math flag
        end

        Quantizer->>API: Call with input, output, hadamard_matrix, quant_config
        API->>Kernel: Launch fused kernel

        Kernel->>Kernel: Read amax values
        Kernel->>Kernel: Perform rowwise quantization to FP4
        Kernel->>Kernel: Compute RHT using MMA hardware
        Kernel->>Kernel: Transpose and quantize to FP4
        Kernel->>Kernel: Write FP8 scales

        Kernel-->>API: Complete
        API-->>Quantizer: Return

    else Not eligible for fusion
        Quantizer->>AmaxKernel: Compute amax
        AmaxKernel-->>Quantizer: amax values

        alt Rowwise usage
            Quantizer->>Quantizer: Call nvte_quantize_v2 for rowwise
        end

        alt Columnwise usage
            Quantizer->>Quantizer: Call nvte_hadamard_transform for RHT
            Quantizer->>Quantizer: Call nvte_quantize_v2 for columnwise
        end
    end

    Quantizer-->>User: Quantized output
```
