Skip to content

Commit 1d0d7da

Browse files
jpienaarjoker-eph
andauthored
[mlir] Add symbol user attribute interface. (#153206)
Enables verification of attributes, independent of op, that references symbols. This enables verifying Attribute with symbol usage independent of operation attached to (e.g., the validity is on the Attribute independent of the operation). --------- Co-authored-by: Mehdi Amini <[email protected]>
1 parent 617ba83 commit 1d0d7da

File tree

9 files changed

+87
-4
lines changed

9 files changed

+87
-4
lines changed

mlir/include/mlir/IR/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
add_mlir_interface(SymbolInterfaces)
2+
set(LLVM_TARGET_DEFINITIONS SymbolInterfaces.td)
3+
mlir_tablegen(SymbolInterfacesAttrInterface.h.inc -gen-attr-interface-decls)
4+
mlir_tablegen(SymbolInterfacesAttrInterface.cpp.inc -gen-attr-interface-defs)
25
add_mlir_interface(RegionKindInterface)
36

47
set(LLVM_TARGET_DEFINITIONS OpAsmInterface.td)

mlir/include/mlir/IR/SymbolInterfaces.td

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ def SymbolUserOpInterface : OpInterface<"SymbolUserOpInterface"> {
210210
This interface describes an operation that may use a `Symbol`. This
211211
interface allows for users of symbols to hook into verification and other
212212
symbol related utilities that are either costly or otherwise disallowed
213-
within a traditional operation.
213+
within an operation.
214214
}];
215215
let cppNamespace = "::mlir";
216216

@@ -222,6 +222,25 @@ def SymbolUserOpInterface : OpInterface<"SymbolUserOpInterface"> {
222222
];
223223
}
224224

225+
def SymbolUserAttrInterface : AttrInterface<"SymbolUserAttrInterface"> {
226+
let description = [{
227+
This interface describes an attribute that may use a `Symbol`. This
228+
interface allows for users of symbols to hook into verification and other
229+
symbol related utilities that are either costly or otherwise disallowed
230+
within an operation (e.g., recreating symbol users per op verified rather
231+
than per symbol table, or querying symbols usage of sibblings).
232+
}];
233+
let cppNamespace = "::mlir";
234+
235+
let methods = [
236+
InterfaceMethod<"Verify the symbol uses held by this attribute of this operation.",
237+
"::llvm::LogicalResult", "verifySymbolUses",
238+
(ins "::mlir::Operation *":$op,
239+
"::mlir::SymbolTableCollection &":$symbolTable)
240+
>,
241+
];
242+
}
243+
225244
//===----------------------------------------------------------------------===//
226245
// Symbol Traits
227246
//===----------------------------------------------------------------------===//

