Description
Is this a new feature, an improvement, or a change to existing functionality?
New Feature
How would you describe the priority of this feature request?
Critical (currently preventing usage)
Please provide a clear description of problem this feature solves
I am implementing a wavefront-style parallel algorithm where each thread needs to access the value computed by its neighbor in the previous step. Mathematically, this represents a "Shift Right" operation on a 1D Tile residing in registers: new_vec[i] = old_vec[i-1].
Currently, cuTile enforces strict Power-of-2 shape constraints on extract and cat. This makes it impossible to implement a shift by extracting the first
Real usage example:
In stencil computations or dynamic programming wavefronts, data often flows diagonally or horizontally between threads. Without a register-level shift, developers are forced to use high-overhead workarounds:
- Global Memory: Writing to global memory and reading back with an offset (scatter gather).
- Matrix Multiplication: Constructing a shift matrix and using mma to perform the shift. This works but is computationally expensive (overkill) for a simple data movement operation.
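To make the cost of the matrix-multiplication workaround concrete, here is a minimal plain-Python sketch of the idea. It is not cuTile code: `shift_matrix` and `matvec` are hypothetical helpers written only for illustration, and the fill value is implicitly 0 because row 0 of the matrix is all zeros.

```python
def shift_matrix(n):
    """Build the n x n matrix S with S[i][i-1] = 1, so that (S @ v)[i] == v[i-1]."""
    return [[1 if j == i - 1 else 0 for j in range(n)] for i in range(n)]


def matvec(matrix, vec):
    """Dense matrix-vector product: n*n multiply-adds for what is really n moves."""
    return [sum(m * x for m, x in zip(row, vec)) for row in matrix]


old_vec = [10, 20, 30, 40]
new_vec = matvec(shift_matrix(len(old_vec)), old_vec)
print(new_vec)  # [0, 10, 20, 30]
```

The product does n*n multiply-adds to move n elements by one position, which is why this workaround is disproportionate for a pure data-movement step.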
Feature Description
As a high-performance kernel developer,
I want to efficiently shift or rotate elements within a Tile (intra-tile communication),
So that I can implement stencil and wavefront dependencies entirely within registers without incurring global memory latency or Tensor Core overhead.
Describe your ideal solution
I propose adding a dedicated primitive for intra-tile communication, which maps to efficient hardware instructions (like __shfl_up_sync or __shfl_down_sync in CUDA).
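The intended semantics can be modeled in plain Python as a rough sketch. This is not cuTile code and not a statement about how the primitive would be lowered: `shfl_up_model` and `shift_right` are hypothetical names, and the list stands in for a 1D tile. The model assumes the documented behavior of a CUDA shuffle-up, where lanes with an index below the shift amount keep their own value, which is why an explicit fill step is still needed.

```python
def shfl_up_model(lanes, delta):
    """Model of a shuffle-up: lane i reads lane i - delta.

    Lanes with i < delta have no source lane and keep their own value.
    """
    return [lanes[i - delta] if i >= delta else lanes[i] for i in range(len(lanes))]


def shift_right(values, shift_amount=1, fill_value=0):
    """Reference semantics for the requested shift: out[i] = values[i - shift_amount].

    Positions with no source element are overwritten with fill_value.
    """
    moved = shfl_up_model(values, shift_amount)
    return [moved[i] if i >= shift_amount else fill_value for i in range(len(values))]


print(shfl_up_model([1, 2, 3, 4], 1))  # [1, 1, 2, 3]
print(shift_right([1, 2, 3, 4]))       # [0, 1, 2, 3]
```

Separating the shuffle from the fill keeps the boundary handling explicit, which is the part a user cannot express today without slicing to a non-power-of-2 shape.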
Proposed API:
# Shift elements to the right by 'shift_amount'.
# Elements shifted in are filled with 'fill_value'.
output_tile = ct.shift(input_tile, shift_amount=1, fill_value=0)
Alternative Solution:
Relax the Power-of-2 constraint for ct.extract and ct.cat. If the library allowed operations on arbitrary shapes (e.g., extracting a size-127 tile), users could manually implement shifts via slicing and concatenation:
# Ideally, this should be allowed:
slice = ct.extract(val, index=(0,), shape=(127,))
boundary = ct.full((1,), 0, dtype=ct.int32)
shifted = ct.cat(boundary, slice, axis=0)
Describe any alternatives you have considered
No response
Additional context
No response
Contributing Guidelines
- I agree to follow cuTile Python's contributing guidelines
- I have searched the open feature requests and have found no duplicates for this feature request