@tdophung
Collaborator

@tdophung tdophung commented Feb 11, 2026

Description

Changes needed on the TE side to make the MaxText integration work.

Issue # 2585

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Mask out padding tokens on each local EP (used in the local permute step)
  • Pass split_sizes and sorted_indices through the residuals of sort_chunks_by_index (local_permute) to avoid a size mismatch during tracing when EP>1
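
The residual-passing idea in the second bullet can be sketched in isolation. This is a minimal illustration, not the code in this PR: `permute_rows` is a hypothetical stand-in for sort_chunks_by_index, and it assumes the index array is a full permutation so that input and output have the same shape. The integer operand stays an ordinary traced argument and is saved in the residuals by the forward rule, rather than being declared through nondiff_argnums.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def permute_rows(inp, row_ids):
    """Hypothetical permutation op: out[i] = inp[row_ids[i]]."""
    return inp[row_ids]


def _permute_rows_fwd(inp, row_ids):
    # Save the integer index array as a residual so the backward rule can
    # use it, instead of marking it with nondiff_argnums.
    return inp[row_ids], (row_ids,)


def _permute_rows_bwd(residuals, grad_out):
    (row_ids,) = residuals
    # Scatter each output-row gradient back to the input row it came from.
    inp_grad = jnp.zeros_like(grad_out).at[row_ids].add(grad_out)
    # Integer inputs carry no gradient; None marks a zero cotangent.
    return inp_grad, None


permute_rows.defvjp(_permute_rows_fwd, _permute_rows_bwd)

# Usage: differentiate with respect to inp only, under jit.
x = jnp.arange(6.0).reshape(3, 2)
ids = jnp.array([2, 0, 1], dtype=jnp.int32)
weights = jnp.array([[1.0], [10.0], [100.0]])
grads = jax.jit(jax.grad(lambda a, i: jnp.sum(permute_rows(a, i) * weights)))(x, ids)
```

Returning None for the integer operand sidesteps the question of which dtype a zero gradient for an index array should have.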

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

pre-commit-ci bot and others added 6 commits February 5, 2026 13:55
Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
…ger than num tokens

Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: JAX Toolbox <jax@nvidia.com>
@tdophung
Collaborator Author

This PR contains changes cherry-picked from #2651. I can wait until that PR is merged and then merge mine, but if this PR is needed more urgently, I am happy to remove the cherry-picked changes.

@greptile-apps
Contributor

greptile-apps bot commented Feb 11, 2026

Greptile Overview

Greptile Summary

This PR fixes MaxText integration issues by addressing tensor permutation tracing problems when EP>1 and adding debugging utilities.

Major changes:

  • Modified sort_chunks_by_index to remove nondiff_argnums and instead pass split_sizes and sorted_indices through residuals, preventing shape mismatch errors during JAX tracing
  • Added padding token masking in Triton kernel to handle buffers larger than valid token count using identity mapping
  • Introduced new inspect_array debugging utility with C++ FFI backend that dumps tensors to binary files
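
The identity mapping for padding rows can be illustrated with a small NumPy reference. This is a sketch under assumptions, not the Triton kernel: the function name `chunk_sort_map_reference` is hypothetical, and the map is assumed to store, for each source row, its destination row.

```python
import numpy as np


def chunk_sort_map_reference(split_sizes, sorted_indices, num_tokens):
    """Return the destination row for every source row of a padded buffer.

    Rows past sum(split_sizes) are padding and map to themselves.
    """
    split_sizes = np.asarray(split_sizes)
    sorted_indices = np.asarray(sorted_indices)
    total_valid_tokens = int(split_sizes.sum())
    if total_valid_tokens > num_tokens:
        raise ValueError("split_sizes describe more rows than the buffer holds")

    # First row of each chunk in the source layout.
    src_starts = np.concatenate(([0], np.cumsum(split_sizes)[:-1]))
    # First row of each chunk in the destination layout, where chunks
    # appear in the order given by sorted_indices.
    dst_sizes = split_sizes[sorted_indices]
    dst_starts_in_order = np.concatenate(([0], np.cumsum(dst_sizes)[:-1]))
    dst_starts = np.empty_like(dst_starts_in_order)
    dst_starts[sorted_indices] = dst_starts_in_order

    # Start from the identity so padding rows stay where they are.
    row_id_map = np.arange(num_tokens)
    for chunk, size in enumerate(split_sizes):
        offsets = np.arange(size)
        row_id_map[src_starts[chunk] + offsets] = dst_starts[chunk] + offsets
    return row_id_map


# Three chunks of sizes 2, 1, 3 in an 8-row buffer; the last two rows are padding.
row_id_map = chunk_sort_map_reference([2, 1, 3], [2, 0, 1], num_tokens=8)
```

Because the padding rows keep their own indices, the map remains a valid permutation of the whole buffer even when the buffer is larger than the valid token count.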

Issues found:

  • Missing error handling in InspectFFI - file write failures are silently ignored
  • The name parameter in inspect_array is accepted but never used
  • Misleading comment about float32 dtype when code preserves original dtype

Confidence Score: 3/5

  • This PR is mostly safe but has a critical error handling gap in the inspect functionality
  • The core permutation fixes appear solid and address real tracing issues. However, the inspect feature has missing error handling that could cause silent failures in production. The inspect code also has unused parameters and incomplete TODO items.
  • Pay close attention to transformer_engine/jax/csrc/extensions/amax.cpp - the InspectFFI function needs error handling for file operations

Important Files Changed

| Filename | Overview |
|---|---|
| transformer_engine/common/triton/permutation.py | Added padding token masking logic to handle buffers larger than valid token count - uses identity mapping for out-of-bounds indices |
| transformer_engine/jax/permutation.py | Removed nondiff_argnums and now passes split_sizes/sorted_indices through residuals to fix tracing issues when EP>1, with misleading comment about dtype |
| transformer_engine/jax/inspect.py | New debugging utility for inspecting JAX arrays - name parameter unused, minimal docstrings, TODOs not addressed |
| transformer_engine/jax/csrc/extensions/amax.cpp | Added InspectFFI C++ handler with file I/O for debugging - missing error handling for file operations that could silently fail |

Sequence Diagram

sequenceDiagram
    participant User
    participant JAX
    participant sort_chunks_by_index
    participant make_chunk_sort_map
    participant Triton Kernel
    participant InspectPrimitive
    participant InspectFFI

    User->>JAX: Call sort_chunks_by_index(inp, split_sizes, sorted_indices)
    JAX->>sort_chunks_by_index: Forward pass
    sort_chunks_by_index->>make_chunk_sort_map: Generate row_id_map
    make_chunk_sort_map->>Triton Kernel: Execute _make_chunk_sort_map_kernel
    Note over Triton Kernel: Compute total_valid_tokens<br/>Apply identity mapping for padding
    Triton Kernel-->>make_chunk_sort_map: Return row_id_map
    make_chunk_sort_map-->>sort_chunks_by_index: row_id_map
    sort_chunks_by_index->>sort_chunks_by_index: Store split_sizes & sorted_indices in residuals
    sort_chunks_by_index-->>JAX: Return (output, row_id_map), residuals
    
    Note over User,InspectFFI: Optional debugging path
    User->>InspectPrimitive: inspect_array(x, name)
    InspectPrimitive->>InspectFFI: FFI call with input buffer
    InspectFFI->>InspectFFI: cudaMemcpyAsync to host
    InspectFFI->>InspectFFI: Write to my_tensor_gpu{N}.bin
    InspectFFI-->>InspectPrimitive: Return aliased buffer
    InspectPrimitive-->>User: Return x (unchanged)
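
A dump produced on the debugging path can be read back with NumPy. The sketch below assumes the file holds only the raw, contiguous bytes of the tensor with no header, so the caller has to supply the dtype and shape; `load_tensor_dump` is a hypothetical helper, not part of this PR.

```python
import numpy as np


def load_tensor_dump(path, dtype, shape):
    """Load a raw binary tensor dump, assuming the given dtype and shape."""
    data = np.fromfile(path, dtype=dtype)
    expected = int(np.prod(shape))
    if data.size != expected:
        # A size mismatch usually means the wrong dtype or shape was assumed,
        # or that the write was truncated.
        raise ValueError(
            f"{path}: found {data.size} elements, expected {expected} for shape {shape}"
        )
    return data.reshape(shape)
```

For example, `load_tensor_dump("my_tensor_gpu0.bin", np.float32, (128, 64))` would return a 128x64 array if the dump was written from a float32 tensor of that shape; both values here are placeholders.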

@greptile-apps greptile-apps bot left a comment

6 files reviewed, 3 comments

_inspect.defvjp(_inspect_fwd_rule, _inspect_bwd_rule)


def inspect_array(x: jnp.ndarray, name: str) -> jnp.ndarray:

name parameter is unused - not passed to C++ backend or used in filename

Suggested change
- def inspect_array(x: jnp.ndarray, name: str) -> jnp.ndarray:
+ def inspect_array(x: jnp.ndarray) -> jnp.ndarray:

Comment on lines +116 to +120
std::ofstream file(filename, std::ios::binary);
if (file.is_open()) {
  file.write(reinterpret_cast<const char *>(input_data.data()), input_data.size());
  file.close();
}

No error handling if file fails to open - silently continues without writing data

Suggested change
- std::ofstream file(filename, std::ios::binary);
- if (file.is_open()) {
-   file.write(reinterpret_cast<const char *>(input_data.data()), input_data.size());
-   file.close();
- }
+ std::ofstream file(filename, std::ios::binary);
+ if (!file.is_open()) {
+   return ffi::Error(ffi::ErrorCode::kInternal, "Failed to open file for writing");
+ }
+ file.write(reinterpret_cast<const char *>(input_data.data()), input_data.size());
+ file.close();

# Return gradients for all inputs: (inp, split_sizes, sorted_indices)
# split_sizes and sorted_indices are integer arrays, so their gradients are zeros
# with matching dtype (use float32 as a safe default for index arrays)
split_sizes_grad = jnp.zeros_like(split_sizes, dtype=split_sizes.dtype)

Comment says "use float32 as safe default" but code preserves original dtype (likely int32/int64) - comment is misleading
