Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions mlir/include/mlir/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
add_mlir_interface(SymbolInterfaces)
set(LLVM_TARGET_DEFINITIONS SymbolInterfaces.td)
mlir_tablegen(SymbolInterfacesAttrInterface.h.inc -gen-attr-interface-decls)
mlir_tablegen(SymbolInterfacesAttrInterface.cpp.inc -gen-attr-interface-defs)
add_mlir_interface(RegionKindInterface)

set(LLVM_TARGET_DEFINITIONS OpAsmInterface.td)
Expand Down
21 changes: 20 additions & 1 deletion mlir/include/mlir/IR/SymbolInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def SymbolUserOpInterface : OpInterface<"SymbolUserOpInterface"> {
This interface describes an operation that may use a `Symbol`. This
interface allows for users of symbols to hook into verification and other
symbol related utilities that are either costly or otherwise disallowed
within a traditional operation.
within an operation.
}];
let cppNamespace = "::mlir";

Expand All @@ -222,6 +222,25 @@ def SymbolUserOpInterface : OpInterface<"SymbolUserOpInterface"> {
];
}

def SymbolUserAttrInterface : AttrInterface<"SymbolUserAttrInterface"> {
let description = [{
This interface describes an attribute that may use a `Symbol`. This
interface allows for users of symbols to hook into verification and other
symbol related utilities that are either costly or otherwise disallowed
within an operation (e.g., recreating symbol users per op verified rather
than per symbol table, or querying symbols usage of sibblings).
}];
let cppNamespace = "::mlir";

let methods = [
InterfaceMethod<"Verify the symbol uses held by this attribute of this operation.",
"::llvm::LogicalResult", "verifySymbolUses",
(ins "::mlir::Operation *":$op,
"::mlir::SymbolTableCollection &":$symbolTable)
>,
];
}

//===----------------------------------------------------------------------===//
// Symbol Traits
//===----------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/IR/SymbolTable.h
Original file line number Diff line number Diff line change
Expand Up @@ -499,5 +499,6 @@ ParseResult parseOptionalVisibilityKeyword(OpAsmParser &parser,

/// Include the generated symbol interfaces.
#include "mlir/IR/SymbolInterfaces.h.inc"
#include "mlir/IR/SymbolInterfacesAttrInterface.h.inc"

#endif // MLIR_IR_SYMBOLTABLE_H
10 changes: 9 additions & 1 deletion mlir/lib/IR/SymbolTable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,14 @@ LogicalResult detail::verifySymbolTable(Operation *op) {
SymbolTableCollection symbolTable;
auto verifySymbolUserFn = [&](Operation *op) -> std::optional<WalkResult> {
if (SymbolUserOpInterface user = dyn_cast<SymbolUserOpInterface>(op))
return WalkResult(user.verifySymbolUses(symbolTable));
if (failed(user.verifySymbolUses(symbolTable)))
return WalkResult::interrupt();
for (auto &attr : op->getDiscardableAttrs()) {
if (auto user = dyn_cast<SymbolUserAttrInterface>(attr.getValue())) {
if (failed(user.verifySymbolUses(op, symbolTable)))
return WalkResult::interrupt();
}
}
return WalkResult::advance();
};

Expand Down Expand Up @@ -1132,3 +1139,4 @@ ParseResult impl::parseOptionalVisibilityKeyword(OpAsmParser &parser,

/// Include the generated symbol interfaces.
#include "mlir/IR/SymbolInterfaces.cpp.inc"
#include "mlir/IR/SymbolInterfacesAttrInterface.cpp.inc"
16 changes: 16 additions & 0 deletions mlir/test/IR/test-verifiers-attr.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// RUN: mlir-opt %s -split-input-file -verify-diagnostics

// Test basic symbol verification using discardable attribute.
module {
func.func @existing_symbol() { return }

func.func @test() attributes {symbol_ref = #test.symbol_ref_attr<@existing_symbol>} { return }
}

// -----

// Test invalid symbol reference, symbol does not exist.
module {
// expected-error@+1 {{TestSymbolRefAttr::verifySymbolUses: '@non_existent_symbol' does not reference a valid symbol}}
func.func @test() attributes {symbol_ref = #test.symbol_ref_attr<@non_existent_symbol>} { return }
}
14 changes: 14 additions & 0 deletions mlir/test/lib/Dialect/Test/TestAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/BuiltinAttributeInterfaces.td"
include "mlir/IR/EnumAttr.td"
include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/SymbolInterfaces.td"
include "mlir/IR/TensorEncoding.td"

// All of the attributes will extend this class.
Expand Down Expand Up @@ -456,4 +457,17 @@ def TestMemRefLayoutAttr : Test_Attr<"TestMemRefLayout",
let assemblyFormat = "`<` $dummy `>`";
}

// Test attribute that implements SymbolUserAttrInterface.
def TestSymbolRefAttr : Test_Attr<"TestSymbolRef",
[DeclareAttrInterfaceMethods<SymbolUserAttrInterface>]> {
let mnemonic = "symbol_ref_attr";
let summary = "Test attribute that references a symbol";
let description = [{
This attribute holds a reference to a symbol and implements
SymbolUserAttrInterface to verify that the referenced symbol exists.
}];
let parameters = (ins "::mlir::FlatSymbolRefAttr":$symbol);
let assemblyFormat = "`<` $symbol `>`";
}

#endif // TEST_ATTRDEFS
19 changes: 19 additions & 0 deletions mlir/test/lib/Dialect/Test/TestAttributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,25 @@ LogicalResult TestCopyCountAttr::verify(
return success();
}

//===----------------------------------------------------------------------===//
// TestSymbolRefAttr
//===----------------------------------------------------------------------===//

LogicalResult
TestSymbolRefAttr::verifySymbolUses(Operation *op,
SymbolTableCollection &symbolTable) const {
// Verify that the referenced symbol exists
if (!symbolTable.lookupNearestSymbolFrom<SymbolOpInterface>(op, getSymbol()))
return op->emitOpError()
<< "TestSymbolRefAttr::verifySymbolUses: '" << getSymbol()
<< "' does not reference a valid symbol";
return success();
}

//===----------------------------------------------------------------------===//
// Generated Attribute Definitions
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
// CopyCountAttr Implementation
//===----------------------------------------------------------------------===//
Expand Down
5 changes: 3 additions & 2 deletions mlir/test/lib/Dialect/Test/TestAttributes.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@
#include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/DialectResourceBlobManager.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/TensorEncoding.h"

// generated files require above includes to come first
Expand Down
2 changes: 2 additions & 0 deletions utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ exports_files(glob(["include/**/*.td"]))
tbl_outs = {
"include/mlir/IR/" + name + ".h.inc": ["-gen-op-interface-decls"],
"include/mlir/IR/" + name + ".cpp.inc": ["-gen-op-interface-defs"],
"include/mlir/IR/" + name + "AttrInterface.h.inc": ["-gen-attr-interface-decls"],
"include/mlir/IR/" + name + "AttrInterface.cpp.inc": ["-gen-attr-interface-defs"],
},
tblgen = ":mlir-tblgen",
td_file = "include/mlir/IR/" + name + ".td",
Expand Down
Loading