
[BUG] Add FP4 TMA kernel example for Blackwell in CuTe Python DSL #2922

@whatdhack


Which component has the problem?

CuTe DSL

Bug Report

Summary
CuTe DSL has all the low-level primitives for FP4 (Float4E2M1FN) with TMA on Blackwell, but lacks high-level examples or helper functions that demonstrate how to use them together. All existing Blackwell TMA examples use FP8.

Background
We are attempting to create a persistent blockwise GEMM kernel with TMA for FP4 data types on Blackwell B200, similar to the existing FP8 examples in examples/python/CuTeDSL/blackwell/blockwise_gemm/.

What Works
✅ Non-TMA FP4 kernels (using direct SMEM access)
✅ FP8 TMA kernels (existing examples)
✅ Low-level FP4 primitives:
  - `cutlass.Float4E2M1FN` datatype
  - `MmaMXF4NVF4Op` (block-scaled FP4 MMA atom)
  - `sm100_utils.make_blockscaled_trivial_tiled_mma()`

What's Missing
❌ No FP4+TMA kernel examples
❌ No documentation on FP4-specific constraints for TMA
❌ Helper functions assume FP8 behavior (K=32)

Technical Challenges Encountered
When adapting an FP8 TMA kernel to FP4, we encountered:

MMA Instruction K Constraint

FP8 MMA: K = 32 per instruction
FP4 MMA: K = 64 per instruction (hardcoded in `MmaMXF4NVF4Op`)
Source: /python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/mma.py, lines 1293-1300
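Both K values fit one rule of thumb: each tcgen05 MMA instruction appears to consume 256 bits of each operand row along K (32 x 8-bit, 64 x 4-bit), so halving the element width doubles K. The minimal sketch below assumes that rule and checks how many MMA instructions a K-tile needs. The helper names are hypothetical, not CuTe DSL APIs. A K-tile sized for a single FP8 instruction (K = 32) cannot hold even one FP4 instruction.

```python
# Hypothetical helpers (not part of CuTe DSL): K-dimension bookkeeping for
# tcgen05 MMAs, ASSUMING one instruction consumes 256 bits of each operand
# row along K (32 x 8-bit FP8, 64 x 4-bit FP4).
MMA_K_BITS = 256


def mma_inst_k(element_bits: int) -> int:
    """Elements along K consumed by one MMA instruction."""
    return MMA_K_BITS // element_bits


def mmas_per_k_tile(tile_k: int, element_bits: int) -> int:
    """Number of MMA instructions needed to cover a K-tile of tile_k elements."""
    inst_k = mma_inst_k(element_bits)
    if tile_k % inst_k != 0:
        raise ValueError(
            f"K-tile of {tile_k} elements is not a multiple of MMA K={inst_k}"
        )
    return tile_k // inst_k


if __name__ == "__main__":
    print(mma_inst_k(8), mma_inst_k(4))  # 32 64
    print(mmas_per_k_tile(128, 8), mmas_per_k_tile(128, 4))  # 4 2
```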

Scale Factor Granularity

FP8: typically uses scale_granularity = 128
FP4: `MmaMXF4NVF4Op` requires sf_vec_size = 16
Each K-tile therefore carries multiple scale factors per row (e.g., a 128-element K-tile needs 128 / 16 = 8 scales)
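With sf_vec_size = 16, every 16 consecutive FP4 elements along K share one scale factor. A host-side reference like the minimal sketch below can help validate a kernel's output. It assumes the standard E2M1 encoding (1 sign bit, 2 exponent bits, 1 mantissa bit, exponent bias 1) and models the Float8E4M3FN scale factors as plain Python floats. All function names are hypothetical.

```python
import math

# Magnitudes of the 8 non-negative Float4E2M1FN codes, indexed by the low
# 3 bits (2 exponent bits, 1 mantissa bit, exponent bias 1). Bit 3 is the sign.
E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def decode_e2m1(code: int) -> float:
    """Decode one 4-bit E2M1 code to a Python float."""
    magnitude = E2M1_MAGNITUDES[code & 0x7]
    return -magnitude if code & 0x8 else magnitude


def scales_per_k_tile(tile_k: int, sf_vec_size: int) -> int:
    """Scale factors needed per row to cover a K-tile of tile_k elements."""
    if tile_k % sf_vec_size != 0:
        raise ValueError(f"K-tile {tile_k} is not a multiple of sf_vec_size {sf_vec_size}")
    return tile_k // sf_vec_size


def block_scaled_dot(a_codes, a_scales, b_codes, b_scales, sf_vec_size=16):
    """Reference dot product of one A row and one B column along K.

    Every sf_vec_size consecutive K elements share one scale per operand.
    Scales are plain floats here, standing in for Float8E4M3FN values.
    """
    k = len(a_codes)
    if len(b_codes) != k:
        raise ValueError("A and B must have the same K extent")
    num_blocks = math.ceil(k / sf_vec_size)
    if len(a_scales) != num_blocks or len(b_scales) != num_blocks:
        raise ValueError(f"expected {num_blocks} scale factors per operand")
    acc = 0.0
    for i in range(k):
        block = i // sf_vec_size
        a = decode_e2m1(a_codes[i]) * a_scales[block]
        b = decode_e2m1(b_codes[i]) * b_scales[block]
        acc += a * b
    return acc


if __name__ == "__main__":
    print(scales_per_k_tile(128, 16))  # 8 scales per 128-element K-tile
    # 32 elements of 1.0 (code 0b0010); A's second 16-element block is scaled by 2.
    print(block_scaled_dot([0b0010] * 32, [1.0, 2.0], [0b0010] * 32, [1.0, 1.0]))  # 48.0
```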

Initial Error Path

Error: "unsupported ab_dtype, got Float4E2M1FN"
Fixed by using `make_blockscaled_trivial_tiled_mma()` instead of `make_trivial_tiled_mma()`
Source: /python/CuTeDSL/cutlass/utils/blackwell_helpers.py, lines 999-1016

Current Error

Error: "profile of input tuples doesn't match: (32, (16, 2))"
Cause: a thread/value layout mismatch due to FP4's different K-dimension requirement
Shared memory layouts and TMA tile shapes need to be adjusted for K=64
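One way FP8-tuned shared memory and TMA settings can go wrong for FP4 is that the same number of elements along K occupies half as many bytes. The minimal sketch below does that arithmetic. It assumes a K-major operand and a 128-byte swizzle atom (an illustrative choice, not necessarily what the FP8 examples use). The 16-byte check mirrors the CUDA driver's requirement that a TMA box's innermost dimension span a multiple of 16 bytes. The helper names are hypothetical.

```python
SWIZZLE_BYTES = 128  # ASSUMED 128-byte swizzle atom width
TMA_INNER_ALIGN_BYTES = 16  # TMA box inner extent must be a multiple of 16 bytes


def row_bytes(num_elements: int, element_bits: int) -> int:
    """Bytes occupied by num_elements packed elements along K."""
    total_bits = num_elements * element_bits
    if total_bits % 8 != 0:
        raise ValueError("row does not end on a byte boundary")
    return total_bits // 8


def elements_per_swizzle_row(element_bits: int, swizzle_bytes: int = SWIZZLE_BYTES) -> int:
    """Elements along K that fill one swizzle row."""
    return swizzle_bytes * 8 // element_bits


def tma_inner_box_ok(box_k: int, element_bits: int) -> bool:
    """Whether a TMA box of box_k elements along K meets the 16-byte rule."""
    return row_bytes(box_k, element_bits) % TMA_INNER_ALIGN_BYTES == 0


if __name__ == "__main__":
    for name, bits in (("FP8", 8), ("FP4", 4)):
        # Bytes per 128-element K-row, and elements per 128B swizzle row.
        print(name, row_bytes(128, bits), elements_per_swizzle_row(bits))
    # FP8 128 128
    # FP4 64 256
```

Under these assumptions, a 128-element FP4 K-tile fills only half of a 128-byte swizzle row, so layouts derived from FP8 element counts end up with a different shape.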

Request
Could you provide one of the following?

Option 1: Add an FP4 TMA example kernel
- Similar to blockwise_gemm.py, but using Float4E2M1FN
- Demonstrates the proper configuration for FP4's K=64 constraint
- Shows correct scale-factor handling with sf_vec_size=16

Option 2: Update helper functions
- Extend make_smem_layout_a/b() to handle FP4's K=64 dimension
- Add validation and documentation for FP4-specific constraints
- Provide guidelines for adapting FP8 kernels to FP4

Option 3: Documentation
- Explain the differences between FP8 and FP4 TMA kernel configuration
- Document the K=64 constraint and its implications
- Provide a migration guide from FP8 to FP4

Environment
- GPU: NVIDIA B200
- CUDA: 13.0.0
- CUTLASS: latest from main branch
- Python CuTe DSL: nvidia-cutlass, nvidia-cutlass-dsl
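
Code Reference
We identified that FP4 support requires the block-scaled tiled-MMA helper:

```python
import cutlass
import cutlass.utils.blackwell_helpers as sm100_utils

# make_trivial_tiled_mma() rejects Float4E2M1FN ("unsupported ab_dtype"),
# so FP4 operands go through the block-scaled variant instead.
tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma(
    ab_dtype,              # cutlass.Float4E2M1FN
    a_major_mode,          # A operand major mode
    b_major_mode,          # B operand major mode
    cutlass.Float8E4M3FN,  # sf_dtype: scale-factor element type
    16,                    # sf_vec_size: one scale factor per 16 FP4 elements along K
    cta_group,             # 1-CTA or 2-CTA MMA
    mma_tiler_mn,          # (M, N) shape of the MMA tile
)
```

However, the rest of the kernel configuration (shared memory layouts, TMA tile shapes, pipeline logic) still needs to be adjusted for FP4's different hardware characteristics.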

Impact
On Blackwell, FP4 offers up to 2x the peak Tensor Core throughput of FP8 and halves operand storage before scale-factor overhead, which benefits inference workloads. Working TMA examples would let the community realize this gain while still using TMA for efficient asynchronous global-to-shared memory copies.
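On the memory side, the per-element storage cost including scale factors is element_bits + sf_bits / sf_vec_size. The minimal sketch below computes it. It assumes 8-bit Float8E4M3FN scales every 16 FP4 elements, as in the Code Reference. For comparison, it assumes 32-bit scales every 128 FP8 elements along K. The FP8 scale width and 1x128 layout are assumptions, not taken from the existing examples. The function names are hypothetical.

```python
def effective_bits(element_bits: int, sf_bits: int, sf_vec_size: int) -> float:
    """Average storage bits per element, including one scale per sf_vec_size elements."""
    return element_bits + sf_bits / sf_vec_size


def matrix_bytes(rows: int, cols: int, element_bits: int, sf_bits: int, sf_vec_size: int) -> float:
    """Total bytes for a rows x cols operand, including its scale factors."""
    return rows * cols * effective_bits(element_bits, sf_bits, sf_vec_size) / 8


if __name__ == "__main__":
    fp4 = effective_bits(4, 8, 16)    # 4.5 bits/element
    fp8 = effective_bits(8, 32, 128)  # 8.25 bits/element (ASSUMED FP32 scales, 1x128)
    print(fp4, fp8, round(fp8 / fp4, 2))  # 4.5 8.25 1.83
```

Under these assumptions, the storage saving is about 1.83x rather than a full 2x once scale factors are included.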
