-
Notifications
You must be signed in to change notification settings - Fork 633
[JAX] TE Permutation integration to Maxtext #2672
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
for more information, see https://pre-commit.ci
Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
…ger than num tokens Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: JAX Toolbox <jax@nvidia.com>
for more information, see https://pre-commit.ci
|
This PR contain changes cherry-picked from #2651 . I can wait until this gets merged and then merge mine, but if my PR is needed more urgently, happy to remove the cherry picked change |
Greptile OverviewGreptile SummaryThis PR fixes MaxText integration issues by addressing tensor permutation tracing problems when EP>1 and adding debugging utilities. Major changes:
Issues found:
Confidence Score: 3/5
Important Files Changed
Sequence DiagramsequenceDiagram
participant User
participant JAX
participant sort_chunks_by_index
participant make_chunk_sort_map
participant Triton Kernel
participant InspectPrimitive
participant InspectFFI
User->>JAX: Call sort_chunks_by_index(inp, split_sizes, sorted_indices)
JAX->>sort_chunks_by_index: Forward pass
sort_chunks_by_index->>make_chunk_sort_map: Generate row_id_map
make_chunk_sort_map->>Triton Kernel: Execute _make_chunk_sort_map_kernel
Note over Triton Kernel: Compute total_valid_tokens<br/>Apply identity mapping for padding
Triton Kernel-->>make_chunk_sort_map: Return row_id_map
make_chunk_sort_map-->>sort_chunks_by_index: row_id_map
sort_chunks_by_index->>sort_chunks_by_index: Store split_sizes & sorted_indices in residuals
sort_chunks_by_index-->>JAX: Return (output, row_id_map), residuals
Note over User,InspectFFI: Optional debugging path
User->>InspectPrimitive: inspect_array(x, name)
InspectPrimitive->>InspectFFI: FFI call with input buffer
InspectFFI->>InspectFFI: cudaMemcpyAsync to host
InspectFFI->>InspectFFI: Write to my_tensor_gpu{N}.bin
InspectFFI-->>InspectPrimitive: Return aliased buffer
InspectPrimitive-->>User: Return x (unchanged)
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
6 files reviewed, 3 comments
| _inspect.defvjp(_inspect_fwd_rule, _inspect_bwd_rule) | ||
|
|
||
|
|
||
| def inspect_array(x: jnp.ndarray, name: str) -> jnp.ndarray: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
name parameter is unused - not passed to C++ backend or used in filename
| def inspect_array(x: jnp.ndarray, name: str) -> jnp.ndarray: | |
| def inspect_array(x: jnp.ndarray) -> jnp.ndarray: |
| std::ofstream file(filename, std::ios::binary); | ||
| if (file.is_open()) { | ||
| file.write(reinterpret_cast<const char *>(input_data.data()), input_data.size()); | ||
| file.close(); | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No error handling if file fails to open - silently continues without writing data
| std::ofstream file(filename, std::ios::binary); | |
| if (file.is_open()) { | |
| file.write(reinterpret_cast<const char *>(input_data.data()), input_data.size()); | |
| file.close(); | |
| } | |
| std::ofstream file(filename, std::ios::binary); | |
| if (!file.is_open()) { | |
| return ffi::Error(ffi::ErrorCode::kInternal, "Failed to open file for writing"); | |
| } | |
| file.write(reinterpret_cast<const char *>(input_data.data()), input_data.size()); | |
| file.close(); |
| # Return gradients for all inputs: (inp, split_sizes, sorted_indices) | ||
| # split_sizes and sorted_indices are integer arrays, so their gradients are zeros | ||
| # with matching dtype (use float32 as a safe default for index arrays) | ||
| split_sizes_grad = jnp.zeros_like(split_sizes, dtype=split_sizes.dtype) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Comment says "use float32 as safe default" but code preserves original dtype (likely int32/int64) - comment is misleading
Description
Changes needed on TE side to make maxtext integration works
Issue # 2585
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: