Is your feature request related to a problem? Please describe.
I am currently exploring the performance benefits of the Blackwell architecture, specifically the SM120 compute capability. Recent releases mention MXFP8 support, but documentation for the SM120 target appears limited compared to SM100, and some features may be missing.
Describe the solution you'd like
I would appreciate clarification on:
- The official support status and feature roadmap for MXFP8 on SM120 devices.
- Whether there are plans to introduce SM120-optimized kernels specifically for MXFP8BlockScaling that differ from the current Float8BlockScaling execution path.
Describe alternatives you've considered
I have tested the Float8BlockScaling recipe, and it functions correctly on SM120. However, I believe MXFP8BlockScaling might offer better performance or more native hardware utilization on the Blackwell architecture.
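For reference, this is roughly how I exercised the recipe (a minimal sketch assuming the PyTorch integration, a single te.Linear layer with placeholder sizes, and the default Float8BlockScaling constructor):

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Float8BlockScaling

# Placeholder layer/batch sizes, used only to exercise the recipe on an SM120 device.
layer = te.Linear(4096, 4096, bias=False).cuda()
inp = torch.randn(16, 4096, device="cuda")

with te.fp8_autocast(enabled=True, fp8_recipe=Float8BlockScaling()):
    out = layer(inp)  # runs without errors on SM120
```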
Additional context
After reviewing the TransformerEngine source code, I observed that on SM120 (and generally for sm_arch() >= 100), the Float8BlockScaling implementation converts its 1D/2D block scales into an MXFP8 layout with E8M0 scales and then invokes the same MXFP8 GEMM path as MXFP8BlockScaling.
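To make the conversion concrete, here is a rough illustration in plain Python (not TE code) of what mapping an FP32 block scale onto an E8M0 scale implies: E8M0 is an exponent-only format, so each scale is rounded to a power of two. The rounding direction and range handling below are my assumptions, not necessarily what the actual kernel does:

```python
import math

def fp32_scale_to_e8m0(scale: float) -> int:
    """Approximate a positive FP32 block scale with an E8M0 byte (power of two, bias 127).

    Illustrative only: the real conversion kernel may round differently
    (e.g. round up to avoid overflow of the scaled values).
    """
    assert scale > 0.0
    exp = round(math.log2(scale))    # nearest power-of-two exponent
    exp = max(-127, min(127, exp))   # clamp to the E8M0 representable range
    return exp + 127                 # biased exponent (255 is reserved for NaN)

def e8m0_to_fp32(byte: int) -> float:
    """Decode an E8M0 byte back into its power-of-two FP32 value."""
    return math.ldexp(1.0, byte - 127)
```

If this is roughly what happens, that repacking step is exactly the part whose cost I would like to understand relative to a native MXFP8BlockScaling path.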
Since both recipes eventually converge on the same GEMM kernel execution path, I have a few technical questions regarding this design:
- What is the primary design purpose behind having Float8BlockScaling and MXFP8BlockScaling share the same kernel path on Blackwell?
- Does the conversion step in Float8BlockScaling introduce significant overhead compared to a native MXFP8BlockScaling approach? (A rough timing sketch I would use to check this is included below.)
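If it helps, this is a minimal timing sketch of how I would compare the two recipes end to end on SM120 (assuming default recipe constructors, a single te.Linear layer with placeholder sizes, and CUDA-event timing; not a rigorous benchmark):

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Float8BlockScaling, MXFP8BlockScaling

def time_recipe(recipe, iters=100):
    """Average forward+backward time (ms) of one te.Linear layer under the given FP8 recipe."""
    layer = te.Linear(4096, 4096, bias=False).cuda()
    inp = torch.randn(4096, 4096, device="cuda", requires_grad=True)
    # Warm-up to exclude one-time initialization costs.
    for _ in range(10):
        with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
            out = layer(inp)
        out.sum().backward()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
            out = layer(inp)
        out.sum().backward()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # ms per iteration

for name, recipe in [("Float8BlockScaling", Float8BlockScaling()),
                     ("MXFP8BlockScaling", MXFP8BlockScaling())]:
    print(f"{name}: {time_recipe(recipe):.3f} ms/iter")
```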