[Encoding] Add iree_encoding.dim op and reification
#23311
Conversation
Signed-off-by: Jorn Tuyls <[email protected]>
Force-pushed from 38700b3 to 26c689c
hanhanW left a comment
I feel that adding the interface may be over-design. Can't what you're trying to do in this PR just be a single canonicalization pattern? Below is an implementation example from Claude.
I think the main question is: why do we need this interface? In practice, the encodings either come from the SetEncoding op or from the frontend. Is it because you need such support for Flow::EncodeOp and Stream::EncodeOp as well?
struct ReifyEncodingDim : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto result = dyn_cast<OpResult>(dimOp.getSource());
    if (!result)
      return failure();
    Operation *producer = result.getOwner();
    // Assumes getConstantIndex() mirrors tensor.dim and returns
    // std::optional<int64_t>, so non-constant indices bail out.
    std::optional<int64_t> maybeIndex = dimOp.getConstantIndex();
    if (!maybeIndex)
      return failure();
    int64_t dimIndex = *maybeIndex;

    // Source: set_encoding directly provides encoding dims.
    if (auto setEnc = dyn_cast<SetEncodingOp>(producer)) {
      ValueRange encodingDims = setEnc.getEncodingDims();
      if (dimIndex < 0 ||
          dimIndex >= static_cast<int64_t>(encodingDims.size()))
        return failure();
      rewriter.replaceOp(dimOp, encodingDims[dimIndex]);
      return success();
    }

    // Pass-through: tensor.cast forwards to source.
    if (auto castOp = dyn_cast<tensor::CastOp>(producer)) {
      rewriter.replaceOpWithNewOp<DimOp>(dimOp, castOp.getSource(),
                                         dimIndex);
      return success();
    }

    // Pass-through: DPS ops forward to tied init.
    if (auto dpsOp = dyn_cast<DestinationStyleOpInterface>(producer)) {
      if (OpOperand *tiedInit = dpsOp.getTiedOpOperand(result)) {
        rewriter.replaceOpWithNewOp<DimOp>(dimOp, tiedInit->get(),
                                           dimIndex);
        return success();
      }
    }
    return failure();
  }
};
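For completeness, a minimal sketch of how such a pattern would typically be hooked up, assuming it is exposed as a canonicalization of the dim op:

// Hypothetical registration; assumes the pattern lives alongside the
// op's other canonicalizers.
void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results.add<ReifyEncodingDim>(context);
}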
btw, please expand the context in the PR description. I was not aware that an interface was added until I reviewed the code. Here is good guidance: https://google.github.io/eng-practices/review/developer/cl-descriptions.html#informative
This interface enables reification of `iree_encoding.encoding_dim` operations
by tracing through producer chains to find where encoding dimension values
were originally captured.
Suggested change:
-This interface enables reification of `iree_encoding.encoding_dim` operations
-by tracing through producer chains to find where encoding dimension values
-were originally captured.
+This interface enables reification of `iree_encoding.dim` operations by tracing
+through producer chains to find where encoding dimension values were
+originally captured.
- `set_encoding` implements `EncodingDimReificationInterface` and returns
  the corresponding `encoding_dims` value
- `tensor.cast` and DPS ops (like `linalg.fill`, `linalg.generic`) forward
  the query to their source/init operands
Suggested change:
-- `set_encoding` implements `EncodingDimReificationInterface` and returns
-  the corresponding `encoding_dims` value
-- `tensor.cast` and DPS ops (like `linalg.fill`, `linalg.generic`) forward
-  the query to their source/init operands
+- `set_encoding` implements `EncodingDimReificationInterface` and returns
+  the corresponding `encoding_dims` value.
+- `tensor.cast` and DPS ops (like `linalg.fill`, `linalg.generic`) forward
+  the query to their source/init operands.
let results = (outs Index:$result);

let assemblyFormat = [{
  attr-dict $source `[` $index `]` `:` type($source)
Putting the attribute list after the source and index seems more common?
Suggested change:
-attr-dict $source `[` $index `]` `:` type($source)
+$source `[` $index `]` attr-dict `:` type($source)
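With the suggested format, a discretionary attribute would print after the operands rather than before them, along these lines (the attribute, tensor type, and encoding here are illustrative):

%dim = iree_encoding.dim %source[%c0] {test.attr} : tensor<?x?xf32, #encoding>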
/*methodName=*/"reifyEncodingDim",
/*args=*/(ins
  "::mlir::OpBuilder &":$builder,
  "unsigned":$resultIndex,
Do we pass OpResult instead? It usually provides more information, and passing it is not expensive, as my understanding is that it is a pointer-like value.
I don't have a full picture of how you'd use the interface, so I'll leave the decision to you.
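For reference, the args with an OpResult in place of the index might look like this (a sketch; the remaining arguments are kept as in the snippet above):

/*methodName=*/"reifyEncodingDim",
/*args=*/(ins
  "::mlir::OpBuilder &":$builder,
  "::mlir::OpResult":$result,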
- Success with the value if the dimension can be resolved directly
- Failure if the operation cannot directly provide the value
  (caller should use `getEncodingDimSource` to trace through)
Suggested change:
-- Success with the value if the dimension can be resolved directly
-- Failure if the operation cannot directly provide the value
-  (caller should use `getEncodingDimSource` to trace through)
+- Success with the value if the dimension can be resolved directly.
+- Failure if the operation cannot directly provide the value.
+  (caller should use `getEncodingDimSource` to trace through)
Do we need FailureOr? Or do we just follow the other methods that return either a Value or null?
“caller should use `getEncodingDimSource` to trace through”
Can you elaborate a bit more? Does it mean that the caller used the wrong method if it returns failure?
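For comparison, a minimal sketch of the Value-or-null convention being suggested (the dim-index argument is an assumption; it is cut off in the snippet above):

// Returns the reified encoding dim value, or a null Value if this op
// cannot provide it directly.
Value reifyEncodingDim(OpBuilder &builder, unsigned resultIndex,
                       unsigned dimIndex);

Callers would then check `if (Value v = op.reifyEncodingDim(...))`.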
It is weird to see *Patterns under IR/. Are they only used as canonicalization patterns? If so, can you move them to EncodingOps.cpp?
/// 2. Operations that forward encoding dims from a source (like tensor.cast):
///    The pattern calls getEncodingDimSource() and creates a new dim op on
///    that source.
///
I'd drop this blank comment.
OpResult result = dyn_cast<OpResult>(dimOp.getSource());
if (!result) {
  return failure();
}
Suggested change:
-OpResult result = dyn_cast<OpResult>(dimOp.getSource());
+auto result = dyn_cast<OpResult>(dimOp.getSource());
Please also replace `return failure()` with a more meaningful message where possible, i.e., `return rewriter.notifyMatchFailure(...)`. The error message also serves as a self-comment, which looks better to me.
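A minimal sketch of that style (the message text here is illustrative):

auto result = dyn_cast<OpResult>(dimOp.getSource());
if (!result) {
  return rewriter.notifyMatchFailure(
      dimOp, "source is a block argument, not an op result");
}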
// Verify encodings match.
auto resultType = dyn_cast<RankedTensorType>(result.getType());
auto initType = dyn_cast<RankedTensorType>(tiedInit->get().getType());
if (!resultType || !initType) {
  return failure();
}

if (resultType.getEncoding() != initType.getEncoding()) {
  return failure();
}
My intuition told me that this should already be checked by the interface, and yes, I confirmed it: https://github.com/llvm/llvm-project/blob/52dfcab327fe959074563603b6ebaaed314e9677/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp#L51-L59
for (OpOperand *opOperand : outputTensorOperands) {
  OpResult result = dstStyleOp.getTiedOpResult(opOperand);
  if (result.getType() != opOperand->get().getType())
    return op->emitOpError("expected type of operand #")
           << opOperand->getOperandNumber() << " ("
           << opOperand->get().getType() << ")"
           << " to match type of corresponding result (" << result.getType()
           << ")";
}

They have the same type, which indicates that they have the same encoding.
Why don't we have a unified interface method? Then the next question is: why do we need the interface? Can't it just be a single method?
(I may be missing how it is used in other places; I haven't gotten to that part yet.)
Imo, the interface is appropriate: for the cost of a bit of additional code, we avoid baking encoding op-specific knowledge into the transformation logic. Per my understanding, that is precisely what interfaces exist for, and it keeps future additions isolated.
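For concreteness, a rough sketch of what the interface declaration might look like in TableGen (names follow this thread; the dim-index argument and namespace are assumptions):

def EncodingDimReificationInterface
    : OpInterface<"EncodingDimReificationInterface"> {
  let cppNamespace = "::mlir::iree_compiler::IREE::Encoding";
  let methods = [
    InterfaceMethod<
      /*desc=*/"Reify the requested encoding dimension if this op can provide it directly.",
      /*retTy=*/"::mlir::FailureOr<::mlir::Value>",
      /*methodName=*/"reifyEncodingDim",
      /*args=*/(ins "::mlir::OpBuilder &":$builder,
                    "unsigned":$resultIndex,
                    "unsigned":$dimIndex)>,
  ];
}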
Implements phase 3 for adding specialization support on dynamic values for data-tiling: #22370. This adds the `iree_encoding.dim` operation and reification patterns to query encoding dimensions; see the usage sketch below.

Assisted-by: Claude
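A hypothetical usage sketch, based on the assembly format discussed above (the tensor type and encoding attribute are illustrative):

%c0 = arith.constant 0 : index
%d = iree_encoding.dim %source[%c0] : tensor<?x?xf32, #encoding>

The reification patterns would then resolve `%d` to the value captured as the corresponding encoding dim at the producing `set_encoding`.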