@@ -1094,6 +1094,93 @@ struct UnrollCreateMaskPattern : public OpRewritePattern<vector::CreateMaskOp> {
10941094 vector::UnrollVectorOptions options;
10951095};
10961096
1097+ // / This pattern unrolls `vector.constant_mask` operations into smaller mask
1098+ // / operations based on the target unroll shape. Each unrolled slice computes
1099+ // / whether its elements should be masked based on the original mask dimensions
1100+ // / and the slice's offset position.
1101+ // /
1102+ // / Example:
1103+ // / Given a constant_mask operation:
1104+ // / %0 = vector.constant_mask [6, 10] : vector<8x16xi1>
1105+ // /
1106+ // / and a target unroll shape of <4x8>, the pattern produces:
1107+ // /
1108+ // / %false = arith.constant dense<false> : vector<8x16xi1>
1109+ // /
1110+ // / Slice [0,0]: elements [0:4, 0:8] - fully within [6, 10] bounds
1111+ // / %mask00 = vector.constant_mask [4, 8] : vector<4x8xi1>
1112+ // / %r0 = vector.insert_strided_slice %mask00, %false [0, 0], [1, 1]
1113+ // / : vector<4x8xi1> into vector<8x16xi1>
1114+ // /
1115+ // / Slice [0,8]: elements [0:4, 8:16] - partially within bounds
1116+ // / %mask01 = vector.constant_mask [4, 2] : vector<4x8xi1>
1117+ // / %r1 = vector.insert_strided_slice %mask01, %r0 [0, 8], [1, 1]
1118+ // / : vector<4x8xi1> into vector<8x16xi1>
1119+ // /
1120+ // / Slice [4,0]: elements [4:8, 0:8] - partially within bounds
1121+ // / %mask10 = vector.constant_mask [2, 8] : vector<4x8xi1>
1122+ // / %r2 = vector.insert_strided_slice %mask10, %r1 [4, 0], [1, 1]
1123+ // / : vector<4x8xi1> into vector<8x16xi1>
1124+ // /
1125+ // / Slice [4,8]: elements [4:8, 8:16] - partially within bounds
1126+ // / %mask11 = vector.constant_mask [2, 2] : vector<4x8xi1>
1127+ // / %result = vector.insert_strided_slice %mask11, %r2 [4, 8], [1, 1]
1128+ // / : vector<4x8xi1> into vector<8x16xi1>
1129+ struct UnrollConstantMaskPattern
1130+ : public OpRewritePattern<vector::ConstantMaskOp> {
1131+ UnrollConstantMaskPattern (MLIRContext *context,
1132+ const vector::UnrollVectorOptions &options,
1133+ PatternBenefit benefit = 1 )
1134+ : OpRewritePattern<vector::ConstantMaskOp>(context, benefit),
1135+ options (options) {}
1136+
1137+ LogicalResult matchAndRewrite (vector::ConstantMaskOp constantMaskOp,
1138+ PatternRewriter &rewriter) const override {
1139+ std::optional<SmallVector<int64_t >> targetShape =
1140+ getTargetShape (options, constantMaskOp);
1141+ if (!targetShape)
1142+ return failure ();
1143+
1144+ VectorType resultType = constantMaskOp.getVectorType ();
1145+ SmallVector<int64_t > originalSize = *constantMaskOp.getShapeForUnroll ();
1146+ Location loc = constantMaskOp.getLoc ();
1147+
1148+ Value result = arith::ConstantOp::create (rewriter, loc, resultType,
1149+ rewriter.getZeroAttr (resultType));
1150+ VectorType targetVectorType =
1151+ VectorType::get (*targetShape, rewriter.getI1Type ());
1152+ SmallVector<int64_t > strides (targetShape->size (), 1 );
1153+
1154+ // In each dimension (d), each unrolled vector computes its mask size as:
1155+ // min(max(originalMaskDim[d] - offset[d], 0), unrolledDimSize[d]).
1156+ for (const SmallVector<int64_t > &offsets :
1157+ StaticTileOffsetRange (originalSize, *targetShape)) {
1158+ SmallVector<int64_t > unrolledMaskDims;
1159+
1160+ for (auto [i, originalMaskDim] :
1161+ llvm::enumerate (constantMaskOp.getMaskDimSizes ())) {
1162+ // Calculate how many elements in this dimension should be masked
1163+ // for this particular slice
1164+ int64_t adjustedMaskSize =
1165+ std::max (originalMaskDim - offsets[i], static_cast <int64_t >(0 ));
1166+ int64_t unrolledMaskDim =
1167+ std::min (adjustedMaskSize, static_cast <int64_t >((*targetShape)[i]));
1168+ unrolledMaskDims.push_back (unrolledMaskDim);
1169+ }
1170+
1171+ auto unrolledMask = rewriter.createOrFold <vector::ConstantMaskOp>(
1172+ loc, targetVectorType, unrolledMaskDims);
1173+ result = rewriter.createOrFold <vector::InsertStridedSliceOp>(
1174+ loc, unrolledMask, result, offsets, strides);
1175+ }
1176+ rewriter.replaceOp (constantMaskOp, result);
1177+ return success ();
1178+ }
1179+
1180+ private:
1181+ vector::UnrollVectorOptions options;
1182+ };
1183+
10971184// / Checks whether extractShape is a contiguous slice of shape.
10981185// / For extractShape to be contiguous in shape:
10991186// / 1) All but the leading dimension of extractShape and shape must match
@@ -1294,8 +1381,8 @@ void mlir::vector::populateVectorUnrollPatterns(
12941381 UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
12951382 UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements,
12961383 UnrollToElements, UnrollStepPattern, UnrollShapeCastPattern,
1297- UnrollCreateMaskPattern>(patterns. getContext (), options,
1298- benefit);
1384+ UnrollCreateMaskPattern, UnrollConstantMaskPattern>(
1385+ patterns. getContext (), options, benefit);
12991386}
13001387
13011388void mlir::vector::populateVectorToElementsUnrollPatterns (
0 commit comments