mlir/include/mlir/IR/SymbolTable.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -499,5 +499,6 @@ ParseResult parseOptionalVisibilityKeyword(OpAsmParser &parser,
499499

500500
/// Include the generated symbol interfaces.
501501
#include "mlir/IR/SymbolInterfaces.h.inc"
502+
#include "mlir/IR/SymbolInterfacesAttrInterface.h.inc"
502503

503504
#endif // MLIR_IR_SYMBOLTABLE_H

mlir/lib/IR/SymbolTable.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -511,7 +511,14 @@ LogicalResult detail::verifySymbolTable(Operation *op) {
511511
SymbolTableCollection symbolTable;
512512
auto verifySymbolUserFn = [&](Operation *op) -> std::optional<WalkResult> {
513513
if (SymbolUserOpInterface user = dyn_cast<SymbolUserOpInterface>(op))
514-
return WalkResult(user.verifySymbolUses(symbolTable));
514+
if (failed(user.verifySymbolUses(symbolTable)))
515+
return WalkResult::interrupt();
516+
for (auto &attr : op->getDiscardableAttrs()) {
517+
if (auto user = dyn_cast<SymbolUserAttrInterface>(attr.getValue())) {
518+
if (failed(user.verifySymbolUses(op, symbolTable)))
519+
return WalkResult::interrupt();
520+
}
521+
}
515522
return WalkResult::advance();
516523
};
517524

@@ -1132,3 +1139,4 @@ ParseResult impl::parseOptionalVisibilityKeyword(OpAsmParser &parser,
11321139

11331140
/// Include the generated symbol interfaces.
11341141
#include "mlir/IR/SymbolInterfaces.cpp.inc"
1142+
#include "mlir/IR/SymbolInterfacesAttrInterface.cpp.inc"
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
// RUN: mlir-opt %s -split-input-file -verify-diagnostics
2+
3+
// Test basic symbol verification using discardable attribute.
4+
module {
5+
func.func @existing_symbol() { return }
6+
7+
func.func @test() attributes {symbol_ref = #test.symbol_ref_attr<@existing_symbol>} { return }
8+
}
9+
10+
// -----
11+
12+
// Test invalid symbol reference, symbol does not exist.
13+
module {
14+
// expected-error@+1 {{TestSymbolRefAttr::verifySymbolUses: '@non_existent_symbol' does not reference a valid symbol}}
15+
func.func @test() attributes {symbol_ref = #test.symbol_ref_attr<@non_existent_symbol>} { return }
16+
}

mlir/test/lib/Dialect/Test/TestAttrDefs.td

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ include "mlir/IR/AttrTypeBase.td"
2222
include "mlir/IR/BuiltinAttributeInterfaces.td"
2323
include "mlir/IR/EnumAttr.td"
2424
include "mlir/IR/OpAsmInterface.td"
25+
include "mlir/IR/SymbolInterfaces.td"
2526
include "mlir/IR/TensorEncoding.td"
2627

2728
// All of the attributes will extend this class.
@@ -456,4 +457,17 @@ def TestMemRefLayoutAttr : Test_Attr<"TestMemRefLayout",
456457
let assemblyFormat = "`<` $dummy `>`";
457458
}
458459

460+
// Test attribute that implements SymbolUserAttrInterface.
461+
def TestSymbolRefAttr : Test_Attr<"TestSymbolRef",
462+
[DeclareAttrInterfaceMethods<SymbolUserAttrInterface>]> {
463+
let mnemonic = "symbol_ref_attr";
464+
let summary = "Test attribute that references a symbol";
465+
let description = [{
466+
This attribute holds a reference to a symbol and implements
467+
SymbolUserAttrInterface to verify that the referenced symbol exists.
468+
}];
469+
let parameters = (ins "::mlir::FlatSymbolRefAttr":$symbol);
470+
let assemblyFormat = "`<` $symbol `>`";
471+
}
472+
459473
#endif // TEST_ATTRDEFS

mlir/test/lib/Dialect/Test/TestAttributes.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,25 @@ LogicalResult TestCopyCountAttr::verify(
223223
return success();
224224
}
225225

226+
//===----------------------------------------------------------------------===//
227+
// TestSymbolRefAttr
228+
//===----------------------------------------------------------------------===//
229+
230+
LogicalResult
231+
TestSymbolRefAttr::verifySymbolUses(Operation *op,
232+
SymbolTableCollection &symbolTable) const {
233+
// Verify that the referenced symbol exists
234+
if (!symbolTable.lookupNearestSymbolFrom<SymbolOpInterface>(op, getSymbol()))
235+
return op->emitOpError()
236+
<< "TestSymbolRefAttr::verifySymbolUses: '" << getSymbol()
237+
<< "' does not reference a valid symbol";
238+
return success();
239+
}
240+
241+
//===----------------------------------------------------------------------===//
242+
// Generated Attribute Definitions
243+
//===----------------------------------------------------------------------===//
244+
226245
//===----------------------------------------------------------------------===//
227246
// CopyCountAttr Implementation
228247
//===----------------------------------------------------------------------===//

mlir/test/lib/Dialect/Test/TestAttributes.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,11 @@
2020
#include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.h"
2121
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
2222
#include "mlir/IR/Attributes.h"
23-
#include "mlir/IR/Diagnostics.h"
23+
#include "mlir/IR/BuiltinAttributes.h"
2424
#include "mlir/IR/Dialect.h"
25-
#include "mlir/IR/DialectImplementation.h"
2625
#include "mlir/IR/DialectResourceBlobManager.h"
26+
#include "mlir/IR/OpImplementation.h"
27+
#include "mlir/IR/SymbolTable.h"
2728
#include "mlir/IR/TensorEncoding.h"
2829

2930
// generated files require above includes to come first

utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,8 @@ exports_files(glob(["include/**/*.td"]))
8585
tbl_outs = {
8686
"include/mlir/IR/" + name + ".h.inc": ["-gen-op-interface-decls"],
8787
"include/mlir/IR/" + name + ".cpp.inc": ["-gen-op-interface-defs"],
88+
"include/mlir/IR/" + name + "AttrInterface.h.inc": ["-gen-attr-interface-decls"],
89+
"include/mlir/IR/" + name + "AttrInterface.cpp.inc": ["-gen-attr-interface-defs"],
8890
},
8991
tblgen = ":mlir-tblgen",
9092
td_file = "include/mlir/IR/" + name + ".td",

0 commit comments

Comments
 (0)