Skip to content

Commit 71ee84a

Browse files
authored
[MLIR][Vector] Add unroll pattern for vector.constant_mask (#171518)
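This PR adds an unroll pattern for the vector.constant_mask op, driven by the target shape. For each unrolled vector at a given offset, the local mask size in each dimension (d) is computed as: min(max(originalMaskSize[d] - offset[d], 0), targetShape[d]), i.e. the original mask extent shifted by the tile's offset and clamped to the range [0, targetShape[d]].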
1 parent 757c5b3 commit 71ee84a

File tree

4 files changed

+116
-8
lines changed

4 files changed

+116
-8
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2534,7 +2534,9 @@ def Vector_TypeCastOp :
25342534
}
25352535

25362536
def Vector_ConstantMaskOp :
2537-
Vector_Op<"constant_mask", [Pure]>,
2537+
Vector_Op<"constant_mask", [Pure,
2538+
DeclareOpInterfaceMethods<VectorUnrollOpInterface>
2539+
]>,
25382540
Arguments<(ins DenseI64ArrayAttr:$mask_dim_sizes)>,
25392541
Results<(outs VectorOfAnyRankOf<[I1]>)> {
25402542
let summary = "creates a constant vector mask";

mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp

Lines changed: 89 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1094,6 +1094,93 @@ struct UnrollCreateMaskPattern : public OpRewritePattern<vector::CreateMaskOp> {
10941094
vector::UnrollVectorOptions options;
10951095
};
10961096

1097+
/// This pattern unrolls `vector.constant_mask` operations into smaller mask
1098+
/// operations based on the target unroll shape. Each unrolled slice derives
1099+
/// its own mask sizes from the original mask dimensions: in every dimension
1100+
/// the original size is shifted by the slice's offset and clamped to the slice.
1101+
///
1102+
/// Example:
1103+
/// Given a constant_mask operation:
1104+
/// %0 = vector.constant_mask [6, 10] : vector<8x16xi1>
1105+
///
1106+
/// and a target unroll shape of <4x8>, the pattern produces:
1107+
///
1108+
/// %false = arith.constant dense<false> : vector<8x16xi1>
1109+
///
1110+
/// Slice [0,0]: elements [0:4, 0:8] - fully within [6, 10] bounds
1111+
/// %mask00 = vector.constant_mask [4, 8] : vector<4x8xi1>
1112+
/// %r0 = vector.insert_strided_slice %mask00, %false [0, 0], [1, 1]
1113+
/// : vector<4x8xi1> into vector<8x16xi1>
1114+
///
1115+
/// Slice [0,8]: elements [0:4, 8:16] - partially within bounds
1116+
/// %mask01 = vector.constant_mask [4, 2] : vector<4x8xi1>
1117+
/// %r1 = vector.insert_strided_slice %mask01, %r0 [0, 8], [1, 1]
1118+
/// : vector<4x8xi1> into vector<8x16xi1>
1119+
///
1120+
/// Slice [4,0]: elements [4:8, 0:8] - partially within bounds
1121+
/// %mask10 = vector.constant_mask [2, 8] : vector<4x8xi1>
1122+
/// %r2 = vector.insert_strided_slice %mask10, %r1 [4, 0], [1, 1]
1123+
/// : vector<4x8xi1> into vector<8x16xi1>
1124+
///
1125+
/// Slice [4,8]: elements [4:8, 8:16] - partially within bounds
1126+
/// %mask11 = vector.constant_mask [2, 2] : vector<4x8xi1>
1127+
/// %result = vector.insert_strided_slice %mask11, %r2 [4, 8], [1, 1]
1128+
/// : vector<4x8xi1> into vector<8x16xi1>
1129+
struct UnrollConstantMaskPattern
1130+
: public OpRewritePattern<vector::ConstantMaskOp> {
1131+
UnrollConstantMaskPattern(MLIRContext *context,
1132+
const vector::UnrollVectorOptions &options,
1133+
PatternBenefit benefit = 1)
1134+
: OpRewritePattern<vector::ConstantMaskOp>(context, benefit),
1135+
options(options) {}
1136+
1137+
LogicalResult matchAndRewrite(vector::ConstantMaskOp constantMaskOp,
1138+
PatternRewriter &rewriter) const override {
1139+
std::optional<SmallVector<int64_t>> targetShape =
1140+
getTargetShape(options, constantMaskOp);
1141+
if (!targetShape)
1142+
return failure();
1143+
1144+
VectorType resultType = constantMaskOp.getVectorType();
1145+
SmallVector<int64_t> originalSize = *constantMaskOp.getShapeForUnroll();
1146+
Location loc = constantMaskOp.getLoc();
1147+
1148+
Value result = arith::ConstantOp::create(rewriter, loc, resultType,
1149+
rewriter.getZeroAttr(resultType));
1150+
VectorType targetVectorType =
1151+
VectorType::get(*targetShape, rewriter.getI1Type());
1152+
SmallVector<int64_t> strides(targetShape->size(), 1);
1153+
1154+
// In each dimension (d), each unrolled vector computes its mask size as:
1155+
// min(max(originalMaskDim[d] - offset[d], 0), unrolledDimSize[d]).
1156+
for (const SmallVector<int64_t> &offsets :
1157+
StaticTileOffsetRange(originalSize, *targetShape)) {
1158+
SmallVector<int64_t> unrolledMaskDims;
1159+
1160+
for (auto [i, originalMaskDim] :
1161+
llvm::enumerate(constantMaskOp.getMaskDimSizes())) {
1162+
// Number of leading elements of this slice, along this dimension, that
1163+
// still fall inside the original mask (zero if the slice starts past it).
1164+
int64_t adjustedMaskSize =
1165+
std::max(originalMaskDim - offsets[i], static_cast<int64_t>(0));
1166+
int64_t unrolledMaskDim =
1167+
std::min(adjustedMaskSize, static_cast<int64_t>((*targetShape)[i]));
1168+
unrolledMaskDims.push_back(unrolledMaskDim);
1169+
}
1170+
1171+
auto unrolledMask = rewriter.createOrFold<vector::ConstantMaskOp>(
1172+
loc, targetVectorType, unrolledMaskDims);
1173+
result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
1174+
loc, unrolledMask, result, offsets, strides);
1175+
}
1176+
rewriter.replaceOp(constantMaskOp, result);
1177+
return success();
1178+
}
1179+
1180+
private:
1181+
vector::UnrollVectorOptions options;
1182+
};
1183+
10971184
/// Checks whether extractShape is a contiguous slice of shape.
10981185
/// For extractShape to be contiguous in shape:
10991186
/// 1) All but the leading dimension of extractShape and shape must match
@@ -1294,8 +1381,8 @@ void mlir::vector::populateVectorUnrollPatterns(
12941381
UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
12951382
UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements,
12961383
UnrollToElements, UnrollStepPattern, UnrollShapeCastPattern,
1297-
UnrollCreateMaskPattern>(patterns.getContext(), options,
1298-
benefit);
1384+
UnrollCreateMaskPattern, UnrollConstantMaskPattern>(
1385+
patterns.getContext(), options, benefit);
12991386
}
13001387

13011388
void mlir::vector::populateVectorToElementsUnrollPatterns(

mlir/test/Dialect/Vector/vector-unroll-options.mlir

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -552,6 +552,23 @@ func.func @vector_create_mask_constant_dim_sizes() -> vector<16x16xi1> {
552552
// CHECK: %[[S3:.*]] = vector.insert_strided_slice %[[CST_0]], %[[S2]] {offsets = [8, 8], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
553553
// CHECK: return %[[S3]] : vector<16x16xi1>
554554

555+
func.func @vector_constant_mask() -> vector<16x16xi1> {
556+
%0 = vector.constant_mask [12, 10] : vector<16x16xi1>
557+
return %0 : vector<16x16xi1>
558+
}
559+
560+
// CHECK-LABEL: func @vector_constant_mask
561+
// CHECK-SAME: () -> vector<16x16xi1>
562+
// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<16x16xi1>
563+
// CHECK: %[[CST_TRUE:.*]] = arith.constant dense<true> : vector<8x8xi1>
564+
// CHECK: %[[INS00:.*]] = vector.insert_strided_slice %[[CST_TRUE]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
565+
// CHECK: %[[MASK01:.*]] = vector.constant_mask [8, 2] : vector<8x8xi1>
566+
// CHECK: %[[INS01:.*]] = vector.insert_strided_slice %[[MASK01]], %[[INS00]] {offsets = [0, 8], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
567+
// CHECK: %[[MASK10:.*]] = vector.constant_mask [4, 8] : vector<8x8xi1>
568+
// CHECK: %[[INS10:.*]] = vector.insert_strided_slice %[[MASK10]], %[[INS01]] {offsets = [8, 0], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
569+
// CHECK: %[[MASK11:.*]] = vector.constant_mask [4, 2] : vector<8x8xi1>
570+
// CHECK: %[[INS11:.*]] = vector.insert_strided_slice %[[MASK11]], %[[INS10]] {offsets = [8, 8], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
571+
// CHECK: return %[[INS11]] : vector<16x16xi1>
555572

556573
func.func @shape_cast_1D(%v: vector<16xf32>) -> vector<2x2x4xf32> {
557574
%0 = vector.shape_cast %v : vector<16xf32> to vector<2x2x4xf32>

mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -179,11 +179,13 @@ struct TestVectorUnrollingPatterns
179179
return success(isa<vector::StepOp>(op));
180180
}));
181181
populateVectorUnrollPatterns(
182-
patterns, UnrollVectorOptions()
183-
.setNativeShape(ArrayRef<int64_t>{8, 8})
184-
.setFilterConstraint([](Operation *op) {
185-
return success(isa<vector::CreateMaskOp>(op));
186-
}));
182+
patterns,
183+
UnrollVectorOptions()
184+
.setNativeShape(ArrayRef<int64_t>{8, 8})
185+
.setFilterConstraint([](Operation *op) {
186+
return success(
187+
isa<vector::CreateMaskOp, vector::ConstantMaskOp>(op));
188+
}));
187189
populateVectorUnrollPatterns(
188190
patterns,
189191
UnrollVectorOptions()

0 commit comments

Comments
 (0)