diff --git a/.gitignore b/.gitignore index 008cb93..21453af 100644 --- a/.gitignore +++ b/.gitignore @@ -14,3 +14,4 @@ llvm-* ipu external xla/jax-test +.cache diff --git a/mlir/example/.devcontainer/Dockerfile b/mlir/example/.devcontainer/Dockerfile new file mode 100644 index 0000000..02a83dd --- /dev/null +++ b/mlir/example/.devcontainer/Dockerfile @@ -0,0 +1,32 @@ +FROM alwaysproblem/fastdev-u2204:zsh + +ARG UID=1000 +ARG GID=1000 + +RUN echo "deb http://apt.llvm.org/jammy/ llvm-toolchain-jammy main" > /etc/apt/sources.list.d/llvm.list \ + && echo "deb-src http://apt.llvm.org/jammy/ llvm-toolchain-jammy main" >> /etc/apt/sources.list.d/llvm.list \ + && echo "# 18" >> /etc/apt/sources.list.d/llvm.list \ + && echo "deb http://apt.llvm.org/jammy/ llvm-toolchain-jammy-18 main" >> /etc/apt/sources.list.d/llvm.list \ + && echo "deb-src http://apt.llvm.org/jammy/ llvm-toolchain-jammy-18 main" >> /etc/apt/sources.list.d/llvm.list \ + && echo "# 19" >> /etc/apt/sources.list.d/llvm.list \ + && echo "deb http://apt.llvm.org/jammy/ llvm-toolchain-jammy-19 main" >> /etc/apt/sources.list.d/llvm.list \ + && echo "deb-src http://apt.llvm.org/jammy/ llvm-toolchain-jammy-19 main" >> /etc/apt/sources.list.d/llvm.list \ + && wget -qO- https://apt.llvm.org/llvm-snapshot.gpg.key | tee /etc/apt/trusted.gpg.d/apt.llvm.org.asc \ + && apt update -y && \ + apt install -y \ + python3 python3-dev python3-setuptools python3-pip \ + libtinfo-dev zlib1g-dev \ + build-essential cmake ninja-build \ + clang-19 clang-tidy-19 clangd-19 cmake-format \ + clang-format-19 lldb-19 lld-19 libfmt-dev libspdlog-dev \ + && update-alternatives --install /usr/bin/clang clang /usr/bin/clang-19 100 \ + && update-alternatives --install /usr/bin/clang++ clang++ /usr/bin/clang++-19 100 \ + && update-alternatives --install /usr/bin/clangd clangd /usr/bin/clangd-19 100 \ + && update-alternatives --install /usr/bin/clang-format clang-format /usr/bin/clang-format-19 100 \ + && update-alternatives --install 
/usr/bin/clang-tidy clang-tidy /usr/bin/clang-tidy-19 100 \ + && update-alternatives --install /usr/bin/lld lld /usr/bin/lld-19 100 \ + && update-alternatives --install /usr/bin/lldb lldb /usr/bin/lldb-19 100 + +RUN git config --global --add safe.directory '*' && \ + /root/.local/bin/setup_new_user ${UID} ${GID} && \ + python3 -m pip install pre-commit compdb diff --git a/mlir/example/.devcontainer/devcontainer.json b/mlir/example/.devcontainer/devcontainer.json new file mode 100644 index 0000000..c576b68 --- /dev/null +++ b/mlir/example/.devcontainer/devcontainer.json @@ -0,0 +1,73 @@ +// For format details, see https://aka.ms/devcontainer.json. For config options, see the +// README at: https://github.com/devcontainers/templates/tree/main/src/anaconda +{ + "remoteUser": "root", + "name": "mlir-example", + "workspaceMount": "source=${localWorkspaceFolder},target=${localWorkspaceFolder}/../../../MLcompiler-tutorial/mlir/${localWorkspaceFolderBasename},type=bind", + "workspaceFolder": "/root/Desktop/dockerVolumn/MLcompiler-tutorial/mlir/${localWorkspaceFolderBasename}", + "build": { + "context": "${localWorkspaceFolder}/.devcontainer", + "dockerfile": "Dockerfile", + "options": [ + "--net=host" + ], + "args": { + "UID": "1000", + "GID": "1000" + } + }, + // Features to add to the dev container. More info: https://containers.dev/features. + // "features": {}, + // Use 'forwardPorts' to make a list of ports inside the container available locally. + // "forwardPorts": [], + // Use 'postCreateCommand' to run commands after the container is created. + // "postCreateCommand": "python --version", + // Configure tool-specific properties. + // "customizations": {}, + // Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root. 
+ // "remoteUser": "root" + "privileged": true, + // "capAdd": ["SYS_PTRACE"], + "mounts": [ + { + "source": "${localWorkspaceFolder}/../../../", + "target": "/root/Desktop/dockerVolumn", + "type": "bind" + } + ], + "runArgs": [ + // "--cap-add=SYS_PTRACE", + // "--security-opt", + // "seccomp=unconfined", + "--name", + // "${localEnv:USER}-tvm", + "yyx-mlir-example", + // "-v", + // "/data/rech/yongxiy/Desktop/dockerVolumn:/root/Desktop/dockerVolumn" + ], + "customizations": { + "vscode": { + "extensions": [ + "jeff-hykin.better-cpp-syntax", + "aaron-bond.better-comments", + "ms-vscode.cpptools-themes", + "revng.llvm-ir", + "jakob-erzar.llvm-tablegen", + "MomenAbdelkarim-WyattCalandro-LuisPrieto.mlir", + "ms-vscode.cpptools", + "ms-vscode.cpptools-extension-pack", + "twxs.cmake", + "josetr.cmake-language-support-vscode", + "ms-vscode.cmake-tools", + "cheshirekow.cmake-format", + "yzhang.markdown-all-in-one", + "bierner.markdown-preview-github-styles", + "bierner.markdown-mermaid", + "DavidAnson.vscode-markdownlint", + "llvm-vs-code-extensions.vscode-mlir", + "llvm-vs-code-extensions.vscode-clangd", + "llvm-vs-code-extensions.lldb-dap" + ] + } + } +} \ No newline at end of file diff --git a/mlir/example/.devcontainer/noop.txt b/mlir/example/.devcontainer/noop.txt new file mode 100644 index 0000000..dde8dc3 --- /dev/null +++ b/mlir/example/.devcontainer/noop.txt @@ -0,0 +1,3 @@ +This file copied into the container along with environment.yml* from the parent +folder. This file is included to prevents the Dockerfile COPY instruction from +failing if no environment.yml is found. 
\ No newline at end of file diff --git a/mlir/example/.pre-commit-config.yaml b/mlir/example/.pre-commit-config.yaml index b0a3c02..5736549 100644 --- a/mlir/example/.pre-commit-config.yaml +++ b/mlir/example/.pre-commit-config.yaml @@ -6,38 +6,12 @@ repos: - id: trailing-whitespace - id: end-of-file-fixer -- repo: https://github.com/pycqa/pylint - rev: v2.15.5 - hooks: - - id: pylint - args: - - "--rcfile=.pylintrc" - exclude: tests(/\w*)*/functional/|tests/input|tests(/\w*)*data/|doc/|TFserving/ClientAPI/go/pkg/ - -- repo: https://github.com/google/yapf - rev: v0.32.0 - hooks: - - id: yapf - -- repo: https://github.com/pre-commit/mirrors-mypy - rev: v0.982 - hooks: - - id: mypy - additional_dependencies: ['types-requests'] - exclude: tests(/\w*)*/functional/|tests/input|tests(/\w*)*data/|doc/|TFserving/ClientAPI/go/pkg/|(/\w*)*_test.py|xla/jax-core - - repo: https://github.com/pre-commit/mirrors-clang-format rev: 'v14.0.6' hooks: - id: clang-format types_or: [c++, c] -- repo: https://github.com/mwouts/jupytext - rev: v1.14.1 - hooks: - - id: jupytext - args: [--sync] - - repo: https://github.com/cheshirekow/cmake-format-precommit rev: v0.6.10 hooks: diff --git a/mlir/example/CMakeLists.txt b/mlir/example/CMakeLists.txt index 6e9c1e7..b032946 100644 --- a/mlir/example/CMakeLists.txt +++ b/mlir/example/CMakeLists.txt @@ -43,3 +43,6 @@ add_subdirectory(Ch5) add_subdirectory(Ch6) add_subdirectory(Ch7) add_subdirectory(Ch8) +add_subdirectory(transform_Ch2) +add_subdirectory(transform_Ch3) +add_subdirectory(transform_Ch4) diff --git a/mlir/example/Ch1/include/toy/AST.h b/mlir/example/Ch1/include/toy/AST.h index c9d1bdb..d2ba101 100644 --- a/mlir/example/Ch1/include/toy/AST.h +++ b/mlir/example/Ch1/include/toy/AST.h @@ -15,14 +15,14 @@ #ifndef TOY_AST_H #define TOY_AST_H -#include -#include -#include +#include "toy/Lexer.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Casting.h" -#include "toy/Lexer.h" +#include +#include 
+#include namespace toy { @@ -33,7 +33,7 @@ struct VarType { /// Base class for all expression nodes. class ExprAST { - public: +public: enum ExprASTKind { Expr_VarDecl, Expr_Return, @@ -53,7 +53,7 @@ class ExprAST { const Location &loc() { return location; } - private: +private: const ExprASTKind kind; Location location; }; @@ -65,7 +65,7 @@ using ExprASTList = std::vector>; class NumberExprAST : public ExprAST { double val; - public: +public: NumberExprAST(Location loc, double val) : ExprAST(Expr_Num, std::move(loc)), val(val) {} @@ -80,11 +80,10 @@ class LiteralExprAST : public ExprAST { std::vector> values; std::vector dims; - public: +public: LiteralExprAST(Location loc, std::vector> values, std::vector dims) - : ExprAST(Expr_Literal, std::move(loc)), - values(std::move(values)), + : ExprAST(Expr_Literal, std::move(loc)), values(std::move(values)), dims(std::move(dims)) {} llvm::ArrayRef> getValues() { return values; } @@ -98,7 +97,7 @@ class LiteralExprAST : public ExprAST { class VariableExprAST : public ExprAST { std::string name; - public: +public: VariableExprAST(Location loc, llvm::StringRef name) : ExprAST(Expr_Var, std::move(loc)), name(name) {} @@ -114,13 +113,11 @@ class VarDeclExprAST : public ExprAST { VarType type; std::unique_ptr initVal; - public: +public: VarDeclExprAST(Location loc, llvm::StringRef name, VarType type, std::unique_ptr initVal) - : ExprAST(Expr_VarDecl, std::move(loc)), - name(name), - type(std::move(type)), - initVal(std::move(initVal)) {} + : ExprAST(Expr_VarDecl, std::move(loc)), name(name), + type(std::move(type)), initVal(std::move(initVal)) {} llvm::StringRef getName() { return name; } ExprAST *getInitVal() { return initVal.get(); } @@ -134,12 +131,13 @@ class VarDeclExprAST : public ExprAST { class ReturnExprAST : public ExprAST { std::optional> expr; - public: +public: ReturnExprAST(Location loc, std::optional> expr) : ExprAST(Expr_Return, std::move(loc)), expr(std::move(expr)) {} std::optional getExpr() { - if 
(expr.has_value()) return expr->get(); + if (expr.has_value()) + return expr->get(); return std::nullopt; } @@ -152,16 +150,14 @@ class BinaryExprAST : public ExprAST { char op; std::unique_ptr lhs, rhs; - public: +public: char getOp() { return op; } ExprAST *getLHS() { return lhs.get(); } ExprAST *getRHS() { return rhs.get(); } BinaryExprAST(Location loc, char op, std::unique_ptr lhs, std::unique_ptr rhs) - : ExprAST(Expr_BinOp, std::move(loc)), - op(op), - lhs(std::move(lhs)), + : ExprAST(Expr_BinOp, std::move(loc)), op(op), lhs(std::move(lhs)), rhs(std::move(rhs)) {} /// LLVM style RTTI @@ -173,11 +169,10 @@ class CallExprAST : public ExprAST { std::string callee; std::vector> args; - public: +public: CallExprAST(Location loc, const std::string &callee, std::vector> args) - : ExprAST(Expr_Call, std::move(loc)), - callee(callee), + : ExprAST(Expr_Call, std::move(loc)), callee(callee), args(std::move(args)) {} llvm::StringRef getCallee() { return callee; } @@ -191,7 +186,7 @@ class CallExprAST : public ExprAST { class PrintExprAST : public ExprAST { std::unique_ptr arg; - public: +public: PrintExprAST(Location loc, std::unique_ptr arg) : ExprAST(Expr_Print, std::move(loc)), arg(std::move(arg)) {} @@ -209,7 +204,7 @@ class PrototypeAST { std::string name; std::vector> args; - public: +public: PrototypeAST(Location location, const std::string &name, std::vector> args) : location(std::move(location)), name(name), args(std::move(args)) {} @@ -224,7 +219,7 @@ class FunctionAST { std::unique_ptr proto; std::unique_ptr body; - public: +public: FunctionAST(std::unique_ptr proto, std::unique_ptr body) : proto(std::move(proto)), body(std::move(body)) {} @@ -236,7 +231,7 @@ class FunctionAST { class ModuleAST { std::vector functions; - public: +public: ModuleAST(std::vector functions) : functions(std::move(functions)) {} @@ -246,6 +241,6 @@ class ModuleAST { void dump(ModuleAST &); -} // namespace toy +} // namespace toy -#endif // TOY_AST_H +#endif // TOY_AST_H diff --git 
a/mlir/example/Ch1/include/toy/Lexer.h b/mlir/example/Ch1/include/toy/Lexer.h index ec0a1ae..ecbb3b4 100644 --- a/mlir/example/Ch1/include/toy/Lexer.h +++ b/mlir/example/Ch1/include/toy/Lexer.h @@ -13,18 +13,18 @@ #ifndef TOY_LEXER_H #define TOY_LEXER_H +#include "llvm/ADT/StringRef.h" + #include #include -#include "llvm/ADT/StringRef.h" - namespace toy { /// Structure definition a location in a file. struct Location { - std::shared_ptr file; ///< filename. - int line; ///< line number. - int col; ///< column number. + std::shared_ptr file; ///< filename. + int line; ///< line number. + int col; ///< column number. }; // List of Token returned by the lexer. @@ -56,7 +56,7 @@ enum Token : int { /// can proceed by reading the next line from the standard input or from a /// memory mapped file. class Lexer { - public: +public: /// Create a lexer for the given filename. The filename is kept only for /// debugging purposes (attaching a location to a Token). Lexer(std::string filename) @@ -98,7 +98,7 @@ class Lexer { // Return the current column in the file. int getCol() { return curCol; } - private: +private: /// Delegate to a derived class fetching the next line. Returns an empty /// string to signal end of file (EOF). Lines are expected to always finish /// with "\n" @@ -109,11 +109,13 @@ class Lexer { /// needed. int getNextChar() { // The current line buffer should not be empty unless it is the end of file. - if (curLineBuffer.empty()) return EOF; + if (curLineBuffer.empty()) + return EOF; ++curCol; auto nextchar = curLineBuffer.front(); curLineBuffer = curLineBuffer.drop_front(); - if (curLineBuffer.empty()) curLineBuffer = readNextLine(); + if (curLineBuffer.empty()) + curLineBuffer = readNextLine(); if (nextchar == '\n') { ++curLineNum; curCol = 0; @@ -124,7 +126,8 @@ class Lexer { /// Return the next token from standard input. Token getTok() { // Skip any whitespace. 
- while (isspace(lastChar)) lastChar = Token(getNextChar()); + while (isspace(lastChar)) + lastChar = Token(getNextChar()); // Save the current location before reading the token characters. lastLocation.line = curLineNum; @@ -136,9 +139,12 @@ class Lexer { while (isalnum((lastChar = Token(getNextChar()))) || lastChar == '_') identifierStr += (char)lastChar; - if (identifierStr == "return") return tok_return; - if (identifierStr == "def") return tok_def; - if (identifierStr == "var") return tok_var; + if (identifierStr == "return") + return tok_return; + if (identifierStr == "def") + return tok_def; + if (identifierStr == "var") + return tok_var; return tok_identifier; } @@ -160,11 +166,13 @@ class Lexer { lastChar = Token(getNextChar()); } while (lastChar != EOF && lastChar != '\n' && lastChar != '\r'); - if (lastChar != EOF) return getTok(); + if (lastChar != EOF) + return getTok(); } // Check for end of file. Don't eat the EOF. - if (lastChar == EOF) return tok_eof; + if (lastChar == EOF) + return tok_eof; // Otherwise, just return the character as its ascii value. Token thisChar = Token(lastChar); @@ -201,22 +209,24 @@ class Lexer { /// A lexer implementation operating on a buffer in memory. class LexerBuffer final : public Lexer { - public: +public: LexerBuffer(const char *begin, const char *end, std::string filename) : Lexer(std::move(filename)), current(begin), end(end) {} - private: +private: /// Provide one line at a time to the Lexer, return an empty string when /// reaching the end of the buffer. 
llvm::StringRef readNextLine() override { auto *begin = current; - while (current <= end && *current && *current != '\n') ++current; - if (current <= end && *current) ++current; + while (current <= end && *current && *current != '\n') + ++current; + if (current <= end && *current) + ++current; llvm::StringRef result{begin, static_cast(current - begin)}; return result; } const char *current, *end; }; -} // namespace toy +} // namespace toy -#endif // TOY_LEXER_H +#endif // TOY_LEXER_H diff --git a/mlir/example/Ch1/include/toy/Parser.h b/mlir/example/Ch1/include/toy/Parser.h index ededa4c..1f20616 100644 --- a/mlir/example/Ch1/include/toy/Parser.h +++ b/mlir/example/Ch1/include/toy/Parser.h @@ -14,16 +14,17 @@ #ifndef TOY_PARSER_H #define TOY_PARSER_H -#include -#include -#include -#include +#include "toy/AST.h" +#include "toy/Lexer.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Support/raw_ostream.h" -#include "toy/AST.h" -#include "toy/Lexer.h" + +#include +#include +#include +#include namespace toy { @@ -33,19 +34,20 @@ namespace toy { /// string and the code could reference an undeclared variable and the parsing /// succeeds. class Parser { - public: +public: /// Create a Parser for the supplied lexer. Parser(Lexer &lexer) : lexer(lexer) {} /// Parse a full Module. A module is a list of function definitions. std::unique_ptr parseModule() { - lexer.getNextToken(); // prime the lexer + lexer.getNextToken(); // prime the lexer // Parse functions one at a time and accumulate in this vector. std::vector functions; while (auto f = parseDefinition()) { functions.push_back(std::move(*f)); - if (lexer.getCurToken() == tok_eof) break; + if (lexer.getCurToken() == tok_eof) + break; } // If we didn't reach EOF, there was an error during parsing if (lexer.getCurToken() != tok_eof) @@ -54,7 +56,7 @@ class Parser { return std::make_unique(std::move(functions)); } - private: +private: Lexer &lexer; /// Parse a return statement. 
@@ -67,7 +69,8 @@ class Parser { std::optional> expr; if (lexer.getCurToken() != ';') { expr = parseExpression(); - if (!expr) return nullptr; + if (!expr) + return nullptr; } return std::make_unique(std::move(loc), std::move(expr)); } @@ -97,7 +100,8 @@ class Parser { // We can have either another nested array or a number literal. if (lexer.getCurToken() == '[') { values.push_back(parseTensorLiteralExpr()); - if (!values.back()) return nullptr; // parse error in the nested array. + if (!values.back()) + return nullptr; // parse error in the nested array. } else { if (lexer.getCurToken() != tok_number) return parseError(" or [", "in literal expression"); @@ -105,17 +109,18 @@ class Parser { } // End of this list on ']' - if (lexer.getCurToken() == ']') break; + if (lexer.getCurToken() == ']') + break; // Elements are separated by a comma. if (lexer.getCurToken() != ',') return parseError("] or ,", "in literal expression"); - lexer.getNextToken(); // eat , + lexer.getNextToken(); // eat , } while (true); if (values.empty()) return parseError("", "to fill literal expression"); - lexer.getNextToken(); // eat ] + lexer.getNextToken(); // eat ] /// Fill in the dimensions now. First the current nesting level: dims.push_back(values.size()); @@ -151,9 +156,10 @@ class Parser { /// parenexpr ::= '(' expression ')' std::unique_ptr parseParenExpr() { - lexer.getNextToken(); // eat (. + lexer.getNextToken(); // eat (. auto v = parseExpression(); - if (!v) return nullptr; + if (!v) + return nullptr; if (lexer.getCurToken() != ')') return parseError(")", "to close expression with parentheses"); @@ -168,9 +174,9 @@ class Parser { std::string name(lexer.getId()); auto loc = lexer.getLastLocation(); - lexer.getNextToken(); // eat identifier. + lexer.getNextToken(); // eat identifier. - if (lexer.getCurToken() != '(') // Simple variable ref. + if (lexer.getCurToken() != '(') // Simple variable ref. return std::make_unique(std::move(loc), name); // This is a function call. 
@@ -183,7 +189,8 @@ class Parser { else return nullptr; - if (lexer.getCurToken() == ')') break; + if (lexer.getCurToken() == ')') + break; if (lexer.getCurToken() != ',') return parseError(", or )", "in argument list"); @@ -211,22 +218,22 @@ class Parser { /// ::= tensorliteral std::unique_ptr parsePrimary() { switch (lexer.getCurToken()) { - default: - llvm::errs() << "unknown token '" << lexer.getCurToken() - << "' when expecting an expression\n"; - return nullptr; - case tok_identifier: - return parseIdentifierExpr(); - case tok_number: - return parseNumberExpr(); - case '(': - return parseParenExpr(); - case '[': - return parseTensorLiteralExpr(); - case ';': - return nullptr; - case '}': - return nullptr; + default: + llvm::errs() << "unknown token '" << lexer.getCurToken() + << "' when expecting an expression\n"; + return nullptr; + case tok_identifier: + return parseIdentifierExpr(); + case tok_number: + return parseNumberExpr(); + case '(': + return parseParenExpr(); + case '[': + return parseTensorLiteralExpr(); + case ';': + return nullptr; + case '}': + return nullptr; } } @@ -242,7 +249,8 @@ class Parser { // If this is a binop that binds at least as tightly as the current binop, // consume it, otherwise we are done. - if (tokPrec < exprPrec) return lhs; + if (tokPrec < exprPrec) + return lhs; // Okay, we know this is a binop. int binOp = lexer.getCurToken(); @@ -259,7 +267,8 @@ class Parser { int nextPrec = getTokPrecedence(); if (tokPrec < nextPrec) { rhs = parseBinOpRHS(tokPrec + 1, std::move(rhs)); - if (!rhs) return nullptr; + if (!rhs) + return nullptr; } // Merge lhs/RHS. 
@@ -271,7 +280,8 @@ class Parser { /// expression::= primary binop rhs std::unique_ptr parseExpression() { auto lhs = parsePrimary(); - if (!lhs) return nullptr; + if (!lhs) + return nullptr; return parseBinOpRHS(0, std::move(lhs)); } @@ -281,19 +291,20 @@ class Parser { std::unique_ptr parseType() { if (lexer.getCurToken() != '<') return parseError("<", "to begin type"); - lexer.getNextToken(); // eat < + lexer.getNextToken(); // eat < auto type = std::make_unique(); while (lexer.getCurToken() == tok_number) { type->shape.push_back(lexer.getValue()); lexer.getNextToken(); - if (lexer.getCurToken() == ',') lexer.getNextToken(); + if (lexer.getCurToken() == ',') + lexer.getNextToken(); } if (lexer.getCurToken() != '>') return parseError(">", "to end type"); - lexer.getNextToken(); // eat > + lexer.getNextToken(); // eat > return type; } @@ -305,21 +316,23 @@ class Parser { if (lexer.getCurToken() != tok_var) return parseError("var", "to begin declaration"); auto loc = lexer.getLastLocation(); - lexer.getNextToken(); // eat var + lexer.getNextToken(); // eat var if (lexer.getCurToken() != tok_identifier) return parseError("identified", "after 'var' declaration"); std::string id(lexer.getId()); - lexer.getNextToken(); // eat id + lexer.getNextToken(); // eat id - std::unique_ptr type; // Type is optional, it can be inferred + std::unique_ptr type; // Type is optional, it can be inferred if (lexer.getCurToken() == '<') { type = parseType(); - if (!type) return nullptr; + if (!type) + return nullptr; } - if (!type) type = std::make_unique(); + if (!type) + type = std::make_unique(); lexer.consume(Token('=')); auto expr = parseExpression(); return std::make_unique(std::move(loc), std::move(id), @@ -340,23 +353,27 @@ class Parser { auto exprList = std::make_unique(); // Ignore empty expressions: swallow sequences of semicolons. 
- while (lexer.getCurToken() == ';') lexer.consume(Token(';')); + while (lexer.getCurToken() == ';') + lexer.consume(Token(';')); while (lexer.getCurToken() != '}' && lexer.getCurToken() != tok_eof) { if (lexer.getCurToken() == tok_var) { // Variable declaration auto varDecl = parseDeclaration(); - if (!varDecl) return nullptr; + if (!varDecl) + return nullptr; exprList->push_back(std::move(varDecl)); } else if (lexer.getCurToken() == tok_return) { // Return statement auto ret = parseReturn(); - if (!ret) return nullptr; + if (!ret) + return nullptr; exprList->push_back(std::move(ret)); } else { // General expression auto expr = parseExpression(); - if (!expr) return nullptr; + if (!expr) + return nullptr; exprList->push_back(std::move(expr)); } // Ensure that elements are separated by a semicolon. @@ -364,7 +381,8 @@ class Parser { return parseError(";", "after expression"); // Ignore empty expressions: swallow sequences of semicolons. - while (lexer.getCurToken() == ';') lexer.consume(Token(';')); + while (lexer.getCurToken() == ';') + lexer.consume(Token(';')); } if (lexer.getCurToken() != '}') @@ -401,7 +419,8 @@ class Parser { lexer.consume(tok_identifier); auto decl = std::make_unique(std::move(loc), name); args.push_back(std::move(decl)); - if (lexer.getCurToken() != ',') break; + if (lexer.getCurToken() != ',') + break; lexer.consume(Token(',')); if (lexer.getCurToken() != tok_identifier) return parseError( @@ -423,7 +442,8 @@ class Parser { /// definition ::= prototype block std::unique_ptr parseDefinition() { auto proto = parsePrototype(); - if (!proto) return nullptr; + if (!proto) + return nullptr; if (auto block = parseBlock()) return std::make_unique(std::move(proto), std::move(block)); @@ -432,18 +452,19 @@ class Parser { /// Get the precedence of the pending binary operator token. int getTokPrecedence() { - if (!isascii(lexer.getCurToken())) return -1; + if (!isascii(lexer.getCurToken())) + return -1; // 1 is lowest precedence. 
switch (static_cast(lexer.getCurToken())) { - case '-': - return 20; - case '+': - return 20; - case '*': - return 40; - default: - return -1; + case '-': + return 20; + case '+': + return 20; + case '*': + return 40; + default: + return -1; } } @@ -456,12 +477,13 @@ class Parser { llvm::errs() << "Parse error (" << lexer.getLastLocation().line << ", " << lexer.getLastLocation().col << "): expected '" << expected << "' " << context << " but has Token " << curToken; - if (isprint(curToken)) llvm::errs() << " '" << (char)curToken << "'"; + if (isprint(curToken)) + llvm::errs() << " '" << (char)curToken << "'"; llvm::errs() << "\n"; return nullptr; } }; -} // namespace toy +} // namespace toy -#endif // TOY_PARSER_H +#endif // TOY_PARSER_H diff --git a/mlir/example/Ch1/parser/AST.cpp b/mlir/example/Ch1/parser/AST.cpp index e0bd2fd..2546f2a 100644 --- a/mlir/example/Ch1/parser/AST.cpp +++ b/mlir/example/Ch1/parser/AST.cpp @@ -12,9 +12,12 @@ #include "toy/AST.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Twine.h" #include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/raw_ostream.h" +#include using namespace toy; @@ -31,10 +34,10 @@ struct Indent { /// Helper class that implement the AST tree traversal and print the nodes along /// the way. The only data member is the current indentation level. 
class ASTDumper { - public: +public: void dump(ModuleAST *node); - private: +private: void dump(const VarType &type); void dump(VarDeclExprAST *varDecl); void dump(ExprAST *expr); @@ -51,12 +54,13 @@ class ASTDumper { // Actually print spaces matching the current indentation level void indent() { - for (int i = 0; i < curIndent; i++) llvm::errs() << " "; + for (int i = 0; i < curIndent; i++) + llvm::errs() << " "; } int curIndent = 0; }; -} // namespace +} // namespace /// Return a formatted string for the location of any node template @@ -69,8 +73,8 @@ static std::string loc(T *node) { // Helper Macro to bump the indentation level and print the leading spaces for // the current indentations -#define INDENT() \ - Indent level_(curIndent); \ +#define INDENT() \ + Indent level_(curIndent); \ indent(); /// Dispatch to a generic expressions to the appropriate subclass using RTTI @@ -100,7 +104,8 @@ void ASTDumper::dump(VarDeclExprAST *varDecl) { void ASTDumper::dump(ExprASTList *exprList) { INDENT(); llvm::errs() << "Block {\n"; - for (auto &expr : *exprList) dump(expr.get()); + for (auto &expr : *exprList) + dump(expr.get()); indent(); llvm::errs() << "} // Block\n"; } @@ -153,7 +158,8 @@ void ASTDumper::dump(VariableExprAST *node) { void ASTDumper::dump(ReturnExprAST *node) { INDENT(); llvm::errs() << "Return\n"; - if (node->getExpr().has_value()) return dump(*node->getExpr()); + if (node->getExpr().has_value()) + return dump(*node->getExpr()); { INDENT(); llvm::errs() << "(void)\n"; @@ -173,7 +179,8 @@ void ASTDumper::dump(BinaryExprAST *node) { void ASTDumper::dump(CallExprAST *node) { INDENT(); llvm::errs() << "Call '" << node->getCallee() << "' [ " << loc(node) << "\n"; - for (auto &arg : node->getArgs()) dump(arg.get()); + for (auto &arg : node->getArgs()) + dump(arg.get()); indent(); llvm::errs() << "]\n"; } @@ -218,7 +225,8 @@ void ASTDumper::dump(FunctionAST *node) { void ASTDumper::dump(ModuleAST *node) { INDENT(); llvm::errs() << "Module:\n"; - for (auto &f 
: *node) dump(&f); + for (auto &f : *node) + dump(&f); } namespace toy { @@ -226,4 +234,4 @@ namespace toy { // Public API void dump(ModuleAST &module) { ASTDumper().dump(&module); } -} // namespace toy +} // namespace toy diff --git a/mlir/example/Ch1/toyc.cpp b/mlir/example/Ch1/toyc.cpp index 825228f..fb7b484 100644 --- a/mlir/example/Ch1/toyc.cpp +++ b/mlir/example/Ch1/toyc.cpp @@ -10,12 +10,18 @@ // //===----------------------------------------------------------------------===// +#include "toy/AST.h" +#include "toy/Lexer.h" +#include "toy/Parser.h" + #include "llvm/ADT/StringRef.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/ErrorOr.h" #include "llvm/Support/MemoryBuffer.h" #include "llvm/Support/raw_ostream.h" -#include "toy/Parser.h" +#include +#include +#include using namespace toy; namespace cl = llvm::cl; @@ -26,11 +32,11 @@ static cl::opt inputFilename(cl::Positional, cl::value_desc("filename")); namespace { enum Action { None, DumpAST }; -} // namespace +} // namespace -static cl::opt emitAction( - "emit", cl::desc("Select the kind of output desired"), - cl::values(clEnumValN(DumpAST, "ast", "output the AST dump"))); +static cl::opt + emitAction("emit", cl::desc("Select the kind of output desired"), + cl::values(clEnumValN(DumpAST, "ast", "output the AST dump"))); /// Returns a Toy AST resulting from parsing the file or a nullptr on error. 
std::unique_ptr parseInputFile(llvm::StringRef filename) { @@ -50,15 +56,15 @@ int main(int argc, char **argv) { cl::ParseCommandLineOptions(argc, argv, "toy compiler\n"); auto moduleAST = parseInputFile(inputFilename); - if (!moduleAST) return 1; + if (!moduleAST) + return 1; switch (emitAction) { - case Action::DumpAST: - dump(*moduleAST); - return 0; - default: - llvm::errs() - << "No action specified (parsing only?), use -emit=\n"; + case Action::DumpAST: + dump(*moduleAST); + return 0; + default: + llvm::errs() << "No action specified (parsing only?), use -emit=\n"; } return 0; diff --git a/mlir/example/Ch2/include/toy/AST.h b/mlir/example/Ch2/include/toy/AST.h index c9d1bdb..d2ba101 100644 --- a/mlir/example/Ch2/include/toy/AST.h +++ b/mlir/example/Ch2/include/toy/AST.h @@ -15,14 +15,14 @@ #ifndef TOY_AST_H #define TOY_AST_H -#include -#include -#include +#include "toy/Lexer.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Casting.h" -#include "toy/Lexer.h" +#include +#include +#include namespace toy { @@ -33,7 +33,7 @@ struct VarType { /// Base class for all expression nodes. 
class ExprAST { - public: +public: enum ExprASTKind { Expr_VarDecl, Expr_Return, @@ -53,7 +53,7 @@ class ExprAST { const Location &loc() { return location; } - private: +private: const ExprASTKind kind; Location location; }; @@ -65,7 +65,7 @@ using ExprASTList = std::vector>; class NumberExprAST : public ExprAST { double val; - public: +public: NumberExprAST(Location loc, double val) : ExprAST(Expr_Num, std::move(loc)), val(val) {} @@ -80,11 +80,10 @@ class LiteralExprAST : public ExprAST { std::vector> values; std::vector dims; - public: +public: LiteralExprAST(Location loc, std::vector> values, std::vector dims) - : ExprAST(Expr_Literal, std::move(loc)), - values(std::move(values)), + : ExprAST(Expr_Literal, std::move(loc)), values(std::move(values)), dims(std::move(dims)) {} llvm::ArrayRef> getValues() { return values; } @@ -98,7 +97,7 @@ class LiteralExprAST : public ExprAST { class VariableExprAST : public ExprAST { std::string name; - public: +public: VariableExprAST(Location loc, llvm::StringRef name) : ExprAST(Expr_Var, std::move(loc)), name(name) {} @@ -114,13 +113,11 @@ class VarDeclExprAST : public ExprAST { VarType type; std::unique_ptr initVal; - public: +public: VarDeclExprAST(Location loc, llvm::StringRef name, VarType type, std::unique_ptr initVal) - : ExprAST(Expr_VarDecl, std::move(loc)), - name(name), - type(std::move(type)), - initVal(std::move(initVal)) {} + : ExprAST(Expr_VarDecl, std::move(loc)), name(name), + type(std::move(type)), initVal(std::move(initVal)) {} llvm::StringRef getName() { return name; } ExprAST *getInitVal() { return initVal.get(); } @@ -134,12 +131,13 @@ class VarDeclExprAST : public ExprAST { class ReturnExprAST : public ExprAST { std::optional> expr; - public: +public: ReturnExprAST(Location loc, std::optional> expr) : ExprAST(Expr_Return, std::move(loc)), expr(std::move(expr)) {} std::optional getExpr() { - if (expr.has_value()) return expr->get(); + if (expr.has_value()) + return expr->get(); return std::nullopt; } @@ 
-152,16 +150,14 @@ class BinaryExprAST : public ExprAST { char op; std::unique_ptr lhs, rhs; - public: +public: char getOp() { return op; } ExprAST *getLHS() { return lhs.get(); } ExprAST *getRHS() { return rhs.get(); } BinaryExprAST(Location loc, char op, std::unique_ptr lhs, std::unique_ptr rhs) - : ExprAST(Expr_BinOp, std::move(loc)), - op(op), - lhs(std::move(lhs)), + : ExprAST(Expr_BinOp, std::move(loc)), op(op), lhs(std::move(lhs)), rhs(std::move(rhs)) {} /// LLVM style RTTI @@ -173,11 +169,10 @@ class CallExprAST : public ExprAST { std::string callee; std::vector> args; - public: +public: CallExprAST(Location loc, const std::string &callee, std::vector> args) - : ExprAST(Expr_Call, std::move(loc)), - callee(callee), + : ExprAST(Expr_Call, std::move(loc)), callee(callee), args(std::move(args)) {} llvm::StringRef getCallee() { return callee; } @@ -191,7 +186,7 @@ class CallExprAST : public ExprAST { class PrintExprAST : public ExprAST { std::unique_ptr arg; - public: +public: PrintExprAST(Location loc, std::unique_ptr arg) : ExprAST(Expr_Print, std::move(loc)), arg(std::move(arg)) {} @@ -209,7 +204,7 @@ class PrototypeAST { std::string name; std::vector> args; - public: +public: PrototypeAST(Location location, const std::string &name, std::vector> args) : location(std::move(location)), name(name), args(std::move(args)) {} @@ -224,7 +219,7 @@ class FunctionAST { std::unique_ptr proto; std::unique_ptr body; - public: +public: FunctionAST(std::unique_ptr proto, std::unique_ptr body) : proto(std::move(proto)), body(std::move(body)) {} @@ -236,7 +231,7 @@ class FunctionAST { class ModuleAST { std::vector functions; - public: +public: ModuleAST(std::vector functions) : functions(std::move(functions)) {} @@ -246,6 +241,6 @@ class ModuleAST { void dump(ModuleAST &); -} // namespace toy +} // namespace toy -#endif // TOY_AST_H +#endif // TOY_AST_H diff --git a/mlir/example/Ch2/include/toy/Dialect.h b/mlir/example/Ch2/include/toy/Dialect.h index 78f9a63..292f50f 100644 
--- a/mlir/example/Ch2/include/toy/Dialect.h +++ b/mlir/example/Ch2/include/toy/Dialect.h @@ -14,10 +14,11 @@ #ifndef MLIR_TUTORIAL_TOY_DIALECT_H_ #define MLIR_TUTORIAL_TOY_DIALECT_H_ +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/IR/Dialect.h" -#include "mlir/IR/FunctionInterfaces.h" #include "mlir/IR/SymbolTable.h" #include "mlir/Interfaces/CallInterfaces.h" +#include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" /// Include the auto-generated header file containing the declaration of the toy @@ -29,4 +30,4 @@ #define GET_OP_CLASSES #include "toy/Ops.h.inc" -#endif // MLIR_TUTORIAL_TOY_DIALECT_H_ +#endif // MLIR_TUTORIAL_TOY_DIALECT_H_ diff --git a/mlir/example/Ch2/include/toy/Lexer.h b/mlir/example/Ch2/include/toy/Lexer.h index 176a40c..3c59cd9 100644 --- a/mlir/example/Ch2/include/toy/Lexer.h +++ b/mlir/example/Ch2/include/toy/Lexer.h @@ -13,18 +13,18 @@ #ifndef TOY_LEXER_H #define TOY_LEXER_H +#include "llvm/ADT/StringRef.h" + #include #include -#include "llvm/ADT/StringRef.h" - namespace toy { /// Structure definition a location in a file. struct Location { - std::shared_ptr file; ///< filename. - int line; ///< line number. - int col; ///< column number. + std::shared_ptr file; ///< filename. + int line; ///< line number. + int col; ///< column number. }; // List of Token returned by the lexer. @@ -56,7 +56,7 @@ enum Token : int { /// can proceed by reading the next line from the standard input or from a /// memory mapped file. class Lexer { - public: +public: /// Create a lexer for the given filename. The filename is kept only for /// debugging purpose (attaching a location to a Token). Lexer(std::string filename) @@ -98,7 +98,7 @@ class Lexer { // Return the current column in the file. int getCol() { return curCol; } - private: +private: /// Delegate to a derived class fetching the next line. Returns an empty /// string to signal end of file (EOF). 
Lines are expected to always finish /// with "\n" @@ -109,11 +109,13 @@ class Lexer { /// needed. int getNextChar() { // The current line buffer should not be empty unless it is the end of file. - if (curLineBuffer.empty()) return EOF; + if (curLineBuffer.empty()) + return EOF; ++curCol; auto nextchar = curLineBuffer.front(); curLineBuffer = curLineBuffer.drop_front(); - if (curLineBuffer.empty()) curLineBuffer = readNextLine(); + if (curLineBuffer.empty()) + curLineBuffer = readNextLine(); if (nextchar == '\n') { ++curLineNum; curCol = 0; @@ -124,7 +126,8 @@ class Lexer { /// Return the next token from standard input. Token getTok() { // Skip any whitespace. - while (isspace(lastChar)) lastChar = Token(getNextChar()); + while (isspace(lastChar)) + lastChar = Token(getNextChar()); // Save the current location before reading the token characters. lastLocation.line = curLineNum; @@ -136,9 +139,12 @@ class Lexer { while (isalnum((lastChar = Token(getNextChar()))) || lastChar == '_') identifierStr += (char)lastChar; - if (identifierStr == "return") return tok_return; - if (identifierStr == "def") return tok_def; - if (identifierStr == "var") return tok_var; + if (identifierStr == "return") + return tok_return; + if (identifierStr == "def") + return tok_def; + if (identifierStr == "var") + return tok_var; return tok_identifier; } @@ -160,11 +166,13 @@ class Lexer { lastChar = Token(getNextChar()); } while (lastChar != EOF && lastChar != '\n' && lastChar != '\r'); - if (lastChar != EOF) return getTok(); + if (lastChar != EOF) + return getTok(); } // Check for end of file. Don't eat the EOF. - if (lastChar == EOF) return tok_eof; + if (lastChar == EOF) + return tok_eof; // Otherwise, just return the character as its ascii value. Token thisChar = Token(lastChar); @@ -201,22 +209,24 @@ class Lexer { /// A lexer implementation operating on a buffer in memory. 
class LexerBuffer final : public Lexer { - public: +public: LexerBuffer(const char *begin, const char *end, std::string filename) : Lexer(std::move(filename)), current(begin), end(end) {} - private: +private: /// Provide one line at a time to the Lexer, return an empty string when /// reaching the end of the buffer. llvm::StringRef readNextLine() override { auto *begin = current; - while (current <= end && *current && *current != '\n') ++current; - if (current <= end && *current) ++current; + while (current <= end && *current && *current != '\n') + ++current; + if (current <= end && *current) + ++current; llvm::StringRef result{begin, static_cast(current - begin)}; return result; } const char *current, *end; }; -} // namespace toy +} // namespace toy -#endif // TOY_LEXER_H +#endif // TOY_LEXER_H diff --git a/mlir/example/Ch2/include/toy/MLIRGen.h b/mlir/example/Ch2/include/toy/MLIRGen.h index 5afdf62..fe9dbe5 100644 --- a/mlir/example/Ch2/include/toy/MLIRGen.h +++ b/mlir/example/Ch2/include/toy/MLIRGen.h @@ -21,7 +21,7 @@ class MLIRContext; template class OwningOpRef; class ModuleOp; -} // namespace mlir +} // namespace mlir namespace toy { class ModuleAST; @@ -30,6 +30,6 @@ class ModuleAST; /// or nullptr on failure. 
mlir::OwningOpRef mlirGen(mlir::MLIRContext &context, ModuleAST &moduleAST); -} // namespace toy +} // namespace toy -#endif // TOY_MLIRGEN_H +#endif // TOY_MLIRGEN_H diff --git a/mlir/example/Ch2/include/toy/Ops.td b/mlir/example/Ch2/include/toy/Ops.td index cba08eb..1a1b136 100644 --- a/mlir/example/Ch2/include/toy/Ops.td +++ b/mlir/example/Ch2/include/toy/Ops.td @@ -14,7 +14,7 @@ #define TOY_OPS include "mlir/IR/OpBase.td" -include "mlir/IR/FunctionInterfaces.td" +include "mlir/Interfaces/FunctionInterfaces.td" include "mlir/IR/SymbolInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" @@ -144,6 +144,7 @@ def FuncOp : Toy_Op<"func", [ "StringRef":$name, "FunctionType":$type, CArg<"ArrayRef", "{}">:$attrs) >]; + let extraClassDeclaration = [{ //===------------------------------------------------------------------===// // FunctionOpInterface Methods @@ -154,7 +155,10 @@ def FuncOp : Toy_Op<"func", [ /// Returns the result types of this function. ArrayRef getResultTypes() { return getFunctionType().getResults(); } + + Region *getCallableRegion() { return &getBody(); } }]; + let hasCustomAssemblyFormat = 1; let skipDefaultBuilders = 1; } diff --git a/mlir/example/Ch2/include/toy/Parser.h b/mlir/example/Ch2/include/toy/Parser.h index ededa4c..1f20616 100644 --- a/mlir/example/Ch2/include/toy/Parser.h +++ b/mlir/example/Ch2/include/toy/Parser.h @@ -14,16 +14,17 @@ #ifndef TOY_PARSER_H #define TOY_PARSER_H -#include -#include -#include -#include +#include "toy/AST.h" +#include "toy/Lexer.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Support/raw_ostream.h" -#include "toy/AST.h" -#include "toy/Lexer.h" + +#include +#include +#include +#include namespace toy { @@ -33,19 +34,20 @@ namespace toy { /// string and the code could reference an undeclared variable and the parsing /// succeeds. class Parser { - public: +public: /// Create a Parser for the supplied lexer. 
Parser(Lexer &lexer) : lexer(lexer) {} /// Parse a full Module. A module is a list of function definitions. std::unique_ptr parseModule() { - lexer.getNextToken(); // prime the lexer + lexer.getNextToken(); // prime the lexer // Parse functions one at a time and accumulate in this vector. std::vector functions; while (auto f = parseDefinition()) { functions.push_back(std::move(*f)); - if (lexer.getCurToken() == tok_eof) break; + if (lexer.getCurToken() == tok_eof) + break; } // If we didn't reach EOF, there was an error during parsing if (lexer.getCurToken() != tok_eof) @@ -54,7 +56,7 @@ class Parser { return std::make_unique(std::move(functions)); } - private: +private: Lexer &lexer; /// Parse a return statement. @@ -67,7 +69,8 @@ class Parser { std::optional> expr; if (lexer.getCurToken() != ';') { expr = parseExpression(); - if (!expr) return nullptr; + if (!expr) + return nullptr; } return std::make_unique(std::move(loc), std::move(expr)); } @@ -97,7 +100,8 @@ class Parser { // We can have either another nested array or a number literal. if (lexer.getCurToken() == '[') { values.push_back(parseTensorLiteralExpr()); - if (!values.back()) return nullptr; // parse error in the nested array. + if (!values.back()) + return nullptr; // parse error in the nested array. } else { if (lexer.getCurToken() != tok_number) return parseError(" or [", "in literal expression"); @@ -105,17 +109,18 @@ class Parser { } // End of this list on ']' - if (lexer.getCurToken() == ']') break; + if (lexer.getCurToken() == ']') + break; // Elements are separated by a comma. if (lexer.getCurToken() != ',') return parseError("] or ,", "in literal expression"); - lexer.getNextToken(); // eat , + lexer.getNextToken(); // eat , } while (true); if (values.empty()) return parseError("", "to fill literal expression"); - lexer.getNextToken(); // eat ] + lexer.getNextToken(); // eat ] /// Fill in the dimensions now. 
First the current nesting level: dims.push_back(values.size()); @@ -151,9 +156,10 @@ class Parser { /// parenexpr ::= '(' expression ')' std::unique_ptr parseParenExpr() { - lexer.getNextToken(); // eat (. + lexer.getNextToken(); // eat (. auto v = parseExpression(); - if (!v) return nullptr; + if (!v) + return nullptr; if (lexer.getCurToken() != ')') return parseError(")", "to close expression with parentheses"); @@ -168,9 +174,9 @@ class Parser { std::string name(lexer.getId()); auto loc = lexer.getLastLocation(); - lexer.getNextToken(); // eat identifier. + lexer.getNextToken(); // eat identifier. - if (lexer.getCurToken() != '(') // Simple variable ref. + if (lexer.getCurToken() != '(') // Simple variable ref. return std::make_unique(std::move(loc), name); // This is a function call. @@ -183,7 +189,8 @@ class Parser { else return nullptr; - if (lexer.getCurToken() == ')') break; + if (lexer.getCurToken() == ')') + break; if (lexer.getCurToken() != ',') return parseError(", or )", "in argument list"); @@ -211,22 +218,22 @@ class Parser { /// ::= tensorliteral std::unique_ptr parsePrimary() { switch (lexer.getCurToken()) { - default: - llvm::errs() << "unknown token '" << lexer.getCurToken() - << "' when expecting an expression\n"; - return nullptr; - case tok_identifier: - return parseIdentifierExpr(); - case tok_number: - return parseNumberExpr(); - case '(': - return parseParenExpr(); - case '[': - return parseTensorLiteralExpr(); - case ';': - return nullptr; - case '}': - return nullptr; + default: + llvm::errs() << "unknown token '" << lexer.getCurToken() + << "' when expecting an expression\n"; + return nullptr; + case tok_identifier: + return parseIdentifierExpr(); + case tok_number: + return parseNumberExpr(); + case '(': + return parseParenExpr(); + case '[': + return parseTensorLiteralExpr(); + case ';': + return nullptr; + case '}': + return nullptr; } } @@ -242,7 +249,8 @@ class Parser { // If this is a binop that binds at least as tightly as the 
current binop, // consume it, otherwise we are done. - if (tokPrec < exprPrec) return lhs; + if (tokPrec < exprPrec) + return lhs; // Okay, we know this is a binop. int binOp = lexer.getCurToken(); @@ -259,7 +267,8 @@ class Parser { int nextPrec = getTokPrecedence(); if (tokPrec < nextPrec) { rhs = parseBinOpRHS(tokPrec + 1, std::move(rhs)); - if (!rhs) return nullptr; + if (!rhs) + return nullptr; } // Merge lhs/RHS. @@ -271,7 +280,8 @@ class Parser { /// expression::= primary binop rhs std::unique_ptr parseExpression() { auto lhs = parsePrimary(); - if (!lhs) return nullptr; + if (!lhs) + return nullptr; return parseBinOpRHS(0, std::move(lhs)); } @@ -281,19 +291,20 @@ class Parser { std::unique_ptr parseType() { if (lexer.getCurToken() != '<') return parseError("<", "to begin type"); - lexer.getNextToken(); // eat < + lexer.getNextToken(); // eat < auto type = std::make_unique(); while (lexer.getCurToken() == tok_number) { type->shape.push_back(lexer.getValue()); lexer.getNextToken(); - if (lexer.getCurToken() == ',') lexer.getNextToken(); + if (lexer.getCurToken() == ',') + lexer.getNextToken(); } if (lexer.getCurToken() != '>') return parseError(">", "to end type"); - lexer.getNextToken(); // eat > + lexer.getNextToken(); // eat > return type; } @@ -305,21 +316,23 @@ class Parser { if (lexer.getCurToken() != tok_var) return parseError("var", "to begin declaration"); auto loc = lexer.getLastLocation(); - lexer.getNextToken(); // eat var + lexer.getNextToken(); // eat var if (lexer.getCurToken() != tok_identifier) return parseError("identified", "after 'var' declaration"); std::string id(lexer.getId()); - lexer.getNextToken(); // eat id + lexer.getNextToken(); // eat id - std::unique_ptr type; // Type is optional, it can be inferred + std::unique_ptr type; // Type is optional, it can be inferred if (lexer.getCurToken() == '<') { type = parseType(); - if (!type) return nullptr; + if (!type) + return nullptr; } - if (!type) type = std::make_unique(); + if (!type) + 
type = std::make_unique(); lexer.consume(Token('=')); auto expr = parseExpression(); return std::make_unique(std::move(loc), std::move(id), @@ -340,23 +353,27 @@ class Parser { auto exprList = std::make_unique(); // Ignore empty expressions: swallow sequences of semicolons. - while (lexer.getCurToken() == ';') lexer.consume(Token(';')); + while (lexer.getCurToken() == ';') + lexer.consume(Token(';')); while (lexer.getCurToken() != '}' && lexer.getCurToken() != tok_eof) { if (lexer.getCurToken() == tok_var) { // Variable declaration auto varDecl = parseDeclaration(); - if (!varDecl) return nullptr; + if (!varDecl) + return nullptr; exprList->push_back(std::move(varDecl)); } else if (lexer.getCurToken() == tok_return) { // Return statement auto ret = parseReturn(); - if (!ret) return nullptr; + if (!ret) + return nullptr; exprList->push_back(std::move(ret)); } else { // General expression auto expr = parseExpression(); - if (!expr) return nullptr; + if (!expr) + return nullptr; exprList->push_back(std::move(expr)); } // Ensure that elements are separated by a semicolon. @@ -364,7 +381,8 @@ class Parser { return parseError(";", "after expression"); // Ignore empty expressions: swallow sequences of semicolons. 
- while (lexer.getCurToken() == ';') lexer.consume(Token(';')); + while (lexer.getCurToken() == ';') + lexer.consume(Token(';')); } if (lexer.getCurToken() != '}') @@ -401,7 +419,8 @@ class Parser { lexer.consume(tok_identifier); auto decl = std::make_unique(std::move(loc), name); args.push_back(std::move(decl)); - if (lexer.getCurToken() != ',') break; + if (lexer.getCurToken() != ',') + break; lexer.consume(Token(',')); if (lexer.getCurToken() != tok_identifier) return parseError( @@ -423,7 +442,8 @@ class Parser { /// definition ::= prototype block std::unique_ptr parseDefinition() { auto proto = parsePrototype(); - if (!proto) return nullptr; + if (!proto) + return nullptr; if (auto block = parseBlock()) return std::make_unique(std::move(proto), std::move(block)); @@ -432,18 +452,19 @@ class Parser { /// Get the precedence of the pending binary operator token. int getTokPrecedence() { - if (!isascii(lexer.getCurToken())) return -1; + if (!isascii(lexer.getCurToken())) + return -1; // 1 is lowest precedence. 
switch (static_cast(lexer.getCurToken())) { - case '-': - return 20; - case '+': - return 20; - case '*': - return 40; - default: - return -1; + case '-': + return 20; + case '+': + return 20; + case '*': + return 40; + default: + return -1; } } @@ -456,12 +477,13 @@ class Parser { llvm::errs() << "Parse error (" << lexer.getLastLocation().line << ", " << lexer.getLastLocation().col << "): expected '" << expected << "' " << context << " but has Token " << curToken; - if (isprint(curToken)) llvm::errs() << " '" << (char)curToken << "'"; + if (isprint(curToken)) + llvm::errs() << " '" << (char)curToken << "'"; llvm::errs() << "\n"; return nullptr; } }; -} // namespace toy +} // namespace toy -#endif // TOY_PARSER_H +#endif // TOY_PARSER_H diff --git a/mlir/example/Ch2/mlir/Dialect.cpp b/mlir/example/Ch2/mlir/Dialect.cpp index cd1737d..489f348 100644 --- a/mlir/example/Ch2/mlir/Dialect.cpp +++ b/mlir/example/Ch2/mlir/Dialect.cpp @@ -13,10 +13,21 @@ #include "toy/Dialect.h" +#include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/FunctionImplementation.h" #include "mlir/IR/OpImplementation.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/Value.h" +#include "mlir/Interfaces/FunctionImplementation.h" +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include +#include using namespace mlir; using namespace mlir::toy; @@ -130,11 +141,10 @@ void ConstantOp::print(mlir::OpAsmPrinter &printer) { /// Verifier for the constant operation. This corresponds to the /// `let hasVerifier = 1` in the op definition. -mlir::LogicalResult ConstantOp::verify() { +llvm::LogicalResult ConstantOp::verify() { // If the return type of the constant is not an unranked tensor, the shape // must match the shape of the attribute holding the data. 
- auto resultType = - llvm::dyn_cast(getResult().getType()); + auto resultType = llvm::dyn_cast(getResult().getType()); if (!resultType) return success(); @@ -246,7 +256,7 @@ void MulOp::print(mlir::OpAsmPrinter &p) { printBinaryOp(p, *this); } // ReturnOp //===----------------------------------------------------------------------===// -mlir::LogicalResult ReturnOp::verify() { +llvm::LogicalResult ReturnOp::verify() { // We know that the parent operation is a function, because of the 'HasParent' // trait attached to the operation definition. auto function = cast((*this)->getParentOp()); @@ -270,8 +280,7 @@ mlir::LogicalResult ReturnOp::verify() { auto resultType = results.front(); // Check that the result type of the function matches the operand type. - if (inputType == resultType || - llvm::isa(inputType) || + if (inputType == resultType || llvm::isa(inputType) || llvm::isa(resultType)) return mlir::success(); @@ -290,7 +299,7 @@ void TransposeOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, state.addOperands(value); } -mlir::LogicalResult TransposeOp::verify() { +llvm::LogicalResult TransposeOp::verify() { auto inputType = llvm::dyn_cast(getOperand().getType()); auto resultType = llvm::dyn_cast(getType()); if (!inputType || !resultType) diff --git a/mlir/example/Ch2/mlir/MLIRGen.cpp b/mlir/example/Ch2/mlir/MLIRGen.cpp index 589bd3a..bf4c099 100644 --- a/mlir/example/Ch2/mlir/MLIRGen.cpp +++ b/mlir/example/Ch2/mlir/MLIRGen.cpp @@ -12,20 +12,30 @@ //===----------------------------------------------------------------------===// #include "toy/MLIRGen.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/Value.h" +#include "toy/AST.h" +#include "toy/Dialect.h" -#include - -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/ScopedHashTable.h" -#include "llvm/Support/raw_ostream.h" -#include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include 
"mlir/IR/MLIRContext.h" #include "mlir/IR/Verifier.h" -#include "toy/AST.h" -#include "toy/Dialect.h" +#include "toy/Lexer.h" + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/ScopedHashTable.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/Twine.h" +#include +#include +#include +#include +#include +#include using namespace mlir::toy; using namespace toy; @@ -47,7 +57,7 @@ namespace { /// the semantics of the language and (hopefully) allow to perform accurate /// analysis and transformation based on these high level semantics. class MLIRGenImpl { - public: +public: MLIRGenImpl(mlir::MLIRContext &context) : builder(&context) {} /// Public API: convert the AST for a Toy module (source file) to an MLIR @@ -57,7 +67,8 @@ class MLIRGenImpl { // add them to the module. theModule = mlir::ModuleOp::create(builder.getUnknownLoc()); - for (FunctionAST &f : moduleAST) mlirGen(f); + for (FunctionAST &f : moduleAST) + mlirGen(f); // Verify the module after we have finished constructing it, this will check // the structural properties of the IR and invoke any specific verifiers we @@ -70,7 +81,7 @@ class MLIRGenImpl { return theModule; } - private: +private: /// A "module" matches a Toy source file: containing a list of functions. mlir::ModuleOp theModule; @@ -93,8 +104,9 @@ class MLIRGenImpl { /// Declare a variable in the current scope, return success if the variable /// wasn't declared yet. - mlir::LogicalResult declare(llvm::StringRef var, mlir::Value value) { - if (symbolTable.count(var)) return mlir::failure(); + llvm::LogicalResult declare(llvm::StringRef var, mlir::Value value) { + if (symbolTable.count(var)) + return mlir::failure(); symbolTable.insert(var, value); return mlir::success(); } @@ -121,7 +133,8 @@ class MLIRGenImpl { // Create an MLIR function for the given prototype. 
builder.setInsertionPointToEnd(theModule.getBody()); mlir::toy::FuncOp function = mlirGen(*funcAST.getProto()); - if (!function) return nullptr; + if (!function) + return nullptr; // Let's start the body of the function now! mlir::Block &entryBlock = function.front(); @@ -150,7 +163,8 @@ class MLIRGenImpl { // FIXME: we may fix the parser instead to always return the last expression // (this would possibly help the REPL case later) ReturnOp returnOp; - if (!entryBlock.empty()) returnOp = dyn_cast(entryBlock.back()); + if (!entryBlock.empty()) + returnOp = dyn_cast(entryBlock.back()); if (!returnOp) { builder.create(loc(funcAST.getProto()->loc())); } else if (returnOp.hasOperand()) { @@ -177,18 +191,20 @@ class MLIRGenImpl { // and propagate. // mlir::Value lhs = mlirGen(*binop.getLHS()); - if (!lhs) return nullptr; + if (!lhs) + return nullptr; mlir::Value rhs = mlirGen(*binop.getRHS()); - if (!rhs) return nullptr; + if (!rhs) + return nullptr; auto location = loc(binop.loc()); // Derive the operation name from the binary operator. At the moment we only // support '+' and '*'. switch (binop.getOp()) { - case '+': - return builder.create(location, lhs, rhs); - case '*': - return builder.create(location, lhs, rhs); + case '+': + return builder.create(location, lhs, rhs); + case '*': + return builder.create(location, lhs, rhs); } emitError(location, "invalid binary operator '") << binop.getOp() << "'"; @@ -199,7 +215,8 @@ class MLIRGenImpl { /// expected to have been declared and so should have a value in the symbol /// table, otherwise emit an error and return nullptr. mlir::Value mlirGen(VariableExprAST &expr) { - if (auto variable = symbolTable.lookup(expr.getName())) return variable; + if (auto variable = symbolTable.lookup(expr.getName())) + return variable; emitError(loc(expr.loc()), "error: unknown variable '") << expr.getName() << "'"; @@ -207,13 +224,14 @@ class MLIRGenImpl { } /// Emit a return operation. This will return failure if any generation fails. 
- mlir::LogicalResult mlirGen(ReturnExprAST &ret) { + llvm::LogicalResult mlirGen(ReturnExprAST &ret) { auto location = loc(ret.loc()); // 'return' takes an optional expression, handle that case here. mlir::Value expr = nullptr; if (ret.getExpr().has_value()) { - if (!(expr = mlirGen(**ret.getExpr()))) return mlir::failure(); + if (!(expr = mlirGen(**ret.getExpr()))) + return mlir::failure(); } // Otherwise, this return operation has zero operands. @@ -275,7 +293,8 @@ class MLIRGenImpl { /// Attributes are the way MLIR attaches constant to operations. void collectData(ExprAST &expr, std::vector &data) { if (auto *lit = dyn_cast(&expr)) { - for (auto &value : lit->getValues()) collectData(*value, data); + for (auto &value : lit->getValues()) + collectData(*value, data); return; } @@ -293,7 +312,8 @@ class MLIRGenImpl { SmallVector operands; for (auto &expr : call.getArgs()) { auto arg = mlirGen(*expr); - if (!arg) return nullptr; + if (!arg) + return nullptr; operands.push_back(arg); } @@ -301,9 +321,8 @@ class MLIRGenImpl { // straightforward emission. if (callee == "transpose") { if (call.getArgs().size() != 1) { - emitError(location, - "MLIR codegen encountered an error: toy.transpose " - "does not accept multiple arguments"); + emitError(location, "MLIR codegen encountered an error: toy.transpose " + "does not accept multiple arguments"); return nullptr; } return builder.create(location, operands[0]); @@ -317,9 +336,10 @@ class MLIRGenImpl { /// Emit a print expression. It emits specific operations for two builtins: /// transpose(x) and print(x). - mlir::LogicalResult mlirGen(PrintExprAST &call) { + llvm::LogicalResult mlirGen(PrintExprAST &call) { auto arg = mlirGen(*call.getArg()); - if (!arg) return mlir::failure(); + if (!arg) + return mlir::failure(); builder.create(loc(call.loc()), arg); return mlir::success(); @@ -333,21 +353,21 @@ class MLIRGenImpl { /// Dispatch codegen for the right expression subclass using RTTI. 
mlir::Value mlirGen(ExprAST &expr) { switch (expr.getKind()) { - case toy::ExprAST::Expr_BinOp: - return mlirGen(cast(expr)); - case toy::ExprAST::Expr_Var: - return mlirGen(cast(expr)); - case toy::ExprAST::Expr_Literal: - return mlirGen(cast(expr)); - case toy::ExprAST::Expr_Call: - return mlirGen(cast(expr)); - case toy::ExprAST::Expr_Num: - return mlirGen(cast(expr)); - default: - emitError(loc(expr.loc())) - << "MLIR codegen encountered an unhandled expr kind '" - << Twine(expr.getKind()) << "'"; - return nullptr; + case toy::ExprAST::Expr_BinOp: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Var: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Literal: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Call: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Num: + return mlirGen(cast(expr)); + default: + emitError(loc(expr.loc())) + << "MLIR codegen encountered an unhandled expr kind '" + << Twine(expr.getKind()) << "'"; + return nullptr; } } @@ -364,7 +384,8 @@ class MLIRGenImpl { } mlir::Value value = mlirGen(*init); - if (!value) return nullptr; + if (!value) + return nullptr; // We have the initializer value, but in case the variable was declared // with specific shape, we emit a "reshape" operation. It will get @@ -375,29 +396,34 @@ class MLIRGenImpl { } // Register the value in the symbol table. - if (failed(declare(vardecl.getName(), value))) return nullptr; + if (failed(declare(vardecl.getName(), value))) + return nullptr; return value; } /// Codegen a list of expression, return failure if one of them hit an error. - mlir::LogicalResult mlirGen(ExprASTList &blockAST) { + llvm::LogicalResult mlirGen(ExprASTList &blockAST) { ScopedHashTableScope varScope(symbolTable); for (auto &expr : blockAST) { // Specific handling for variable declarations, return statement, and // print. These can only appear in block list and not in nested // expressions. 
if (auto *vardecl = dyn_cast(expr.get())) { - if (!mlirGen(*vardecl)) return mlir::failure(); + if (!mlirGen(*vardecl)) + return mlir::failure(); continue; } - if (auto *ret = dyn_cast(expr.get())) return mlirGen(*ret); + if (auto *ret = dyn_cast(expr.get())) + return mlirGen(*ret); if (auto *print = dyn_cast(expr.get())) { - if (mlir::failed(mlirGen(*print))) return mlir::success(); + if (mlir::failed(mlirGen(*print))) + return mlir::success(); continue; } // Generic expression dispatch codegen. - if (!mlirGen(*expr)) return mlir::failure(); + if (!mlirGen(*expr)) + return mlir::failure(); } return mlir::success(); } @@ -417,7 +443,7 @@ class MLIRGenImpl { mlir::Type getType(const VarType &type) { return getType(type.shape); } }; -} // namespace +} // namespace namespace toy { @@ -427,4 +453,4 @@ mlir::OwningOpRef mlirGen(mlir::MLIRContext &context, return MLIRGenImpl(context).mlirGen(moduleAST); } -} // namespace toy +} // namespace toy diff --git a/mlir/example/Ch2/parser/AST.cpp b/mlir/example/Ch2/parser/AST.cpp index e0bd2fd..2546f2a 100644 --- a/mlir/example/Ch2/parser/AST.cpp +++ b/mlir/example/Ch2/parser/AST.cpp @@ -12,9 +12,12 @@ #include "toy/AST.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Twine.h" #include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/raw_ostream.h" +#include using namespace toy; @@ -31,10 +34,10 @@ struct Indent { /// Helper class that implement the AST tree traversal and print the nodes along /// the way. The only data member is the current indentation level. 
class ASTDumper { - public: +public: void dump(ModuleAST *node); - private: +private: void dump(const VarType &type); void dump(VarDeclExprAST *varDecl); void dump(ExprAST *expr); @@ -51,12 +54,13 @@ class ASTDumper { // Actually print spaces matching the current indentation level void indent() { - for (int i = 0; i < curIndent; i++) llvm::errs() << " "; + for (int i = 0; i < curIndent; i++) + llvm::errs() << " "; } int curIndent = 0; }; -} // namespace +} // namespace /// Return a formatted string for the location of any node template @@ -69,8 +73,8 @@ static std::string loc(T *node) { // Helper Macro to bump the indentation level and print the leading spaces for // the current indentations -#define INDENT() \ - Indent level_(curIndent); \ +#define INDENT() \ + Indent level_(curIndent); \ indent(); /// Dispatch to a generic expressions to the appropriate subclass using RTTI @@ -100,7 +104,8 @@ void ASTDumper::dump(VarDeclExprAST *varDecl) { void ASTDumper::dump(ExprASTList *exprList) { INDENT(); llvm::errs() << "Block {\n"; - for (auto &expr : *exprList) dump(expr.get()); + for (auto &expr : *exprList) + dump(expr.get()); indent(); llvm::errs() << "} // Block\n"; } @@ -153,7 +158,8 @@ void ASTDumper::dump(VariableExprAST *node) { void ASTDumper::dump(ReturnExprAST *node) { INDENT(); llvm::errs() << "Return\n"; - if (node->getExpr().has_value()) return dump(*node->getExpr()); + if (node->getExpr().has_value()) + return dump(*node->getExpr()); { INDENT(); llvm::errs() << "(void)\n"; @@ -173,7 +179,8 @@ void ASTDumper::dump(BinaryExprAST *node) { void ASTDumper::dump(CallExprAST *node) { INDENT(); llvm::errs() << "Call '" << node->getCallee() << "' [ " << loc(node) << "\n"; - for (auto &arg : node->getArgs()) dump(arg.get()); + for (auto &arg : node->getArgs()) + dump(arg.get()); indent(); llvm::errs() << "]\n"; } @@ -218,7 +225,8 @@ void ASTDumper::dump(FunctionAST *node) { void ASTDumper::dump(ModuleAST *node) { INDENT(); llvm::errs() << "Module:\n"; - for (auto &f 
: *node) dump(&f); + for (auto &f : *node) + dump(&f); } namespace toy { @@ -226,4 +234,4 @@ namespace toy { // Public API void dump(ModuleAST &module) { ASTDumper().dump(&module); } -} // namespace toy +} // namespace toy diff --git a/mlir/example/Ch2/toyc.cpp b/mlir/example/Ch2/toyc.cpp index fe62831..e33b49b 100644 --- a/mlir/example/Ch2/toyc.cpp +++ b/mlir/example/Ch2/toyc.cpp @@ -10,7 +10,20 @@ // //===----------------------------------------------------------------------===// +#include "toy/AST.h" +#include "toy/Dialect.h" +#include "toy/Lexer.h" +#include "toy/MLIRGen.h" +#include "toy/Parser.h" #include +#include +#include +#include + +#include "mlir/IR/AsmState.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/Parser/Parser.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/CommandLine.h" @@ -18,14 +31,6 @@ #include "llvm/Support/MemoryBuffer.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/raw_ostream.h" -#include "mlir/IR/AsmState.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/Verifier.h" -#include "mlir/Parser/Parser.h" -#include "toy/Dialect.h" -#include "toy/MLIRGen.h" -#include "toy/Parser.h" using namespace toy; namespace cl = llvm::cl; @@ -37,7 +42,7 @@ static cl::opt inputFilename(cl::Positional, namespace { enum InputType { Toy, MLIR }; -} // namespace +} // namespace static cl::opt inputType( "x", cl::init(Toy), cl::desc("Decided the kind of output desired"), cl::values(clEnumValN(Toy, "toy", "load the input file as a Toy source.")), @@ -46,7 +51,7 @@ static cl::opt inputType( namespace { enum Action { None, DumpAST, DumpMLIR }; -} // namespace +} // namespace static cl::opt emitAction( "emit", cl::desc("Select the kind of output desired"), cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")), @@ -73,11 +78,13 @@ int dumpMLIR() { // Handle '.toy' input to the compiler. 
if (inputType != InputType::MLIR && - !llvm::StringRef(inputFilename).endswith(".mlir")) { + !llvm::StringRef(inputFilename).ends_with(".mlir")) { auto moduleAST = parseInputFile(inputFilename); - if (!moduleAST) return 6; + if (!moduleAST) + return 6; mlir::OwningOpRef module = mlirGen(context, *moduleAST); - if (!module) return 1; + if (!module) + return 1; module->dump(); return 0; @@ -112,7 +119,8 @@ int dumpAST() { } auto moduleAST = parseInputFile(inputFilename); - if (!moduleAST) return 1; + if (!moduleAST) + return 1; dump(*moduleAST); return 0; @@ -125,13 +133,12 @@ int main(int argc, char **argv) { cl::ParseCommandLineOptions(argc, argv, "toy compiler\n"); switch (emitAction) { - case Action::DumpAST: - return dumpAST(); - case Action::DumpMLIR: - return dumpMLIR(); - default: - llvm::errs() - << "No action specified (parsing only?), use -emit=\n"; + case Action::DumpAST: + return dumpAST(); + case Action::DumpMLIR: + return dumpMLIR(); + default: + llvm::errs() << "No action specified (parsing only?), use -emit=\n"; } return 0; diff --git a/mlir/example/Ch3/include/toy/AST.h b/mlir/example/Ch3/include/toy/AST.h index c9d1bdb..d2ba101 100644 --- a/mlir/example/Ch3/include/toy/AST.h +++ b/mlir/example/Ch3/include/toy/AST.h @@ -15,14 +15,14 @@ #ifndef TOY_AST_H #define TOY_AST_H -#include -#include -#include +#include "toy/Lexer.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Casting.h" -#include "toy/Lexer.h" +#include +#include +#include namespace toy { @@ -33,7 +33,7 @@ struct VarType { /// Base class for all expression nodes. 
class ExprAST { - public: +public: enum ExprASTKind { Expr_VarDecl, Expr_Return, @@ -53,7 +53,7 @@ class ExprAST { const Location &loc() { return location; } - private: +private: const ExprASTKind kind; Location location; }; @@ -65,7 +65,7 @@ using ExprASTList = std::vector>; class NumberExprAST : public ExprAST { double val; - public: +public: NumberExprAST(Location loc, double val) : ExprAST(Expr_Num, std::move(loc)), val(val) {} @@ -80,11 +80,10 @@ class LiteralExprAST : public ExprAST { std::vector> values; std::vector dims; - public: +public: LiteralExprAST(Location loc, std::vector> values, std::vector dims) - : ExprAST(Expr_Literal, std::move(loc)), - values(std::move(values)), + : ExprAST(Expr_Literal, std::move(loc)), values(std::move(values)), dims(std::move(dims)) {} llvm::ArrayRef> getValues() { return values; } @@ -98,7 +97,7 @@ class LiteralExprAST : public ExprAST { class VariableExprAST : public ExprAST { std::string name; - public: +public: VariableExprAST(Location loc, llvm::StringRef name) : ExprAST(Expr_Var, std::move(loc)), name(name) {} @@ -114,13 +113,11 @@ class VarDeclExprAST : public ExprAST { VarType type; std::unique_ptr initVal; - public: +public: VarDeclExprAST(Location loc, llvm::StringRef name, VarType type, std::unique_ptr initVal) - : ExprAST(Expr_VarDecl, std::move(loc)), - name(name), - type(std::move(type)), - initVal(std::move(initVal)) {} + : ExprAST(Expr_VarDecl, std::move(loc)), name(name), + type(std::move(type)), initVal(std::move(initVal)) {} llvm::StringRef getName() { return name; } ExprAST *getInitVal() { return initVal.get(); } @@ -134,12 +131,13 @@ class VarDeclExprAST : public ExprAST { class ReturnExprAST : public ExprAST { std::optional> expr; - public: +public: ReturnExprAST(Location loc, std::optional> expr) : ExprAST(Expr_Return, std::move(loc)), expr(std::move(expr)) {} std::optional getExpr() { - if (expr.has_value()) return expr->get(); + if (expr.has_value()) + return expr->get(); return std::nullopt; } @@ 
-152,16 +150,14 @@ class BinaryExprAST : public ExprAST { char op; std::unique_ptr lhs, rhs; - public: +public: char getOp() { return op; } ExprAST *getLHS() { return lhs.get(); } ExprAST *getRHS() { return rhs.get(); } BinaryExprAST(Location loc, char op, std::unique_ptr lhs, std::unique_ptr rhs) - : ExprAST(Expr_BinOp, std::move(loc)), - op(op), - lhs(std::move(lhs)), + : ExprAST(Expr_BinOp, std::move(loc)), op(op), lhs(std::move(lhs)), rhs(std::move(rhs)) {} /// LLVM style RTTI @@ -173,11 +169,10 @@ class CallExprAST : public ExprAST { std::string callee; std::vector> args; - public: +public: CallExprAST(Location loc, const std::string &callee, std::vector> args) - : ExprAST(Expr_Call, std::move(loc)), - callee(callee), + : ExprAST(Expr_Call, std::move(loc)), callee(callee), args(std::move(args)) {} llvm::StringRef getCallee() { return callee; } @@ -191,7 +186,7 @@ class CallExprAST : public ExprAST { class PrintExprAST : public ExprAST { std::unique_ptr arg; - public: +public: PrintExprAST(Location loc, std::unique_ptr arg) : ExprAST(Expr_Print, std::move(loc)), arg(std::move(arg)) {} @@ -209,7 +204,7 @@ class PrototypeAST { std::string name; std::vector> args; - public: +public: PrototypeAST(Location location, const std::string &name, std::vector> args) : location(std::move(location)), name(name), args(std::move(args)) {} @@ -224,7 +219,7 @@ class FunctionAST { std::unique_ptr proto; std::unique_ptr body; - public: +public: FunctionAST(std::unique_ptr proto, std::unique_ptr body) : proto(std::move(proto)), body(std::move(body)) {} @@ -236,7 +231,7 @@ class FunctionAST { class ModuleAST { std::vector functions; - public: +public: ModuleAST(std::vector functions) : functions(std::move(functions)) {} @@ -246,6 +241,6 @@ class ModuleAST { void dump(ModuleAST &); -} // namespace toy +} // namespace toy -#endif // TOY_AST_H +#endif // TOY_AST_H diff --git a/mlir/example/Ch3/include/toy/Dialect.h b/mlir/example/Ch3/include/toy/Dialect.h index 78f9a63..292f50f 100644 
--- a/mlir/example/Ch3/include/toy/Dialect.h +++ b/mlir/example/Ch3/include/toy/Dialect.h @@ -14,10 +14,11 @@ #ifndef MLIR_TUTORIAL_TOY_DIALECT_H_ #define MLIR_TUTORIAL_TOY_DIALECT_H_ +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/IR/Dialect.h" -#include "mlir/IR/FunctionInterfaces.h" #include "mlir/IR/SymbolTable.h" #include "mlir/Interfaces/CallInterfaces.h" +#include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" /// Include the auto-generated header file containing the declaration of the toy @@ -29,4 +30,4 @@ #define GET_OP_CLASSES #include "toy/Ops.h.inc" -#endif // MLIR_TUTORIAL_TOY_DIALECT_H_ +#endif // MLIR_TUTORIAL_TOY_DIALECT_H_ diff --git a/mlir/example/Ch3/include/toy/Lexer.h b/mlir/example/Ch3/include/toy/Lexer.h index 176a40c..3c59cd9 100644 --- a/mlir/example/Ch3/include/toy/Lexer.h +++ b/mlir/example/Ch3/include/toy/Lexer.h @@ -13,18 +13,18 @@ #ifndef TOY_LEXER_H #define TOY_LEXER_H +#include "llvm/ADT/StringRef.h" + #include #include -#include "llvm/ADT/StringRef.h" - namespace toy { /// Structure definition a location in a file. struct Location { - std::shared_ptr file; ///< filename. - int line; ///< line number. - int col; ///< column number. + std::shared_ptr file; ///< filename. + int line; ///< line number. + int col; ///< column number. }; // List of Token returned by the lexer. @@ -56,7 +56,7 @@ enum Token : int { /// can proceed by reading the next line from the standard input or from a /// memory mapped file. class Lexer { - public: +public: /// Create a lexer for the given filename. The filename is kept only for /// debugging purpose (attaching a location to a Token). Lexer(std::string filename) @@ -98,7 +98,7 @@ class Lexer { // Return the current column in the file. int getCol() { return curCol; } - private: +private: /// Delegate to a derived class fetching the next line. Returns an empty /// string to signal end of file (EOF). 
Lines are expected to always finish /// with "\n" @@ -109,11 +109,13 @@ class Lexer { /// needed. int getNextChar() { // The current line buffer should not be empty unless it is the end of file. - if (curLineBuffer.empty()) return EOF; + if (curLineBuffer.empty()) + return EOF; ++curCol; auto nextchar = curLineBuffer.front(); curLineBuffer = curLineBuffer.drop_front(); - if (curLineBuffer.empty()) curLineBuffer = readNextLine(); + if (curLineBuffer.empty()) + curLineBuffer = readNextLine(); if (nextchar == '\n') { ++curLineNum; curCol = 0; @@ -124,7 +126,8 @@ class Lexer { /// Return the next token from standard input. Token getTok() { // Skip any whitespace. - while (isspace(lastChar)) lastChar = Token(getNextChar()); + while (isspace(lastChar)) + lastChar = Token(getNextChar()); // Save the current location before reading the token characters. lastLocation.line = curLineNum; @@ -136,9 +139,12 @@ class Lexer { while (isalnum((lastChar = Token(getNextChar()))) || lastChar == '_') identifierStr += (char)lastChar; - if (identifierStr == "return") return tok_return; - if (identifierStr == "def") return tok_def; - if (identifierStr == "var") return tok_var; + if (identifierStr == "return") + return tok_return; + if (identifierStr == "def") + return tok_def; + if (identifierStr == "var") + return tok_var; return tok_identifier; } @@ -160,11 +166,13 @@ class Lexer { lastChar = Token(getNextChar()); } while (lastChar != EOF && lastChar != '\n' && lastChar != '\r'); - if (lastChar != EOF) return getTok(); + if (lastChar != EOF) + return getTok(); } // Check for end of file. Don't eat the EOF. - if (lastChar == EOF) return tok_eof; + if (lastChar == EOF) + return tok_eof; // Otherwise, just return the character as its ascii value. Token thisChar = Token(lastChar); @@ -201,22 +209,24 @@ class Lexer { /// A lexer implementation operating on a buffer in memory. 
class LexerBuffer final : public Lexer { - public: +public: LexerBuffer(const char *begin, const char *end, std::string filename) : Lexer(std::move(filename)), current(begin), end(end) {} - private: +private: /// Provide one line at a time to the Lexer, return an empty string when /// reaching the end of the buffer. llvm::StringRef readNextLine() override { auto *begin = current; - while (current <= end && *current && *current != '\n') ++current; - if (current <= end && *current) ++current; + while (current <= end && *current && *current != '\n') + ++current; + if (current <= end && *current) + ++current; llvm::StringRef result{begin, static_cast(current - begin)}; return result; } const char *current, *end; }; -} // namespace toy +} // namespace toy -#endif // TOY_LEXER_H +#endif // TOY_LEXER_H diff --git a/mlir/example/Ch3/include/toy/MLIRGen.h b/mlir/example/Ch3/include/toy/MLIRGen.h index 5afdf62..fe9dbe5 100644 --- a/mlir/example/Ch3/include/toy/MLIRGen.h +++ b/mlir/example/Ch3/include/toy/MLIRGen.h @@ -21,7 +21,7 @@ class MLIRContext; template class OwningOpRef; class ModuleOp; -} // namespace mlir +} // namespace mlir namespace toy { class ModuleAST; @@ -30,6 +30,6 @@ class ModuleAST; /// or nullptr on failure. mlir::OwningOpRef mlirGen(mlir::MLIRContext &context, ModuleAST &moduleAST); -} // namespace toy +} // namespace toy -#endif // TOY_MLIRGEN_H +#endif // TOY_MLIRGEN_H diff --git a/mlir/example/Ch3/include/toy/Ops.td b/mlir/example/Ch3/include/toy/Ops.td index b7add5a..021802b 100644 --- a/mlir/example/Ch3/include/toy/Ops.td +++ b/mlir/example/Ch3/include/toy/Ops.td @@ -13,7 +13,7 @@ #ifndef TOY_OPS #define TOY_OPS -include "mlir/IR/FunctionInterfaces.td" +include "mlir/Interfaces/FunctionInterfaces.td" include "mlir/IR/SymbolInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" @@ -153,6 +153,9 @@ def FuncOp : Toy_Op<"func", [ /// Returns the result types of this function. 
ArrayRef getResultTypes() { return getFunctionType().getResults(); } + + /// Returns the region on the current operation that is callable. + ::mlir::Region *getCallableRegion() { return &getBody(); } }]; let hasCustomAssemblyFormat = 1; let skipDefaultBuilders = 1; diff --git a/mlir/example/Ch3/include/toy/Parser.h b/mlir/example/Ch3/include/toy/Parser.h index ededa4c..1f20616 100644 --- a/mlir/example/Ch3/include/toy/Parser.h +++ b/mlir/example/Ch3/include/toy/Parser.h @@ -14,16 +14,17 @@ #ifndef TOY_PARSER_H #define TOY_PARSER_H -#include -#include -#include -#include +#include "toy/AST.h" +#include "toy/Lexer.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Support/raw_ostream.h" -#include "toy/AST.h" -#include "toy/Lexer.h" + +#include +#include +#include +#include namespace toy { @@ -33,19 +34,20 @@ namespace toy { /// string and the code could reference an undeclared variable and the parsing /// succeeds. class Parser { - public: +public: /// Create a Parser for the supplied lexer. Parser(Lexer &lexer) : lexer(lexer) {} /// Parse a full Module. A module is a list of function definitions. std::unique_ptr parseModule() { - lexer.getNextToken(); // prime the lexer + lexer.getNextToken(); // prime the lexer // Parse functions one at a time and accumulate in this vector. std::vector functions; while (auto f = parseDefinition()) { functions.push_back(std::move(*f)); - if (lexer.getCurToken() == tok_eof) break; + if (lexer.getCurToken() == tok_eof) + break; } // If we didn't reach EOF, there was an error during parsing if (lexer.getCurToken() != tok_eof) @@ -54,7 +56,7 @@ class Parser { return std::make_unique(std::move(functions)); } - private: +private: Lexer &lexer; /// Parse a return statement. 
@@ -67,7 +69,8 @@ class Parser { std::optional> expr; if (lexer.getCurToken() != ';') { expr = parseExpression(); - if (!expr) return nullptr; + if (!expr) + return nullptr; } return std::make_unique(std::move(loc), std::move(expr)); } @@ -97,7 +100,8 @@ class Parser { // We can have either another nested array or a number literal. if (lexer.getCurToken() == '[') { values.push_back(parseTensorLiteralExpr()); - if (!values.back()) return nullptr; // parse error in the nested array. + if (!values.back()) + return nullptr; // parse error in the nested array. } else { if (lexer.getCurToken() != tok_number) return parseError(" or [", "in literal expression"); @@ -105,17 +109,18 @@ class Parser { } // End of this list on ']' - if (lexer.getCurToken() == ']') break; + if (lexer.getCurToken() == ']') + break; // Elements are separated by a comma. if (lexer.getCurToken() != ',') return parseError("] or ,", "in literal expression"); - lexer.getNextToken(); // eat , + lexer.getNextToken(); // eat , } while (true); if (values.empty()) return parseError("", "to fill literal expression"); - lexer.getNextToken(); // eat ] + lexer.getNextToken(); // eat ] /// Fill in the dimensions now. First the current nesting level: dims.push_back(values.size()); @@ -151,9 +156,10 @@ class Parser { /// parenexpr ::= '(' expression ')' std::unique_ptr parseParenExpr() { - lexer.getNextToken(); // eat (. + lexer.getNextToken(); // eat (. auto v = parseExpression(); - if (!v) return nullptr; + if (!v) + return nullptr; if (lexer.getCurToken() != ')') return parseError(")", "to close expression with parentheses"); @@ -168,9 +174,9 @@ class Parser { std::string name(lexer.getId()); auto loc = lexer.getLastLocation(); - lexer.getNextToken(); // eat identifier. + lexer.getNextToken(); // eat identifier. - if (lexer.getCurToken() != '(') // Simple variable ref. + if (lexer.getCurToken() != '(') // Simple variable ref. return std::make_unique(std::move(loc), name); // This is a function call. 
@@ -183,7 +189,8 @@ class Parser { else return nullptr; - if (lexer.getCurToken() == ')') break; + if (lexer.getCurToken() == ')') + break; if (lexer.getCurToken() != ',') return parseError(", or )", "in argument list"); @@ -211,22 +218,22 @@ class Parser { /// ::= tensorliteral std::unique_ptr parsePrimary() { switch (lexer.getCurToken()) { - default: - llvm::errs() << "unknown token '" << lexer.getCurToken() - << "' when expecting an expression\n"; - return nullptr; - case tok_identifier: - return parseIdentifierExpr(); - case tok_number: - return parseNumberExpr(); - case '(': - return parseParenExpr(); - case '[': - return parseTensorLiteralExpr(); - case ';': - return nullptr; - case '}': - return nullptr; + default: + llvm::errs() << "unknown token '" << lexer.getCurToken() + << "' when expecting an expression\n"; + return nullptr; + case tok_identifier: + return parseIdentifierExpr(); + case tok_number: + return parseNumberExpr(); + case '(': + return parseParenExpr(); + case '[': + return parseTensorLiteralExpr(); + case ';': + return nullptr; + case '}': + return nullptr; } } @@ -242,7 +249,8 @@ class Parser { // If this is a binop that binds at least as tightly as the current binop, // consume it, otherwise we are done. - if (tokPrec < exprPrec) return lhs; + if (tokPrec < exprPrec) + return lhs; // Okay, we know this is a binop. int binOp = lexer.getCurToken(); @@ -259,7 +267,8 @@ class Parser { int nextPrec = getTokPrecedence(); if (tokPrec < nextPrec) { rhs = parseBinOpRHS(tokPrec + 1, std::move(rhs)); - if (!rhs) return nullptr; + if (!rhs) + return nullptr; } // Merge lhs/RHS. 
@@ -271,7 +280,8 @@ class Parser { /// expression::= primary binop rhs std::unique_ptr parseExpression() { auto lhs = parsePrimary(); - if (!lhs) return nullptr; + if (!lhs) + return nullptr; return parseBinOpRHS(0, std::move(lhs)); } @@ -281,19 +291,20 @@ class Parser { std::unique_ptr parseType() { if (lexer.getCurToken() != '<') return parseError("<", "to begin type"); - lexer.getNextToken(); // eat < + lexer.getNextToken(); // eat < auto type = std::make_unique(); while (lexer.getCurToken() == tok_number) { type->shape.push_back(lexer.getValue()); lexer.getNextToken(); - if (lexer.getCurToken() == ',') lexer.getNextToken(); + if (lexer.getCurToken() == ',') + lexer.getNextToken(); } if (lexer.getCurToken() != '>') return parseError(">", "to end type"); - lexer.getNextToken(); // eat > + lexer.getNextToken(); // eat > return type; } @@ -305,21 +316,23 @@ class Parser { if (lexer.getCurToken() != tok_var) return parseError("var", "to begin declaration"); auto loc = lexer.getLastLocation(); - lexer.getNextToken(); // eat var + lexer.getNextToken(); // eat var if (lexer.getCurToken() != tok_identifier) return parseError("identified", "after 'var' declaration"); std::string id(lexer.getId()); - lexer.getNextToken(); // eat id + lexer.getNextToken(); // eat id - std::unique_ptr type; // Type is optional, it can be inferred + std::unique_ptr type; // Type is optional, it can be inferred if (lexer.getCurToken() == '<') { type = parseType(); - if (!type) return nullptr; + if (!type) + return nullptr; } - if (!type) type = std::make_unique(); + if (!type) + type = std::make_unique(); lexer.consume(Token('=')); auto expr = parseExpression(); return std::make_unique(std::move(loc), std::move(id), @@ -340,23 +353,27 @@ class Parser { auto exprList = std::make_unique(); // Ignore empty expressions: swallow sequences of semicolons. 
- while (lexer.getCurToken() == ';') lexer.consume(Token(';')); + while (lexer.getCurToken() == ';') + lexer.consume(Token(';')); while (lexer.getCurToken() != '}' && lexer.getCurToken() != tok_eof) { if (lexer.getCurToken() == tok_var) { // Variable declaration auto varDecl = parseDeclaration(); - if (!varDecl) return nullptr; + if (!varDecl) + return nullptr; exprList->push_back(std::move(varDecl)); } else if (lexer.getCurToken() == tok_return) { // Return statement auto ret = parseReturn(); - if (!ret) return nullptr; + if (!ret) + return nullptr; exprList->push_back(std::move(ret)); } else { // General expression auto expr = parseExpression(); - if (!expr) return nullptr; + if (!expr) + return nullptr; exprList->push_back(std::move(expr)); } // Ensure that elements are separated by a semicolon. @@ -364,7 +381,8 @@ class Parser { return parseError(";", "after expression"); // Ignore empty expressions: swallow sequences of semicolons. - while (lexer.getCurToken() == ';') lexer.consume(Token(';')); + while (lexer.getCurToken() == ';') + lexer.consume(Token(';')); } if (lexer.getCurToken() != '}') @@ -401,7 +419,8 @@ class Parser { lexer.consume(tok_identifier); auto decl = std::make_unique(std::move(loc), name); args.push_back(std::move(decl)); - if (lexer.getCurToken() != ',') break; + if (lexer.getCurToken() != ',') + break; lexer.consume(Token(',')); if (lexer.getCurToken() != tok_identifier) return parseError( @@ -423,7 +442,8 @@ class Parser { /// definition ::= prototype block std::unique_ptr parseDefinition() { auto proto = parsePrototype(); - if (!proto) return nullptr; + if (!proto) + return nullptr; if (auto block = parseBlock()) return std::make_unique(std::move(proto), std::move(block)); @@ -432,18 +452,19 @@ class Parser { /// Get the precedence of the pending binary operator token. int getTokPrecedence() { - if (!isascii(lexer.getCurToken())) return -1; + if (!isascii(lexer.getCurToken())) + return -1; // 1 is lowest precedence. 
switch (static_cast(lexer.getCurToken())) { - case '-': - return 20; - case '+': - return 20; - case '*': - return 40; - default: - return -1; + case '-': + return 20; + case '+': + return 20; + case '*': + return 40; + default: + return -1; } } @@ -456,12 +477,13 @@ class Parser { llvm::errs() << "Parse error (" << lexer.getLastLocation().line << ", " << lexer.getLastLocation().col << "): expected '" << expected << "' " << context << " but has Token " << curToken; - if (isprint(curToken)) llvm::errs() << " '" << (char)curToken << "'"; + if (isprint(curToken)) + llvm::errs() << " '" << (char)curToken << "'"; llvm::errs() << "\n"; return nullptr; } }; -} // namespace toy +} // namespace toy -#endif // TOY_PARSER_H +#endif // TOY_PARSER_H diff --git a/mlir/example/Ch3/mlir/Dialect.cpp b/mlir/example/Ch3/mlir/Dialect.cpp index dcf1045..708855f 100644 --- a/mlir/example/Ch3/mlir/Dialect.cpp +++ b/mlir/example/Ch3/mlir/Dialect.cpp @@ -13,10 +13,21 @@ #include "toy/Dialect.h" +#include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/FunctionImplementation.h" #include "mlir/IR/OpImplementation.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/Value.h" +#include "mlir/Interfaces/FunctionImplementation.h" +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include +#include using namespace mlir; using namespace mlir::toy; @@ -130,11 +141,10 @@ void ConstantOp::print(mlir::OpAsmPrinter &printer) { /// Verifier for the constant operation. This corresponds to the /// `let hasVerifier = 1` in the op definition. -mlir::LogicalResult ConstantOp::verify() { +llvm::LogicalResult ConstantOp::verify() { // If the return type of the constant is not an unranked tensor, the shape // must match the shape of the attribute holding the data. 
- auto resultType = - llvm::dyn_cast(getResult().getType()); + auto resultType = llvm::dyn_cast(getResult().getType()); if (!resultType) return success(); @@ -246,7 +256,7 @@ void MulOp::print(mlir::OpAsmPrinter &p) { printBinaryOp(p, *this); } // ReturnOp //===----------------------------------------------------------------------===// -mlir::LogicalResult ReturnOp::verify() { +llvm::LogicalResult ReturnOp::verify() { // We know that the parent operation is a function, because of the 'HasParent' // trait attached to the operation definition. auto function = cast((*this)->getParentOp()); @@ -270,8 +280,7 @@ mlir::LogicalResult ReturnOp::verify() { auto resultType = results.front(); // Check that the result type of the function matches the operand type. - if (inputType == resultType || - llvm::isa(inputType) || + if (inputType == resultType || llvm::isa(inputType) || llvm::isa(resultType)) return mlir::success(); @@ -290,7 +299,7 @@ void TransposeOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, state.addOperands(value); } -mlir::LogicalResult TransposeOp::verify() { +llvm::LogicalResult TransposeOp::verify() { auto inputType = llvm::dyn_cast(getOperand().getType()); auto resultType = llvm::dyn_cast(getType()); if (!inputType || !resultType) diff --git a/mlir/example/Ch3/mlir/MLIRGen.cpp b/mlir/example/Ch3/mlir/MLIRGen.cpp index 589bd3a..bf4c099 100644 --- a/mlir/example/Ch3/mlir/MLIRGen.cpp +++ b/mlir/example/Ch3/mlir/MLIRGen.cpp @@ -12,20 +12,30 @@ //===----------------------------------------------------------------------===// #include "toy/MLIRGen.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/Value.h" +#include "toy/AST.h" +#include "toy/Dialect.h" -#include - -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/ScopedHashTable.h" -#include "llvm/Support/raw_ostream.h" -#include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include 
"mlir/IR/MLIRContext.h" #include "mlir/IR/Verifier.h" -#include "toy/AST.h" -#include "toy/Dialect.h" +#include "toy/Lexer.h" + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/ScopedHashTable.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/Twine.h" +#include +#include +#include +#include +#include +#include using namespace mlir::toy; using namespace toy; @@ -47,7 +57,7 @@ namespace { /// the semantics of the language and (hopefully) allow to perform accurate /// analysis and transformation based on these high level semantics. class MLIRGenImpl { - public: +public: MLIRGenImpl(mlir::MLIRContext &context) : builder(&context) {} /// Public API: convert the AST for a Toy module (source file) to an MLIR @@ -57,7 +67,8 @@ class MLIRGenImpl { // add them to the module. theModule = mlir::ModuleOp::create(builder.getUnknownLoc()); - for (FunctionAST &f : moduleAST) mlirGen(f); + for (FunctionAST &f : moduleAST) + mlirGen(f); // Verify the module after we have finished constructing it, this will check // the structural properties of the IR and invoke any specific verifiers we @@ -70,7 +81,7 @@ class MLIRGenImpl { return theModule; } - private: +private: /// A "module" matches a Toy source file: containing a list of functions. mlir::ModuleOp theModule; @@ -93,8 +104,9 @@ class MLIRGenImpl { /// Declare a variable in the current scope, return success if the variable /// wasn't declared yet. - mlir::LogicalResult declare(llvm::StringRef var, mlir::Value value) { - if (symbolTable.count(var)) return mlir::failure(); + llvm::LogicalResult declare(llvm::StringRef var, mlir::Value value) { + if (symbolTable.count(var)) + return mlir::failure(); symbolTable.insert(var, value); return mlir::success(); } @@ -121,7 +133,8 @@ class MLIRGenImpl { // Create an MLIR function for the given prototype. 
builder.setInsertionPointToEnd(theModule.getBody()); mlir::toy::FuncOp function = mlirGen(*funcAST.getProto()); - if (!function) return nullptr; + if (!function) + return nullptr; // Let's start the body of the function now! mlir::Block &entryBlock = function.front(); @@ -150,7 +163,8 @@ class MLIRGenImpl { // FIXME: we may fix the parser instead to always return the last expression // (this would possibly help the REPL case later) ReturnOp returnOp; - if (!entryBlock.empty()) returnOp = dyn_cast(entryBlock.back()); + if (!entryBlock.empty()) + returnOp = dyn_cast(entryBlock.back()); if (!returnOp) { builder.create(loc(funcAST.getProto()->loc())); } else if (returnOp.hasOperand()) { @@ -177,18 +191,20 @@ class MLIRGenImpl { // and propagate. // mlir::Value lhs = mlirGen(*binop.getLHS()); - if (!lhs) return nullptr; + if (!lhs) + return nullptr; mlir::Value rhs = mlirGen(*binop.getRHS()); - if (!rhs) return nullptr; + if (!rhs) + return nullptr; auto location = loc(binop.loc()); // Derive the operation name from the binary operator. At the moment we only // support '+' and '*'. switch (binop.getOp()) { - case '+': - return builder.create(location, lhs, rhs); - case '*': - return builder.create(location, lhs, rhs); + case '+': + return builder.create(location, lhs, rhs); + case '*': + return builder.create(location, lhs, rhs); } emitError(location, "invalid binary operator '") << binop.getOp() << "'"; @@ -199,7 +215,8 @@ class MLIRGenImpl { /// expected to have been declared and so should have a value in the symbol /// table, otherwise emit an error and return nullptr. mlir::Value mlirGen(VariableExprAST &expr) { - if (auto variable = symbolTable.lookup(expr.getName())) return variable; + if (auto variable = symbolTable.lookup(expr.getName())) + return variable; emitError(loc(expr.loc()), "error: unknown variable '") << expr.getName() << "'"; @@ -207,13 +224,14 @@ class MLIRGenImpl { } /// Emit a return operation. This will return failure if any generation fails. 
- mlir::LogicalResult mlirGen(ReturnExprAST &ret) { + llvm::LogicalResult mlirGen(ReturnExprAST &ret) { auto location = loc(ret.loc()); // 'return' takes an optional expression, handle that case here. mlir::Value expr = nullptr; if (ret.getExpr().has_value()) { - if (!(expr = mlirGen(**ret.getExpr()))) return mlir::failure(); + if (!(expr = mlirGen(**ret.getExpr()))) + return mlir::failure(); } // Otherwise, this return operation has zero operands. @@ -275,7 +293,8 @@ class MLIRGenImpl { /// Attributes are the way MLIR attaches constant to operations. void collectData(ExprAST &expr, std::vector &data) { if (auto *lit = dyn_cast(&expr)) { - for (auto &value : lit->getValues()) collectData(*value, data); + for (auto &value : lit->getValues()) + collectData(*value, data); return; } @@ -293,7 +312,8 @@ class MLIRGenImpl { SmallVector operands; for (auto &expr : call.getArgs()) { auto arg = mlirGen(*expr); - if (!arg) return nullptr; + if (!arg) + return nullptr; operands.push_back(arg); } @@ -301,9 +321,8 @@ class MLIRGenImpl { // straightforward emission. if (callee == "transpose") { if (call.getArgs().size() != 1) { - emitError(location, - "MLIR codegen encountered an error: toy.transpose " - "does not accept multiple arguments"); + emitError(location, "MLIR codegen encountered an error: toy.transpose " + "does not accept multiple arguments"); return nullptr; } return builder.create(location, operands[0]); @@ -317,9 +336,10 @@ class MLIRGenImpl { /// Emit a print expression. It emits specific operations for two builtins: /// transpose(x) and print(x). - mlir::LogicalResult mlirGen(PrintExprAST &call) { + llvm::LogicalResult mlirGen(PrintExprAST &call) { auto arg = mlirGen(*call.getArg()); - if (!arg) return mlir::failure(); + if (!arg) + return mlir::failure(); builder.create(loc(call.loc()), arg); return mlir::success(); @@ -333,21 +353,21 @@ class MLIRGenImpl { /// Dispatch codegen for the right expression subclass using RTTI. 
mlir::Value mlirGen(ExprAST &expr) { switch (expr.getKind()) { - case toy::ExprAST::Expr_BinOp: - return mlirGen(cast(expr)); - case toy::ExprAST::Expr_Var: - return mlirGen(cast(expr)); - case toy::ExprAST::Expr_Literal: - return mlirGen(cast(expr)); - case toy::ExprAST::Expr_Call: - return mlirGen(cast(expr)); - case toy::ExprAST::Expr_Num: - return mlirGen(cast(expr)); - default: - emitError(loc(expr.loc())) - << "MLIR codegen encountered an unhandled expr kind '" - << Twine(expr.getKind()) << "'"; - return nullptr; + case toy::ExprAST::Expr_BinOp: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Var: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Literal: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Call: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Num: + return mlirGen(cast(expr)); + default: + emitError(loc(expr.loc())) + << "MLIR codegen encountered an unhandled expr kind '" + << Twine(expr.getKind()) << "'"; + return nullptr; } } @@ -364,7 +384,8 @@ class MLIRGenImpl { } mlir::Value value = mlirGen(*init); - if (!value) return nullptr; + if (!value) + return nullptr; // We have the initializer value, but in case the variable was declared // with specific shape, we emit a "reshape" operation. It will get @@ -375,29 +396,34 @@ class MLIRGenImpl { } // Register the value in the symbol table. - if (failed(declare(vardecl.getName(), value))) return nullptr; + if (failed(declare(vardecl.getName(), value))) + return nullptr; return value; } /// Codegen a list of expression, return failure if one of them hit an error. - mlir::LogicalResult mlirGen(ExprASTList &blockAST) { + llvm::LogicalResult mlirGen(ExprASTList &blockAST) { ScopedHashTableScope varScope(symbolTable); for (auto &expr : blockAST) { // Specific handling for variable declarations, return statement, and // print. These can only appear in block list and not in nested // expressions. 
if (auto *vardecl = dyn_cast(expr.get())) { - if (!mlirGen(*vardecl)) return mlir::failure(); + if (!mlirGen(*vardecl)) + return mlir::failure(); continue; } - if (auto *ret = dyn_cast(expr.get())) return mlirGen(*ret); + if (auto *ret = dyn_cast(expr.get())) + return mlirGen(*ret); if (auto *print = dyn_cast(expr.get())) { - if (mlir::failed(mlirGen(*print))) return mlir::success(); + if (mlir::failed(mlirGen(*print))) + return mlir::success(); continue; } // Generic expression dispatch codegen. - if (!mlirGen(*expr)) return mlir::failure(); + if (!mlirGen(*expr)) + return mlir::failure(); } return mlir::success(); } @@ -417,7 +443,7 @@ class MLIRGenImpl { mlir::Type getType(const VarType &type) { return getType(type.shape); } }; -} // namespace +} // namespace namespace toy { @@ -427,4 +453,4 @@ mlir::OwningOpRef mlirGen(mlir::MLIRContext &context, return MLIRGenImpl(context).mlirGen(moduleAST); } -} // namespace toy +} // namespace toy diff --git a/mlir/example/Ch3/mlir/ToyCombine.cpp b/mlir/example/Ch3/mlir/ToyCombine.cpp index d124e86..f8397c2 100644 --- a/mlir/example/Ch3/mlir/ToyCombine.cpp +++ b/mlir/example/Ch3/mlir/ToyCombine.cpp @@ -11,10 +11,9 @@ // //===----------------------------------------------------------------------===// -#include - -#include "mlir/IR/Matchers.h" +#include "mlir/IR/MLIRContext.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Value.h" #include "toy/Dialect.h" using namespace mlir; using namespace toy; @@ -22,7 +21,7 @@ using namespace toy; namespace { /// Include the patterns defined in the Declarative Rewrite framework. #include "ToyCombine.inc" -} // namespace +} // namespace /// This is an example of a c++ rewrite pattern for the TransposeOp. It /// optimizes the following scenario: transpose(transpose(x)) -> x @@ -36,14 +35,16 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern { /// This method attempts to match a pattern and rewrite it. 
The rewriter /// argument is the orchestrator of the sequence of rewrites. The pattern is /// expected to interact with it to perform any changes to the IR from here. - mlir::LogicalResult matchAndRewrite( - TransposeOp op, mlir::PatternRewriter &rewriter) const override { + llvm::LogicalResult + matchAndRewrite(TransposeOp op, + mlir::PatternRewriter &rewriter) const override { // Look through the input of the current transpose. mlir::Value transposeInput = op.getOperand(); TransposeOp transposeInputOp = transposeInput.getDefiningOp(); // Input defined by another transpose? If not, no match. - if (!transposeInputOp) return failure(); + if (!transposeInputOp) + return failure(); // Otherwise, we have a redundant transpose. Use the rewriter. rewriter.replaceOp(op, {transposeInputOp.getOperand()}); diff --git a/mlir/example/Ch3/mlir/ToyCombine.td b/mlir/example/Ch3/mlir/ToyCombine.td index 11d7831..8bd2b44 100644 --- a/mlir/example/Ch3/mlir/ToyCombine.td +++ b/mlir/example/Ch3/mlir/ToyCombine.td @@ -22,6 +22,7 @@ include "toy/Ops.td" /// class Pattern< /// dag sourcePattern, list resultPatterns, /// list additionalConstraints = [], +// list supplementalPatterns = [], /// dag benefitsAdded = (addBenefit 0) /// >; diff --git a/mlir/example/Ch3/parser/AST.cpp b/mlir/example/Ch3/parser/AST.cpp index e0bd2fd..2546f2a 100644 --- a/mlir/example/Ch3/parser/AST.cpp +++ b/mlir/example/Ch3/parser/AST.cpp @@ -12,9 +12,12 @@ #include "toy/AST.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Twine.h" #include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/raw_ostream.h" +#include using namespace toy; @@ -31,10 +34,10 @@ struct Indent { /// Helper class that implement the AST tree traversal and print the nodes along /// the way. The only data member is the current indentation level. 
class ASTDumper { - public: +public: void dump(ModuleAST *node); - private: +private: void dump(const VarType &type); void dump(VarDeclExprAST *varDecl); void dump(ExprAST *expr); @@ -51,12 +54,13 @@ class ASTDumper { // Actually print spaces matching the current indentation level void indent() { - for (int i = 0; i < curIndent; i++) llvm::errs() << " "; + for (int i = 0; i < curIndent; i++) + llvm::errs() << " "; } int curIndent = 0; }; -} // namespace +} // namespace /// Return a formatted string for the location of any node template @@ -69,8 +73,8 @@ static std::string loc(T *node) { // Helper Macro to bump the indentation level and print the leading spaces for // the current indentations -#define INDENT() \ - Indent level_(curIndent); \ +#define INDENT() \ + Indent level_(curIndent); \ indent(); /// Dispatch to a generic expressions to the appropriate subclass using RTTI @@ -100,7 +104,8 @@ void ASTDumper::dump(VarDeclExprAST *varDecl) { void ASTDumper::dump(ExprASTList *exprList) { INDENT(); llvm::errs() << "Block {\n"; - for (auto &expr : *exprList) dump(expr.get()); + for (auto &expr : *exprList) + dump(expr.get()); indent(); llvm::errs() << "} // Block\n"; } @@ -153,7 +158,8 @@ void ASTDumper::dump(VariableExprAST *node) { void ASTDumper::dump(ReturnExprAST *node) { INDENT(); llvm::errs() << "Return\n"; - if (node->getExpr().has_value()) return dump(*node->getExpr()); + if (node->getExpr().has_value()) + return dump(*node->getExpr()); { INDENT(); llvm::errs() << "(void)\n"; @@ -173,7 +179,8 @@ void ASTDumper::dump(BinaryExprAST *node) { void ASTDumper::dump(CallExprAST *node) { INDENT(); llvm::errs() << "Call '" << node->getCallee() << "' [ " << loc(node) << "\n"; - for (auto &arg : node->getArgs()) dump(arg.get()); + for (auto &arg : node->getArgs()) + dump(arg.get()); indent(); llvm::errs() << "]\n"; } @@ -218,7 +225,8 @@ void ASTDumper::dump(FunctionAST *node) { void ASTDumper::dump(ModuleAST *node) { INDENT(); llvm::errs() << "Module:\n"; - for (auto &f 
: *node) dump(&f); + for (auto &f : *node) + dump(&f); } namespace toy { @@ -226,4 +234,4 @@ namespace toy { // Public API void dump(ModuleAST &module) { ASTDumper().dump(&module); } -} // namespace toy +} // namespace toy diff --git a/mlir/example/Ch3/toyc.cpp b/mlir/example/Ch3/toyc.cpp index e4e2aad..f8aa846 100644 --- a/mlir/example/Ch3/toyc.cpp +++ b/mlir/example/Ch3/toyc.cpp @@ -10,23 +10,31 @@ // //===----------------------------------------------------------------------===// +#include "mlir/IR/Diagnostics.h" +#include "toy/AST.h" +#include "toy/Dialect.h" +#include "toy/Lexer.h" +#include "toy/MLIRGen.h" +#include "toy/Parser.h" + #include "mlir/IR/AsmState.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Verifier.h" #include "mlir/Parser/Parser.h" -#include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/Passes.h" -#include "toy/Dialect.h" -#include "toy/MLIRGen.h" -#include "toy/Parser.h" + #include "llvm/ADT/StringRef.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/ErrorOr.h" #include "llvm/Support/MemoryBuffer.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/raw_ostream.h" +#include +#include +#include +#include using namespace toy; namespace cl = llvm::cl; @@ -73,7 +81,7 @@ int loadMLIR(llvm::SourceMgr &sourceMgr, mlir::MLIRContext &context, mlir::OwningOpRef &module) { // Handle '.toy' input to the compiler. 
if (inputType != InputType::MLIR && - !llvm::StringRef(inputFilename).endswith(".mlir")) { + !llvm::StringRef(inputFilename).ends_with(".mlir")) { auto moduleAST = parseInputFile(inputFilename); if (!moduleAST) return 6; diff --git a/mlir/example/Ch4/include/toy/AST.h b/mlir/example/Ch4/include/toy/AST.h index c9d1bdb..d2ba101 100644 --- a/mlir/example/Ch4/include/toy/AST.h +++ b/mlir/example/Ch4/include/toy/AST.h @@ -15,14 +15,14 @@ #ifndef TOY_AST_H #define TOY_AST_H -#include -#include -#include +#include "toy/Lexer.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Casting.h" -#include "toy/Lexer.h" +#include +#include +#include namespace toy { @@ -33,7 +33,7 @@ struct VarType { /// Base class for all expression nodes. class ExprAST { - public: +public: enum ExprASTKind { Expr_VarDecl, Expr_Return, @@ -53,7 +53,7 @@ class ExprAST { const Location &loc() { return location; } - private: +private: const ExprASTKind kind; Location location; }; @@ -65,7 +65,7 @@ using ExprASTList = std::vector>; class NumberExprAST : public ExprAST { double val; - public: +public: NumberExprAST(Location loc, double val) : ExprAST(Expr_Num, std::move(loc)), val(val) {} @@ -80,11 +80,10 @@ class LiteralExprAST : public ExprAST { std::vector> values; std::vector dims; - public: +public: LiteralExprAST(Location loc, std::vector> values, std::vector dims) - : ExprAST(Expr_Literal, std::move(loc)), - values(std::move(values)), + : ExprAST(Expr_Literal, std::move(loc)), values(std::move(values)), dims(std::move(dims)) {} llvm::ArrayRef> getValues() { return values; } @@ -98,7 +97,7 @@ class LiteralExprAST : public ExprAST { class VariableExprAST : public ExprAST { std::string name; - public: +public: VariableExprAST(Location loc, llvm::StringRef name) : ExprAST(Expr_Var, std::move(loc)), name(name) {} @@ -114,13 +113,11 @@ class VarDeclExprAST : public ExprAST { VarType type; std::unique_ptr initVal; - public: +public: VarDeclExprAST(Location loc, 
llvm::StringRef name, VarType type, std::unique_ptr initVal) - : ExprAST(Expr_VarDecl, std::move(loc)), - name(name), - type(std::move(type)), - initVal(std::move(initVal)) {} + : ExprAST(Expr_VarDecl, std::move(loc)), name(name), + type(std::move(type)), initVal(std::move(initVal)) {} llvm::StringRef getName() { return name; } ExprAST *getInitVal() { return initVal.get(); } @@ -134,12 +131,13 @@ class VarDeclExprAST : public ExprAST { class ReturnExprAST : public ExprAST { std::optional> expr; - public: +public: ReturnExprAST(Location loc, std::optional> expr) : ExprAST(Expr_Return, std::move(loc)), expr(std::move(expr)) {} std::optional getExpr() { - if (expr.has_value()) return expr->get(); + if (expr.has_value()) + return expr->get(); return std::nullopt; } @@ -152,16 +150,14 @@ class BinaryExprAST : public ExprAST { char op; std::unique_ptr lhs, rhs; - public: +public: char getOp() { return op; } ExprAST *getLHS() { return lhs.get(); } ExprAST *getRHS() { return rhs.get(); } BinaryExprAST(Location loc, char op, std::unique_ptr lhs, std::unique_ptr rhs) - : ExprAST(Expr_BinOp, std::move(loc)), - op(op), - lhs(std::move(lhs)), + : ExprAST(Expr_BinOp, std::move(loc)), op(op), lhs(std::move(lhs)), rhs(std::move(rhs)) {} /// LLVM style RTTI @@ -173,11 +169,10 @@ class CallExprAST : public ExprAST { std::string callee; std::vector> args; - public: +public: CallExprAST(Location loc, const std::string &callee, std::vector> args) - : ExprAST(Expr_Call, std::move(loc)), - callee(callee), + : ExprAST(Expr_Call, std::move(loc)), callee(callee), args(std::move(args)) {} llvm::StringRef getCallee() { return callee; } @@ -191,7 +186,7 @@ class CallExprAST : public ExprAST { class PrintExprAST : public ExprAST { std::unique_ptr arg; - public: +public: PrintExprAST(Location loc, std::unique_ptr arg) : ExprAST(Expr_Print, std::move(loc)), arg(std::move(arg)) {} @@ -209,7 +204,7 @@ class PrototypeAST { std::string name; std::vector> args; - public: +public: PrototypeAST(Location 
location, const std::string &name, std::vector> args) : location(std::move(location)), name(name), args(std::move(args)) {} @@ -224,7 +219,7 @@ class FunctionAST { std::unique_ptr proto; std::unique_ptr body; - public: +public: FunctionAST(std::unique_ptr proto, std::unique_ptr body) : proto(std::move(proto)), body(std::move(body)) {} @@ -236,7 +231,7 @@ class FunctionAST { class ModuleAST { std::vector functions; - public: +public: ModuleAST(std::vector functions) : functions(std::move(functions)) {} @@ -246,6 +241,6 @@ class ModuleAST { void dump(ModuleAST &); -} // namespace toy +} // namespace toy -#endif // TOY_AST_H +#endif // TOY_AST_H diff --git a/mlir/example/Ch4/include/toy/Dialect.h b/mlir/example/Ch4/include/toy/Dialect.h index 8aae4b8..5db325e 100644 --- a/mlir/example/Ch4/include/toy/Dialect.h +++ b/mlir/example/Ch4/include/toy/Dialect.h @@ -14,12 +14,13 @@ #ifndef MLIR_TUTORIAL_TOY_DIALECT_H_ #define MLIR_TUTORIAL_TOY_DIALECT_H_ +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" -#include "mlir/IR/FunctionInterfaces.h" #include "mlir/IR/SymbolTable.h" #include "mlir/Interfaces/CallInterfaces.h" #include "mlir/Interfaces/CastInterfaces.h" +#include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "toy/ShapeInferenceInterface.h" @@ -32,4 +33,4 @@ #define GET_OP_CLASSES #include "toy/Ops.h.inc" -#endif // MLIR_TUTORIAL_TOY_DIALECT_H_ +#endif // MLIR_TUTORIAL_TOY_DIALECT_H_ diff --git a/mlir/example/Ch4/include/toy/Lexer.h b/mlir/example/Ch4/include/toy/Lexer.h index 176a40c..3c59cd9 100644 --- a/mlir/example/Ch4/include/toy/Lexer.h +++ b/mlir/example/Ch4/include/toy/Lexer.h @@ -13,18 +13,18 @@ #ifndef TOY_LEXER_H #define TOY_LEXER_H +#include "llvm/ADT/StringRef.h" + #include #include -#include "llvm/ADT/StringRef.h" - namespace toy { /// Structure definition a location in a file. struct Location { - std::shared_ptr file; ///< filename. 
- int line; ///< line number. - int col; ///< column number. + std::shared_ptr file; ///< filename. + int line; ///< line number. + int col; ///< column number. }; // List of Token returned by the lexer. @@ -56,7 +56,7 @@ enum Token : int { /// can proceed by reading the next line from the standard input or from a /// memory mapped file. class Lexer { - public: +public: /// Create a lexer for the given filename. The filename is kept only for /// debugging purpose (attaching a location to a Token). Lexer(std::string filename) @@ -98,7 +98,7 @@ class Lexer { // Return the current column in the file. int getCol() { return curCol; } - private: +private: /// Delegate to a derived class fetching the next line. Returns an empty /// string to signal end of file (EOF). Lines are expected to always finish /// with "\n" @@ -109,11 +109,13 @@ class Lexer { /// needed. int getNextChar() { // The current line buffer should not be empty unless it is the end of file. - if (curLineBuffer.empty()) return EOF; + if (curLineBuffer.empty()) + return EOF; ++curCol; auto nextchar = curLineBuffer.front(); curLineBuffer = curLineBuffer.drop_front(); - if (curLineBuffer.empty()) curLineBuffer = readNextLine(); + if (curLineBuffer.empty()) + curLineBuffer = readNextLine(); if (nextchar == '\n') { ++curLineNum; curCol = 0; @@ -124,7 +126,8 @@ class Lexer { /// Return the next token from standard input. Token getTok() { // Skip any whitespace. - while (isspace(lastChar)) lastChar = Token(getNextChar()); + while (isspace(lastChar)) + lastChar = Token(getNextChar()); // Save the current location before reading the token characters. 
lastLocation.line = curLineNum; @@ -136,9 +139,12 @@ class Lexer { while (isalnum((lastChar = Token(getNextChar()))) || lastChar == '_') identifierStr += (char)lastChar; - if (identifierStr == "return") return tok_return; - if (identifierStr == "def") return tok_def; - if (identifierStr == "var") return tok_var; + if (identifierStr == "return") + return tok_return; + if (identifierStr == "def") + return tok_def; + if (identifierStr == "var") + return tok_var; return tok_identifier; } @@ -160,11 +166,13 @@ class Lexer { lastChar = Token(getNextChar()); } while (lastChar != EOF && lastChar != '\n' && lastChar != '\r'); - if (lastChar != EOF) return getTok(); + if (lastChar != EOF) + return getTok(); } // Check for end of file. Don't eat the EOF. - if (lastChar == EOF) return tok_eof; + if (lastChar == EOF) + return tok_eof; // Otherwise, just return the character as its ascii value. Token thisChar = Token(lastChar); @@ -201,22 +209,24 @@ class Lexer { /// A lexer implementation operating on a buffer in memory. class LexerBuffer final : public Lexer { - public: +public: LexerBuffer(const char *begin, const char *end, std::string filename) : Lexer(std::move(filename)), current(begin), end(end) {} - private: +private: /// Provide one line at a time to the Lexer, return an empty string when /// reaching the end of the buffer. 
llvm::StringRef readNextLine() override { auto *begin = current; - while (current <= end && *current && *current != '\n') ++current; - if (current <= end && *current) ++current; + while (current <= end && *current && *current != '\n') + ++current; + if (current <= end && *current) + ++current; llvm::StringRef result{begin, static_cast(current - begin)}; return result; } const char *current, *end; }; -} // namespace toy +} // namespace toy -#endif // TOY_LEXER_H +#endif // TOY_LEXER_H diff --git a/mlir/example/Ch4/include/toy/MLIRGen.h b/mlir/example/Ch4/include/toy/MLIRGen.h index 5afdf62..fe9dbe5 100644 --- a/mlir/example/Ch4/include/toy/MLIRGen.h +++ b/mlir/example/Ch4/include/toy/MLIRGen.h @@ -21,7 +21,7 @@ class MLIRContext; template class OwningOpRef; class ModuleOp; -} // namespace mlir +} // namespace mlir namespace toy { class ModuleAST; @@ -30,6 +30,6 @@ class ModuleAST; /// or nullptr on failure. mlir::OwningOpRef mlirGen(mlir::MLIRContext &context, ModuleAST &moduleAST); -} // namespace toy +} // namespace toy -#endif // TOY_MLIRGEN_H +#endif // TOY_MLIRGEN_H diff --git a/mlir/example/Ch4/include/toy/Ops.td b/mlir/example/Ch4/include/toy/Ops.td index c394134..075fd1a 100644 --- a/mlir/example/Ch4/include/toy/Ops.td +++ b/mlir/example/Ch4/include/toy/Ops.td @@ -13,7 +13,7 @@ #ifndef TOY_OPS #define TOY_OPS -include "mlir/IR/FunctionInterfaces.td" +include "mlir/Interfaces/FunctionInterfaces.td" include "mlir/IR/SymbolInterfaces.td" include "mlir/Interfaces/CallInterfaces.td" include "mlir/Interfaces/CastInterfaces.td" @@ -141,8 +141,7 @@ def CastOp : Toy_Op<"cast", [ //===----------------------------------------------------------------------===// def FuncOp : Toy_Op<"func", [ - DeclareOpInterfaceMethods, FunctionOpInterface, - IsolatedFromAbove + FunctionOpInterface, IsolatedFromAbove ]> { let summary = "user defined function operation"; let description = [{ @@ -173,6 +172,7 @@ def FuncOp : Toy_Op<"func", [ "StringRef":$name, "FunctionType":$type, 
CArg<"ArrayRef", "{}">:$attrs) >]; + let extraClassDeclaration = [{ //===------------------------------------------------------------------===// // FunctionOpInterface Methods @@ -183,7 +183,10 @@ def FuncOp : Toy_Op<"func", [ /// Returns the result types of this function. ArrayRef getResultTypes() { return getFunctionType().getResults(); } + + Region *getCallableRegion() { return &getBody(); } }]; + let hasCustomAssemblyFormat = 1; let skipDefaultBuilders = 1; } diff --git a/mlir/example/Ch4/include/toy/Parser.h b/mlir/example/Ch4/include/toy/Parser.h index ededa4c..1f20616 100644 --- a/mlir/example/Ch4/include/toy/Parser.h +++ b/mlir/example/Ch4/include/toy/Parser.h @@ -14,16 +14,17 @@ #ifndef TOY_PARSER_H #define TOY_PARSER_H -#include -#include -#include -#include +#include "toy/AST.h" +#include "toy/Lexer.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Support/raw_ostream.h" -#include "toy/AST.h" -#include "toy/Lexer.h" + +#include +#include +#include +#include namespace toy { @@ -33,19 +34,20 @@ namespace toy { /// string and the code could reference an undeclared variable and the parsing /// succeeds. class Parser { - public: +public: /// Create a Parser for the supplied lexer. Parser(Lexer &lexer) : lexer(lexer) {} /// Parse a full Module. A module is a list of function definitions. std::unique_ptr parseModule() { - lexer.getNextToken(); // prime the lexer + lexer.getNextToken(); // prime the lexer // Parse functions one at a time and accumulate in this vector. std::vector functions; while (auto f = parseDefinition()) { functions.push_back(std::move(*f)); - if (lexer.getCurToken() == tok_eof) break; + if (lexer.getCurToken() == tok_eof) + break; } // If we didn't reach EOF, there was an error during parsing if (lexer.getCurToken() != tok_eof) @@ -54,7 +56,7 @@ class Parser { return std::make_unique(std::move(functions)); } - private: +private: Lexer &lexer; /// Parse a return statement. 
@@ -67,7 +69,8 @@ class Parser { std::optional> expr; if (lexer.getCurToken() != ';') { expr = parseExpression(); - if (!expr) return nullptr; + if (!expr) + return nullptr; } return std::make_unique(std::move(loc), std::move(expr)); } @@ -97,7 +100,8 @@ class Parser { // We can have either another nested array or a number literal. if (lexer.getCurToken() == '[') { values.push_back(parseTensorLiteralExpr()); - if (!values.back()) return nullptr; // parse error in the nested array. + if (!values.back()) + return nullptr; // parse error in the nested array. } else { if (lexer.getCurToken() != tok_number) return parseError(" or [", "in literal expression"); @@ -105,17 +109,18 @@ class Parser { } // End of this list on ']' - if (lexer.getCurToken() == ']') break; + if (lexer.getCurToken() == ']') + break; // Elements are separated by a comma. if (lexer.getCurToken() != ',') return parseError("] or ,", "in literal expression"); - lexer.getNextToken(); // eat , + lexer.getNextToken(); // eat , } while (true); if (values.empty()) return parseError("", "to fill literal expression"); - lexer.getNextToken(); // eat ] + lexer.getNextToken(); // eat ] /// Fill in the dimensions now. First the current nesting level: dims.push_back(values.size()); @@ -151,9 +156,10 @@ class Parser { /// parenexpr ::= '(' expression ')' std::unique_ptr parseParenExpr() { - lexer.getNextToken(); // eat (. + lexer.getNextToken(); // eat (. auto v = parseExpression(); - if (!v) return nullptr; + if (!v) + return nullptr; if (lexer.getCurToken() != ')') return parseError(")", "to close expression with parentheses"); @@ -168,9 +174,9 @@ class Parser { std::string name(lexer.getId()); auto loc = lexer.getLastLocation(); - lexer.getNextToken(); // eat identifier. + lexer.getNextToken(); // eat identifier. - if (lexer.getCurToken() != '(') // Simple variable ref. + if (lexer.getCurToken() != '(') // Simple variable ref. return std::make_unique(std::move(loc), name); // This is a function call. 
@@ -183,7 +189,8 @@ class Parser { else return nullptr; - if (lexer.getCurToken() == ')') break; + if (lexer.getCurToken() == ')') + break; if (lexer.getCurToken() != ',') return parseError(", or )", "in argument list"); @@ -211,22 +218,22 @@ class Parser { /// ::= tensorliteral std::unique_ptr parsePrimary() { switch (lexer.getCurToken()) { - default: - llvm::errs() << "unknown token '" << lexer.getCurToken() - << "' when expecting an expression\n"; - return nullptr; - case tok_identifier: - return parseIdentifierExpr(); - case tok_number: - return parseNumberExpr(); - case '(': - return parseParenExpr(); - case '[': - return parseTensorLiteralExpr(); - case ';': - return nullptr; - case '}': - return nullptr; + default: + llvm::errs() << "unknown token '" << lexer.getCurToken() + << "' when expecting an expression\n"; + return nullptr; + case tok_identifier: + return parseIdentifierExpr(); + case tok_number: + return parseNumberExpr(); + case '(': + return parseParenExpr(); + case '[': + return parseTensorLiteralExpr(); + case ';': + return nullptr; + case '}': + return nullptr; } } @@ -242,7 +249,8 @@ class Parser { // If this is a binop that binds at least as tightly as the current binop, // consume it, otherwise we are done. - if (tokPrec < exprPrec) return lhs; + if (tokPrec < exprPrec) + return lhs; // Okay, we know this is a binop. int binOp = lexer.getCurToken(); @@ -259,7 +267,8 @@ class Parser { int nextPrec = getTokPrecedence(); if (tokPrec < nextPrec) { rhs = parseBinOpRHS(tokPrec + 1, std::move(rhs)); - if (!rhs) return nullptr; + if (!rhs) + return nullptr; } // Merge lhs/RHS. 
@@ -271,7 +280,8 @@ class Parser { /// expression::= primary binop rhs std::unique_ptr parseExpression() { auto lhs = parsePrimary(); - if (!lhs) return nullptr; + if (!lhs) + return nullptr; return parseBinOpRHS(0, std::move(lhs)); } @@ -281,19 +291,20 @@ class Parser { std::unique_ptr parseType() { if (lexer.getCurToken() != '<') return parseError("<", "to begin type"); - lexer.getNextToken(); // eat < + lexer.getNextToken(); // eat < auto type = std::make_unique(); while (lexer.getCurToken() == tok_number) { type->shape.push_back(lexer.getValue()); lexer.getNextToken(); - if (lexer.getCurToken() == ',') lexer.getNextToken(); + if (lexer.getCurToken() == ',') + lexer.getNextToken(); } if (lexer.getCurToken() != '>') return parseError(">", "to end type"); - lexer.getNextToken(); // eat > + lexer.getNextToken(); // eat > return type; } @@ -305,21 +316,23 @@ class Parser { if (lexer.getCurToken() != tok_var) return parseError("var", "to begin declaration"); auto loc = lexer.getLastLocation(); - lexer.getNextToken(); // eat var + lexer.getNextToken(); // eat var if (lexer.getCurToken() != tok_identifier) return parseError("identified", "after 'var' declaration"); std::string id(lexer.getId()); - lexer.getNextToken(); // eat id + lexer.getNextToken(); // eat id - std::unique_ptr type; // Type is optional, it can be inferred + std::unique_ptr type; // Type is optional, it can be inferred if (lexer.getCurToken() == '<') { type = parseType(); - if (!type) return nullptr; + if (!type) + return nullptr; } - if (!type) type = std::make_unique(); + if (!type) + type = std::make_unique(); lexer.consume(Token('=')); auto expr = parseExpression(); return std::make_unique(std::move(loc), std::move(id), @@ -340,23 +353,27 @@ class Parser { auto exprList = std::make_unique(); // Ignore empty expressions: swallow sequences of semicolons. 
- while (lexer.getCurToken() == ';') lexer.consume(Token(';')); + while (lexer.getCurToken() == ';') + lexer.consume(Token(';')); while (lexer.getCurToken() != '}' && lexer.getCurToken() != tok_eof) { if (lexer.getCurToken() == tok_var) { // Variable declaration auto varDecl = parseDeclaration(); - if (!varDecl) return nullptr; + if (!varDecl) + return nullptr; exprList->push_back(std::move(varDecl)); } else if (lexer.getCurToken() == tok_return) { // Return statement auto ret = parseReturn(); - if (!ret) return nullptr; + if (!ret) + return nullptr; exprList->push_back(std::move(ret)); } else { // General expression auto expr = parseExpression(); - if (!expr) return nullptr; + if (!expr) + return nullptr; exprList->push_back(std::move(expr)); } // Ensure that elements are separated by a semicolon. @@ -364,7 +381,8 @@ class Parser { return parseError(";", "after expression"); // Ignore empty expressions: swallow sequences of semicolons. - while (lexer.getCurToken() == ';') lexer.consume(Token(';')); + while (lexer.getCurToken() == ';') + lexer.consume(Token(';')); } if (lexer.getCurToken() != '}') @@ -401,7 +419,8 @@ class Parser { lexer.consume(tok_identifier); auto decl = std::make_unique(std::move(loc), name); args.push_back(std::move(decl)); - if (lexer.getCurToken() != ',') break; + if (lexer.getCurToken() != ',') + break; lexer.consume(Token(',')); if (lexer.getCurToken() != tok_identifier) return parseError( @@ -423,7 +442,8 @@ class Parser { /// definition ::= prototype block std::unique_ptr parseDefinition() { auto proto = parsePrototype(); - if (!proto) return nullptr; + if (!proto) + return nullptr; if (auto block = parseBlock()) return std::make_unique(std::move(proto), std::move(block)); @@ -432,18 +452,19 @@ class Parser { /// Get the precedence of the pending binary operator token. int getTokPrecedence() { - if (!isascii(lexer.getCurToken())) return -1; + if (!isascii(lexer.getCurToken())) + return -1; // 1 is lowest precedence. 
switch (static_cast(lexer.getCurToken())) { - case '-': - return 20; - case '+': - return 20; - case '*': - return 40; - default: - return -1; + case '-': + return 20; + case '+': + return 20; + case '*': + return 40; + default: + return -1; } } @@ -456,12 +477,13 @@ class Parser { llvm::errs() << "Parse error (" << lexer.getLastLocation().line << ", " << lexer.getLastLocation().col << "): expected '" << expected << "' " << context << " but has Token " << curToken; - if (isprint(curToken)) llvm::errs() << " '" << (char)curToken << "'"; + if (isprint(curToken)) + llvm::errs() << " '" << (char)curToken << "'"; llvm::errs() << "\n"; return nullptr; } }; -} // namespace toy +} // namespace toy -#endif // TOY_PARSER_H +#endif // TOY_PARSER_H diff --git a/mlir/example/Ch4/include/toy/Passes.h b/mlir/example/Ch4/include/toy/Passes.h index b33aee8..0eafa08 100644 --- a/mlir/example/Ch4/include/toy/Passes.h +++ b/mlir/example/Ch4/include/toy/Passes.h @@ -20,7 +20,7 @@ class Pass; namespace toy { std::unique_ptr createShapeInferencePass(); -} // namespace toy -} // namespace mlir +} // namespace toy +} // namespace mlir -#endif // TOY_PASSES_H +#endif // TOY_PASSES_H diff --git a/mlir/example/Ch4/include/toy/ShapeInferenceInterface.h b/mlir/example/Ch4/include/toy/ShapeInferenceInterface.h index a32ef17..cfe5a87 100644 --- a/mlir/example/Ch4/include/toy/ShapeInferenceInterface.h +++ b/mlir/example/Ch4/include/toy/ShapeInferenceInterface.h @@ -22,7 +22,7 @@ namespace toy { /// Include the auto-generated declarations. 
#include "toy/ShapeInferenceOpInterfaces.h.inc" -} // namespace toy -} // namespace mlir +} // namespace toy +} // namespace mlir -#endif // MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_ +#endif // MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_ diff --git a/mlir/example/Ch4/mlir/Dialect.cpp b/mlir/example/Ch4/mlir/Dialect.cpp index b43d289..6c6cdd9 100644 --- a/mlir/example/Ch4/mlir/Dialect.cpp +++ b/mlir/example/Ch4/mlir/Dialect.cpp @@ -13,11 +13,25 @@ #include "toy/Dialect.h" +#include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/FunctionImplementation.h" +#include "mlir/IR/Location.h" #include "mlir/IR/OpImplementation.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Interfaces/CallInterfaces.h" +#include "mlir/Interfaces/FunctionImplementation.h" +#include "mlir/Support/LLVM.h" #include "mlir/Transforms/InliningUtils.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include +#include +#include +#include using namespace mlir; using namespace mlir::toy; @@ -59,8 +73,7 @@ struct ToyInlinerInterface : public DialectInlinerInterface { /// Handle the given inlined terminator(toy.return) by replacing it with a new /// operation as necessary. - void handleTerminator(Operation *op, - ArrayRef valuesToRepl) const final { + void handleTerminator(Operation *op, ValueRange valuesToRepl) const final { // Only "toy.return" needs to be handled here. auto returnOp = cast(op); @@ -190,11 +203,10 @@ void ConstantOp::print(mlir::OpAsmPrinter &printer) { /// Verifier for the constant operation. This corresponds to the /// `let hasVerifier = 1` in the op definition. -mlir::LogicalResult ConstantOp::verify() { +llvm::LogicalResult ConstantOp::verify() { // If the return type of the constant is not an unranked tensor, the shape // must match the shape of the attribute holding the data. 
- auto resultType = - llvm::dyn_cast(getResult().getType()); + auto resultType = llvm::dyn_cast(getResult().getType()); if (!resultType) return success(); @@ -299,27 +311,6 @@ void FuncOp::print(mlir::OpAsmPrinter &p) { getArgAttrsAttrName(), getResAttrsAttrName()); } -/// Returns the region on the function operation that is callable. -mlir::Region *FuncOp::getCallableRegion() { return &getBody(); } - -/// Returns the results types that the callable region produces when -/// executed. -llvm::ArrayRef FuncOp::getCallableResults() { - return getFunctionType().getResults(); -} - -/// Returns the argument attributes for all callable region arguments or -/// null if there are none. -ArrayAttr FuncOp::getCallableArgAttrs() { - return getArgAttrs().value_or(nullptr); -} - -/// Returns the result attributes for all callable region results or -/// null if there are none. -ArrayAttr FuncOp::getCallableResAttrs() { - return getResAttrs().value_or(nullptr); -} - //===----------------------------------------------------------------------===// // GenericCallOp //===----------------------------------------------------------------------===// @@ -349,6 +340,12 @@ void GenericCallOp::setCalleeFromCallable(CallInterfaceCallable callee) { /// call interface. Operation::operand_range GenericCallOp::getArgOperands() { return getInputs(); } +/// Get the argument operands to the called function as a mutable range, this is +/// required by the call interface. 
+MutableOperandRange GenericCallOp::getArgOperandsMutable() { + return getInputsMutable(); +} + //===----------------------------------------------------------------------===// // MulOp //===----------------------------------------------------------------------===// @@ -374,7 +371,7 @@ void MulOp::inferShapes() { getResult().setType(getLhs().getType()); } // ReturnOp //===----------------------------------------------------------------------===// -mlir::LogicalResult ReturnOp::verify() { +llvm::LogicalResult ReturnOp::verify() { // We know that the parent operation is a function, because of the 'HasParent' // trait attached to the operation definition. auto function = cast((*this)->getParentOp()); @@ -398,8 +395,7 @@ mlir::LogicalResult ReturnOp::verify() { auto resultType = results.front(); // Check that the result type of the function matches the operand type. - if (inputType == resultType || - llvm::isa(inputType) || + if (inputType == resultType || llvm::isa(inputType) || llvm::isa(resultType)) return mlir::success(); @@ -424,7 +420,7 @@ void TransposeOp::inferShapes() { getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType())); } -mlir::LogicalResult TransposeOp::verify() { +llvm::LogicalResult TransposeOp::verify() { auto inputType = llvm::dyn_cast(getOperand().getType()); auto resultType = llvm::dyn_cast(getType()); if (!inputType || !resultType) diff --git a/mlir/example/Ch4/mlir/MLIRGen.cpp b/mlir/example/Ch4/mlir/MLIRGen.cpp index 58ec129..b56e2f7 100644 --- a/mlir/example/Ch4/mlir/MLIRGen.cpp +++ b/mlir/example/Ch4/mlir/MLIRGen.cpp @@ -12,20 +12,30 @@ //===----------------------------------------------------------------------===// #include "toy/MLIRGen.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/Value.h" +#include "toy/AST.h" +#include "toy/Dialect.h" -#include - -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/ScopedHashTable.h" -#include "llvm/Support/raw_ostream.h" -#include 
"mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Verifier.h" -#include "toy/AST.h" -#include "toy/Dialect.h" +#include "toy/Lexer.h" + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/ScopedHashTable.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/Twine.h" +#include +#include +#include +#include +#include +#include using namespace mlir::toy; using namespace toy; @@ -47,7 +57,7 @@ namespace { /// the semantics of the language and (hopefully) allow to perform accurate /// analysis and transformation based on these high level semantics. class MLIRGenImpl { - public: +public: MLIRGenImpl(mlir::MLIRContext &context) : builder(&context) {} /// Public API: convert the AST for a Toy module (source file) to an MLIR @@ -57,7 +67,8 @@ class MLIRGenImpl { // add them to the module. theModule = mlir::ModuleOp::create(builder.getUnknownLoc()); - for (FunctionAST &f : moduleAST) mlirGen(f); + for (FunctionAST &f : moduleAST) + mlirGen(f); // Verify the module after we have finished constructing it, this will check // the structural properties of the IR and invoke any specific verifiers we @@ -70,7 +81,7 @@ class MLIRGenImpl { return theModule; } - private: +private: /// A "module" matches a Toy source file: containing a list of functions. mlir::ModuleOp theModule; @@ -93,8 +104,9 @@ class MLIRGenImpl { /// Declare a variable in the current scope, return success if the variable /// wasn't declared yet. - mlir::LogicalResult declare(llvm::StringRef var, mlir::Value value) { - if (symbolTable.count(var)) return mlir::failure(); + llvm::LogicalResult declare(llvm::StringRef var, mlir::Value value) { + if (symbolTable.count(var)) + return mlir::failure(); symbolTable.insert(var, value); return mlir::success(); } @@ -121,7 +133,8 @@ class MLIRGenImpl { // Create an MLIR function for the given prototype. 
builder.setInsertionPointToEnd(theModule.getBody()); mlir::toy::FuncOp function = mlirGen(*funcAST.getProto()); - if (!function) return nullptr; + if (!function) + return nullptr; // Let's start the body of the function now! mlir::Block &entryBlock = function.front(); @@ -150,7 +163,8 @@ class MLIRGenImpl { // FIXME: we may fix the parser instead to always return the last expression // (this would possibly help the REPL case later) ReturnOp returnOp; - if (!entryBlock.empty()) returnOp = dyn_cast(entryBlock.back()); + if (!entryBlock.empty()) + returnOp = dyn_cast(entryBlock.back()); if (!returnOp) { builder.create(loc(funcAST.getProto()->loc())); } else if (returnOp.hasOperand()) { @@ -161,7 +175,8 @@ class MLIRGenImpl { } // If this function isn't main, then set the visibility to private. - if (funcAST.getProto()->getName() != "main") function.setPrivate(); + if (funcAST.getProto()->getName() != "main") + function.setPrivate(); return function; } @@ -180,18 +195,20 @@ class MLIRGenImpl { // and propagate. // mlir::Value lhs = mlirGen(*binop.getLHS()); - if (!lhs) return nullptr; + if (!lhs) + return nullptr; mlir::Value rhs = mlirGen(*binop.getRHS()); - if (!rhs) return nullptr; + if (!rhs) + return nullptr; auto location = loc(binop.loc()); // Derive the operation name from the binary operator. At the moment we only // support '+' and '*'. switch (binop.getOp()) { - case '+': - return builder.create(location, lhs, rhs); - case '*': - return builder.create(location, lhs, rhs); + case '+': + return builder.create(location, lhs, rhs); + case '*': + return builder.create(location, lhs, rhs); } emitError(location, "invalid binary operator '") << binop.getOp() << "'"; @@ -202,7 +219,8 @@ class MLIRGenImpl { /// expected to have been declared and so should have a value in the symbol /// table, otherwise emit an error and return nullptr. 
mlir::Value mlirGen(VariableExprAST &expr) { - if (auto variable = symbolTable.lookup(expr.getName())) return variable; + if (auto variable = symbolTable.lookup(expr.getName())) + return variable; emitError(loc(expr.loc()), "error: unknown variable '") << expr.getName() << "'"; @@ -210,13 +228,14 @@ class MLIRGenImpl { } /// Emit a return operation. This will return failure if any generation fails. - mlir::LogicalResult mlirGen(ReturnExprAST &ret) { + llvm::LogicalResult mlirGen(ReturnExprAST &ret) { auto location = loc(ret.loc()); // 'return' takes an optional expression, handle that case here. mlir::Value expr = nullptr; if (ret.getExpr().has_value()) { - if (!(expr = mlirGen(**ret.getExpr()))) return mlir::failure(); + if (!(expr = mlirGen(**ret.getExpr()))) + return mlir::failure(); } // Otherwise, this return operation has zero operands. @@ -278,7 +297,8 @@ class MLIRGenImpl { /// Attributes are the way MLIR attaches constant to operations. void collectData(ExprAST &expr, std::vector &data) { if (auto *lit = dyn_cast(&expr)) { - for (auto &value : lit->getValues()) collectData(*value, data); + for (auto &value : lit->getValues()) + collectData(*value, data); return; } @@ -296,7 +316,8 @@ class MLIRGenImpl { SmallVector operands; for (auto &expr : call.getArgs()) { auto arg = mlirGen(*expr); - if (!arg) return nullptr; + if (!arg) + return nullptr; operands.push_back(arg); } @@ -304,9 +325,8 @@ class MLIRGenImpl { // straightforward emission. if (callee == "transpose") { if (call.getArgs().size() != 1) { - emitError(location, - "MLIR codegen encountered an error: toy.transpose " - "does not accept multiple arguments"); + emitError(location, "MLIR codegen encountered an error: toy.transpose " + "does not accept multiple arguments"); return nullptr; } return builder.create(location, operands[0]); @@ -320,9 +340,10 @@ class MLIRGenImpl { /// Emit a print expression. It emits specific operations for two builtins: /// transpose(x) and print(x). 
- mlir::LogicalResult mlirGen(PrintExprAST &call) { + llvm::LogicalResult mlirGen(PrintExprAST &call) { auto arg = mlirGen(*call.getArg()); - if (!arg) return mlir::failure(); + if (!arg) + return mlir::failure(); builder.create(loc(call.loc()), arg); return mlir::success(); @@ -336,21 +357,21 @@ class MLIRGenImpl { /// Dispatch codegen for the right expression subclass using RTTI. mlir::Value mlirGen(ExprAST &expr) { switch (expr.getKind()) { - case toy::ExprAST::Expr_BinOp: - return mlirGen(cast(expr)); - case toy::ExprAST::Expr_Var: - return mlirGen(cast(expr)); - case toy::ExprAST::Expr_Literal: - return mlirGen(cast(expr)); - case toy::ExprAST::Expr_Call: - return mlirGen(cast(expr)); - case toy::ExprAST::Expr_Num: - return mlirGen(cast(expr)); - default: - emitError(loc(expr.loc())) - << "MLIR codegen encountered an unhandled expr kind '" - << Twine(expr.getKind()) << "'"; - return nullptr; + case toy::ExprAST::Expr_BinOp: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Var: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Literal: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Call: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Num: + return mlirGen(cast(expr)); + default: + emitError(loc(expr.loc())) + << "MLIR codegen encountered an unhandled expr kind '" + << Twine(expr.getKind()) << "'"; + return nullptr; } } @@ -367,7 +388,8 @@ class MLIRGenImpl { } mlir::Value value = mlirGen(*init); - if (!value) return nullptr; + if (!value) + return nullptr; // We have the initializer value, but in case the variable was declared // with specific shape, we emit a "reshape" operation. It will get @@ -378,29 +400,34 @@ class MLIRGenImpl { } // Register the value in the symbol table. - if (failed(declare(vardecl.getName(), value))) return nullptr; + if (failed(declare(vardecl.getName(), value))) + return nullptr; return value; } /// Codegen a list of expression, return failure if one of them hit an error. 
- mlir::LogicalResult mlirGen(ExprASTList &blockAST) { + llvm::LogicalResult mlirGen(ExprASTList &blockAST) { ScopedHashTableScope varScope(symbolTable); for (auto &expr : blockAST) { // Specific handling for variable declarations, return statement, and // print. These can only appear in block list and not in nested // expressions. if (auto *vardecl = dyn_cast(expr.get())) { - if (!mlirGen(*vardecl)) return mlir::failure(); + if (!mlirGen(*vardecl)) + return mlir::failure(); continue; } - if (auto *ret = dyn_cast(expr.get())) return mlirGen(*ret); + if (auto *ret = dyn_cast(expr.get())) + return mlirGen(*ret); if (auto *print = dyn_cast(expr.get())) { - if (mlir::failed(mlirGen(*print))) return mlir::success(); + if (mlir::failed(mlirGen(*print))) + return mlir::success(); continue; } // Generic expression dispatch codegen. - if (!mlirGen(*expr)) return mlir::failure(); + if (!mlirGen(*expr)) + return mlir::failure(); } return mlir::success(); } @@ -420,7 +447,7 @@ class MLIRGenImpl { mlir::Type getType(const VarType &type) { return getType(type.shape); } }; -} // namespace +} // namespace namespace toy { @@ -430,4 +457,4 @@ mlir::OwningOpRef mlirGen(mlir::MLIRContext &context, return MLIRGenImpl(context).mlirGen(moduleAST); } -} // namespace toy +} // namespace toy diff --git a/mlir/example/Ch4/mlir/ShapeInferencePass.cpp b/mlir/example/Ch4/mlir/ShapeInferencePass.cpp index d45baa1..a9e995e 100644 --- a/mlir/example/Ch4/mlir/ShapeInferencePass.cpp +++ b/mlir/example/Ch4/mlir/ShapeInferencePass.cpp @@ -11,13 +11,21 @@ // //===----------------------------------------------------------------------===// +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Types.h" #include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/TypeID.h" #include "toy/Dialect.h" #include "toy/Passes.h" #include "toy/ShapeInferenceInterface.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" +#include 
"llvm/Support/Casting.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" +#include #define DEBUG_TYPE "shape-inference" diff --git a/mlir/example/Ch4/mlir/ToyCombine.cpp b/mlir/example/Ch4/mlir/ToyCombine.cpp index d124e86..f8397c2 100644 --- a/mlir/example/Ch4/mlir/ToyCombine.cpp +++ b/mlir/example/Ch4/mlir/ToyCombine.cpp @@ -11,10 +11,9 @@ // //===----------------------------------------------------------------------===// -#include - -#include "mlir/IR/Matchers.h" +#include "mlir/IR/MLIRContext.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Value.h" #include "toy/Dialect.h" using namespace mlir; using namespace toy; @@ -22,7 +21,7 @@ using namespace toy; namespace { /// Include the patterns defined in the Declarative Rewrite framework. #include "ToyCombine.inc" -} // namespace +} // namespace /// This is an example of a c++ rewrite pattern for the TransposeOp. It /// optimizes the following scenario: transpose(transpose(x)) -> x @@ -36,14 +35,16 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern { /// This method attempts to match a pattern and rewrite it. The rewriter /// argument is the orchestrator of the sequence of rewrites. The pattern is /// expected to interact with it to perform any changes to the IR from here. - mlir::LogicalResult matchAndRewrite( - TransposeOp op, mlir::PatternRewriter &rewriter) const override { + llvm::LogicalResult + matchAndRewrite(TransposeOp op, + mlir::PatternRewriter &rewriter) const override { // Look through the input of the current transpose. mlir::Value transposeInput = op.getOperand(); TransposeOp transposeInputOp = transposeInput.getDefiningOp(); // Input defined by another transpose? If not, no match. - if (!transposeInputOp) return failure(); + if (!transposeInputOp) + return failure(); // Otherwise, we have a redundant transpose. Use the rewriter. 
rewriter.replaceOp(op, {transposeInputOp.getOperand()}); diff --git a/mlir/example/Ch4/parser/AST.cpp b/mlir/example/Ch4/parser/AST.cpp index e0bd2fd..2546f2a 100644 --- a/mlir/example/Ch4/parser/AST.cpp +++ b/mlir/example/Ch4/parser/AST.cpp @@ -12,9 +12,12 @@ #include "toy/AST.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Twine.h" #include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/raw_ostream.h" +#include using namespace toy; @@ -31,10 +34,10 @@ struct Indent { /// Helper class that implement the AST tree traversal and print the nodes along /// the way. The only data member is the current indentation level. class ASTDumper { - public: +public: void dump(ModuleAST *node); - private: +private: void dump(const VarType &type); void dump(VarDeclExprAST *varDecl); void dump(ExprAST *expr); @@ -51,12 +54,13 @@ class ASTDumper { // Actually print spaces matching the current indentation level void indent() { - for (int i = 0; i < curIndent; i++) llvm::errs() << " "; + for (int i = 0; i < curIndent; i++) + llvm::errs() << " "; } int curIndent = 0; }; -} // namespace +} // namespace /// Return a formatted string for the location of any node template @@ -69,8 +73,8 @@ static std::string loc(T *node) { // Helper Macro to bump the indentation level and print the leading spaces for // the current indentations -#define INDENT() \ - Indent level_(curIndent); \ +#define INDENT() \ + Indent level_(curIndent); \ indent(); /// Dispatch to a generic expressions to the appropriate subclass using RTTI @@ -100,7 +104,8 @@ void ASTDumper::dump(VarDeclExprAST *varDecl) { void ASTDumper::dump(ExprASTList *exprList) { INDENT(); llvm::errs() << "Block {\n"; - for (auto &expr : *exprList) dump(expr.get()); + for (auto &expr : *exprList) + dump(expr.get()); indent(); llvm::errs() << "} // Block\n"; } @@ -153,7 +158,8 @@ void ASTDumper::dump(VariableExprAST *node) { void ASTDumper::dump(ReturnExprAST *node) { INDENT(); llvm::errs() << "Return\n"; - 
if (node->getExpr().has_value()) return dump(*node->getExpr()); + if (node->getExpr().has_value()) + return dump(*node->getExpr()); { INDENT(); llvm::errs() << "(void)\n"; @@ -173,7 +179,8 @@ void ASTDumper::dump(BinaryExprAST *node) { void ASTDumper::dump(CallExprAST *node) { INDENT(); llvm::errs() << "Call '" << node->getCallee() << "' [ " << loc(node) << "\n"; - for (auto &arg : node->getArgs()) dump(arg.get()); + for (auto &arg : node->getArgs()) + dump(arg.get()); indent(); llvm::errs() << "]\n"; } @@ -218,7 +225,8 @@ void ASTDumper::dump(FunctionAST *node) { void ASTDumper::dump(ModuleAST *node) { INDENT(); llvm::errs() << "Module:\n"; - for (auto &f : *node) dump(&f); + for (auto &f : *node) + dump(&f); } namespace toy { @@ -226,4 +234,4 @@ namespace toy { // Public API void dump(ModuleAST &module) { ASTDumper().dump(&module); } -} // namespace toy +} // namespace toy diff --git a/mlir/example/Ch4/toyc.cpp b/mlir/example/Ch4/toyc.cpp index e575e13..ae02bc4 100644 --- a/mlir/example/Ch4/toyc.cpp +++ b/mlir/example/Ch4/toyc.cpp @@ -10,24 +10,32 @@ // //===----------------------------------------------------------------------===// +#include "mlir/IR/Diagnostics.h" +#include "toy/AST.h" +#include "toy/Dialect.h" +#include "toy/Lexer.h" +#include "toy/MLIRGen.h" +#include "toy/Parser.h" +#include "toy/Passes.h" + #include "mlir/IR/AsmState.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Verifier.h" #include "mlir/Parser/Parser.h" -#include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/Passes.h" -#include "toy/Dialect.h" -#include "toy/MLIRGen.h" -#include "toy/Parser.h" -#include "toy/Passes.h" + #include "llvm/ADT/StringRef.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/ErrorOr.h" #include "llvm/Support/MemoryBuffer.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/raw_ostream.h" +#include +#include +#include +#include using namespace toy; namespace cl = 
llvm::cl; @@ -74,7 +82,7 @@ int loadMLIR(llvm::SourceMgr &sourceMgr, mlir::MLIRContext &context, mlir::OwningOpRef &module) { // Handle '.toy' input to the compiler. if (inputType != InputType::MLIR && - !llvm::StringRef(inputFilename).endswith(".mlir")) { + !llvm::StringRef(inputFilename).ends_with(".mlir")) { auto moduleAST = parseInputFile(inputFilename); if (!moduleAST) return 6; diff --git a/mlir/example/Ch5/include/toy/AST.h b/mlir/example/Ch5/include/toy/AST.h index c9d1bdb..d2ba101 100644 --- a/mlir/example/Ch5/include/toy/AST.h +++ b/mlir/example/Ch5/include/toy/AST.h @@ -15,14 +15,14 @@ #ifndef TOY_AST_H #define TOY_AST_H -#include -#include -#include +#include "toy/Lexer.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Casting.h" -#include "toy/Lexer.h" +#include +#include +#include namespace toy { @@ -33,7 +33,7 @@ struct VarType { /// Base class for all expression nodes. class ExprAST { - public: +public: enum ExprASTKind { Expr_VarDecl, Expr_Return, @@ -53,7 +53,7 @@ class ExprAST { const Location &loc() { return location; } - private: +private: const ExprASTKind kind; Location location; }; @@ -65,7 +65,7 @@ using ExprASTList = std::vector>; class NumberExprAST : public ExprAST { double val; - public: +public: NumberExprAST(Location loc, double val) : ExprAST(Expr_Num, std::move(loc)), val(val) {} @@ -80,11 +80,10 @@ class LiteralExprAST : public ExprAST { std::vector> values; std::vector dims; - public: +public: LiteralExprAST(Location loc, std::vector> values, std::vector dims) - : ExprAST(Expr_Literal, std::move(loc)), - values(std::move(values)), + : ExprAST(Expr_Literal, std::move(loc)), values(std::move(values)), dims(std::move(dims)) {} llvm::ArrayRef> getValues() { return values; } @@ -98,7 +97,7 @@ class LiteralExprAST : public ExprAST { class VariableExprAST : public ExprAST { std::string name; - public: +public: VariableExprAST(Location loc, llvm::StringRef name) : ExprAST(Expr_Var, std::move(loc)), 
name(name) {} @@ -114,13 +113,11 @@ class VarDeclExprAST : public ExprAST { VarType type; std::unique_ptr initVal; - public: +public: VarDeclExprAST(Location loc, llvm::StringRef name, VarType type, std::unique_ptr initVal) - : ExprAST(Expr_VarDecl, std::move(loc)), - name(name), - type(std::move(type)), - initVal(std::move(initVal)) {} + : ExprAST(Expr_VarDecl, std::move(loc)), name(name), + type(std::move(type)), initVal(std::move(initVal)) {} llvm::StringRef getName() { return name; } ExprAST *getInitVal() { return initVal.get(); } @@ -134,12 +131,13 @@ class VarDeclExprAST : public ExprAST { class ReturnExprAST : public ExprAST { std::optional> expr; - public: +public: ReturnExprAST(Location loc, std::optional> expr) : ExprAST(Expr_Return, std::move(loc)), expr(std::move(expr)) {} std::optional getExpr() { - if (expr.has_value()) return expr->get(); + if (expr.has_value()) + return expr->get(); return std::nullopt; } @@ -152,16 +150,14 @@ class BinaryExprAST : public ExprAST { char op; std::unique_ptr lhs, rhs; - public: +public: char getOp() { return op; } ExprAST *getLHS() { return lhs.get(); } ExprAST *getRHS() { return rhs.get(); } BinaryExprAST(Location loc, char op, std::unique_ptr lhs, std::unique_ptr rhs) - : ExprAST(Expr_BinOp, std::move(loc)), - op(op), - lhs(std::move(lhs)), + : ExprAST(Expr_BinOp, std::move(loc)), op(op), lhs(std::move(lhs)), rhs(std::move(rhs)) {} /// LLVM style RTTI @@ -173,11 +169,10 @@ class CallExprAST : public ExprAST { std::string callee; std::vector> args; - public: +public: CallExprAST(Location loc, const std::string &callee, std::vector> args) - : ExprAST(Expr_Call, std::move(loc)), - callee(callee), + : ExprAST(Expr_Call, std::move(loc)), callee(callee), args(std::move(args)) {} llvm::StringRef getCallee() { return callee; } @@ -191,7 +186,7 @@ class CallExprAST : public ExprAST { class PrintExprAST : public ExprAST { std::unique_ptr arg; - public: +public: PrintExprAST(Location loc, std::unique_ptr arg) : 
ExprAST(Expr_Print, std::move(loc)), arg(std::move(arg)) {} @@ -209,7 +204,7 @@ class PrototypeAST { std::string name; std::vector> args; - public: +public: PrototypeAST(Location location, const std::string &name, std::vector> args) : location(std::move(location)), name(name), args(std::move(args)) {} @@ -224,7 +219,7 @@ class FunctionAST { std::unique_ptr proto; std::unique_ptr body; - public: +public: FunctionAST(std::unique_ptr proto, std::unique_ptr body) : proto(std::move(proto)), body(std::move(body)) {} @@ -236,7 +231,7 @@ class FunctionAST { class ModuleAST { std::vector functions; - public: +public: ModuleAST(std::vector functions) : functions(std::move(functions)) {} @@ -246,6 +241,6 @@ class ModuleAST { void dump(ModuleAST &); -} // namespace toy +} // namespace toy -#endif // TOY_AST_H +#endif // TOY_AST_H diff --git a/mlir/example/Ch5/include/toy/Dialect.h b/mlir/example/Ch5/include/toy/Dialect.h index 8aae4b8..5db325e 100644 --- a/mlir/example/Ch5/include/toy/Dialect.h +++ b/mlir/example/Ch5/include/toy/Dialect.h @@ -14,12 +14,13 @@ #ifndef MLIR_TUTORIAL_TOY_DIALECT_H_ #define MLIR_TUTORIAL_TOY_DIALECT_H_ +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" -#include "mlir/IR/FunctionInterfaces.h" #include "mlir/IR/SymbolTable.h" #include "mlir/Interfaces/CallInterfaces.h" #include "mlir/Interfaces/CastInterfaces.h" +#include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "toy/ShapeInferenceInterface.h" @@ -32,4 +33,4 @@ #define GET_OP_CLASSES #include "toy/Ops.h.inc" -#endif // MLIR_TUTORIAL_TOY_DIALECT_H_ +#endif // MLIR_TUTORIAL_TOY_DIALECT_H_ diff --git a/mlir/example/Ch5/include/toy/Lexer.h b/mlir/example/Ch5/include/toy/Lexer.h index 176a40c..3c59cd9 100644 --- a/mlir/example/Ch5/include/toy/Lexer.h +++ b/mlir/example/Ch5/include/toy/Lexer.h @@ -13,18 +13,18 @@ #ifndef TOY_LEXER_H #define TOY_LEXER_H +#include "llvm/ADT/StringRef.h" + 
#include #include -#include "llvm/ADT/StringRef.h" - namespace toy { /// Structure definition a location in a file. struct Location { - std::shared_ptr file; ///< filename. - int line; ///< line number. - int col; ///< column number. + std::shared_ptr file; ///< filename. + int line; ///< line number. + int col; ///< column number. }; // List of Token returned by the lexer. @@ -56,7 +56,7 @@ enum Token : int { /// can proceed by reading the next line from the standard input or from a /// memory mapped file. class Lexer { - public: +public: /// Create a lexer for the given filename. The filename is kept only for /// debugging purpose (attaching a location to a Token). Lexer(std::string filename) @@ -98,7 +98,7 @@ class Lexer { // Return the current column in the file. int getCol() { return curCol; } - private: +private: /// Delegate to a derived class fetching the next line. Returns an empty /// string to signal end of file (EOF). Lines are expected to always finish /// with "\n" @@ -109,11 +109,13 @@ class Lexer { /// needed. int getNextChar() { // The current line buffer should not be empty unless it is the end of file. - if (curLineBuffer.empty()) return EOF; + if (curLineBuffer.empty()) + return EOF; ++curCol; auto nextchar = curLineBuffer.front(); curLineBuffer = curLineBuffer.drop_front(); - if (curLineBuffer.empty()) curLineBuffer = readNextLine(); + if (curLineBuffer.empty()) + curLineBuffer = readNextLine(); if (nextchar == '\n') { ++curLineNum; curCol = 0; @@ -124,7 +126,8 @@ class Lexer { /// Return the next token from standard input. Token getTok() { // Skip any whitespace. - while (isspace(lastChar)) lastChar = Token(getNextChar()); + while (isspace(lastChar)) + lastChar = Token(getNextChar()); // Save the current location before reading the token characters. 
lastLocation.line = curLineNum; @@ -136,9 +139,12 @@ class Lexer { while (isalnum((lastChar = Token(getNextChar()))) || lastChar == '_') identifierStr += (char)lastChar; - if (identifierStr == "return") return tok_return; - if (identifierStr == "def") return tok_def; - if (identifierStr == "var") return tok_var; + if (identifierStr == "return") + return tok_return; + if (identifierStr == "def") + return tok_def; + if (identifierStr == "var") + return tok_var; return tok_identifier; } @@ -160,11 +166,13 @@ class Lexer { lastChar = Token(getNextChar()); } while (lastChar != EOF && lastChar != '\n' && lastChar != '\r'); - if (lastChar != EOF) return getTok(); + if (lastChar != EOF) + return getTok(); } // Check for end of file. Don't eat the EOF. - if (lastChar == EOF) return tok_eof; + if (lastChar == EOF) + return tok_eof; // Otherwise, just return the character as its ascii value. Token thisChar = Token(lastChar); @@ -201,22 +209,24 @@ class Lexer { /// A lexer implementation operating on a buffer in memory. class LexerBuffer final : public Lexer { - public: +public: LexerBuffer(const char *begin, const char *end, std::string filename) : Lexer(std::move(filename)), current(begin), end(end) {} - private: +private: /// Provide one line at a time to the Lexer, return an empty string when /// reaching the end of the buffer. 
llvm::StringRef readNextLine() override { auto *begin = current; - while (current <= end && *current && *current != '\n') ++current; - if (current <= end && *current) ++current; + while (current <= end && *current && *current != '\n') + ++current; + if (current <= end && *current) + ++current; llvm::StringRef result{begin, static_cast(current - begin)}; return result; } const char *current, *end; }; -} // namespace toy +} // namespace toy -#endif // TOY_LEXER_H +#endif // TOY_LEXER_H diff --git a/mlir/example/Ch5/include/toy/MLIRGen.h b/mlir/example/Ch5/include/toy/MLIRGen.h index 5afdf62..fe9dbe5 100644 --- a/mlir/example/Ch5/include/toy/MLIRGen.h +++ b/mlir/example/Ch5/include/toy/MLIRGen.h @@ -21,7 +21,7 @@ class MLIRContext; template class OwningOpRef; class ModuleOp; -} // namespace mlir +} // namespace mlir namespace toy { class ModuleAST; @@ -30,6 +30,6 @@ class ModuleAST; /// or nullptr on failure. mlir::OwningOpRef mlirGen(mlir::MLIRContext &context, ModuleAST &moduleAST); -} // namespace toy +} // namespace toy -#endif // TOY_MLIRGEN_H +#endif // TOY_MLIRGEN_H diff --git a/mlir/example/Ch5/include/toy/Ops.td b/mlir/example/Ch5/include/toy/Ops.td index c97d22a..ec6762f 100644 --- a/mlir/example/Ch5/include/toy/Ops.td +++ b/mlir/example/Ch5/include/toy/Ops.td @@ -13,7 +13,7 @@ #ifndef TOY_OPS #define TOY_OPS -include "mlir/IR/FunctionInterfaces.td" +include "mlir/Interfaces/FunctionInterfaces.td" include "mlir/IR/SymbolInterfaces.td" include "mlir/Interfaces/CallInterfaces.td" include "mlir/Interfaces/CastInterfaces.td" @@ -141,8 +141,7 @@ def CastOp : Toy_Op<"cast", [ //===----------------------------------------------------------------------===// def FuncOp : Toy_Op<"func", [ - DeclareOpInterfaceMethods, FunctionOpInterface, - IsolatedFromAbove + FunctionOpInterface, IsolatedFromAbove ]> { let summary = "user defined function operation"; let description = [{ @@ -183,6 +182,9 @@ def FuncOp : Toy_Op<"func", [ /// Returns the result types of this function. 
ArrayRef getResultTypes() { return getFunctionType().getResults(); } + + /// Returns the region on the function operation that is callable. + Region *getCallableRegion() { return &getBody(); } }]; let hasCustomAssemblyFormat = 1; let skipDefaultBuilders = 1; diff --git a/mlir/example/Ch5/include/toy/Parser.h b/mlir/example/Ch5/include/toy/Parser.h index ededa4c..1f20616 100644 --- a/mlir/example/Ch5/include/toy/Parser.h +++ b/mlir/example/Ch5/include/toy/Parser.h @@ -14,16 +14,17 @@ #ifndef TOY_PARSER_H #define TOY_PARSER_H -#include -#include -#include -#include +#include "toy/AST.h" +#include "toy/Lexer.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Support/raw_ostream.h" -#include "toy/AST.h" -#include "toy/Lexer.h" + +#include +#include +#include +#include namespace toy { @@ -33,19 +34,20 @@ namespace toy { /// string and the code could reference an undeclared variable and the parsing /// succeeds. class Parser { - public: +public: /// Create a Parser for the supplied lexer. Parser(Lexer &lexer) : lexer(lexer) {} /// Parse a full Module. A module is a list of function definitions. std::unique_ptr parseModule() { - lexer.getNextToken(); // prime the lexer + lexer.getNextToken(); // prime the lexer // Parse functions one at a time and accumulate in this vector. std::vector functions; while (auto f = parseDefinition()) { functions.push_back(std::move(*f)); - if (lexer.getCurToken() == tok_eof) break; + if (lexer.getCurToken() == tok_eof) + break; } // If we didn't reach EOF, there was an error during parsing if (lexer.getCurToken() != tok_eof) @@ -54,7 +56,7 @@ class Parser { return std::make_unique(std::move(functions)); } - private: +private: Lexer &lexer; /// Parse a return statement. 
@@ -67,7 +69,8 @@ class Parser { std::optional> expr; if (lexer.getCurToken() != ';') { expr = parseExpression(); - if (!expr) return nullptr; + if (!expr) + return nullptr; } return std::make_unique(std::move(loc), std::move(expr)); } @@ -97,7 +100,8 @@ class Parser { // We can have either another nested array or a number literal. if (lexer.getCurToken() == '[') { values.push_back(parseTensorLiteralExpr()); - if (!values.back()) return nullptr; // parse error in the nested array. + if (!values.back()) + return nullptr; // parse error in the nested array. } else { if (lexer.getCurToken() != tok_number) return parseError(" or [", "in literal expression"); @@ -105,17 +109,18 @@ class Parser { } // End of this list on ']' - if (lexer.getCurToken() == ']') break; + if (lexer.getCurToken() == ']') + break; // Elements are separated by a comma. if (lexer.getCurToken() != ',') return parseError("] or ,", "in literal expression"); - lexer.getNextToken(); // eat , + lexer.getNextToken(); // eat , } while (true); if (values.empty()) return parseError("", "to fill literal expression"); - lexer.getNextToken(); // eat ] + lexer.getNextToken(); // eat ] /// Fill in the dimensions now. First the current nesting level: dims.push_back(values.size()); @@ -151,9 +156,10 @@ class Parser { /// parenexpr ::= '(' expression ')' std::unique_ptr parseParenExpr() { - lexer.getNextToken(); // eat (. + lexer.getNextToken(); // eat (. auto v = parseExpression(); - if (!v) return nullptr; + if (!v) + return nullptr; if (lexer.getCurToken() != ')') return parseError(")", "to close expression with parentheses"); @@ -168,9 +174,9 @@ class Parser { std::string name(lexer.getId()); auto loc = lexer.getLastLocation(); - lexer.getNextToken(); // eat identifier. + lexer.getNextToken(); // eat identifier. - if (lexer.getCurToken() != '(') // Simple variable ref. + if (lexer.getCurToken() != '(') // Simple variable ref. return std::make_unique(std::move(loc), name); // This is a function call. 
@@ -183,7 +189,8 @@ class Parser { else return nullptr; - if (lexer.getCurToken() == ')') break; + if (lexer.getCurToken() == ')') + break; if (lexer.getCurToken() != ',') return parseError(", or )", "in argument list"); @@ -211,22 +218,22 @@ class Parser { /// ::= tensorliteral std::unique_ptr parsePrimary() { switch (lexer.getCurToken()) { - default: - llvm::errs() << "unknown token '" << lexer.getCurToken() - << "' when expecting an expression\n"; - return nullptr; - case tok_identifier: - return parseIdentifierExpr(); - case tok_number: - return parseNumberExpr(); - case '(': - return parseParenExpr(); - case '[': - return parseTensorLiteralExpr(); - case ';': - return nullptr; - case '}': - return nullptr; + default: + llvm::errs() << "unknown token '" << lexer.getCurToken() + << "' when expecting an expression\n"; + return nullptr; + case tok_identifier: + return parseIdentifierExpr(); + case tok_number: + return parseNumberExpr(); + case '(': + return parseParenExpr(); + case '[': + return parseTensorLiteralExpr(); + case ';': + return nullptr; + case '}': + return nullptr; } } @@ -242,7 +249,8 @@ class Parser { // If this is a binop that binds at least as tightly as the current binop, // consume it, otherwise we are done. - if (tokPrec < exprPrec) return lhs; + if (tokPrec < exprPrec) + return lhs; // Okay, we know this is a binop. int binOp = lexer.getCurToken(); @@ -259,7 +267,8 @@ class Parser { int nextPrec = getTokPrecedence(); if (tokPrec < nextPrec) { rhs = parseBinOpRHS(tokPrec + 1, std::move(rhs)); - if (!rhs) return nullptr; + if (!rhs) + return nullptr; } // Merge lhs/RHS. 
@@ -271,7 +280,8 @@ class Parser { /// expression::= primary binop rhs std::unique_ptr parseExpression() { auto lhs = parsePrimary(); - if (!lhs) return nullptr; + if (!lhs) + return nullptr; return parseBinOpRHS(0, std::move(lhs)); } @@ -281,19 +291,20 @@ class Parser { std::unique_ptr parseType() { if (lexer.getCurToken() != '<') return parseError("<", "to begin type"); - lexer.getNextToken(); // eat < + lexer.getNextToken(); // eat < auto type = std::make_unique(); while (lexer.getCurToken() == tok_number) { type->shape.push_back(lexer.getValue()); lexer.getNextToken(); - if (lexer.getCurToken() == ',') lexer.getNextToken(); + if (lexer.getCurToken() == ',') + lexer.getNextToken(); } if (lexer.getCurToken() != '>') return parseError(">", "to end type"); - lexer.getNextToken(); // eat > + lexer.getNextToken(); // eat > return type; } @@ -305,21 +316,23 @@ class Parser { if (lexer.getCurToken() != tok_var) return parseError("var", "to begin declaration"); auto loc = lexer.getLastLocation(); - lexer.getNextToken(); // eat var + lexer.getNextToken(); // eat var if (lexer.getCurToken() != tok_identifier) return parseError("identified", "after 'var' declaration"); std::string id(lexer.getId()); - lexer.getNextToken(); // eat id + lexer.getNextToken(); // eat id - std::unique_ptr type; // Type is optional, it can be inferred + std::unique_ptr type; // Type is optional, it can be inferred if (lexer.getCurToken() == '<') { type = parseType(); - if (!type) return nullptr; + if (!type) + return nullptr; } - if (!type) type = std::make_unique(); + if (!type) + type = std::make_unique(); lexer.consume(Token('=')); auto expr = parseExpression(); return std::make_unique(std::move(loc), std::move(id), @@ -340,23 +353,27 @@ class Parser { auto exprList = std::make_unique(); // Ignore empty expressions: swallow sequences of semicolons. 
- while (lexer.getCurToken() == ';') lexer.consume(Token(';')); + while (lexer.getCurToken() == ';') + lexer.consume(Token(';')); while (lexer.getCurToken() != '}' && lexer.getCurToken() != tok_eof) { if (lexer.getCurToken() == tok_var) { // Variable declaration auto varDecl = parseDeclaration(); - if (!varDecl) return nullptr; + if (!varDecl) + return nullptr; exprList->push_back(std::move(varDecl)); } else if (lexer.getCurToken() == tok_return) { // Return statement auto ret = parseReturn(); - if (!ret) return nullptr; + if (!ret) + return nullptr; exprList->push_back(std::move(ret)); } else { // General expression auto expr = parseExpression(); - if (!expr) return nullptr; + if (!expr) + return nullptr; exprList->push_back(std::move(expr)); } // Ensure that elements are separated by a semicolon. @@ -364,7 +381,8 @@ class Parser { return parseError(";", "after expression"); // Ignore empty expressions: swallow sequences of semicolons. - while (lexer.getCurToken() == ';') lexer.consume(Token(';')); + while (lexer.getCurToken() == ';') + lexer.consume(Token(';')); } if (lexer.getCurToken() != '}') @@ -401,7 +419,8 @@ class Parser { lexer.consume(tok_identifier); auto decl = std::make_unique(std::move(loc), name); args.push_back(std::move(decl)); - if (lexer.getCurToken() != ',') break; + if (lexer.getCurToken() != ',') + break; lexer.consume(Token(',')); if (lexer.getCurToken() != tok_identifier) return parseError( @@ -423,7 +442,8 @@ class Parser { /// definition ::= prototype block std::unique_ptr parseDefinition() { auto proto = parsePrototype(); - if (!proto) return nullptr; + if (!proto) + return nullptr; if (auto block = parseBlock()) return std::make_unique(std::move(proto), std::move(block)); @@ -432,18 +452,19 @@ class Parser { /// Get the precedence of the pending binary operator token. int getTokPrecedence() { - if (!isascii(lexer.getCurToken())) return -1; + if (!isascii(lexer.getCurToken())) + return -1; // 1 is lowest precedence. 
switch (static_cast(lexer.getCurToken())) { - case '-': - return 20; - case '+': - return 20; - case '*': - return 40; - default: - return -1; + case '-': + return 20; + case '+': + return 20; + case '*': + return 40; + default: + return -1; } } @@ -456,12 +477,13 @@ class Parser { llvm::errs() << "Parse error (" << lexer.getLastLocation().line << ", " << lexer.getLastLocation().col << "): expected '" << expected << "' " << context << " but has Token " << curToken; - if (isprint(curToken)) llvm::errs() << " '" << (char)curToken << "'"; + if (isprint(curToken)) + llvm::errs() << " '" << (char)curToken << "'"; llvm::errs() << "\n"; return nullptr; } }; -} // namespace toy +} // namespace toy -#endif // TOY_PARSER_H +#endif // TOY_PARSER_H diff --git a/mlir/example/Ch5/include/toy/Passes.h b/mlir/example/Ch5/include/toy/Passes.h index 1825c29..02a83cf 100644 --- a/mlir/example/Ch5/include/toy/Passes.h +++ b/mlir/example/Ch5/include/toy/Passes.h @@ -25,7 +25,7 @@ std::unique_ptr createShapeInferencePass(); /// for a subset of the Toy IR (e.g. matmul). std::unique_ptr createLowerToAffinePass(); -} // namespace toy -} // namespace mlir +} // namespace toy +} // namespace mlir -#endif // TOY_PASSES_H +#endif // TOY_PASSES_H diff --git a/mlir/example/Ch5/include/toy/ShapeInferenceInterface.h b/mlir/example/Ch5/include/toy/ShapeInferenceInterface.h index a32ef17..cfe5a87 100644 --- a/mlir/example/Ch5/include/toy/ShapeInferenceInterface.h +++ b/mlir/example/Ch5/include/toy/ShapeInferenceInterface.h @@ -22,7 +22,7 @@ namespace toy { /// Include the auto-generated declarations. 
#include "toy/ShapeInferenceOpInterfaces.h.inc" -} // namespace toy -} // namespace mlir +} // namespace toy +} // namespace mlir -#endif // MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_ +#endif // MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_ diff --git a/mlir/example/Ch5/mlir/Dialect.cpp b/mlir/example/Ch5/mlir/Dialect.cpp index 6ec105a..72072f9 100644 --- a/mlir/example/Ch5/mlir/Dialect.cpp +++ b/mlir/example/Ch5/mlir/Dialect.cpp @@ -13,11 +13,25 @@ #include "toy/Dialect.h" +#include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/FunctionImplementation.h" +#include "mlir/IR/Location.h" #include "mlir/IR/OpImplementation.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Interfaces/CallInterfaces.h" +#include "mlir/Interfaces/FunctionImplementation.h" +#include "mlir/Support/LLVM.h" #include "mlir/Transforms/InliningUtils.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include +#include +#include +#include using namespace mlir; using namespace mlir::toy; @@ -59,8 +73,7 @@ struct ToyInlinerInterface : public DialectInlinerInterface { /// Handle the given inlined terminator(toy.return) by replacing it with a new /// operation as necessary. - void handleTerminator(Operation *op, - ArrayRef valuesToRepl) const final { + void handleTerminator(Operation *op, ValueRange valuesToRepl) const final { // Only "toy.return" needs to be handled here. auto returnOp = cast(op); @@ -190,11 +203,10 @@ void ConstantOp::print(mlir::OpAsmPrinter &printer) { /// Verifier for the constant operation. This corresponds to the /// `let hasVerifier = 1` in the op definition. -mlir::LogicalResult ConstantOp::verify() { +llvm::LogicalResult ConstantOp::verify() { // If the return type of the constant is not an unranked tensor, the shape // must match the shape of the attribute holding the data. 
- auto resultType = - llvm::dyn_cast(getResult().getType()); + auto resultType = llvm::dyn_cast(getResult().getType()); if (!resultType) return success(); @@ -299,27 +311,6 @@ void FuncOp::print(mlir::OpAsmPrinter &p) { getArgAttrsAttrName(), getResAttrsAttrName()); } -/// Returns the region on the function operation that is callable. -mlir::Region *FuncOp::getCallableRegion() { return &getBody(); } - -/// Returns the results types that the callable region produces when -/// executed. -llvm::ArrayRef FuncOp::getCallableResults() { - return getFunctionType().getResults(); -} - -/// Returns the argument attributes for all callable region arguments or -/// null if there are none. -ArrayAttr FuncOp::getCallableArgAttrs() { - return getArgAttrs().value_or(nullptr); -} - -/// Returns the result attributes for all callable region results or -/// null if there are none. -ArrayAttr FuncOp::getCallableResAttrs() { - return getResAttrs().value_or(nullptr); -} - //===----------------------------------------------------------------------===// // GenericCallOp //===----------------------------------------------------------------------===// @@ -349,6 +340,12 @@ void GenericCallOp::setCalleeFromCallable(CallInterfaceCallable callee) { /// call interface. Operation::operand_range GenericCallOp::getArgOperands() { return getInputs(); } +/// Get the argument operands to the called function as a mutable range, this is +/// required by the call interface. 
+MutableOperandRange GenericCallOp::getArgOperandsMutable() { + return getInputsMutable(); +} + //===----------------------------------------------------------------------===// // MulOp //===----------------------------------------------------------------------===// @@ -374,7 +371,7 @@ void MulOp::inferShapes() { getResult().setType(getLhs().getType()); } // ReturnOp //===----------------------------------------------------------------------===// -mlir::LogicalResult ReturnOp::verify() { +llvm::LogicalResult ReturnOp::verify() { // We know that the parent operation is a function, because of the 'HasParent' // trait attached to the operation definition. auto function = cast((*this)->getParentOp()); @@ -398,8 +395,7 @@ mlir::LogicalResult ReturnOp::verify() { auto resultType = results.front(); // Check that the result type of the function matches the operand type. - if (inputType == resultType || - llvm::isa(inputType) || + if (inputType == resultType || llvm::isa(inputType) || llvm::isa(resultType)) return mlir::success(); @@ -424,7 +420,7 @@ void TransposeOp::inferShapes() { getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType())); } -mlir::LogicalResult TransposeOp::verify() { +llvm::LogicalResult TransposeOp::verify() { auto inputType = llvm::dyn_cast(getOperand().getType()); auto resultType = llvm::dyn_cast(getType()); if (!inputType || !resultType) diff --git a/mlir/example/Ch5/mlir/LowerToAffineLoops.cpp b/mlir/example/Ch5/mlir/LowerToAffineLoops.cpp index 03e87ba..7413214 100644 --- a/mlir/example/Ch5/mlir/LowerToAffineLoops.cpp +++ b/mlir/example/Ch5/mlir/LowerToAffineLoops.cpp @@ -12,16 +12,34 @@ // //===----------------------------------------------------------------------===// +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinDialect.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/PatternMatch.h" +#include 
"mlir/IR/ValueRange.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/TypeID.h" +#include "toy/Dialect.h" +#include "toy/Passes.h" + #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/IR/BuiltinDialect.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" -#include "toy/Dialect.h" -#include "toy/Passes.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Sequence.h" +#include "llvm/Support/Casting.h" +#include +#include +#include +#include +#include using namespace mlir; @@ -154,8 +172,7 @@ struct ConstantOpLowering : public OpRewritePattern { SmallVector constantIndices; if (!valueShape.empty()) { - for (auto i : llvm::seq( - 0, *std::max_element(valueShape.begin(), valueShape.end()))) + for (auto i : llvm::seq(0, *llvm::max_element(valueShape))) constantIndices.push_back( rewriter.create(loc, i)); } else { @@ -241,8 +258,8 @@ struct PrintOpLowering : public OpConversionPattern { ConversionPatternRewriter &rewriter) const final { // We don't lower "toy.print" in this pass, but we need to update its // operands. 
- rewriter.updateRootInPlace(op, - [&] { op->setOperands(adaptor.getOperands()); }); + rewriter.modifyOpInPlace(op, + [&] { op->setOperands(adaptor.getOperands()); }); return success(); } }; diff --git a/mlir/example/Ch5/mlir/MLIRGen.cpp b/mlir/example/Ch5/mlir/MLIRGen.cpp index 58ec129..b56e2f7 100644 --- a/mlir/example/Ch5/mlir/MLIRGen.cpp +++ b/mlir/example/Ch5/mlir/MLIRGen.cpp @@ -12,20 +12,30 @@ //===----------------------------------------------------------------------===// #include "toy/MLIRGen.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/Value.h" +#include "toy/AST.h" +#include "toy/Dialect.h" -#include - -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/ScopedHashTable.h" -#include "llvm/Support/raw_ostream.h" -#include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Verifier.h" -#include "toy/AST.h" -#include "toy/Dialect.h" +#include "toy/Lexer.h" + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/ScopedHashTable.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/Twine.h" +#include +#include +#include +#include +#include +#include using namespace mlir::toy; using namespace toy; @@ -47,7 +57,7 @@ namespace { /// the semantics of the language and (hopefully) allow to perform accurate /// analysis and transformation based on these high level semantics. class MLIRGenImpl { - public: +public: MLIRGenImpl(mlir::MLIRContext &context) : builder(&context) {} /// Public API: convert the AST for a Toy module (source file) to an MLIR @@ -57,7 +67,8 @@ class MLIRGenImpl { // add them to the module. 
theModule = mlir::ModuleOp::create(builder.getUnknownLoc()); - for (FunctionAST &f : moduleAST) mlirGen(f); + for (FunctionAST &f : moduleAST) + mlirGen(f); // Verify the module after we have finished constructing it, this will check // the structural properties of the IR and invoke any specific verifiers we @@ -70,7 +81,7 @@ class MLIRGenImpl { return theModule; } - private: +private: /// A "module" matches a Toy source file: containing a list of functions. mlir::ModuleOp theModule; @@ -93,8 +104,9 @@ class MLIRGenImpl { /// Declare a variable in the current scope, return success if the variable /// wasn't declared yet. - mlir::LogicalResult declare(llvm::StringRef var, mlir::Value value) { - if (symbolTable.count(var)) return mlir::failure(); + llvm::LogicalResult declare(llvm::StringRef var, mlir::Value value) { + if (symbolTable.count(var)) + return mlir::failure(); symbolTable.insert(var, value); return mlir::success(); } @@ -121,7 +133,8 @@ class MLIRGenImpl { // Create an MLIR function for the given prototype. builder.setInsertionPointToEnd(theModule.getBody()); mlir::toy::FuncOp function = mlirGen(*funcAST.getProto()); - if (!function) return nullptr; + if (!function) + return nullptr; // Let's start the body of the function now! mlir::Block &entryBlock = function.front(); @@ -150,7 +163,8 @@ class MLIRGenImpl { // FIXME: we may fix the parser instead to always return the last expression // (this would possibly help the REPL case later) ReturnOp returnOp; - if (!entryBlock.empty()) returnOp = dyn_cast(entryBlock.back()); + if (!entryBlock.empty()) + returnOp = dyn_cast(entryBlock.back()); if (!returnOp) { builder.create(loc(funcAST.getProto()->loc())); } else if (returnOp.hasOperand()) { @@ -161,7 +175,8 @@ class MLIRGenImpl { } // If this function isn't main, then set the visibility to private. 
- if (funcAST.getProto()->getName() != "main") function.setPrivate(); + if (funcAST.getProto()->getName() != "main") + function.setPrivate(); return function; } @@ -180,18 +195,20 @@ class MLIRGenImpl { // and propagate. // mlir::Value lhs = mlirGen(*binop.getLHS()); - if (!lhs) return nullptr; + if (!lhs) + return nullptr; mlir::Value rhs = mlirGen(*binop.getRHS()); - if (!rhs) return nullptr; + if (!rhs) + return nullptr; auto location = loc(binop.loc()); // Derive the operation name from the binary operator. At the moment we only // support '+' and '*'. switch (binop.getOp()) { - case '+': - return builder.create(location, lhs, rhs); - case '*': - return builder.create(location, lhs, rhs); + case '+': + return builder.create(location, lhs, rhs); + case '*': + return builder.create(location, lhs, rhs); } emitError(location, "invalid binary operator '") << binop.getOp() << "'"; @@ -202,7 +219,8 @@ class MLIRGenImpl { /// expected to have been declared and so should have a value in the symbol /// table, otherwise emit an error and return nullptr. mlir::Value mlirGen(VariableExprAST &expr) { - if (auto variable = symbolTable.lookup(expr.getName())) return variable; + if (auto variable = symbolTable.lookup(expr.getName())) + return variable; emitError(loc(expr.loc()), "error: unknown variable '") << expr.getName() << "'"; @@ -210,13 +228,14 @@ class MLIRGenImpl { } /// Emit a return operation. This will return failure if any generation fails. - mlir::LogicalResult mlirGen(ReturnExprAST &ret) { + llvm::LogicalResult mlirGen(ReturnExprAST &ret) { auto location = loc(ret.loc()); // 'return' takes an optional expression, handle that case here. mlir::Value expr = nullptr; if (ret.getExpr().has_value()) { - if (!(expr = mlirGen(**ret.getExpr()))) return mlir::failure(); + if (!(expr = mlirGen(**ret.getExpr()))) + return mlir::failure(); } // Otherwise, this return operation has zero operands. 
@@ -278,7 +297,8 @@ class MLIRGenImpl { /// Attributes are the way MLIR attaches constant to operations. void collectData(ExprAST &expr, std::vector &data) { if (auto *lit = dyn_cast(&expr)) { - for (auto &value : lit->getValues()) collectData(*value, data); + for (auto &value : lit->getValues()) + collectData(*value, data); return; } @@ -296,7 +316,8 @@ class MLIRGenImpl { SmallVector operands; for (auto &expr : call.getArgs()) { auto arg = mlirGen(*expr); - if (!arg) return nullptr; + if (!arg) + return nullptr; operands.push_back(arg); } @@ -304,9 +325,8 @@ class MLIRGenImpl { // straightforward emission. if (callee == "transpose") { if (call.getArgs().size() != 1) { - emitError(location, - "MLIR codegen encountered an error: toy.transpose " - "does not accept multiple arguments"); + emitError(location, "MLIR codegen encountered an error: toy.transpose " + "does not accept multiple arguments"); return nullptr; } return builder.create(location, operands[0]); @@ -320,9 +340,10 @@ class MLIRGenImpl { /// Emit a print expression. It emits specific operations for two builtins: /// transpose(x) and print(x). - mlir::LogicalResult mlirGen(PrintExprAST &call) { + llvm::LogicalResult mlirGen(PrintExprAST &call) { auto arg = mlirGen(*call.getArg()); - if (!arg) return mlir::failure(); + if (!arg) + return mlir::failure(); builder.create(loc(call.loc()), arg); return mlir::success(); @@ -336,21 +357,21 @@ class MLIRGenImpl { /// Dispatch codegen for the right expression subclass using RTTI. 
mlir::Value mlirGen(ExprAST &expr) { switch (expr.getKind()) { - case toy::ExprAST::Expr_BinOp: - return mlirGen(cast(expr)); - case toy::ExprAST::Expr_Var: - return mlirGen(cast(expr)); - case toy::ExprAST::Expr_Literal: - return mlirGen(cast(expr)); - case toy::ExprAST::Expr_Call: - return mlirGen(cast(expr)); - case toy::ExprAST::Expr_Num: - return mlirGen(cast(expr)); - default: - emitError(loc(expr.loc())) - << "MLIR codegen encountered an unhandled expr kind '" - << Twine(expr.getKind()) << "'"; - return nullptr; + case toy::ExprAST::Expr_BinOp: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Var: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Literal: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Call: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Num: + return mlirGen(cast(expr)); + default: + emitError(loc(expr.loc())) + << "MLIR codegen encountered an unhandled expr kind '" + << Twine(expr.getKind()) << "'"; + return nullptr; } } @@ -367,7 +388,8 @@ class MLIRGenImpl { } mlir::Value value = mlirGen(*init); - if (!value) return nullptr; + if (!value) + return nullptr; // We have the initializer value, but in case the variable was declared // with specific shape, we emit a "reshape" operation. It will get @@ -378,29 +400,34 @@ class MLIRGenImpl { } // Register the value in the symbol table. - if (failed(declare(vardecl.getName(), value))) return nullptr; + if (failed(declare(vardecl.getName(), value))) + return nullptr; return value; } /// Codegen a list of expression, return failure if one of them hit an error. - mlir::LogicalResult mlirGen(ExprASTList &blockAST) { + llvm::LogicalResult mlirGen(ExprASTList &blockAST) { ScopedHashTableScope varScope(symbolTable); for (auto &expr : blockAST) { // Specific handling for variable declarations, return statement, and // print. These can only appear in block list and not in nested // expressions. 
if (auto *vardecl = dyn_cast(expr.get())) { - if (!mlirGen(*vardecl)) return mlir::failure(); + if (!mlirGen(*vardecl)) + return mlir::failure(); continue; } - if (auto *ret = dyn_cast(expr.get())) return mlirGen(*ret); + if (auto *ret = dyn_cast(expr.get())) + return mlirGen(*ret); if (auto *print = dyn_cast(expr.get())) { - if (mlir::failed(mlirGen(*print))) return mlir::success(); + if (mlir::failed(mlirGen(*print))) + return mlir::success(); continue; } // Generic expression dispatch codegen. - if (!mlirGen(*expr)) return mlir::failure(); + if (!mlirGen(*expr)) + return mlir::failure(); } return mlir::success(); } @@ -420,7 +447,7 @@ class MLIRGenImpl { mlir::Type getType(const VarType &type) { return getType(type.shape); } }; -} // namespace +} // namespace namespace toy { @@ -430,4 +457,4 @@ mlir::OwningOpRef mlirGen(mlir::MLIRContext &context, return MLIRGenImpl(context).mlirGen(moduleAST); } -} // namespace toy +} // namespace toy diff --git a/mlir/example/Ch5/mlir/ShapeInferencePass.cpp b/mlir/example/Ch5/mlir/ShapeInferencePass.cpp index d45baa1..a9e995e 100644 --- a/mlir/example/Ch5/mlir/ShapeInferencePass.cpp +++ b/mlir/example/Ch5/mlir/ShapeInferencePass.cpp @@ -11,13 +11,21 @@ // //===----------------------------------------------------------------------===// +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Types.h" #include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/TypeID.h" #include "toy/Dialect.h" #include "toy/Passes.h" #include "toy/ShapeInferenceInterface.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" +#include #define DEBUG_TYPE "shape-inference" diff --git a/mlir/example/Ch5/mlir/ToyCombine.cpp b/mlir/example/Ch5/mlir/ToyCombine.cpp index d124e86..f8397c2 100644 --- a/mlir/example/Ch5/mlir/ToyCombine.cpp +++ b/mlir/example/Ch5/mlir/ToyCombine.cpp @@ 
-11,10 +11,9 @@ // //===----------------------------------------------------------------------===// -#include - -#include "mlir/IR/Matchers.h" +#include "mlir/IR/MLIRContext.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Value.h" #include "toy/Dialect.h" using namespace mlir; using namespace toy; @@ -22,7 +21,7 @@ using namespace toy; namespace { /// Include the patterns defined in the Declarative Rewrite framework. #include "ToyCombine.inc" -} // namespace +} // namespace /// This is an example of a c++ rewrite pattern for the TransposeOp. It /// optimizes the following scenario: transpose(transpose(x)) -> x @@ -36,14 +35,16 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern { /// This method attempts to match a pattern and rewrite it. The rewriter /// argument is the orchestrator of the sequence of rewrites. The pattern is /// expected to interact with it to perform any changes to the IR from here. - mlir::LogicalResult matchAndRewrite( - TransposeOp op, mlir::PatternRewriter &rewriter) const override { + llvm::LogicalResult + matchAndRewrite(TransposeOp op, + mlir::PatternRewriter &rewriter) const override { // Look through the input of the current transpose. mlir::Value transposeInput = op.getOperand(); TransposeOp transposeInputOp = transposeInput.getDefiningOp(); // Input defined by another transpose? If not, no match. - if (!transposeInputOp) return failure(); + if (!transposeInputOp) + return failure(); // Otherwise, we have a redundant transpose. Use the rewriter. 
rewriter.replaceOp(op, {transposeInputOp.getOperand()}); diff --git a/mlir/example/Ch5/parser/AST.cpp b/mlir/example/Ch5/parser/AST.cpp index e0bd2fd..2546f2a 100644 --- a/mlir/example/Ch5/parser/AST.cpp +++ b/mlir/example/Ch5/parser/AST.cpp @@ -12,9 +12,12 @@ #include "toy/AST.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Twine.h" #include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/raw_ostream.h" +#include using namespace toy; @@ -31,10 +34,10 @@ struct Indent { /// Helper class that implement the AST tree traversal and print the nodes along /// the way. The only data member is the current indentation level. class ASTDumper { - public: +public: void dump(ModuleAST *node); - private: +private: void dump(const VarType &type); void dump(VarDeclExprAST *varDecl); void dump(ExprAST *expr); @@ -51,12 +54,13 @@ class ASTDumper { // Actually print spaces matching the current indentation level void indent() { - for (int i = 0; i < curIndent; i++) llvm::errs() << " "; + for (int i = 0; i < curIndent; i++) + llvm::errs() << " "; } int curIndent = 0; }; -} // namespace +} // namespace /// Return a formatted string for the location of any node template @@ -69,8 +73,8 @@ static std::string loc(T *node) { // Helper Macro to bump the indentation level and print the leading spaces for // the current indentations -#define INDENT() \ - Indent level_(curIndent); \ +#define INDENT() \ + Indent level_(curIndent); \ indent(); /// Dispatch to a generic expressions to the appropriate subclass using RTTI @@ -100,7 +104,8 @@ void ASTDumper::dump(VarDeclExprAST *varDecl) { void ASTDumper::dump(ExprASTList *exprList) { INDENT(); llvm::errs() << "Block {\n"; - for (auto &expr : *exprList) dump(expr.get()); + for (auto &expr : *exprList) + dump(expr.get()); indent(); llvm::errs() << "} // Block\n"; } @@ -153,7 +158,8 @@ void ASTDumper::dump(VariableExprAST *node) { void ASTDumper::dump(ReturnExprAST *node) { INDENT(); llvm::errs() << "Return\n"; - 
if (node->getExpr().has_value()) return dump(*node->getExpr()); + if (node->getExpr().has_value()) + return dump(*node->getExpr()); { INDENT(); llvm::errs() << "(void)\n"; @@ -173,7 +179,8 @@ void ASTDumper::dump(BinaryExprAST *node) { void ASTDumper::dump(CallExprAST *node) { INDENT(); llvm::errs() << "Call '" << node->getCallee() << "' [ " << loc(node) << "\n"; - for (auto &arg : node->getArgs()) dump(arg.get()); + for (auto &arg : node->getArgs()) + dump(arg.get()); indent(); llvm::errs() << "]\n"; } @@ -218,7 +225,8 @@ void ASTDumper::dump(FunctionAST *node) { void ASTDumper::dump(ModuleAST *node) { INDENT(); llvm::errs() << "Module:\n"; - for (auto &f : *node) dump(&f); + for (auto &f : *node) + dump(&f); } namespace toy { @@ -226,4 +234,4 @@ namespace toy { // Public API void dump(ModuleAST &module) { ASTDumper().dump(&module); } -} // namespace toy +} // namespace toy diff --git a/mlir/example/Ch5/toyc.cpp b/mlir/example/Ch5/toyc.cpp index 004abcf..6a0c631 100644 --- a/mlir/example/Ch5/toyc.cpp +++ b/mlir/example/Ch5/toyc.cpp @@ -10,27 +10,35 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Affine/Passes.h" #include "mlir/Dialect/Func/Extensions/AllExtensions.h" +#include "mlir/IR/Diagnostics.h" +#include "toy/AST.h" +#include "toy/Dialect.h" +#include "toy/Lexer.h" +#include "toy/MLIRGen.h" +#include "toy/Parser.h" +#include "toy/Passes.h" + +#include "mlir/Dialect/Affine/Passes.h" #include "mlir/IR/AsmState.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Verifier.h" #include "mlir/InitAllDialects.h" #include "mlir/Parser/Parser.h" -#include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/Passes.h" -#include "toy/Dialect.h" -#include "toy/MLIRGen.h" -#include "toy/Parser.h" -#include "toy/Passes.h" + #include "llvm/ADT/StringRef.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/ErrorOr.h" #include 
"llvm/Support/MemoryBuffer.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/raw_ostream.h" +#include +#include +#include +#include using namespace toy; namespace cl = llvm::cl; @@ -79,7 +87,7 @@ int loadMLIR(llvm::SourceMgr &sourceMgr, mlir::MLIRContext &context, mlir::OwningOpRef &module) { // Handle '.toy' input to the compiler. if (inputType != InputType::MLIR && - !llvm::StringRef(inputFilename).endswith(".mlir")) { + !llvm::StringRef(inputFilename).ends_with(".mlir")) { auto moduleAST = parseInputFile(inputFilename); if (!moduleAST) return 6; diff --git a/mlir/example/Ch6/include/toy/AST.h b/mlir/example/Ch6/include/toy/AST.h index c9d1bdb..d2ba101 100644 --- a/mlir/example/Ch6/include/toy/AST.h +++ b/mlir/example/Ch6/include/toy/AST.h @@ -15,14 +15,14 @@ #ifndef TOY_AST_H #define TOY_AST_H -#include -#include -#include +#include "toy/Lexer.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Casting.h" -#include "toy/Lexer.h" +#include +#include +#include namespace toy { @@ -33,7 +33,7 @@ struct VarType { /// Base class for all expression nodes. 
class ExprAST { - public: +public: enum ExprASTKind { Expr_VarDecl, Expr_Return, @@ -53,7 +53,7 @@ class ExprAST { const Location &loc() { return location; } - private: +private: const ExprASTKind kind; Location location; }; @@ -65,7 +65,7 @@ using ExprASTList = std::vector>; class NumberExprAST : public ExprAST { double val; - public: +public: NumberExprAST(Location loc, double val) : ExprAST(Expr_Num, std::move(loc)), val(val) {} @@ -80,11 +80,10 @@ class LiteralExprAST : public ExprAST { std::vector> values; std::vector dims; - public: +public: LiteralExprAST(Location loc, std::vector> values, std::vector dims) - : ExprAST(Expr_Literal, std::move(loc)), - values(std::move(values)), + : ExprAST(Expr_Literal, std::move(loc)), values(std::move(values)), dims(std::move(dims)) {} llvm::ArrayRef> getValues() { return values; } @@ -98,7 +97,7 @@ class LiteralExprAST : public ExprAST { class VariableExprAST : public ExprAST { std::string name; - public: +public: VariableExprAST(Location loc, llvm::StringRef name) : ExprAST(Expr_Var, std::move(loc)), name(name) {} @@ -114,13 +113,11 @@ class VarDeclExprAST : public ExprAST { VarType type; std::unique_ptr initVal; - public: +public: VarDeclExprAST(Location loc, llvm::StringRef name, VarType type, std::unique_ptr initVal) - : ExprAST(Expr_VarDecl, std::move(loc)), - name(name), - type(std::move(type)), - initVal(std::move(initVal)) {} + : ExprAST(Expr_VarDecl, std::move(loc)), name(name), + type(std::move(type)), initVal(std::move(initVal)) {} llvm::StringRef getName() { return name; } ExprAST *getInitVal() { return initVal.get(); } @@ -134,12 +131,13 @@ class VarDeclExprAST : public ExprAST { class ReturnExprAST : public ExprAST { std::optional> expr; - public: +public: ReturnExprAST(Location loc, std::optional> expr) : ExprAST(Expr_Return, std::move(loc)), expr(std::move(expr)) {} std::optional getExpr() { - if (expr.has_value()) return expr->get(); + if (expr.has_value()) + return expr->get(); return std::nullopt; } @@ 
-152,16 +150,14 @@ class BinaryExprAST : public ExprAST { char op; std::unique_ptr lhs, rhs; - public: +public: char getOp() { return op; } ExprAST *getLHS() { return lhs.get(); } ExprAST *getRHS() { return rhs.get(); } BinaryExprAST(Location loc, char op, std::unique_ptr lhs, std::unique_ptr rhs) - : ExprAST(Expr_BinOp, std::move(loc)), - op(op), - lhs(std::move(lhs)), + : ExprAST(Expr_BinOp, std::move(loc)), op(op), lhs(std::move(lhs)), rhs(std::move(rhs)) {} /// LLVM style RTTI @@ -173,11 +169,10 @@ class CallExprAST : public ExprAST { std::string callee; std::vector> args; - public: +public: CallExprAST(Location loc, const std::string &callee, std::vector> args) - : ExprAST(Expr_Call, std::move(loc)), - callee(callee), + : ExprAST(Expr_Call, std::move(loc)), callee(callee), args(std::move(args)) {} llvm::StringRef getCallee() { return callee; } @@ -191,7 +186,7 @@ class CallExprAST : public ExprAST { class PrintExprAST : public ExprAST { std::unique_ptr arg; - public: +public: PrintExprAST(Location loc, std::unique_ptr arg) : ExprAST(Expr_Print, std::move(loc)), arg(std::move(arg)) {} @@ -209,7 +204,7 @@ class PrototypeAST { std::string name; std::vector> args; - public: +public: PrototypeAST(Location location, const std::string &name, std::vector> args) : location(std::move(location)), name(name), args(std::move(args)) {} @@ -224,7 +219,7 @@ class FunctionAST { std::unique_ptr proto; std::unique_ptr body; - public: +public: FunctionAST(std::unique_ptr proto, std::unique_ptr body) : proto(std::move(proto)), body(std::move(body)) {} @@ -236,7 +231,7 @@ class FunctionAST { class ModuleAST { std::vector functions; - public: +public: ModuleAST(std::vector functions) : functions(std::move(functions)) {} @@ -246,6 +241,6 @@ class ModuleAST { void dump(ModuleAST &); -} // namespace toy +} // namespace toy -#endif // TOY_AST_H +#endif // TOY_AST_H diff --git a/mlir/example/Ch6/include/toy/Dialect.h b/mlir/example/Ch6/include/toy/Dialect.h index 8aae4b8..5db325e 100644 
--- a/mlir/example/Ch6/include/toy/Dialect.h +++ b/mlir/example/Ch6/include/toy/Dialect.h @@ -14,12 +14,13 @@ #ifndef MLIR_TUTORIAL_TOY_DIALECT_H_ #define MLIR_TUTORIAL_TOY_DIALECT_H_ +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" -#include "mlir/IR/FunctionInterfaces.h" #include "mlir/IR/SymbolTable.h" #include "mlir/Interfaces/CallInterfaces.h" #include "mlir/Interfaces/CastInterfaces.h" +#include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "toy/ShapeInferenceInterface.h" @@ -32,4 +33,4 @@ #define GET_OP_CLASSES #include "toy/Ops.h.inc" -#endif // MLIR_TUTORIAL_TOY_DIALECT_H_ +#endif // MLIR_TUTORIAL_TOY_DIALECT_H_ diff --git a/mlir/example/Ch6/include/toy/Lexer.h b/mlir/example/Ch6/include/toy/Lexer.h index 176a40c..3c59cd9 100644 --- a/mlir/example/Ch6/include/toy/Lexer.h +++ b/mlir/example/Ch6/include/toy/Lexer.h @@ -13,18 +13,18 @@ #ifndef TOY_LEXER_H #define TOY_LEXER_H +#include "llvm/ADT/StringRef.h" + #include #include -#include "llvm/ADT/StringRef.h" - namespace toy { /// Structure definition a location in a file. struct Location { - std::shared_ptr file; ///< filename. - int line; ///< line number. - int col; ///< column number. + std::shared_ptr file; ///< filename. + int line; ///< line number. + int col; ///< column number. }; // List of Token returned by the lexer. @@ -56,7 +56,7 @@ enum Token : int { /// can proceed by reading the next line from the standard input or from a /// memory mapped file. class Lexer { - public: +public: /// Create a lexer for the given filename. The filename is kept only for /// debugging purpose (attaching a location to a Token). Lexer(std::string filename) @@ -98,7 +98,7 @@ class Lexer { // Return the current column in the file. int getCol() { return curCol; } - private: +private: /// Delegate to a derived class fetching the next line. Returns an empty /// string to signal end of file (EOF). 
Lines are expected to always finish /// with "\n" @@ -109,11 +109,13 @@ class Lexer { /// needed. int getNextChar() { // The current line buffer should not be empty unless it is the end of file. - if (curLineBuffer.empty()) return EOF; + if (curLineBuffer.empty()) + return EOF; ++curCol; auto nextchar = curLineBuffer.front(); curLineBuffer = curLineBuffer.drop_front(); - if (curLineBuffer.empty()) curLineBuffer = readNextLine(); + if (curLineBuffer.empty()) + curLineBuffer = readNextLine(); if (nextchar == '\n') { ++curLineNum; curCol = 0; @@ -124,7 +126,8 @@ class Lexer { /// Return the next token from standard input. Token getTok() { // Skip any whitespace. - while (isspace(lastChar)) lastChar = Token(getNextChar()); + while (isspace(lastChar)) + lastChar = Token(getNextChar()); // Save the current location before reading the token characters. lastLocation.line = curLineNum; @@ -136,9 +139,12 @@ class Lexer { while (isalnum((lastChar = Token(getNextChar()))) || lastChar == '_') identifierStr += (char)lastChar; - if (identifierStr == "return") return tok_return; - if (identifierStr == "def") return tok_def; - if (identifierStr == "var") return tok_var; + if (identifierStr == "return") + return tok_return; + if (identifierStr == "def") + return tok_def; + if (identifierStr == "var") + return tok_var; return tok_identifier; } @@ -160,11 +166,13 @@ class Lexer { lastChar = Token(getNextChar()); } while (lastChar != EOF && lastChar != '\n' && lastChar != '\r'); - if (lastChar != EOF) return getTok(); + if (lastChar != EOF) + return getTok(); } // Check for end of file. Don't eat the EOF. - if (lastChar == EOF) return tok_eof; + if (lastChar == EOF) + return tok_eof; // Otherwise, just return the character as its ascii value. Token thisChar = Token(lastChar); @@ -201,22 +209,24 @@ class Lexer { /// A lexer implementation operating on a buffer in memory. 
class LexerBuffer final : public Lexer { - public: +public: LexerBuffer(const char *begin, const char *end, std::string filename) : Lexer(std::move(filename)), current(begin), end(end) {} - private: +private: /// Provide one line at a time to the Lexer, return an empty string when /// reaching the end of the buffer. llvm::StringRef readNextLine() override { auto *begin = current; - while (current <= end && *current && *current != '\n') ++current; - if (current <= end && *current) ++current; + while (current <= end && *current && *current != '\n') + ++current; + if (current <= end && *current) + ++current; llvm::StringRef result{begin, static_cast(current - begin)}; return result; } const char *current, *end; }; -} // namespace toy +} // namespace toy -#endif // TOY_LEXER_H +#endif // TOY_LEXER_H diff --git a/mlir/example/Ch6/include/toy/MLIRGen.h b/mlir/example/Ch6/include/toy/MLIRGen.h index 5afdf62..fe9dbe5 100644 --- a/mlir/example/Ch6/include/toy/MLIRGen.h +++ b/mlir/example/Ch6/include/toy/MLIRGen.h @@ -21,7 +21,7 @@ class MLIRContext; template class OwningOpRef; class ModuleOp; -} // namespace mlir +} // namespace mlir namespace toy { class ModuleAST; @@ -30,6 +30,6 @@ class ModuleAST; /// or nullptr on failure. 
mlir::OwningOpRef mlirGen(mlir::MLIRContext &context, ModuleAST &moduleAST); -} // namespace toy +} // namespace toy -#endif // TOY_MLIRGEN_H +#endif // TOY_MLIRGEN_H diff --git a/mlir/example/Ch6/include/toy/Ops.td b/mlir/example/Ch6/include/toy/Ops.td index 157e207..a52bebc 100644 --- a/mlir/example/Ch6/include/toy/Ops.td +++ b/mlir/example/Ch6/include/toy/Ops.td @@ -13,7 +13,7 @@ #ifndef TOY_OPS #define TOY_OPS -include "mlir/IR/FunctionInterfaces.td" +include "mlir/Interfaces/FunctionInterfaces.td" include "mlir/IR/SymbolInterfaces.td" include "mlir/Interfaces/CallInterfaces.td" include "mlir/Interfaces/CastInterfaces.td" @@ -141,8 +141,7 @@ def CastOp : Toy_Op<"cast", [ //===----------------------------------------------------------------------===// def FuncOp : Toy_Op<"func", [ - DeclareOpInterfaceMethods, FunctionOpInterface, - IsolatedFromAbove + FunctionOpInterface, IsolatedFromAbove ]> { let summary = "user defined function operation"; let description = [{ @@ -183,6 +182,9 @@ def FuncOp : Toy_Op<"func", [ /// Returns the result types of this function. ArrayRef getResultTypes() { return getFunctionType().getResults(); } + + /// Returns the region on the function operation that is callable. 
+ Region *getCallableRegion() { return &getBody(); } }]; let hasCustomAssemblyFormat = 1; let skipDefaultBuilders = 1; diff --git a/mlir/example/Ch6/include/toy/Parser.h b/mlir/example/Ch6/include/toy/Parser.h index ededa4c..1f20616 100644 --- a/mlir/example/Ch6/include/toy/Parser.h +++ b/mlir/example/Ch6/include/toy/Parser.h @@ -14,16 +14,17 @@ #ifndef TOY_PARSER_H #define TOY_PARSER_H -#include -#include -#include -#include +#include "toy/AST.h" +#include "toy/Lexer.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Support/raw_ostream.h" -#include "toy/AST.h" -#include "toy/Lexer.h" + +#include +#include +#include +#include namespace toy { @@ -33,19 +34,20 @@ namespace toy { /// string and the code could reference an undeclared variable and the parsing /// succeeds. class Parser { - public: +public: /// Create a Parser for the supplied lexer. Parser(Lexer &lexer) : lexer(lexer) {} /// Parse a full Module. A module is a list of function definitions. std::unique_ptr parseModule() { - lexer.getNextToken(); // prime the lexer + lexer.getNextToken(); // prime the lexer // Parse functions one at a time and accumulate in this vector. std::vector functions; while (auto f = parseDefinition()) { functions.push_back(std::move(*f)); - if (lexer.getCurToken() == tok_eof) break; + if (lexer.getCurToken() == tok_eof) + break; } // If we didn't reach EOF, there was an error during parsing if (lexer.getCurToken() != tok_eof) @@ -54,7 +56,7 @@ class Parser { return std::make_unique(std::move(functions)); } - private: +private: Lexer &lexer; /// Parse a return statement. @@ -67,7 +69,8 @@ class Parser { std::optional> expr; if (lexer.getCurToken() != ';') { expr = parseExpression(); - if (!expr) return nullptr; + if (!expr) + return nullptr; } return std::make_unique(std::move(loc), std::move(expr)); } @@ -97,7 +100,8 @@ class Parser { // We can have either another nested array or a number literal. 
if (lexer.getCurToken() == '[') { values.push_back(parseTensorLiteralExpr()); - if (!values.back()) return nullptr; // parse error in the nested array. + if (!values.back()) + return nullptr; // parse error in the nested array. } else { if (lexer.getCurToken() != tok_number) return parseError(" or [", "in literal expression"); @@ -105,17 +109,18 @@ class Parser { } // End of this list on ']' - if (lexer.getCurToken() == ']') break; + if (lexer.getCurToken() == ']') + break; // Elements are separated by a comma. if (lexer.getCurToken() != ',') return parseError("] or ,", "in literal expression"); - lexer.getNextToken(); // eat , + lexer.getNextToken(); // eat , } while (true); if (values.empty()) return parseError("", "to fill literal expression"); - lexer.getNextToken(); // eat ] + lexer.getNextToken(); // eat ] /// Fill in the dimensions now. First the current nesting level: dims.push_back(values.size()); @@ -151,9 +156,10 @@ class Parser { /// parenexpr ::= '(' expression ')' std::unique_ptr parseParenExpr() { - lexer.getNextToken(); // eat (. + lexer.getNextToken(); // eat (. auto v = parseExpression(); - if (!v) return nullptr; + if (!v) + return nullptr; if (lexer.getCurToken() != ')') return parseError(")", "to close expression with parentheses"); @@ -168,9 +174,9 @@ class Parser { std::string name(lexer.getId()); auto loc = lexer.getLastLocation(); - lexer.getNextToken(); // eat identifier. + lexer.getNextToken(); // eat identifier. - if (lexer.getCurToken() != '(') // Simple variable ref. + if (lexer.getCurToken() != '(') // Simple variable ref. return std::make_unique(std::move(loc), name); // This is a function call. 
@@ -183,7 +189,8 @@ class Parser { else return nullptr; - if (lexer.getCurToken() == ')') break; + if (lexer.getCurToken() == ')') + break; if (lexer.getCurToken() != ',') return parseError(", or )", "in argument list"); @@ -211,22 +218,22 @@ class Parser { /// ::= tensorliteral std::unique_ptr parsePrimary() { switch (lexer.getCurToken()) { - default: - llvm::errs() << "unknown token '" << lexer.getCurToken() - << "' when expecting an expression\n"; - return nullptr; - case tok_identifier: - return parseIdentifierExpr(); - case tok_number: - return parseNumberExpr(); - case '(': - return parseParenExpr(); - case '[': - return parseTensorLiteralExpr(); - case ';': - return nullptr; - case '}': - return nullptr; + default: + llvm::errs() << "unknown token '" << lexer.getCurToken() + << "' when expecting an expression\n"; + return nullptr; + case tok_identifier: + return parseIdentifierExpr(); + case tok_number: + return parseNumberExpr(); + case '(': + return parseParenExpr(); + case '[': + return parseTensorLiteralExpr(); + case ';': + return nullptr; + case '}': + return nullptr; } } @@ -242,7 +249,8 @@ class Parser { // If this is a binop that binds at least as tightly as the current binop, // consume it, otherwise we are done. - if (tokPrec < exprPrec) return lhs; + if (tokPrec < exprPrec) + return lhs; // Okay, we know this is a binop. int binOp = lexer.getCurToken(); @@ -259,7 +267,8 @@ class Parser { int nextPrec = getTokPrecedence(); if (tokPrec < nextPrec) { rhs = parseBinOpRHS(tokPrec + 1, std::move(rhs)); - if (!rhs) return nullptr; + if (!rhs) + return nullptr; } // Merge lhs/RHS. 
@@ -271,7 +280,8 @@ class Parser { /// expression::= primary binop rhs std::unique_ptr parseExpression() { auto lhs = parsePrimary(); - if (!lhs) return nullptr; + if (!lhs) + return nullptr; return parseBinOpRHS(0, std::move(lhs)); } @@ -281,19 +291,20 @@ class Parser { std::unique_ptr parseType() { if (lexer.getCurToken() != '<') return parseError("<", "to begin type"); - lexer.getNextToken(); // eat < + lexer.getNextToken(); // eat < auto type = std::make_unique(); while (lexer.getCurToken() == tok_number) { type->shape.push_back(lexer.getValue()); lexer.getNextToken(); - if (lexer.getCurToken() == ',') lexer.getNextToken(); + if (lexer.getCurToken() == ',') + lexer.getNextToken(); } if (lexer.getCurToken() != '>') return parseError(">", "to end type"); - lexer.getNextToken(); // eat > + lexer.getNextToken(); // eat > return type; } @@ -305,21 +316,23 @@ class Parser { if (lexer.getCurToken() != tok_var) return parseError("var", "to begin declaration"); auto loc = lexer.getLastLocation(); - lexer.getNextToken(); // eat var + lexer.getNextToken(); // eat var if (lexer.getCurToken() != tok_identifier) return parseError("identified", "after 'var' declaration"); std::string id(lexer.getId()); - lexer.getNextToken(); // eat id + lexer.getNextToken(); // eat id - std::unique_ptr type; // Type is optional, it can be inferred + std::unique_ptr type; // Type is optional, it can be inferred if (lexer.getCurToken() == '<') { type = parseType(); - if (!type) return nullptr; + if (!type) + return nullptr; } - if (!type) type = std::make_unique(); + if (!type) + type = std::make_unique(); lexer.consume(Token('=')); auto expr = parseExpression(); return std::make_unique(std::move(loc), std::move(id), @@ -340,23 +353,27 @@ class Parser { auto exprList = std::make_unique(); // Ignore empty expressions: swallow sequences of semicolons. 
- while (lexer.getCurToken() == ';') lexer.consume(Token(';')); + while (lexer.getCurToken() == ';') + lexer.consume(Token(';')); while (lexer.getCurToken() != '}' && lexer.getCurToken() != tok_eof) { if (lexer.getCurToken() == tok_var) { // Variable declaration auto varDecl = parseDeclaration(); - if (!varDecl) return nullptr; + if (!varDecl) + return nullptr; exprList->push_back(std::move(varDecl)); } else if (lexer.getCurToken() == tok_return) { // Return statement auto ret = parseReturn(); - if (!ret) return nullptr; + if (!ret) + return nullptr; exprList->push_back(std::move(ret)); } else { // General expression auto expr = parseExpression(); - if (!expr) return nullptr; + if (!expr) + return nullptr; exprList->push_back(std::move(expr)); } // Ensure that elements are separated by a semicolon. @@ -364,7 +381,8 @@ class Parser { return parseError(";", "after expression"); // Ignore empty expressions: swallow sequences of semicolons. - while (lexer.getCurToken() == ';') lexer.consume(Token(';')); + while (lexer.getCurToken() == ';') + lexer.consume(Token(';')); } if (lexer.getCurToken() != '}') @@ -401,7 +419,8 @@ class Parser { lexer.consume(tok_identifier); auto decl = std::make_unique(std::move(loc), name); args.push_back(std::move(decl)); - if (lexer.getCurToken() != ',') break; + if (lexer.getCurToken() != ',') + break; lexer.consume(Token(',')); if (lexer.getCurToken() != tok_identifier) return parseError( @@ -423,7 +442,8 @@ class Parser { /// definition ::= prototype block std::unique_ptr parseDefinition() { auto proto = parsePrototype(); - if (!proto) return nullptr; + if (!proto) + return nullptr; if (auto block = parseBlock()) return std::make_unique(std::move(proto), std::move(block)); @@ -432,18 +452,19 @@ class Parser { /// Get the precedence of the pending binary operator token. int getTokPrecedence() { - if (!isascii(lexer.getCurToken())) return -1; + if (!isascii(lexer.getCurToken())) + return -1; // 1 is lowest precedence. 
switch (static_cast(lexer.getCurToken())) { - case '-': - return 20; - case '+': - return 20; - case '*': - return 40; - default: - return -1; + case '-': + return 20; + case '+': + return 20; + case '*': + return 40; + default: + return -1; } } @@ -456,12 +477,13 @@ class Parser { llvm::errs() << "Parse error (" << lexer.getLastLocation().line << ", " << lexer.getLastLocation().col << "): expected '" << expected << "' " << context << " but has Token " << curToken; - if (isprint(curToken)) llvm::errs() << " '" << (char)curToken << "'"; + if (isprint(curToken)) + llvm::errs() << " '" << (char)curToken << "'"; llvm::errs() << "\n"; return nullptr; } }; -} // namespace toy +} // namespace toy -#endif // TOY_PARSER_H +#endif // TOY_PARSER_H diff --git a/mlir/example/Ch6/include/toy/Passes.h b/mlir/example/Ch6/include/toy/Passes.h index 3888afe..62471dd 100644 --- a/mlir/example/Ch6/include/toy/Passes.h +++ b/mlir/example/Ch6/include/toy/Passes.h @@ -29,7 +29,7 @@ std::unique_ptr createLowerToAffinePass(); /// well as `Affine` and `Std`, to the LLVM dialect for codegen. std::unique_ptr createLowerToLLVMPass(); -} // namespace toy -} // namespace mlir +} // namespace toy +} // namespace mlir -#endif // TOY_PASSES_H +#endif // TOY_PASSES_H diff --git a/mlir/example/Ch6/include/toy/ShapeInferenceInterface.h b/mlir/example/Ch6/include/toy/ShapeInferenceInterface.h index a32ef17..cfe5a87 100644 --- a/mlir/example/Ch6/include/toy/ShapeInferenceInterface.h +++ b/mlir/example/Ch6/include/toy/ShapeInferenceInterface.h @@ -22,7 +22,7 @@ namespace toy { /// Include the auto-generated declarations. 
#include "toy/ShapeInferenceOpInterfaces.h.inc" -} // namespace toy -} // namespace mlir +} // namespace toy +} // namespace mlir -#endif // MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_ +#endif // MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_ diff --git a/mlir/example/Ch6/mlir/Dialect.cpp b/mlir/example/Ch6/mlir/Dialect.cpp index 6ec105a..72072f9 100644 --- a/mlir/example/Ch6/mlir/Dialect.cpp +++ b/mlir/example/Ch6/mlir/Dialect.cpp @@ -13,11 +13,25 @@ #include "toy/Dialect.h" +#include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/FunctionImplementation.h" +#include "mlir/IR/Location.h" #include "mlir/IR/OpImplementation.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Interfaces/CallInterfaces.h" +#include "mlir/Interfaces/FunctionImplementation.h" +#include "mlir/Support/LLVM.h" #include "mlir/Transforms/InliningUtils.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include +#include +#include +#include using namespace mlir; using namespace mlir::toy; @@ -59,8 +73,7 @@ struct ToyInlinerInterface : public DialectInlinerInterface { /// Handle the given inlined terminator(toy.return) by replacing it with a new /// operation as necessary. - void handleTerminator(Operation *op, - ArrayRef valuesToRepl) const final { + void handleTerminator(Operation *op, ValueRange valuesToRepl) const final { // Only "toy.return" needs to be handled here. auto returnOp = cast(op); @@ -190,11 +203,10 @@ void ConstantOp::print(mlir::OpAsmPrinter &printer) { /// Verifier for the constant operation. This corresponds to the /// `let hasVerifier = 1` in the op definition. -mlir::LogicalResult ConstantOp::verify() { +llvm::LogicalResult ConstantOp::verify() { // If the return type of the constant is not an unranked tensor, the shape // must match the shape of the attribute holding the data. 
- auto resultType = - llvm::dyn_cast(getResult().getType()); + auto resultType = llvm::dyn_cast(getResult().getType()); if (!resultType) return success(); @@ -299,27 +311,6 @@ void FuncOp::print(mlir::OpAsmPrinter &p) { getArgAttrsAttrName(), getResAttrsAttrName()); } -/// Returns the region on the function operation that is callable. -mlir::Region *FuncOp::getCallableRegion() { return &getBody(); } - -/// Returns the results types that the callable region produces when -/// executed. -llvm::ArrayRef FuncOp::getCallableResults() { - return getFunctionType().getResults(); -} - -/// Returns the argument attributes for all callable region arguments or -/// null if there are none. -ArrayAttr FuncOp::getCallableArgAttrs() { - return getArgAttrs().value_or(nullptr); -} - -/// Returns the result attributes for all callable region results or -/// null if there are none. -ArrayAttr FuncOp::getCallableResAttrs() { - return getResAttrs().value_or(nullptr); -} - //===----------------------------------------------------------------------===// // GenericCallOp //===----------------------------------------------------------------------===// @@ -349,6 +340,12 @@ void GenericCallOp::setCalleeFromCallable(CallInterfaceCallable callee) { /// call interface. Operation::operand_range GenericCallOp::getArgOperands() { return getInputs(); } +/// Get the argument operands to the called function as a mutable range, this is +/// required by the call interface. 
+MutableOperandRange GenericCallOp::getArgOperandsMutable() { + return getInputsMutable(); +} + //===----------------------------------------------------------------------===// // MulOp //===----------------------------------------------------------------------===// @@ -374,7 +371,7 @@ void MulOp::inferShapes() { getResult().setType(getLhs().getType()); } // ReturnOp //===----------------------------------------------------------------------===// -mlir::LogicalResult ReturnOp::verify() { +llvm::LogicalResult ReturnOp::verify() { // We know that the parent operation is a function, because of the 'HasParent' // trait attached to the operation definition. auto function = cast((*this)->getParentOp()); @@ -398,8 +395,7 @@ mlir::LogicalResult ReturnOp::verify() { auto resultType = results.front(); // Check that the result type of the function matches the operand type. - if (inputType == resultType || - llvm::isa(inputType) || + if (inputType == resultType || llvm::isa(inputType) || llvm::isa(resultType)) return mlir::success(); @@ -424,7 +420,7 @@ void TransposeOp::inferShapes() { getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType())); } -mlir::LogicalResult TransposeOp::verify() { +llvm::LogicalResult TransposeOp::verify() { auto inputType = llvm::dyn_cast(getOperand().getType()); auto resultType = llvm::dyn_cast(getType()); if (!inputType || !resultType) diff --git a/mlir/example/Ch6/mlir/LowerToAffineLoops.cpp b/mlir/example/Ch6/mlir/LowerToAffineLoops.cpp index 03e87ba..7413214 100644 --- a/mlir/example/Ch6/mlir/LowerToAffineLoops.cpp +++ b/mlir/example/Ch6/mlir/LowerToAffineLoops.cpp @@ -12,16 +12,34 @@ // //===----------------------------------------------------------------------===// +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinDialect.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/PatternMatch.h" +#include 
"mlir/IR/ValueRange.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/TypeID.h" +#include "toy/Dialect.h" +#include "toy/Passes.h" + #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/IR/BuiltinDialect.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" -#include "toy/Dialect.h" -#include "toy/Passes.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Sequence.h" +#include "llvm/Support/Casting.h" +#include +#include +#include +#include +#include using namespace mlir; @@ -154,8 +172,7 @@ struct ConstantOpLowering : public OpRewritePattern { SmallVector constantIndices; if (!valueShape.empty()) { - for (auto i : llvm::seq( - 0, *std::max_element(valueShape.begin(), valueShape.end()))) + for (auto i : llvm::seq(0, *llvm::max_element(valueShape))) constantIndices.push_back( rewriter.create(loc, i)); } else { @@ -241,8 +258,8 @@ struct PrintOpLowering : public OpConversionPattern { ConversionPatternRewriter &rewriter) const final { // We don't lower "toy.print" in this pass, but we need to update its // operands. 
- rewriter.updateRootInPlace(op, - [&] { op->setOperands(adaptor.getOperands()); }); + rewriter.modifyOpInPlace(op, + [&] { op->setOperands(adaptor.getOperands()); }); return success(); } }; diff --git a/mlir/example/Ch6/mlir/LowerToLLVM.cpp b/mlir/example/Ch6/mlir/LowerToLLVM.cpp index ab28f02..3ad70e7 100644 --- a/mlir/example/Ch6/mlir/LowerToLLVM.cpp +++ b/mlir/example/Ch6/mlir/LowerToLLVM.cpp @@ -22,6 +22,16 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/LLVMIR/LLVMAttrs.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/TypeID.h" +#include "toy/Dialect.h" +#include "toy/Passes.h" + #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" @@ -31,7 +41,6 @@ #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" -#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" @@ -39,9 +48,9 @@ #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" -#include "toy/Dialect.h" -#include "toy/Passes.h" -#include "llvm/ADT/Sequence.h" +#include "llvm/Support/Casting.h" +#include +#include using namespace mlir; @@ -60,6 +69,7 @@ class PrintOpLowering : public ConversionPattern { LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { + auto *context = rewriter.getContext(); auto memRefType = llvm::cast((*op->operand_type_begin())); auto memRefShape = memRefType.getShape(); auto loc = 
op->getLoc(); @@ -91,8 +101,8 @@ class PrintOpLowering : public ConversionPattern { // Insert a newline after each of the inner dimensions of the shape. if (i != e - 1) - rewriter.create(loc, printfRef, - rewriter.getIntegerType(32), newLineCst); + rewriter.create(loc, getPrintfType(context), printfRef, + newLineCst); rewriter.create(loc); rewriter.setInsertionPointToStart(loop.getBody()); } @@ -101,8 +111,8 @@ class PrintOpLowering : public ConversionPattern { auto printOp = cast(op); auto elementLoad = rewriter.create(loc, printOp.getInput(), loopIvs); - rewriter.create( - loc, printfRef, rewriter.getIntegerType(32), + rewriter.create( + loc, getPrintfType(context), printfRef, ArrayRef({formatSpecifierCst, elementLoad})); // Notify the rewriter that this operation has been removed. @@ -111,6 +121,16 @@ class PrintOpLowering : public ConversionPattern { } private: + /// Create a function declaration for printf, the signature is: + /// * `i32 (i8*, ...)` + static LLVM::LLVMFunctionType getPrintfType(MLIRContext *context) { + auto llvmI32Ty = IntegerType::get(context, 32); + auto llvmPtrTy = LLVM::LLVMPointerType::get(context); + auto llvmFnType = LLVM::LLVMFunctionType::get(llvmI32Ty, llvmPtrTy, + /*isVarArg=*/true); + return llvmFnType; + } + /// Return a symbol reference to the printf function, inserting it into the /// module if necessary. static FlatSymbolRefAttr getOrInsertPrintf(PatternRewriter &rewriter, @@ -119,17 +139,11 @@ class PrintOpLowering : public ConversionPattern { if (module.lookupSymbol("printf")) return SymbolRefAttr::get(context, "printf"); - // Create a function declaration for printf, the signature is: - // * `i32 (i8*, ...)` - auto llvmI32Ty = IntegerType::get(context, 32); - auto llvmI8PtrTy = LLVM::LLVMPointerType::get(IntegerType::get(context, 8)); - auto llvmFnType = LLVM::LLVMFunctionType::get(llvmI32Ty, llvmI8PtrTy, - /*isVarArg=*/true); - // Insert the printf function into the body of the parent module. 
PatternRewriter::InsertionGuard insertGuard(rewriter); rewriter.setInsertionPointToStart(module.getBody()); - rewriter.create(module.getLoc(), "printf", llvmFnType); + rewriter.create(module.getLoc(), "printf", + getPrintfType(context)); return SymbolRefAttr::get(context, "printf"); } @@ -156,8 +170,7 @@ class PrintOpLowering : public ConversionPattern { Value cst0 = builder.create(loc, builder.getI64Type(), builder.getIndexAttr(0)); return builder.create( - loc, - LLVM::LLVMPointerType::get(IntegerType::get(builder.getContext(), 8)), + loc, LLVM::LLVMPointerType::get(builder.getContext()), global.getType(), globalPtr, ArrayRef({cst0, cst0})); } }; diff --git a/mlir/example/Ch6/mlir/MLIRGen.cpp b/mlir/example/Ch6/mlir/MLIRGen.cpp index 58ec129..b56e2f7 100644 --- a/mlir/example/Ch6/mlir/MLIRGen.cpp +++ b/mlir/example/Ch6/mlir/MLIRGen.cpp @@ -12,20 +12,30 @@ //===----------------------------------------------------------------------===// #include "toy/MLIRGen.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/Value.h" +#include "toy/AST.h" +#include "toy/Dialect.h" -#include - -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/ScopedHashTable.h" -#include "llvm/Support/raw_ostream.h" -#include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Verifier.h" -#include "toy/AST.h" -#include "toy/Dialect.h" +#include "toy/Lexer.h" + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/ScopedHashTable.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/Twine.h" +#include +#include +#include +#include +#include +#include using namespace mlir::toy; using namespace toy; @@ -47,7 +57,7 @@ namespace { /// the semantics of the language and (hopefully) allow to perform accurate /// analysis and transformation based on these high level semantics. 
class MLIRGenImpl { - public: +public: MLIRGenImpl(mlir::MLIRContext &context) : builder(&context) {} /// Public API: convert the AST for a Toy module (source file) to an MLIR @@ -57,7 +67,8 @@ class MLIRGenImpl { // add them to the module. theModule = mlir::ModuleOp::create(builder.getUnknownLoc()); - for (FunctionAST &f : moduleAST) mlirGen(f); + for (FunctionAST &f : moduleAST) + mlirGen(f); // Verify the module after we have finished constructing it, this will check // the structural properties of the IR and invoke any specific verifiers we @@ -70,7 +81,7 @@ class MLIRGenImpl { return theModule; } - private: +private: /// A "module" matches a Toy source file: containing a list of functions. mlir::ModuleOp theModule; @@ -93,8 +104,9 @@ class MLIRGenImpl { /// Declare a variable in the current scope, return success if the variable /// wasn't declared yet. - mlir::LogicalResult declare(llvm::StringRef var, mlir::Value value) { - if (symbolTable.count(var)) return mlir::failure(); + llvm::LogicalResult declare(llvm::StringRef var, mlir::Value value) { + if (symbolTable.count(var)) + return mlir::failure(); symbolTable.insert(var, value); return mlir::success(); } @@ -121,7 +133,8 @@ class MLIRGenImpl { // Create an MLIR function for the given prototype. builder.setInsertionPointToEnd(theModule.getBody()); mlir::toy::FuncOp function = mlirGen(*funcAST.getProto()); - if (!function) return nullptr; + if (!function) + return nullptr; // Let's start the body of the function now! 
mlir::Block &entryBlock = function.front(); @@ -150,7 +163,8 @@ class MLIRGenImpl { // FIXME: we may fix the parser instead to always return the last expression // (this would possibly help the REPL case later) ReturnOp returnOp; - if (!entryBlock.empty()) returnOp = dyn_cast(entryBlock.back()); + if (!entryBlock.empty()) + returnOp = dyn_cast(entryBlock.back()); if (!returnOp) { builder.create(loc(funcAST.getProto()->loc())); } else if (returnOp.hasOperand()) { @@ -161,7 +175,8 @@ class MLIRGenImpl { } // If this function isn't main, then set the visibility to private. - if (funcAST.getProto()->getName() != "main") function.setPrivate(); + if (funcAST.getProto()->getName() != "main") + function.setPrivate(); return function; } @@ -180,18 +195,20 @@ class MLIRGenImpl { // and propagate. // mlir::Value lhs = mlirGen(*binop.getLHS()); - if (!lhs) return nullptr; + if (!lhs) + return nullptr; mlir::Value rhs = mlirGen(*binop.getRHS()); - if (!rhs) return nullptr; + if (!rhs) + return nullptr; auto location = loc(binop.loc()); // Derive the operation name from the binary operator. At the moment we only // support '+' and '*'. switch (binop.getOp()) { - case '+': - return builder.create(location, lhs, rhs); - case '*': - return builder.create(location, lhs, rhs); + case '+': + return builder.create(location, lhs, rhs); + case '*': + return builder.create(location, lhs, rhs); } emitError(location, "invalid binary operator '") << binop.getOp() << "'"; @@ -202,7 +219,8 @@ class MLIRGenImpl { /// expected to have been declared and so should have a value in the symbol /// table, otherwise emit an error and return nullptr. mlir::Value mlirGen(VariableExprAST &expr) { - if (auto variable = symbolTable.lookup(expr.getName())) return variable; + if (auto variable = symbolTable.lookup(expr.getName())) + return variable; emitError(loc(expr.loc()), "error: unknown variable '") << expr.getName() << "'"; @@ -210,13 +228,14 @@ class MLIRGenImpl { } /// Emit a return operation. 
This will return failure if any generation fails. - mlir::LogicalResult mlirGen(ReturnExprAST &ret) { + llvm::LogicalResult mlirGen(ReturnExprAST &ret) { auto location = loc(ret.loc()); // 'return' takes an optional expression, handle that case here. mlir::Value expr = nullptr; if (ret.getExpr().has_value()) { - if (!(expr = mlirGen(**ret.getExpr()))) return mlir::failure(); + if (!(expr = mlirGen(**ret.getExpr()))) + return mlir::failure(); } // Otherwise, this return operation has zero operands. @@ -278,7 +297,8 @@ class MLIRGenImpl { /// Attributes are the way MLIR attaches constant to operations. void collectData(ExprAST &expr, std::vector &data) { if (auto *lit = dyn_cast(&expr)) { - for (auto &value : lit->getValues()) collectData(*value, data); + for (auto &value : lit->getValues()) + collectData(*value, data); return; } @@ -296,7 +316,8 @@ class MLIRGenImpl { SmallVector operands; for (auto &expr : call.getArgs()) { auto arg = mlirGen(*expr); - if (!arg) return nullptr; + if (!arg) + return nullptr; operands.push_back(arg); } @@ -304,9 +325,8 @@ class MLIRGenImpl { // straightforward emission. if (callee == "transpose") { if (call.getArgs().size() != 1) { - emitError(location, - "MLIR codegen encountered an error: toy.transpose " - "does not accept multiple arguments"); + emitError(location, "MLIR codegen encountered an error: toy.transpose " + "does not accept multiple arguments"); return nullptr; } return builder.create(location, operands[0]); @@ -320,9 +340,10 @@ class MLIRGenImpl { /// Emit a print expression. It emits specific operations for two builtins: /// transpose(x) and print(x). 
- mlir::LogicalResult mlirGen(PrintExprAST &call) { + llvm::LogicalResult mlirGen(PrintExprAST &call) { auto arg = mlirGen(*call.getArg()); - if (!arg) return mlir::failure(); + if (!arg) + return mlir::failure(); builder.create(loc(call.loc()), arg); return mlir::success(); @@ -336,21 +357,21 @@ class MLIRGenImpl { /// Dispatch codegen for the right expression subclass using RTTI. mlir::Value mlirGen(ExprAST &expr) { switch (expr.getKind()) { - case toy::ExprAST::Expr_BinOp: - return mlirGen(cast(expr)); - case toy::ExprAST::Expr_Var: - return mlirGen(cast(expr)); - case toy::ExprAST::Expr_Literal: - return mlirGen(cast(expr)); - case toy::ExprAST::Expr_Call: - return mlirGen(cast(expr)); - case toy::ExprAST::Expr_Num: - return mlirGen(cast(expr)); - default: - emitError(loc(expr.loc())) - << "MLIR codegen encountered an unhandled expr kind '" - << Twine(expr.getKind()) << "'"; - return nullptr; + case toy::ExprAST::Expr_BinOp: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Var: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Literal: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Call: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Num: + return mlirGen(cast(expr)); + default: + emitError(loc(expr.loc())) + << "MLIR codegen encountered an unhandled expr kind '" + << Twine(expr.getKind()) << "'"; + return nullptr; } } @@ -367,7 +388,8 @@ class MLIRGenImpl { } mlir::Value value = mlirGen(*init); - if (!value) return nullptr; + if (!value) + return nullptr; // We have the initializer value, but in case the variable was declared // with specific shape, we emit a "reshape" operation. It will get @@ -378,29 +400,34 @@ class MLIRGenImpl { } // Register the value in the symbol table. - if (failed(declare(vardecl.getName(), value))) return nullptr; + if (failed(declare(vardecl.getName(), value))) + return nullptr; return value; } /// Codegen a list of expression, return failure if one of them hit an error. 
- mlir::LogicalResult mlirGen(ExprASTList &blockAST) { + llvm::LogicalResult mlirGen(ExprASTList &blockAST) { ScopedHashTableScope varScope(symbolTable); for (auto &expr : blockAST) { // Specific handling for variable declarations, return statement, and // print. These can only appear in block list and not in nested // expressions. if (auto *vardecl = dyn_cast(expr.get())) { - if (!mlirGen(*vardecl)) return mlir::failure(); + if (!mlirGen(*vardecl)) + return mlir::failure(); continue; } - if (auto *ret = dyn_cast(expr.get())) return mlirGen(*ret); + if (auto *ret = dyn_cast(expr.get())) + return mlirGen(*ret); if (auto *print = dyn_cast(expr.get())) { - if (mlir::failed(mlirGen(*print))) return mlir::success(); + if (mlir::failed(mlirGen(*print))) + return mlir::success(); continue; } // Generic expression dispatch codegen. - if (!mlirGen(*expr)) return mlir::failure(); + if (!mlirGen(*expr)) + return mlir::failure(); } return mlir::success(); } @@ -420,7 +447,7 @@ class MLIRGenImpl { mlir::Type getType(const VarType &type) { return getType(type.shape); } }; -} // namespace +} // namespace namespace toy { @@ -430,4 +457,4 @@ mlir::OwningOpRef mlirGen(mlir::MLIRContext &context, return MLIRGenImpl(context).mlirGen(moduleAST); } -} // namespace toy +} // namespace toy diff --git a/mlir/example/Ch6/mlir/ShapeInferencePass.cpp b/mlir/example/Ch6/mlir/ShapeInferencePass.cpp index d45baa1..a9e995e 100644 --- a/mlir/example/Ch6/mlir/ShapeInferencePass.cpp +++ b/mlir/example/Ch6/mlir/ShapeInferencePass.cpp @@ -11,13 +11,21 @@ // //===----------------------------------------------------------------------===// +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Types.h" #include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/TypeID.h" #include "toy/Dialect.h" #include "toy/Passes.h" #include "toy/ShapeInferenceInterface.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" +#include 
"llvm/Support/Casting.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" +#include #define DEBUG_TYPE "shape-inference" diff --git a/mlir/example/Ch6/mlir/ToyCombine.cpp b/mlir/example/Ch6/mlir/ToyCombine.cpp index d124e86..f8397c2 100644 --- a/mlir/example/Ch6/mlir/ToyCombine.cpp +++ b/mlir/example/Ch6/mlir/ToyCombine.cpp @@ -11,10 +11,9 @@ // //===----------------------------------------------------------------------===// -#include - -#include "mlir/IR/Matchers.h" +#include "mlir/IR/MLIRContext.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Value.h" #include "toy/Dialect.h" using namespace mlir; using namespace toy; @@ -22,7 +21,7 @@ using namespace toy; namespace { /// Include the patterns defined in the Declarative Rewrite framework. #include "ToyCombine.inc" -} // namespace +} // namespace /// This is an example of a c++ rewrite pattern for the TransposeOp. It /// optimizes the following scenario: transpose(transpose(x)) -> x @@ -36,14 +35,16 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern { /// This method attempts to match a pattern and rewrite it. The rewriter /// argument is the orchestrator of the sequence of rewrites. The pattern is /// expected to interact with it to perform any changes to the IR from here. - mlir::LogicalResult matchAndRewrite( - TransposeOp op, mlir::PatternRewriter &rewriter) const override { + llvm::LogicalResult + matchAndRewrite(TransposeOp op, + mlir::PatternRewriter &rewriter) const override { // Look through the input of the current transpose. mlir::Value transposeInput = op.getOperand(); TransposeOp transposeInputOp = transposeInput.getDefiningOp(); // Input defined by another transpose? If not, no match. - if (!transposeInputOp) return failure(); + if (!transposeInputOp) + return failure(); // Otherwise, we have a redundant transpose. Use the rewriter. 
rewriter.replaceOp(op, {transposeInputOp.getOperand()}); diff --git a/mlir/example/Ch6/parser/AST.cpp b/mlir/example/Ch6/parser/AST.cpp index e0bd2fd..2546f2a 100644 --- a/mlir/example/Ch6/parser/AST.cpp +++ b/mlir/example/Ch6/parser/AST.cpp @@ -12,9 +12,12 @@ #include "toy/AST.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Twine.h" #include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/raw_ostream.h" +#include using namespace toy; @@ -31,10 +34,10 @@ struct Indent { /// Helper class that implement the AST tree traversal and print the nodes along /// the way. The only data member is the current indentation level. class ASTDumper { - public: +public: void dump(ModuleAST *node); - private: +private: void dump(const VarType &type); void dump(VarDeclExprAST *varDecl); void dump(ExprAST *expr); @@ -51,12 +54,13 @@ class ASTDumper { // Actually print spaces matching the current indentation level void indent() { - for (int i = 0; i < curIndent; i++) llvm::errs() << " "; + for (int i = 0; i < curIndent; i++) + llvm::errs() << " "; } int curIndent = 0; }; -} // namespace +} // namespace /// Return a formatted string for the location of any node template @@ -69,8 +73,8 @@ static std::string loc(T *node) { // Helper Macro to bump the indentation level and print the leading spaces for // the current indentations -#define INDENT() \ - Indent level_(curIndent); \ +#define INDENT() \ + Indent level_(curIndent); \ indent(); /// Dispatch to a generic expressions to the appropriate subclass using RTTI @@ -100,7 +104,8 @@ void ASTDumper::dump(VarDeclExprAST *varDecl) { void ASTDumper::dump(ExprASTList *exprList) { INDENT(); llvm::errs() << "Block {\n"; - for (auto &expr : *exprList) dump(expr.get()); + for (auto &expr : *exprList) + dump(expr.get()); indent(); llvm::errs() << "} // Block\n"; } @@ -153,7 +158,8 @@ void ASTDumper::dump(VariableExprAST *node) { void ASTDumper::dump(ReturnExprAST *node) { INDENT(); llvm::errs() << "Return\n"; - 
if (node->getExpr().has_value()) return dump(*node->getExpr()); + if (node->getExpr().has_value()) + return dump(*node->getExpr()); { INDENT(); llvm::errs() << "(void)\n"; @@ -173,7 +179,8 @@ void ASTDumper::dump(BinaryExprAST *node) { void ASTDumper::dump(CallExprAST *node) { INDENT(); llvm::errs() << "Call '" << node->getCallee() << "' [ " << loc(node) << "\n"; - for (auto &arg : node->getArgs()) dump(arg.get()); + for (auto &arg : node->getArgs()) + dump(arg.get()); indent(); llvm::errs() << "]\n"; } @@ -218,7 +225,8 @@ void ASTDumper::dump(FunctionAST *node) { void ASTDumper::dump(ModuleAST *node) { INDENT(); llvm::errs() << "Module:\n"; - for (auto &f : *node) dump(&f); + for (auto &f : *node) + dump(&f); } namespace toy { @@ -226,4 +234,4 @@ namespace toy { // Public API void dump(ModuleAST &module) { ASTDumper().dump(&module); } -} // namespace toy +} // namespace toy diff --git a/mlir/example/Ch6/toyc.cpp b/mlir/example/Ch6/toyc.cpp index a7bc397..c244b31 100644 --- a/mlir/example/Ch6/toyc.cpp +++ b/mlir/example/Ch6/toyc.cpp @@ -10,8 +10,16 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Affine/Passes.h" #include "mlir/Dialect/Func/Extensions/AllExtensions.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "toy/AST.h" +#include "toy/Dialect.h" +#include "toy/Lexer.h" +#include "toy/MLIRGen.h" +#include "toy/Parser.h" +#include "toy/Passes.h" + +#include "mlir/Dialect/Affine/Passes.h" #include "mlir/Dialect/LLVMIR/Transforms/Passes.h" #include "mlir/ExecutionEngine/ExecutionEngine.h" #include "mlir/ExecutionEngine/OptUtils.h" @@ -21,17 +29,14 @@ #include "mlir/IR/Verifier.h" #include "mlir/InitAllDialects.h" #include "mlir/Parser/Parser.h" -#include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Export.h" 
#include "mlir/Transforms/Passes.h" -#include "toy/Dialect.h" -#include "toy/MLIRGen.h" -#include "toy/Parser.h" -#include "toy/Passes.h" + #include "llvm/ADT/StringRef.h" +#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" #include "llvm/IR/Module.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/ErrorOr.h" @@ -39,6 +44,11 @@ #include "llvm/Support/SourceMgr.h" #include "llvm/Support/TargetSelect.h" #include "llvm/Support/raw_ostream.h" +#include +#include +#include +#include +#include using namespace toy; namespace cl = llvm::cl; @@ -101,7 +111,7 @@ int loadMLIR(mlir::MLIRContext &context, mlir::OwningOpRef &module) { // Handle '.toy' input to the compiler. if (inputType != InputType::MLIR && - !llvm::StringRef(inputFilename).endswith(".mlir")) { + !llvm::StringRef(inputFilename).ends_with(".mlir")) { auto moduleAST = parseInputFile(inputFilename); if (!moduleAST) return 6; @@ -176,8 +186,7 @@ int loadAndProcessMLIR(mlir::MLIRContext &context, // This is necessary to have line tables emitted and basic // debugger working. In the future we will add proper debug information // emission directly from our frontend. - pm.addNestedPass( - mlir::LLVM::createDIScopeForLLVMFuncOpPass()); + pm.addPass(mlir::LLVM::createDIScopeForLLVMFuncOpPass()); } if (mlir::failed(pm.run(*module))) diff --git a/mlir/example/Ch7/include/toy/AST.h b/mlir/example/Ch7/include/toy/AST.h index faaf6e8..42d64ed 100644 --- a/mlir/example/Ch7/include/toy/AST.h +++ b/mlir/example/Ch7/include/toy/AST.h @@ -15,14 +15,14 @@ #ifndef TOY_AST_H #define TOY_AST_H -#include -#include -#include +#include "toy/Lexer.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Casting.h" -#include "toy/Lexer.h" +#include +#include +#include namespace toy { @@ -34,7 +34,7 @@ struct VarType { /// Base class for all expression nodes. 
class ExprAST { - public: +public: enum ExprASTKind { Expr_VarDecl, Expr_Return, @@ -55,7 +55,7 @@ class ExprAST { const Location &loc() { return location; } - private: +private: const ExprASTKind kind; Location location; }; @@ -67,7 +67,7 @@ using ExprASTList = std::vector>; class NumberExprAST : public ExprAST { double val; - public: +public: NumberExprAST(Location loc, double val) : ExprAST(Expr_Num, std::move(loc)), val(val) {} @@ -82,11 +82,10 @@ class LiteralExprAST : public ExprAST { std::vector> values; std::vector dims; - public: +public: LiteralExprAST(Location loc, std::vector> values, std::vector dims) - : ExprAST(Expr_Literal, std::move(loc)), - values(std::move(values)), + : ExprAST(Expr_Literal, std::move(loc)), values(std::move(values)), dims(std::move(dims)) {} llvm::ArrayRef> getValues() { return values; } @@ -100,11 +99,11 @@ class LiteralExprAST : public ExprAST { class StructLiteralExprAST : public ExprAST { std::vector> values; - public: +public: StructLiteralExprAST(Location loc, std::vector> values) - : ExprAST(Expr_StructLiteral, std::move(loc)), - values(std::move(values)) {} + : ExprAST(Expr_StructLiteral, std::move(loc)), values(std::move(values)) { + } llvm::ArrayRef> getValues() { return values; } @@ -118,7 +117,7 @@ class StructLiteralExprAST : public ExprAST { class VariableExprAST : public ExprAST { std::string name; - public: +public: VariableExprAST(Location loc, llvm::StringRef name) : ExprAST(Expr_Var, std::move(loc)), name(name) {} @@ -134,13 +133,11 @@ class VarDeclExprAST : public ExprAST { VarType type; std::unique_ptr initVal; - public: +public: VarDeclExprAST(Location loc, llvm::StringRef name, VarType type, std::unique_ptr initVal = nullptr) - : ExprAST(Expr_VarDecl, std::move(loc)), - name(name), - type(std::move(type)), - initVal(std::move(initVal)) {} + : ExprAST(Expr_VarDecl, std::move(loc)), name(name), + type(std::move(type)), initVal(std::move(initVal)) {} llvm::StringRef getName() { return name; } ExprAST 
*getInitVal() { return initVal.get(); } @@ -154,12 +151,13 @@ class VarDeclExprAST : public ExprAST { class ReturnExprAST : public ExprAST { std::optional> expr; - public: +public: ReturnExprAST(Location loc, std::optional> expr) : ExprAST(Expr_Return, std::move(loc)), expr(std::move(expr)) {} std::optional getExpr() { - if (expr.has_value()) return expr->get(); + if (expr.has_value()) + return expr->get(); return std::nullopt; } @@ -172,16 +170,14 @@ class BinaryExprAST : public ExprAST { char op; std::unique_ptr lhs, rhs; - public: +public: char getOp() { return op; } ExprAST *getLHS() { return lhs.get(); } ExprAST *getRHS() { return rhs.get(); } BinaryExprAST(Location loc, char op, std::unique_ptr lhs, std::unique_ptr rhs) - : ExprAST(Expr_BinOp, std::move(loc)), - op(op), - lhs(std::move(lhs)), + : ExprAST(Expr_BinOp, std::move(loc)), op(op), lhs(std::move(lhs)), rhs(std::move(rhs)) {} /// LLVM style RTTI @@ -193,11 +189,10 @@ class CallExprAST : public ExprAST { std::string callee; std::vector> args; - public: +public: CallExprAST(Location loc, const std::string &callee, std::vector> args) - : ExprAST(Expr_Call, std::move(loc)), - callee(callee), + : ExprAST(Expr_Call, std::move(loc)), callee(callee), args(std::move(args)) {} llvm::StringRef getCallee() { return callee; } @@ -211,7 +206,7 @@ class CallExprAST : public ExprAST { class PrintExprAST : public ExprAST { std::unique_ptr arg; - public: +public: PrintExprAST(Location loc, std::unique_ptr arg) : ExprAST(Expr_Print, std::move(loc)), arg(std::move(arg)) {} @@ -229,7 +224,7 @@ class PrototypeAST { std::string name; std::vector> args; - public: +public: PrototypeAST(Location location, const std::string &name, std::vector> args) : location(std::move(location)), name(name), args(std::move(args)) {} @@ -241,7 +236,7 @@ class PrototypeAST { /// This class represents a top level record in a module. 
class RecordAST { - public: +public: enum RecordASTKind { Record_Function, Record_Struct, @@ -252,7 +247,7 @@ class RecordAST { RecordASTKind getKind() const { return kind; } - private: +private: const RecordASTKind kind; }; @@ -261,11 +256,10 @@ class FunctionAST : public RecordAST { std::unique_ptr proto; std::unique_ptr body; - public: +public: FunctionAST(std::unique_ptr proto, std::unique_ptr body) - : RecordAST(Record_Function), - proto(std::move(proto)), + : RecordAST(Record_Function), proto(std::move(proto)), body(std::move(body)) {} PrototypeAST *getProto() { return proto.get(); } ExprASTList *getBody() { return body.get(); } @@ -282,12 +276,10 @@ class StructAST : public RecordAST { std::string name; std::vector> variables; - public: +public: StructAST(Location location, const std::string &name, std::vector> variables) - : RecordAST(Record_Struct), - location(std::move(location)), - name(name), + : RecordAST(Record_Struct), location(std::move(location)), name(name), variables(std::move(variables)) {} const Location &loc() { return location; } @@ -306,7 +298,7 @@ class StructAST : public RecordAST { class ModuleAST { std::vector> records; - public: +public: ModuleAST(std::vector> records) : records(std::move(records)) {} @@ -316,6 +308,6 @@ class ModuleAST { void dump(ModuleAST &); -} // namespace toy +} // namespace toy -#endif // TOY_AST_H +#endif // TOY_AST_H diff --git a/mlir/example/Ch7/include/toy/Dialect.h b/mlir/example/Ch7/include/toy/Dialect.h index ff6b816..64094c3 100644 --- a/mlir/example/Ch7/include/toy/Dialect.h +++ b/mlir/example/Ch7/include/toy/Dialect.h @@ -14,12 +14,13 @@ #ifndef MLIR_TUTORIAL_TOY_DIALECT_H_ #define MLIR_TUTORIAL_TOY_DIALECT_H_ +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" -#include "mlir/IR/FunctionInterfaces.h" #include "mlir/IR/SymbolTable.h" #include "mlir/Interfaces/CallInterfaces.h" #include "mlir/Interfaces/CastInterfaces.h" +#include 
"mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "toy/ShapeInferenceInterface.h" @@ -27,9 +28,9 @@ namespace mlir { namespace toy { namespace detail { struct StructTypeStorage; -} // namespace detail -} // namespace toy -} // namespace mlir +} // namespace detail +} // namespace toy +} // namespace mlir /// Include the auto-generated header file containing the declaration of the toy /// dialect. @@ -58,7 +59,7 @@ namespace toy { /// (StructTypeStorage). class StructType : public mlir::Type::TypeBase { - public: +public: /// Inherit some necessary constructors from 'TypeBase'. using Base::Base; @@ -71,8 +72,11 @@ class StructType : public mlir::Type::TypeBase #include -#include "llvm/ADT/StringRef.h" - namespace toy { /// Structure definition a location in a file. struct Location { - std::shared_ptr file; ///< filename. - int line; ///< line number. - int col; ///< column number. + std::shared_ptr file; ///< filename. + int line; ///< line number. + int col; ///< column number. }; // List of Token returned by the lexer. @@ -57,7 +57,7 @@ enum Token : int { /// can proceed by reading the next line from the standard input or from a /// memory mapped file. class Lexer { - public: +public: /// Create a lexer for the given filename. The filename is kept only for /// debugging purpose (attaching a location to a Token). Lexer(std::string filename) @@ -99,7 +99,7 @@ class Lexer { // Return the current column in the file. int getCol() { return curCol; } - private: +private: /// Delegate to a derived class fetching the next line. Returns an empty /// string to signal end of file (EOF). Lines are expected to always finish /// with "\n" @@ -110,11 +110,13 @@ class Lexer { /// needed. int getNextChar() { // The current line buffer should not be empty unless it is the end of file. 
- if (curLineBuffer.empty()) return EOF; + if (curLineBuffer.empty()) + return EOF; ++curCol; auto nextchar = curLineBuffer.front(); curLineBuffer = curLineBuffer.drop_front(); - if (curLineBuffer.empty()) curLineBuffer = readNextLine(); + if (curLineBuffer.empty()) + curLineBuffer = readNextLine(); if (nextchar == '\n') { ++curLineNum; curCol = 0; @@ -125,7 +127,8 @@ class Lexer { /// Return the next token from standard input. Token getTok() { // Skip any whitespace. - while (isspace(lastChar)) lastChar = Token(getNextChar()); + while (isspace(lastChar)) + lastChar = Token(getNextChar()); // Save the current location before reading the token characters. lastLocation.line = curLineNum; @@ -137,10 +140,14 @@ class Lexer { while (isalnum((lastChar = Token(getNextChar()))) || lastChar == '_') identifierStr += (char)lastChar; - if (identifierStr == "return") return tok_return; - if (identifierStr == "def") return tok_def; - if (identifierStr == "struct") return tok_struct; - if (identifierStr == "var") return tok_var; + if (identifierStr == "return") + return tok_return; + if (identifierStr == "def") + return tok_def; + if (identifierStr == "struct") + return tok_struct; + if (identifierStr == "var") + return tok_var; return tok_identifier; } @@ -162,11 +169,13 @@ class Lexer { lastChar = Token(getNextChar()); } while (lastChar != EOF && lastChar != '\n' && lastChar != '\r'); - if (lastChar != EOF) return getTok(); + if (lastChar != EOF) + return getTok(); } // Check for end of file. Don't eat the EOF. - if (lastChar == EOF) return tok_eof; + if (lastChar == EOF) + return tok_eof; // Otherwise, just return the character as its ascii value. Token thisChar = Token(lastChar); @@ -203,22 +212,24 @@ class Lexer { /// A lexer implementation operating on a buffer in memory. 
class LexerBuffer final : public Lexer { - public: +public: LexerBuffer(const char *begin, const char *end, std::string filename) : Lexer(std::move(filename)), current(begin), end(end) {} - private: +private: /// Provide one line at a time to the Lexer, return an empty string when /// reaching the end of the buffer. llvm::StringRef readNextLine() override { auto *begin = current; - while (current <= end && *current && *current != '\n') ++current; - if (current <= end && *current) ++current; + while (current <= end && *current && *current != '\n') + ++current; + if (current <= end && *current) + ++current; llvm::StringRef result{begin, static_cast(current - begin)}; return result; } const char *current, *end; }; -} // namespace toy +} // namespace toy -#endif // TOY_LEXER_H +#endif // TOY_LEXER_H diff --git a/mlir/example/Ch7/include/toy/MLIRGen.h b/mlir/example/Ch7/include/toy/MLIRGen.h index 5afdf62..fe9dbe5 100644 --- a/mlir/example/Ch7/include/toy/MLIRGen.h +++ b/mlir/example/Ch7/include/toy/MLIRGen.h @@ -21,7 +21,7 @@ class MLIRContext; template class OwningOpRef; class ModuleOp; -} // namespace mlir +} // namespace mlir namespace toy { class ModuleAST; @@ -30,6 +30,6 @@ class ModuleAST; /// or nullptr on failure. 
mlir::OwningOpRef mlirGen(mlir::MLIRContext &context, ModuleAST &moduleAST); -} // namespace toy +} // namespace toy -#endif // TOY_MLIRGEN_H +#endif // TOY_MLIRGEN_H diff --git a/mlir/example/Ch7/include/toy/Ops.td b/mlir/example/Ch7/include/toy/Ops.td index 422a2eb..cfd6859 100644 --- a/mlir/example/Ch7/include/toy/Ops.td +++ b/mlir/example/Ch7/include/toy/Ops.td @@ -13,7 +13,7 @@ #ifndef TOY_OPS #define TOY_OPS -include "mlir/IR/FunctionInterfaces.td" +include "mlir/Interfaces/FunctionInterfaces.td" include "mlir/IR/SymbolInterfaces.td" include "mlir/Interfaces/CallInterfaces.td" include "mlir/Interfaces/CastInterfaces.td" @@ -165,8 +165,7 @@ def CastOp : Toy_Op<"cast", [ //===----------------------------------------------------------------------===// def FuncOp : Toy_Op<"func", [ - DeclareOpInterfaceMethods, FunctionOpInterface, - IsolatedFromAbove + FunctionOpInterface, IsolatedFromAbove ]> { let summary = "user defined function operation"; let description = [{ @@ -207,6 +206,8 @@ def FuncOp : Toy_Op<"func", [ /// Returns the result types of this function. ArrayRef getResultTypes() { return getFunctionType().getResults(); } + + Region *getCallableRegion() { return &getBody(); } }]; let hasCustomAssemblyFormat = 1; let skipDefaultBuilders = 1; diff --git a/mlir/example/Ch7/include/toy/Parser.h b/mlir/example/Ch7/include/toy/Parser.h index a5ddacd..7ba7b8f 100644 --- a/mlir/example/Ch7/include/toy/Parser.h +++ b/mlir/example/Ch7/include/toy/Parser.h @@ -14,16 +14,17 @@ #ifndef TOY_PARSER_H #define TOY_PARSER_H -#include -#include -#include -#include +#include "toy/AST.h" +#include "toy/Lexer.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Support/raw_ostream.h" -#include "toy/AST.h" -#include "toy/Lexer.h" + +#include +#include +#include +#include namespace toy { @@ -33,32 +34,33 @@ namespace toy { /// string and the code could reference an undeclared variable and the parsing /// succeeds. 
class Parser { - public: +public: /// Create a Parser for the supplied lexer. Parser(Lexer &lexer) : lexer(lexer) {} /// Parse a full Module. A module is a list of function definitions. std::unique_ptr parseModule() { - lexer.getNextToken(); // prime the lexer + lexer.getNextToken(); // prime the lexer // Parse functions and structs one at a time and accumulate in this vector. std::vector> records; while (true) { std::unique_ptr record; switch (lexer.getCurToken()) { - case tok_eof: - break; - case tok_def: - record = parseDefinition(); - break; - case tok_struct: - record = parseStruct(); - break; - default: - return parseError("'def' or 'struct'", - "when parsing top level module records"); + case tok_eof: + break; + case tok_def: + record = parseDefinition(); + break; + case tok_struct: + record = parseStruct(); + break; + default: + return parseError("'def' or 'struct'", + "when parsing top level module records"); } - if (!record) break; + if (!record) + break; records.push_back(std::move(record)); } @@ -69,7 +71,7 @@ class Parser { return std::make_unique(std::move(records)); } - private: +private: Lexer &lexer; /// Parse a return statement. @@ -82,7 +84,8 @@ class Parser { std::optional> expr; if (lexer.getCurToken() != ';') { expr = parseExpression(); - if (!expr) return nullptr; + if (!expr) + return nullptr; } return std::make_unique(std::move(loc), std::move(expr)); } @@ -112,7 +115,8 @@ class Parser { // We can have either another nested array or a number literal. if (lexer.getCurToken() == '[') { values.push_back(parseTensorLiteralExpr()); - if (!values.back()) return nullptr; // parse error in the nested array. + if (!values.back()) + return nullptr; // parse error in the nested array. 
} else { if (lexer.getCurToken() != tok_number) return parseError(" or [", "in literal expression"); @@ -120,17 +124,18 @@ class Parser { } // End of this list on ']' - if (lexer.getCurToken() == ']') break; + if (lexer.getCurToken() == ']') + break; // Elements are separated by a comma. if (lexer.getCurToken() != ',') return parseError("] or ,", "in literal expression"); - lexer.getNextToken(); // eat , + lexer.getNextToken(); // eat , } while (true); if (values.empty()) return parseError("", "to fill literal expression"); - lexer.getNextToken(); // eat ] + lexer.getNextToken(); // eat ] /// Fill in the dimensions now. First the current nesting level: dims.push_back(values.size()); @@ -176,10 +181,12 @@ class Parser { // We can have either another nested array or a number literal. if (lexer.getCurToken() == '[') { values.push_back(parseTensorLiteralExpr()); - if (!values.back()) return nullptr; + if (!values.back()) + return nullptr; } else if (lexer.getCurToken() == tok_number) { values.push_back(parseNumberExpr()); - if (!values.back()) return nullptr; + if (!values.back()) + return nullptr; } else { if (lexer.getCurToken() != '{') return parseError("{, [, or number", @@ -188,18 +195,19 @@ class Parser { } // End of this list on '}' - if (lexer.getCurToken() == '}') break; + if (lexer.getCurToken() == '}') + break; // Elements are separated by a comma. if (lexer.getCurToken() != ',') return parseError("} or ,", "in struct literal expression"); - lexer.getNextToken(); // eat , + lexer.getNextToken(); // eat , } while (true); if (values.empty()) return parseError("", "to fill struct literal expression"); - lexer.getNextToken(); // eat } + lexer.getNextToken(); // eat } return std::make_unique(std::move(loc), std::move(values)); @@ -207,9 +215,10 @@ class Parser { /// parenexpr ::= '(' expression ')' std::unique_ptr parseParenExpr() { - lexer.getNextToken(); // eat (. + lexer.getNextToken(); // eat (. 
auto v = parseExpression(); - if (!v) return nullptr; + if (!v) + return nullptr; if (lexer.getCurToken() != ')') return parseError(")", "to close expression with parentheses"); @@ -229,7 +238,8 @@ class Parser { else return nullptr; - if (lexer.getCurToken() == ')') break; + if (lexer.getCurToken() == ')') + break; if (lexer.getCurToken() != ',') return parseError(", or )", "in argument list"); @@ -258,9 +268,9 @@ class Parser { std::string name(lexer.getId()); auto loc = lexer.getLastLocation(); - lexer.getNextToken(); // eat identifier. + lexer.getNextToken(); // eat identifier. - if (lexer.getCurToken() != '(') // Simple variable ref. + if (lexer.getCurToken() != '(') // Simple variable ref. return std::make_unique(std::move(loc), name); // This is a function call. @@ -274,24 +284,24 @@ class Parser { /// ::= tensorliteral std::unique_ptr parsePrimary() { switch (lexer.getCurToken()) { - default: - llvm::errs() << "unknown token '" << lexer.getCurToken() - << "' when expecting an expression\n"; - return nullptr; - case tok_identifier: - return parseIdentifierExpr(); - case tok_number: - return parseNumberExpr(); - case '(': - return parseParenExpr(); - case '[': - return parseTensorLiteralExpr(); - case '{': - return parseStructLiteralExpr(); - case ';': - return nullptr; - case '}': - return nullptr; + default: + llvm::errs() << "unknown token '" << lexer.getCurToken() + << "' when expecting an expression\n"; + return nullptr; + case tok_identifier: + return parseIdentifierExpr(); + case tok_number: + return parseNumberExpr(); + case '(': + return parseParenExpr(); + case '[': + return parseTensorLiteralExpr(); + case '{': + return parseStructLiteralExpr(); + case ';': + return nullptr; + case '}': + return nullptr; } } @@ -307,7 +317,8 @@ class Parser { // If this is a binop that binds at least as tightly as the current binop, // consume it, otherwise we are done. 
- if (tokPrec < exprPrec) return lhs; + if (tokPrec < exprPrec) + return lhs; // Okay, we know this is a binop. int binOp = lexer.getCurToken(); @@ -324,7 +335,8 @@ class Parser { int nextPrec = getTokPrecedence(); if (tokPrec < nextPrec) { rhs = parseBinOpRHS(tokPrec + 1, std::move(rhs)); - if (!rhs) return nullptr; + if (!rhs) + return nullptr; } // Merge lhs/RHS. @@ -336,7 +348,8 @@ class Parser { /// expression::= primary binop rhs std::unique_ptr parseExpression() { auto lhs = parsePrimary(); - if (!lhs) return nullptr; + if (!lhs) + return nullptr; return parseBinOpRHS(0, std::move(lhs)); } @@ -346,19 +359,20 @@ class Parser { std::unique_ptr parseType() { if (lexer.getCurToken() != '<') return parseError("<", "to begin type"); - lexer.getNextToken(); // eat < + lexer.getNextToken(); // eat < auto type = std::make_unique(); while (lexer.getCurToken() == tok_number) { type->shape.push_back(lexer.getValue()); lexer.getNextToken(); - if (lexer.getCurToken() == ',') lexer.getNextToken(); + if (lexer.getCurToken() == ',') + lexer.getNextToken(); } if (lexer.getCurToken() != '>') return parseError(">", "to end type"); - lexer.getNextToken(); // eat > + lexer.getNextToken(); // eat > return type; } @@ -369,20 +383,22 @@ class Parser { lexer.consume(tok_identifier); // Check for a call expression. - if (lexer.getCurToken() == '(') return parseCallExpr(id, loc); + if (lexer.getCurToken() == '(') + return parseCallExpr(id, loc); // Otherwise, this is a variable declaration. return parseTypedDeclaration(id, /*requiresInitializer=*/true, loc); } /// Parse a typed variable declaration. - std::unique_ptr parseTypedDeclaration( - llvm::StringRef typeName, bool requiresInitializer, const Location &loc) { + std::unique_ptr + parseTypedDeclaration(llvm::StringRef typeName, bool requiresInitializer, + const Location &loc) { // Parse the variable name. 
if (lexer.getCurToken() != tok_identifier) return parseError("name", "in variable declaration"); std::string id(lexer.getId()); - lexer.getNextToken(); // eat id + lexer.getNextToken(); // eat id // Parse the initializer. std::unique_ptr expr; @@ -414,7 +430,7 @@ class Parser { return parseError("type name", "in variable declaration"); auto loc = lexer.getLastLocation(); std::string typeName(lexer.getId()); - lexer.getNextToken(); // eat id + lexer.getNextToken(); // eat id // Parse the rest of the declaration. return parseTypedDeclaration(typeName, requiresInitializer, loc); @@ -424,25 +440,27 @@ class Parser { /// and identifier and an optional type (shape specification) before the /// optionally required initializer. /// decl ::= var identifier [ type ] (= expr)? - std::unique_ptr parseVarDeclaration( - bool requiresInitializer) { + std::unique_ptr + parseVarDeclaration(bool requiresInitializer) { if (lexer.getCurToken() != tok_var) return parseError("var", "to begin declaration"); auto loc = lexer.getLastLocation(); - lexer.getNextToken(); // eat var + lexer.getNextToken(); // eat var if (lexer.getCurToken() != tok_identifier) return parseError("identified", "after 'var' declaration"); std::string id(lexer.getId()); - lexer.getNextToken(); // eat id + lexer.getNextToken(); // eat id - std::unique_ptr type; // Type is optional, it can be inferred + std::unique_ptr type; // Type is optional, it can be inferred if (lexer.getCurToken() == '<') { type = parseType(); - if (!type) return nullptr; + if (!type) + return nullptr; } - if (!type) type = std::make_unique(); + if (!type) + type = std::make_unique(); std::unique_ptr expr; if (requiresInitializer) { @@ -467,28 +485,33 @@ class Parser { auto exprList = std::make_unique(); // Ignore empty expressions: swallow sequences of semicolons. 
- while (lexer.getCurToken() == ';') lexer.consume(Token(';')); + while (lexer.getCurToken() == ';') + lexer.consume(Token(';')); while (lexer.getCurToken() != '}' && lexer.getCurToken() != tok_eof) { if (lexer.getCurToken() == tok_identifier) { // Variable declaration or call auto expr = parseDeclarationOrCallExpr(); - if (!expr) return nullptr; + if (!expr) + return nullptr; exprList->push_back(std::move(expr)); } else if (lexer.getCurToken() == tok_var) { // Variable declaration auto varDecl = parseDeclaration(/*requiresInitializer=*/true); - if (!varDecl) return nullptr; + if (!varDecl) + return nullptr; exprList->push_back(std::move(varDecl)); } else if (lexer.getCurToken() == tok_return) { // Return statement auto ret = parseReturn(); - if (!ret) return nullptr; + if (!ret) + return nullptr; exprList->push_back(std::move(ret)); } else { // General expression auto expr = parseExpression(); - if (!expr) return nullptr; + if (!expr) + return nullptr; exprList->push_back(std::move(expr)); } // Ensure that elements are separated by a semicolon. @@ -496,7 +519,8 @@ class Parser { return parseError(";", "after expression"); // Ignore empty expressions: swallow sequences of semicolons. 
- while (lexer.getCurToken() == ';') lexer.consume(Token(';')); + while (lexer.getCurToken() == ';') + lexer.consume(Token(';')); } if (lexer.getCurToken() != '}') @@ -550,7 +574,8 @@ class Parser { args.push_back( std::make_unique(std::move(loc), name, type)); - if (lexer.getCurToken() != ',') break; + if (lexer.getCurToken() != ',') + break; lexer.consume(Token(',')); if (lexer.getCurToken() != tok_identifier) return parseError( @@ -572,7 +597,8 @@ class Parser { /// definition ::= prototype block std::unique_ptr parseDefinition() { auto proto = parsePrototype(); - if (!proto) return nullptr; + if (!proto) + return nullptr; if (auto block = parseBlock()) return std::make_unique(std::move(proto), std::move(block)); @@ -601,7 +627,8 @@ class Parser { std::vector> decls; do { auto decl = parseDeclaration(/*requiresInitializer=*/false); - if (!decl) return nullptr; + if (!decl) + return nullptr; decls.push_back(std::move(decl)); if (lexer.getCurToken() != ';') @@ -617,20 +644,21 @@ class Parser { /// Get the precedence of the pending binary operator token. int getTokPrecedence() { - if (!isascii(lexer.getCurToken())) return -1; + if (!isascii(lexer.getCurToken())) + return -1; // 1 is lowest precedence. 
switch (static_cast(lexer.getCurToken())) { - case '-': - return 20; - case '+': - return 20; - case '*': - return 40; - case '.': - return 60; - default: - return -1; + case '-': + return 20; + case '+': + return 20; + case '*': + return 40; + case '.': + return 60; + default: + return -1; } } @@ -643,12 +671,13 @@ class Parser { llvm::errs() << "Parse error (" << lexer.getLastLocation().line << ", " << lexer.getLastLocation().col << "): expected '" << expected << "' " << context << " but has Token " << curToken; - if (isprint(curToken)) llvm::errs() << " '" << (char)curToken << "'"; + if (isprint(curToken)) + llvm::errs() << " '" << (char)curToken << "'"; llvm::errs() << "\n"; return nullptr; } }; -} // namespace toy +} // namespace toy -#endif // TOY_PARSER_H +#endif // TOY_PARSER_H diff --git a/mlir/example/Ch7/include/toy/Passes.h b/mlir/example/Ch7/include/toy/Passes.h index 3888afe..62471dd 100644 --- a/mlir/example/Ch7/include/toy/Passes.h +++ b/mlir/example/Ch7/include/toy/Passes.h @@ -29,7 +29,7 @@ std::unique_ptr createLowerToAffinePass(); /// well as `Affine` and `Std`, to the LLVM dialect for codegen. std::unique_ptr createLowerToLLVMPass(); -} // namespace toy -} // namespace mlir +} // namespace toy +} // namespace mlir -#endif // TOY_PASSES_H +#endif // TOY_PASSES_H diff --git a/mlir/example/Ch7/include/toy/ShapeInferenceInterface.h b/mlir/example/Ch7/include/toy/ShapeInferenceInterface.h index a32ef17..cfe5a87 100644 --- a/mlir/example/Ch7/include/toy/ShapeInferenceInterface.h +++ b/mlir/example/Ch7/include/toy/ShapeInferenceInterface.h @@ -22,7 +22,7 @@ namespace toy { /// Include the auto-generated declarations. 
#include "toy/ShapeInferenceOpInterfaces.h.inc" -} // namespace toy -} // namespace mlir +} // namespace toy +} // namespace mlir -#endif // MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_ +#endif // MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_ diff --git a/mlir/example/Ch7/mlir/Dialect.cpp b/mlir/example/Ch7/mlir/Dialect.cpp index 579e6af..7e030ff 100644 --- a/mlir/example/Ch7/mlir/Dialect.cpp +++ b/mlir/example/Ch7/mlir/Dialect.cpp @@ -13,12 +13,31 @@ #include "toy/Dialect.h" +#include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectImplementation.h" -#include "mlir/IR/FunctionImplementation.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" #include "mlir/IR/OpImplementation.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/TypeSupport.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Interfaces/CallInterfaces.h" +#include "mlir/Interfaces/FunctionImplementation.h" +#include "mlir/Support/LLVM.h" #include "mlir/Transforms/InliningUtils.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/Hashing.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include +#include +#include +#include +#include using namespace mlir; using namespace mlir::toy; @@ -60,8 +79,7 @@ struct ToyInlinerInterface : public DialectInlinerInterface { /// Handle the given inlined terminator(toy.return) by replacing it with a new /// operation as necessary. - void handleTerminator(Operation *op, - ArrayRef valuesToRepl) const final { + void handleTerminator(Operation *op, ValueRange valuesToRepl) const final { // Only "toy.return" needs to be handled here. auto returnOp = cast(op); @@ -176,7 +194,7 @@ void ConstantOp::print(mlir::OpAsmPrinter &printer) { } /// Verify that the given attribute value is valid for the given type. 
-static mlir::LogicalResult verifyConstantForType(mlir::Type type, +static llvm::LogicalResult verifyConstantForType(mlir::Type type, mlir::Attribute opaqueValue, mlir::Operation *op) { if (llvm::isa(type)) { @@ -233,11 +251,11 @@ static mlir::LogicalResult verifyConstantForType(mlir::Type type, /// Verifier for the constant operation. This corresponds to the `::verify(...)` /// in the op definition. -mlir::LogicalResult ConstantOp::verify() { +llvm::LogicalResult ConstantOp::verify() { return verifyConstantForType(getResult().getType(), getValue(), *this); } -mlir::LogicalResult StructConstantOp::verify() { +llvm::LogicalResult StructConstantOp::verify() { return verifyConstantForType(getResult().getType(), getValue(), *this); } @@ -327,27 +345,6 @@ void FuncOp::print(mlir::OpAsmPrinter &p) { getArgAttrsAttrName(), getResAttrsAttrName()); } -/// Returns the region on the function operation that is callable. -mlir::Region *FuncOp::getCallableRegion() { return &getBody(); } - -/// Returns the results types that the callable region produces when -/// executed. -llvm::ArrayRef FuncOp::getCallableResults() { - return getFunctionType().getResults(); -} - -/// Returns the argument attributes for all callable region arguments or -/// null if there are none. -ArrayAttr FuncOp::getCallableArgAttrs() { - return getArgAttrs().value_or(nullptr); -} - -/// Returns the result attributes for all callable region results or -/// null if there are none. -ArrayAttr FuncOp::getCallableResAttrs() { - return getResAttrs().value_or(nullptr); -} - //===----------------------------------------------------------------------===// // GenericCallOp //===----------------------------------------------------------------------===// @@ -377,6 +374,12 @@ void GenericCallOp::setCalleeFromCallable(CallInterfaceCallable callee) { /// call interface. 
Operation::operand_range GenericCallOp::getArgOperands() { return getInputs(); } +/// Get the argument operands to the called function as a mutable range, this is +/// required by the call interface. +MutableOperandRange GenericCallOp::getArgOperandsMutable() { + return getInputsMutable(); +} + //===----------------------------------------------------------------------===// // MulOp //===----------------------------------------------------------------------===// @@ -402,7 +405,7 @@ void MulOp::inferShapes() { getResult().setType(getLhs().getType()); } // ReturnOp //===----------------------------------------------------------------------===// -mlir::LogicalResult ReturnOp::verify() { +llvm::LogicalResult ReturnOp::verify() { // We know that the parent operation is a function, because of the 'HasParent' // trait attached to the operation definition. auto function = cast((*this)->getParentOp()); @@ -426,8 +429,7 @@ mlir::LogicalResult ReturnOp::verify() { auto resultType = results.front(); // Check that the result type of the function matches the operand type. 
- if (inputType == resultType || - llvm::isa(inputType) || + if (inputType == resultType || llvm::isa(inputType) || llvm::isa(resultType)) return mlir::success(); @@ -451,7 +453,7 @@ void StructAccessOp::build(mlir::OpBuilder &b, mlir::OperationState &state, build(b, state, resultType, input, b.getI64IntegerAttr(index)); } -mlir::LogicalResult StructAccessOp::verify() { +llvm::LogicalResult StructAccessOp::verify() { StructType structTy = llvm::cast(getInput().getType()); size_t indexValue = getIndex(); if (indexValue >= structTy.getNumElementTypes()) @@ -480,7 +482,7 @@ void TransposeOp::inferShapes() { getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType())); } -mlir::LogicalResult TransposeOp::verify() { +llvm::LogicalResult TransposeOp::verify() { auto inputType = llvm::dyn_cast(getOperand().getType()); auto resultType = llvm::dyn_cast(getType()); if (!inputType || !resultType) diff --git a/mlir/example/Ch7/mlir/LowerToAffineLoops.cpp b/mlir/example/Ch7/mlir/LowerToAffineLoops.cpp index 03e87ba..7413214 100644 --- a/mlir/example/Ch7/mlir/LowerToAffineLoops.cpp +++ b/mlir/example/Ch7/mlir/LowerToAffineLoops.cpp @@ -12,16 +12,34 @@ // //===----------------------------------------------------------------------===// +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinDialect.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/TypeID.h" +#include "toy/Dialect.h" +#include "toy/Passes.h" + #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/IR/BuiltinDialect.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" -#include "toy/Dialect.h" -#include "toy/Passes.h" 
+#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Sequence.h" +#include "llvm/Support/Casting.h" +#include +#include +#include +#include +#include using namespace mlir; @@ -154,8 +172,7 @@ struct ConstantOpLowering : public OpRewritePattern { SmallVector constantIndices; if (!valueShape.empty()) { - for (auto i : llvm::seq( - 0, *std::max_element(valueShape.begin(), valueShape.end()))) + for (auto i : llvm::seq(0, *llvm::max_element(valueShape))) constantIndices.push_back( rewriter.create(loc, i)); } else { @@ -241,8 +258,8 @@ struct PrintOpLowering : public OpConversionPattern { ConversionPatternRewriter &rewriter) const final { // We don't lower "toy.print" in this pass, but we need to update its // operands. - rewriter.updateRootInPlace(op, - [&] { op->setOperands(adaptor.getOperands()); }); + rewriter.modifyOpInPlace(op, + [&] { op->setOperands(adaptor.getOperands()); }); return success(); } }; diff --git a/mlir/example/Ch7/mlir/LowerToLLVM.cpp b/mlir/example/Ch7/mlir/LowerToLLVM.cpp index ab28f02..3ad70e7 100644 --- a/mlir/example/Ch7/mlir/LowerToLLVM.cpp +++ b/mlir/example/Ch7/mlir/LowerToLLVM.cpp @@ -22,6 +22,16 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/LLVMIR/LLVMAttrs.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/TypeID.h" +#include "toy/Dialect.h" +#include "toy/Passes.h" + #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" @@ -31,7 +41,6 @@ #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" -#include "mlir/Dialect/Affine/IR/AffineOps.h" #include 
"mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" @@ -39,9 +48,9 @@ #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" -#include "toy/Dialect.h" -#include "toy/Passes.h" -#include "llvm/ADT/Sequence.h" +#include "llvm/Support/Casting.h" +#include +#include using namespace mlir; @@ -60,6 +69,7 @@ class PrintOpLowering : public ConversionPattern { LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { + auto *context = rewriter.getContext(); auto memRefType = llvm::cast((*op->operand_type_begin())); auto memRefShape = memRefType.getShape(); auto loc = op->getLoc(); @@ -91,8 +101,8 @@ class PrintOpLowering : public ConversionPattern { // Insert a newline after each of the inner dimensions of the shape. if (i != e - 1) - rewriter.create(loc, printfRef, - rewriter.getIntegerType(32), newLineCst); + rewriter.create(loc, getPrintfType(context), printfRef, + newLineCst); rewriter.create(loc); rewriter.setInsertionPointToStart(loop.getBody()); } @@ -101,8 +111,8 @@ class PrintOpLowering : public ConversionPattern { auto printOp = cast(op); auto elementLoad = rewriter.create(loc, printOp.getInput(), loopIvs); - rewriter.create( - loc, printfRef, rewriter.getIntegerType(32), + rewriter.create( + loc, getPrintfType(context), printfRef, ArrayRef({formatSpecifierCst, elementLoad})); // Notify the rewriter that this operation has been removed. 
@@ -111,6 +121,16 @@ class PrintOpLowering : public ConversionPattern { } private: + /// Create a function declaration for printf, the signature is: + /// * `i32 (i8*, ...)` + static LLVM::LLVMFunctionType getPrintfType(MLIRContext *context) { + auto llvmI32Ty = IntegerType::get(context, 32); + auto llvmPtrTy = LLVM::LLVMPointerType::get(context); + auto llvmFnType = LLVM::LLVMFunctionType::get(llvmI32Ty, llvmPtrTy, + /*isVarArg=*/true); + return llvmFnType; + } + /// Return a symbol reference to the printf function, inserting it into the /// module if necessary. static FlatSymbolRefAttr getOrInsertPrintf(PatternRewriter &rewriter, @@ -119,17 +139,11 @@ class PrintOpLowering : public ConversionPattern { if (module.lookupSymbol("printf")) return SymbolRefAttr::get(context, "printf"); - // Create a function declaration for printf, the signature is: - // * `i32 (i8*, ...)` - auto llvmI32Ty = IntegerType::get(context, 32); - auto llvmI8PtrTy = LLVM::LLVMPointerType::get(IntegerType::get(context, 8)); - auto llvmFnType = LLVM::LLVMFunctionType::get(llvmI32Ty, llvmI8PtrTy, - /*isVarArg=*/true); - // Insert the printf function into the body of the parent module. 
PatternRewriter::InsertionGuard insertGuard(rewriter); rewriter.setInsertionPointToStart(module.getBody()); - rewriter.create(module.getLoc(), "printf", llvmFnType); + rewriter.create(module.getLoc(), "printf", + getPrintfType(context)); return SymbolRefAttr::get(context, "printf"); } @@ -156,8 +170,7 @@ class PrintOpLowering : public ConversionPattern { Value cst0 = builder.create(loc, builder.getI64Type(), builder.getIndexAttr(0)); return builder.create( - loc, - LLVM::LLVMPointerType::get(IntegerType::get(builder.getContext(), 8)), + loc, LLVM::LLVMPointerType::get(builder.getContext()), global.getType(), globalPtr, ArrayRef({cst0, cst0})); } }; diff --git a/mlir/example/Ch7/mlir/MLIRGen.cpp b/mlir/example/Ch7/mlir/MLIRGen.cpp index 5ba4f18..090e5ff 100644 --- a/mlir/example/Ch7/mlir/MLIRGen.cpp +++ b/mlir/example/Ch7/mlir/MLIRGen.cpp @@ -12,21 +12,36 @@ //===----------------------------------------------------------------------===// #include "toy/MLIRGen.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/Value.h" +#include "toy/AST.h" +#include "toy/Dialect.h" -#include -#include - -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/ScopedHashTable.h" -#include "llvm/Support/raw_ostream.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Verifier.h" -#include "toy/AST.h" -#include "toy/Dialect.h" +#include "toy/Lexer.h" + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/ScopedHashTable.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/Twine.h" +#include "llvm/Support/ErrorHandling.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include using namespace mlir::toy; using namespace toy; @@ -48,7 +63,7 @@ namespace { /// the semantics of the language and (hopefully) allow to perform accurate 
/// analysis and transformation based on these high level semantics. class MLIRGenImpl { - public: +public: MLIRGenImpl(mlir::MLIRContext &context) : builder(&context) {} /// Public API: convert the AST for a Toy module (source file) to an MLIR @@ -61,10 +76,12 @@ class MLIRGenImpl { for (auto &record : moduleAST) { if (FunctionAST *funcAST = llvm::dyn_cast(record.get())) { mlir::toy::FuncOp func = mlirGen(*funcAST); - if (!func) return nullptr; + if (!func) + return nullptr; functionMap.insert({func.getName(), func}); } else if (StructAST *str = llvm::dyn_cast(record.get())) { - if (failed(mlirGen(*str))) return nullptr; + if (failed(mlirGen(*str))) + return nullptr; } else { llvm_unreachable("unknown record type"); } @@ -81,7 +98,7 @@ class MLIRGenImpl { return theModule; } - private: +private: /// A "module" matches a Toy source file: containing a list of functions. mlir::ModuleOp theModule; @@ -115,14 +132,15 @@ class MLIRGenImpl { /// Declare a variable in the current scope, return success if the variable /// wasn't declared yet. - mlir::LogicalResult declare(VarDeclExprAST &var, mlir::Value value) { - if (symbolTable.count(var.getName())) return mlir::failure(); + llvm::LogicalResult declare(VarDeclExprAST &var, mlir::Value value) { + if (symbolTable.count(var.getName())) + return mlir::failure(); symbolTable.insert(var.getName(), {value, &var}); return mlir::success(); } /// Create an MLIR type for the given struct. 
- mlir::LogicalResult mlirGen(StructAST &str) { + llvm::LogicalResult mlirGen(StructAST &str) { if (structMap.count(str.getName())) return emitError(loc(str.loc())) << "error: struct type with name `" << str.getName() << "' already exists"; @@ -141,7 +159,8 @@ class MLIRGenImpl { "initializers"; mlir::Type type = getType(variable->getType(), variable->loc()); - if (!type) return mlir::failure(); + if (!type) + return mlir::failure(); elementTypes.push_back(type); } @@ -159,7 +178,8 @@ class MLIRGenImpl { argTypes.reserve(proto.getArgs().size()); for (auto &arg : proto.getArgs()) { mlir::Type type = getType(arg->getType(), arg->loc()); - if (!type) return nullptr; + if (!type) + return nullptr; argTypes.push_back(type); } auto funcType = builder.getFunctionType(argTypes, std::nullopt); @@ -175,7 +195,8 @@ class MLIRGenImpl { // Create an MLIR function for the given prototype. builder.setInsertionPointToEnd(theModule.getBody()); mlir::toy::FuncOp function = mlirGen(*funcAST.getProto()); - if (!function) return nullptr; + if (!function) + return nullptr; // Let's start the body of the function now! mlir::Block &entryBlock = function.front(); @@ -203,7 +224,8 @@ class MLIRGenImpl { // FIXME: we may fix the parser instead to always return the last expression // (this would possibly help the REPL case later) ReturnOp returnOp; - if (!entryBlock.empty()) returnOp = dyn_cast(entryBlock.back()); + if (!entryBlock.empty()) + returnOp = dyn_cast(entryBlock.back()); if (!returnOp) { builder.create(loc(funcAST.getProto()->loc())); } else if (returnOp.hasOperand()) { @@ -215,7 +237,8 @@ class MLIRGenImpl { } // If this function isn't main, then set the visibility to private. 
- if (funcAST.getProto()->getName() != "main") function.setPrivate(); + if (funcAST.getProto()->getName() != "main") + function.setPrivate(); return function; } @@ -226,15 +249,19 @@ class MLIRGenImpl { llvm::StringRef structName; if (auto *decl = llvm::dyn_cast(expr)) { auto varIt = symbolTable.lookup(decl->getName()); - if (!varIt.first) return nullptr; + if (!varIt.first) + return nullptr; structName = varIt.second->getType().name; } else if (auto *access = llvm::dyn_cast(expr)) { - if (access->getOp() != '.') return nullptr; + if (access->getOp() != '.') + return nullptr; // The name being accessed should be in the RHS. auto *name = llvm::dyn_cast(access->getRHS()); - if (!name) return nullptr; + if (!name) + return nullptr; StructAST *parentStruct = getStructFor(access->getLHS()); - if (!parentStruct) return nullptr; + if (!parentStruct) + return nullptr; // Get the element within the struct corresponding to the name. VarDeclExprAST *decl = nullptr; @@ -244,14 +271,17 @@ class MLIRGenImpl { break; } } - if (!decl) return nullptr; + if (!decl) + return nullptr; structName = decl->getType().name; } - if (structName.empty()) return nullptr; + if (structName.empty()) + return nullptr; // If the struct name was valid, check for an entry in the struct map. auto structIt = structMap.find(structName); - if (structIt == structMap.end()) return nullptr; + if (structIt == structMap.end()) + return nullptr; return structIt->second.second; } @@ -261,17 +291,20 @@ class MLIRGenImpl { // Lookup the struct node for the LHS. StructAST *structAST = getStructFor(accessOp.getLHS()); - if (!structAST) return std::nullopt; + if (!structAST) + return std::nullopt; // Get the name from the RHS. 
VariableExprAST *name = llvm::dyn_cast(accessOp.getRHS()); - if (!name) return std::nullopt; + if (!name) + return std::nullopt; auto structVars = structAST->getVariables(); const auto *it = llvm::find_if(structVars, [&](auto &var) { return var->getName() == name->getName(); }); - if (it == structVars.end()) return std::nullopt; + if (it == structVars.end()) + return std::nullopt; return it - structVars.begin(); } @@ -289,7 +322,8 @@ class MLIRGenImpl { // and propagate. // mlir::Value lhs = mlirGen(*binop.getLHS()); - if (!lhs) return nullptr; + if (!lhs) + return nullptr; auto location = loc(binop.loc()); // If this is an access operation, handle it immediately. @@ -304,15 +338,16 @@ class MLIRGenImpl { // Otherwise, this is a normal binary op. mlir::Value rhs = mlirGen(*binop.getRHS()); - if (!rhs) return nullptr; + if (!rhs) + return nullptr; // Derive the operation name from the binary operator. At the moment we only // support '+' and '*'. switch (binop.getOp()) { - case '+': - return builder.create(location, lhs, rhs); - case '*': - return builder.create(location, lhs, rhs); + case '+': + return builder.create(location, lhs, rhs); + case '*': + return builder.create(location, lhs, rhs); } emitError(location, "invalid binary operator '") << binop.getOp() << "'"; @@ -332,13 +367,14 @@ class MLIRGenImpl { } /// Emit a return operation. This will return failure if any generation fails. - mlir::LogicalResult mlirGen(ReturnExprAST &ret) { + llvm::LogicalResult mlirGen(ReturnExprAST &ret) { auto location = loc(ret.loc()); // 'return' takes an optional expression, handle that case here. mlir::Value expr = nullptr; if (ret.getExpr().has_value()) { - if (!(expr = mlirGen(**ret.getExpr()))) return mlir::failure(); + if (!(expr = mlirGen(**ret.getExpr()))) + return mlir::failure(); } // Otherwise, this return operation has zero operands. @@ -397,8 +433,8 @@ class MLIRGenImpl { /// other literals in an Attribute attached to a `toy.struct_constant` /// operation. 
This function returns the generated constant, along with the /// corresponding struct type. - std::pair getConstantAttr( - StructLiteralExprAST &lit) { + std::pair + getConstantAttr(StructLiteralExprAST &lit) { std::vector attrElements; std::vector typeElements; @@ -454,7 +490,8 @@ class MLIRGenImpl { /// Attributes are the way MLIR attaches constant to operations. void collectData(ExprAST &expr, std::vector &data) { if (auto *lit = dyn_cast(&expr)) { - for (auto &value : lit->getValues()) collectData(*value, data); + for (auto &value : lit->getValues()) + collectData(*value, data); return; } @@ -472,7 +509,8 @@ class MLIRGenImpl { SmallVector operands; for (auto &expr : call.getArgs()) { auto arg = mlirGen(*expr); - if (!arg) return nullptr; + if (!arg) + return nullptr; operands.push_back(arg); } @@ -480,9 +518,8 @@ class MLIRGenImpl { // straightforward emission. if (callee == "transpose") { if (call.getArgs().size() != 1) { - emitError(location, - "MLIR codegen encountered an error: toy.transpose " - "does not accept multiple arguments"); + emitError(location, "MLIR codegen encountered an error: toy.transpose " + "does not accept multiple arguments"); return nullptr; } return builder.create(location, operands[0]); @@ -504,9 +541,10 @@ class MLIRGenImpl { /// Emit a print expression. It emits specific operations for two builtins: /// transpose(x) and print(x). - mlir::LogicalResult mlirGen(PrintExprAST &call) { + llvm::LogicalResult mlirGen(PrintExprAST &call) { auto arg = mlirGen(*call.getArg()); - if (!arg) return mlir::failure(); + if (!arg) + return mlir::failure(); builder.create(loc(call.loc()), arg); return mlir::success(); @@ -520,23 +558,23 @@ class MLIRGenImpl { /// Dispatch codegen for the right expression subclass using RTTI. 
mlir::Value mlirGen(ExprAST &expr) { switch (expr.getKind()) { - case toy::ExprAST::Expr_BinOp: - return mlirGen(cast(expr)); - case toy::ExprAST::Expr_Var: - return mlirGen(cast(expr)); - case toy::ExprAST::Expr_Literal: - return mlirGen(cast(expr)); - case toy::ExprAST::Expr_StructLiteral: - return mlirGen(cast(expr)); - case toy::ExprAST::Expr_Call: - return mlirGen(cast(expr)); - case toy::ExprAST::Expr_Num: - return mlirGen(cast(expr)); - default: - emitError(loc(expr.loc())) - << "MLIR codegen encountered an unhandled expr kind '" - << Twine(expr.getKind()) << "'"; - return nullptr; + case toy::ExprAST::Expr_BinOp: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Var: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Literal: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_StructLiteral: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Call: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Num: + return mlirGen(cast(expr)); + default: + emitError(loc(expr.loc())) + << "MLIR codegen encountered an unhandled expr kind '" + << Twine(expr.getKind()) << "'"; + return nullptr; } } @@ -553,7 +591,8 @@ class MLIRGenImpl { } mlir::Value value = mlirGen(*init); - if (!value) return nullptr; + if (!value) + return nullptr; // Handle the case where we are initializing a struct value. VarType varType = vardecl.getType(); @@ -561,7 +600,8 @@ class MLIRGenImpl { // Check that the initializer type is the same as the variable // declaration. mlir::Type type = getType(varType, vardecl.loc()); - if (!type) return nullptr; + if (!type) + return nullptr; if (type != value.getType()) { emitError(loc(vardecl.loc())) << "struct type of initializer is different than the variable " @@ -579,29 +619,34 @@ class MLIRGenImpl { } // Register the value in the symbol table. 
- if (failed(declare(vardecl, value))) return nullptr; + if (failed(declare(vardecl, value))) + return nullptr; return value; } /// Codegen a list of expression, return failure if one of them hit an error. - mlir::LogicalResult mlirGen(ExprASTList &blockAST) { + llvm::LogicalResult mlirGen(ExprASTList &blockAST) { SymbolTableScopeT varScope(symbolTable); for (auto &expr : blockAST) { // Specific handling for variable declarations, return statement, and // print. These can only appear in block list and not in nested // expressions. if (auto *vardecl = dyn_cast(expr.get())) { - if (!mlirGen(*vardecl)) return mlir::failure(); + if (!mlirGen(*vardecl)) + return mlir::failure(); continue; } - if (auto *ret = dyn_cast(expr.get())) return mlirGen(*ret); + if (auto *ret = dyn_cast(expr.get())) + return mlirGen(*ret); if (auto *print = dyn_cast(expr.get())) { - if (mlir::failed(mlirGen(*print))) return mlir::success(); + if (mlir::failed(mlirGen(*print))) + return mlir::success(); continue; } // Generic expression dispatch codegen. 
- if (!mlirGen(*expr)) return mlir::failure(); + if (!mlirGen(*expr)) + return mlir::failure(); } return mlir::success(); } @@ -633,7 +678,7 @@ class MLIRGenImpl { } }; -} // namespace +} // namespace namespace toy { @@ -643,4 +688,4 @@ mlir::OwningOpRef mlirGen(mlir::MLIRContext &context, return MLIRGenImpl(context).mlirGen(moduleAST); } -} // namespace toy +} // namespace toy diff --git a/mlir/example/Ch7/mlir/ShapeInferencePass.cpp b/mlir/example/Ch7/mlir/ShapeInferencePass.cpp index d45baa1..a9e995e 100644 --- a/mlir/example/Ch7/mlir/ShapeInferencePass.cpp +++ b/mlir/example/Ch7/mlir/ShapeInferencePass.cpp @@ -11,13 +11,21 @@ // //===----------------------------------------------------------------------===// +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Types.h" #include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/TypeID.h" #include "toy/Dialect.h" #include "toy/Passes.h" #include "toy/ShapeInferenceInterface.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" +#include #define DEBUG_TYPE "shape-inference" diff --git a/mlir/example/Ch7/mlir/ToyCombine.cpp b/mlir/example/Ch7/mlir/ToyCombine.cpp index b637377..1d8cf74 100644 --- a/mlir/example/Ch7/mlir/ToyCombine.cpp +++ b/mlir/example/Ch7/mlir/ToyCombine.cpp @@ -11,11 +11,14 @@ // //===----------------------------------------------------------------------===// -#include - -#include "mlir/IR/Matchers.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OpDefinition.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Value.h" #include "toy/Dialect.h" +#include "llvm/Support/Casting.h" +#include using namespace mlir; using namespace toy; @@ -53,7 +56,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern { /// This method attempts to match a pattern and rewrite 
it. The rewriter /// argument is the orchestrator of the sequence of rewrites. The pattern is /// expected to interact with it to perform any changes to the IR from here. - mlir::LogicalResult + llvm::LogicalResult matchAndRewrite(TransposeOp op, mlir::PatternRewriter &rewriter) const override { // Look through the input of the current transpose. diff --git a/mlir/example/Ch7/parser/AST.cpp b/mlir/example/Ch7/parser/AST.cpp index d33cd8a..e38a743 100644 --- a/mlir/example/Ch7/parser/AST.cpp +++ b/mlir/example/Ch7/parser/AST.cpp @@ -12,9 +12,12 @@ #include "toy/AST.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Twine.h" #include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/raw_ostream.h" +#include using namespace toy; @@ -31,10 +34,10 @@ struct Indent { /// Helper class that implement the AST tree traversal and print the nodes along /// the way. The only data member is the current indentation level. class ASTDumper { - public: +public: void dump(ModuleAST *node); - private: +private: void dump(const VarType &type); void dump(VarDeclExprAST *varDecl); void dump(ExprAST *expr); @@ -53,12 +56,13 @@ class ASTDumper { // Actually print spaces matching the current indentation level void indent() { - for (int i = 0; i < curIndent; i++) llvm::errs() << " "; + for (int i = 0; i < curIndent; i++) + llvm::errs() << " "; } int curIndent = 0; }; -} // namespace +} // namespace /// Return a formatted string for the location of any node template @@ -71,8 +75,8 @@ static std::string loc(T *node) { // Helper Macro to bump the indentation level and print the leading spaces for // the current indentations -#define INDENT() \ - Indent level_(curIndent); \ +#define INDENT() \ + Indent level_(curIndent); \ indent(); /// Dispatch to a generic expressions to the appropriate subclass using RTTI @@ -95,14 +99,16 @@ void ASTDumper::dump(VarDeclExprAST *varDecl) { llvm::errs() << "VarDecl " << varDecl->getName(); dump(varDecl->getType()); 
llvm::errs() << " " << loc(varDecl) << "\n"; - if (auto *initVal = varDecl->getInitVal()) dump(initVal); + if (auto *initVal = varDecl->getInitVal()) + dump(initVal); } /// A "block", or a list of expression void ASTDumper::dump(ExprASTList *exprList) { INDENT(); llvm::errs() << "Block {\n"; - for (auto &expr : *exprList) dump(expr.get()); + for (auto &expr : *exprList) + dump(expr.get()); indent(); llvm::errs() << "} // Block\n"; } @@ -149,7 +155,8 @@ void ASTDumper::dump(LiteralExprAST *node) { void ASTDumper::dump(StructLiteralExprAST *node) { INDENT(); llvm::errs() << "Struct Literal: "; - for (auto &value : node->getValues()) dump(value.get()); + for (auto &value : node->getValues()) + dump(value.get()); indent(); llvm::errs() << " " << loc(node) << "\n"; } @@ -164,7 +171,8 @@ void ASTDumper::dump(VariableExprAST *node) { void ASTDumper::dump(ReturnExprAST *node) { INDENT(); llvm::errs() << "Return\n"; - if (node->getExpr().has_value()) return dump(*node->getExpr()); + if (node->getExpr().has_value()) + return dump(*node->getExpr()); { INDENT(); llvm::errs() << "(void)\n"; @@ -184,7 +192,8 @@ void ASTDumper::dump(BinaryExprAST *node) { void ASTDumper::dump(CallExprAST *node) { INDENT(); llvm::errs() << "Call '" << node->getCallee() << "' [ " << loc(node) << "\n"; - for (auto &arg : node->getArgs()) dump(arg.get()); + for (auto &arg : node->getArgs()) + dump(arg.get()); indent(); llvm::errs() << "]\n"; } @@ -236,7 +245,8 @@ void ASTDumper::dump(StructAST *node) { { INDENT(); llvm::errs() << "Variables: [\n"; - for (auto &variable : node->getVariables()) dump(variable.get()); + for (auto &variable : node->getVariables()) + dump(variable.get()); indent(); llvm::errs() << "]\n"; } @@ -261,4 +271,4 @@ namespace toy { // Public API void dump(ModuleAST &module) { ASTDumper().dump(&module); } -} // namespace toy +} // namespace toy diff --git a/mlir/example/Ch7/toyc.cpp b/mlir/example/Ch7/toyc.cpp index 94d809e..fea5679 100644 --- a/mlir/example/Ch7/toyc.cpp +++ 
b/mlir/example/Ch7/toyc.cpp @@ -10,8 +10,16 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Affine/Passes.h" #include "mlir/Dialect/Func/Extensions/AllExtensions.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "toy/AST.h" +#include "toy/Dialect.h" +#include "toy/Lexer.h" +#include "toy/MLIRGen.h" +#include "toy/Parser.h" +#include "toy/Passes.h" + +#include "mlir/Dialect/Affine/Passes.h" #include "mlir/Dialect/LLVMIR/Transforms/Passes.h" #include "mlir/ExecutionEngine/ExecutionEngine.h" #include "mlir/ExecutionEngine/OptUtils.h" @@ -21,17 +29,14 @@ #include "mlir/IR/Verifier.h" #include "mlir/InitAllDialects.h" #include "mlir/Parser/Parser.h" -#include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Export.h" #include "mlir/Transforms/Passes.h" -#include "toy/Dialect.h" -#include "toy/MLIRGen.h" -#include "toy/Parser.h" -#include "toy/Passes.h" + #include "llvm/ADT/StringRef.h" +#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" #include "llvm/IR/Module.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/ErrorOr.h" @@ -39,6 +44,11 @@ #include "llvm/Support/SourceMgr.h" #include "llvm/Support/TargetSelect.h" #include "llvm/Support/raw_ostream.h" +#include +#include +#include +#include +#include using namespace toy; namespace cl = llvm::cl; @@ -101,7 +111,7 @@ int loadMLIR(mlir::MLIRContext &context, mlir::OwningOpRef &module) { // Handle '.toy' input to the compiler. 
if (inputType != InputType::MLIR && - !llvm::StringRef(inputFilename).endswith(".mlir")) { + !llvm::StringRef(inputFilename).ends_with(".mlir")) { auto moduleAST = parseInputFile(inputFilename); if (!moduleAST) return 6; @@ -177,8 +187,7 @@ int loadAndProcessMLIR(mlir::MLIRContext &context, // This is necessary to have line tables emitted and basic // debugger working. In the future we will add proper debug information // emission directly from our frontend. - pm.addNestedPass( - mlir::LLVM::createDIScopeForLLVMFuncOpPass()); + pm.addPass(mlir::LLVM::createDIScopeForLLVMFuncOpPass()); } if (mlir::failed(pm.run(*module))) diff --git a/mlir/example/Ch8/CMakeLists.txt b/mlir/example/Ch8/CMakeLists.txt index 4b9f26f..0bfd338 100644 --- a/mlir/example/Ch8/CMakeLists.txt +++ b/mlir/example/Ch8/CMakeLists.txt @@ -39,12 +39,10 @@ target_link_libraries( MLIRExecutionEngine MLIRIR MLIRLLVMCommonConversion - MLIRLLVMDialect MLIRLLVMToLLVMIRTranslation MLIRMemRefDialect MLIRParser MLIRPass MLIRSideEffectInterfaces - MLIRSupport MLIRTargetLLVMIRExport MLIRTransforms) diff --git a/mlir/example/Ch8/example.toy b/mlir/example/Ch8/example.toy deleted file mode 100644 index 724a23e..0000000 --- a/mlir/example/Ch8/example.toy +++ /dev/null @@ -1,13 +0,0 @@ -def main() { - # Define a variable `a` with shape <2, 3>, initialized with the literal value. - # The shape is inferred from the supplied literal. - var a = [[1, 2, 3], [4, 5, 6]]; - - # b is identical to a, the literal tensor is implicitly reshaped: defining new - # variables is the way to reshape tensors (element count must match). - var b<2, 3> = [1, 2, 3, 4, 5, 6]; - - # transpose() and print() are the only builtin, the following will transpose - # a and b and perform an element-wise multiplication before printing the result. 
- print(transpose(a) * transpose(b)); -} diff --git a/mlir/example/Ch8/include/toy/AST.h b/mlir/example/Ch8/include/toy/AST.h index fe08307..4827865 100644 --- a/mlir/example/Ch8/include/toy/AST.h +++ b/mlir/example/Ch8/include/toy/AST.h @@ -15,19 +15,20 @@ #ifndef TOY_AST_H #define TOY_AST_H -#include -#include -#include - #include "toy/Lexer.h" + #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Casting.h" +#include +#include +#include namespace toy { -/// A variable type with shape information. +/// A variable type with either name or shape information. struct VarType { + std::string name; std::vector shape; }; @@ -39,6 +40,7 @@ class ExprAST { Expr_Return, Expr_Num, Expr_Literal, + Expr_StructLiteral, Expr_Var, Expr_BinOp, Expr_Call, @@ -93,6 +95,24 @@ class LiteralExprAST : public ExprAST { static bool classof(const ExprAST *c) { return c->getKind() == Expr_Literal; } }; +/// Expression class for a literal struct value. +class StructLiteralExprAST : public ExprAST { + std::vector> values; + +public: + StructLiteralExprAST(Location loc, + std::vector> values) + : ExprAST(Expr_StructLiteral, std::move(loc)), values(std::move(values)) { + } + + llvm::ArrayRef> getValues() { return values; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { + return c->getKind() == Expr_StructLiteral; + } +}; + /// Expression class for referencing a variable, like "a". 
class VariableExprAST : public ExprAST { std::string name; @@ -115,7 +135,7 @@ class VarDeclExprAST : public ExprAST { public: VarDeclExprAST(Location loc, llvm::StringRef name, VarType type, - std::unique_ptr initVal) + std::unique_ptr initVal = nullptr) : ExprAST(Expr_VarDecl, std::move(loc)), name(name), type(std::move(type)), initVal(std::move(initVal)) {} @@ -202,41 +222,88 @@ class PrintExprAST : public ExprAST { class PrototypeAST { Location location; std::string name; - std::vector> args; + std::vector> args; public: PrototypeAST(Location location, const std::string &name, - std::vector> args) + std::vector> args) : location(std::move(location)), name(name), args(std::move(args)) {} const Location &loc() { return location; } llvm::StringRef getName() const { return name; } - llvm::ArrayRef> getArgs() { return args; } + llvm::ArrayRef> getArgs() { return args; } +}; + +/// This class represents a top level record in a module. +class RecordAST { +public: + enum RecordASTKind { + Record_Function, + Record_Struct, + }; + + RecordAST(RecordASTKind kind) : kind(kind) {} + virtual ~RecordAST() = default; + + RecordASTKind getKind() const { return kind; } + +private: + const RecordASTKind kind; }; /// This class represents a function definition itself. -class FunctionAST { +class FunctionAST : public RecordAST { std::unique_ptr proto; std::unique_ptr body; public: FunctionAST(std::unique_ptr proto, std::unique_ptr body) - : proto(std::move(proto)), body(std::move(body)) {} + : RecordAST(Record_Function), proto(std::move(proto)), + body(std::move(body)) {} PrototypeAST *getProto() { return proto.get(); } ExprASTList *getBody() { return body.get(); } + + /// LLVM style RTTI + static bool classof(const RecordAST *r) { + return r->getKind() == Record_Function; + } +}; + +/// This class represents a struct definition. 
+class StructAST : public RecordAST { + Location location; + std::string name; + std::vector> variables; + +public: + StructAST(Location location, const std::string &name, + std::vector> variables) + : RecordAST(Record_Struct), location(std::move(location)), name(name), + variables(std::move(variables)) {} + + const Location &loc() { return location; } + llvm::StringRef getName() const { return name; } + llvm::ArrayRef> getVariables() { + return variables; + } + + /// LLVM style RTTI + static bool classof(const RecordAST *r) { + return r->getKind() == Record_Struct; + } }; /// This class represents a list of functions to be processed together class ModuleAST { - std::vector functions; + std::vector> records; public: - ModuleAST(std::vector functions) - : functions(std::move(functions)) {} + ModuleAST(std::vector> records) + : records(std::move(records)) {} - auto begin() { return functions.begin(); } - auto end() { return functions.end(); } + auto begin() { return records.begin(); } + auto end() { return records.end(); } }; void dump(ModuleAST &); diff --git a/mlir/example/Ch8/include/toy/Dialect.h b/mlir/example/Ch8/include/toy/Dialect.h index 927b168..64094c3 100644 --- a/mlir/example/Ch8/include/toy/Dialect.h +++ b/mlir/example/Ch8/include/toy/Dialect.h @@ -14,22 +14,69 @@ #ifndef MLIR_TUTORIAL_TOY_DIALECT_H_ #define MLIR_TUTORIAL_TOY_DIALECT_H_ +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" -#include "mlir/IR/FunctionInterfaces.h" #include "mlir/IR/SymbolTable.h" #include "mlir/Interfaces/CallInterfaces.h" #include "mlir/Interfaces/CastInterfaces.h" +#include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "toy/ShapeInferenceInterface.h" +namespace mlir { +namespace toy { +namespace detail { +struct StructTypeStorage; +} // namespace detail +} // namespace toy +} // namespace mlir + /// Include the auto-generated header file containing the 
declaration of the toy /// dialect. #include "toy/Dialect.h.inc" +//===----------------------------------------------------------------------===// +// Toy Operations +//===----------------------------------------------------------------------===// + /// Include the auto-generated header file containing the declarations of the /// toy operations. #define GET_OP_CLASSES #include "toy/Ops.h.inc" +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// Toy Types +//===----------------------------------------------------------------------===// + +/// This class defines the Toy struct type. It represents a collection of +/// element types. All derived types in MLIR must inherit from the CRTP class +/// 'Type::TypeBase'. It takes as template parameters the concrete type +/// (StructType), the base class to use (Type), and the storage class +/// (StructTypeStorage). +class StructType : public mlir::Type::TypeBase { +public: + /// Inherit some necessary constructors from 'TypeBase'. + using Base::Base; + + /// Create an instance of a `StructType` with the given element types. There + /// *must* be atleast one element type. + static StructType get(llvm::ArrayRef elementTypes); + + /// Returns the element types of this struct type. + llvm::ArrayRef getElementTypes(); + + /// Returns the number of element type held by this struct. + size_t getNumElementTypes() { return getElementTypes().size(); } + + /// The name of this struct type. 
+ static constexpr StringLiteral name = "toy.struct"; +}; +} // namespace toy +} // namespace mlir + #endif // MLIR_TUTORIAL_TOY_DIALECT_H_ diff --git a/mlir/example/Ch8/include/toy/Lexer.h b/mlir/example/Ch8/include/toy/Lexer.h index 17d3e0f..a3fde91 100644 --- a/mlir/example/Ch8/include/toy/Lexer.h +++ b/mlir/example/Ch8/include/toy/Lexer.h @@ -13,11 +13,11 @@ #ifndef TOY_LEXER_H #define TOY_LEXER_H +#include "llvm/ADT/StringRef.h" + #include #include -#include "llvm/ADT/StringRef.h" - namespace toy { /// Structure definition a location in a file. @@ -43,10 +43,11 @@ enum Token : int { tok_return = -2, tok_var = -3, tok_def = -4, + tok_struct = -5, // primary - tok_identifier = -5, - tok_number = -6, + tok_identifier = -6, + tok_number = -7, }; /// The Lexer is an abstract base class providing all the facilities that the @@ -143,13 +144,15 @@ class Lexer { return tok_return; if (identifierStr == "def") return tok_def; + if (identifierStr == "struct") + return tok_struct; if (identifierStr == "var") return tok_var; return tok_identifier; } - // Number: [0-9.]+ - if (isdigit(lastChar) || lastChar == '.') { + // Number: [0-9] ([0-9.])* + if (isdigit(lastChar)) { std::string numStr; do { numStr += lastChar; diff --git a/mlir/example/Ch8/include/toy/Ops.td b/mlir/example/Ch8/include/toy/Ops.td index 298bd3e..f714ee3 100644 --- a/mlir/example/Ch8/include/toy/Ops.td +++ b/mlir/example/Ch8/include/toy/Ops.td @@ -13,7 +13,7 @@ #ifndef TOY_OPS #define TOY_OPS -include "mlir/IR/FunctionInterfaces.td" +include "mlir/Interfaces/FunctionInterfaces.td" include "mlir/IR/SymbolInterfaces.td" include "mlir/Interfaces/CallInterfaces.td" include "mlir/Interfaces/CastInterfaces.td" @@ -25,6 +25,15 @@ include "toy/ShapeInferenceInterface.td" def Toy_Dialect : Dialect { let name = "toy"; let cppNamespace = "::mlir::toy"; + + // We set this bit to generate a declaration of the `materializeConstant` + // method so that we can materialize constants for our toy operations. 
+ let hasConstantMaterializer = 1; + + // We set this bit to generate the declarations for the dialect's type parsing + // and printing hooks. + let useDefaultTypePrinterParser = 1; + } // Base class for toy dialect operations. This operation inherits from the base @@ -35,6 +44,16 @@ def Toy_Dialect : Dialect { class Toy_Op traits = []> : Op; +// Provide a definition for the Toy StructType for use in ODS. This allows for +// using StructType in a similar way to Tensor or MemRef. We use `DialectType` +// to demarcate the StructType as belonging to the Toy dialect. +def Toy_StructType : + DialectType($_self)">, + "Toy struct type">; + +// Provide a definition of the types that are used within the Toy dialect. +def Toy_Type : AnyTypeOf<[F64Tensor, Toy_StructType]>; + //===----------------------------------------------------------------------===// // Toy Operations //===----------------------------------------------------------------------===// @@ -47,7 +66,9 @@ class Toy_Op traits = []> : // Here we provide the mnemonic and a list of traits for the operation. The // constant operation is marked as 'Pure' as it is a pure operation // and may be removed if dead. -def ConstantOp : Toy_Op<"constant", [Pure]> { +def ConstantOp : Toy_Op<"constant", + [ConstantLike, Pure, + DeclareOpInterfaceMethods]> { // Provide a summary and description for this operation. This can be used to // auto-generate documentation of the operations within our dialect. let summary = "constant"; @@ -85,6 +106,9 @@ def ConstantOp : Toy_Op<"constant", [Pure]> { // Indicate that additional verification for this operation is necessary. let hasVerifier = 1; + + // Set the folder bit so that we can implement constant folders. 
+ let hasFolder = 1; } //===----------------------------------------------------------------------===// @@ -141,8 +165,7 @@ def CastOp : Toy_Op<"cast", [ //===----------------------------------------------------------------------===// def FuncOp : Toy_Op<"func", [ - DeclareOpInterfaceMethods, FunctionOpInterface, - IsolatedFromAbove + FunctionOpInterface, IsolatedFromAbove ]> { let summary = "user defined function operation"; let description = [{ @@ -183,6 +206,8 @@ def FuncOp : Toy_Op<"func", [ /// Returns the result types of this function. ArrayRef getResultTypes() { return getFunctionType().getResults(); } + + Region *getCallableRegion() { return &getBody(); } }]; let hasCustomAssemblyFormat = 1; let skipDefaultBuilders = 1; @@ -212,10 +237,11 @@ def GenericCallOp : Toy_Op<"generic_call", // The generic call operation takes a symbol reference attribute as the // callee, and inputs for the call. - let arguments = (ins FlatSymbolRefAttr:$callee, Variadic:$inputs); + let arguments = (ins FlatSymbolRefAttr:$callee, Variadic:$inputs); - // The generic call operation returns a single value of TensorType. - let results = (outs F64Tensor); + // The generic call operation returns a single value of TensorType or + // StructType. + let results = (outs Toy_Type); // Specialize assembly printing and parsing using a declarative format. let assemblyFormat = [{ @@ -307,7 +333,7 @@ def ReturnOp : Toy_Op<"return", [Pure, HasParent<"FuncOp">, let summary = "return operation"; let description = [{ The "return" operation represents a return operation within a function. - The operation takes an optional tensor operand and produces no results. + The operation takes an optional operand and produces no results. The operand type must match the signature of the function that contains the operation. For example: @@ -321,7 +347,7 @@ def ReturnOp : Toy_Op<"return", [Pure, HasParent<"FuncOp">, // The return operation takes an optional input operand to return. 
This // value must match the return type of the enclosing function. - let arguments = (ins Variadic:$input); + let arguments = (ins Variadic:$input); // The return operation only emits the input in the format if it is present. let assemblyFormat = "($input^ `:` type($input))? attr-dict "; @@ -340,6 +366,63 @@ def ReturnOp : Toy_Op<"return", [Pure, HasParent<"FuncOp">, let hasVerifier = 1; } +//===----------------------------------------------------------------------===// +// StructAccessOp +//===----------------------------------------------------------------------===// + +def StructAccessOp : Toy_Op<"struct_access", [Pure]> { + let summary = "struct access"; + let description = [{ + Access the Nth element of a value returning a struct type. + }]; + + let arguments = (ins Toy_StructType:$input, I64Attr:$index); + let results = (outs Toy_Type:$output); + + let assemblyFormat = [{ + $input `[` $index `]` attr-dict `:` type($input) `->` type($output) + }]; + + // Allow building a StructAccessOp with just a struct value and an index. + let builders = [ + OpBuilder<(ins "Value":$input, "size_t":$index)> + ]; + + // Indicate that additional verification for this operation is necessary. + let hasVerifier = 1; + + // Set the folder bit so that we can fold constant accesses. + let hasFolder = 1; +} + +//===----------------------------------------------------------------------===// +// StructConstantOp +//===----------------------------------------------------------------------===// + +def StructConstantOp : Toy_Op<"struct_constant", [ConstantLike, Pure]> { + let summary = "struct constant"; + let description = [{ + Constant operation turns a literal struct value into an SSA value. The data + is attached to the operation as an attribute. The struct constant is encoded + as an array of other constant values. 
For example: + + ```mlir + %0 = toy.struct_constant [ + dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf64> + ] : !toy.struct> + ``` + }]; + + let arguments = (ins ArrayAttr:$value); + let results = (outs Toy_StructType:$output); + + let assemblyFormat = "$value attr-dict `:` type($output)"; + + // Indicate that additional verification for this operation is necessary. + let hasVerifier = 1; + let hasFolder = 1; +} + //===----------------------------------------------------------------------===// // TransposeOp //===----------------------------------------------------------------------===// diff --git a/mlir/example/Ch8/include/toy/Parser.h b/mlir/example/Ch8/include/toy/Parser.h index fa2f882..101b03d 100644 --- a/mlir/example/Ch8/include/toy/Parser.h +++ b/mlir/example/Ch8/include/toy/Parser.h @@ -14,17 +14,18 @@ #ifndef TOY_PARSER_H #define TOY_PARSER_H -#include -#include -#include -#include - #include "toy/AST.h" #include "toy/Lexer.h" + #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Support/raw_ostream.h" +#include +#include +#include +#include + namespace toy { /// This is a simple recursive parser for the Toy language. It produces a well @@ -41,18 +42,33 @@ class Parser { std::unique_ptr parseModule() { lexer.getNextToken(); // prime the lexer - // Parse functions one at a time and accumulate in this vector. - std::vector functions; - while (auto f = parseDefinition()) { - functions.push_back(std::move(*f)); - if (lexer.getCurToken() == tok_eof) + // Parse functions and structs one at a time and accumulate in this vector. 
+ std::vector> records; + while (true) { + std::unique_ptr record; + switch (lexer.getCurToken()) { + case tok_eof: + break; + case tok_def: + record = parseDefinition(); + break; + case tok_struct: + record = parseStruct(); + break; + default: + return parseError("'def' or 'struct'", + "when parsing top level module records"); + } + if (!record) break; + records.push_back(std::move(record)); } + // If we didn't reach EOF, there was an error during parsing if (lexer.getCurToken() != tok_eof) return parseError("nothing", "at end of module"); - return std::make_unique(std::move(functions)); + return std::make_unique(std::move(records)); } private: @@ -153,6 +169,50 @@ class Parser { std::move(dims)); } + /// Parse a literal struct expression. + /// structLiteral ::= { (structLiteral | tensorLiteral)+ } + std::unique_ptr parseStructLiteralExpr() { + auto loc = lexer.getLastLocation(); + lexer.consume(Token('{')); + + // Hold the list of values. + std::vector> values; + do { + // We can have either another nested array or a number literal. + if (lexer.getCurToken() == '[') { + values.push_back(parseTensorLiteralExpr()); + if (!values.back()) + return nullptr; + } else if (lexer.getCurToken() == tok_number) { + values.push_back(parseNumberExpr()); + if (!values.back()) + return nullptr; + } else { + if (lexer.getCurToken() != '{') + return parseError("{, [, or number", + "in struct literal expression"); + values.push_back(parseStructLiteralExpr()); + } + + // End of this list on '}' + if (lexer.getCurToken() == '}') + break; + + // Elements are separated by a comma. 
+ if (lexer.getCurToken() != ',') + return parseError("} or ,", "in struct literal expression"); + + lexer.getNextToken(); // eat , + } while (true); + if (values.empty()) + return parseError("", + "to fill struct literal expression"); + lexer.getNextToken(); // eat } + + return std::make_unique(std::move(loc), + std::move(values)); + } + /// parenexpr ::= '(' expression ')' std::unique_ptr parseParenExpr() { lexer.getNextToken(); // eat (. @@ -166,19 +226,9 @@ class Parser { return v; } - /// identifierexpr - /// ::= identifier - /// ::= identifier '(' expression ')' - std::unique_ptr parseIdentifierExpr() { - std::string name(lexer.getId()); - - auto loc = lexer.getLastLocation(); - lexer.getNextToken(); // eat identifier. - - if (lexer.getCurToken() != '(') // Simple variable ref. - return std::make_unique(std::move(loc), name); - - // This is a function call. + /// Parse a call expression. + std::unique_ptr parseCallExpr(llvm::StringRef name, + const Location &loc) { lexer.consume(Token('(')); std::vector> args; if (lexer.getCurToken() != ')') { @@ -203,11 +253,28 @@ class Parser { if (args.size() != 1) return parseError("", "as argument to print()"); - return std::make_unique(std::move(loc), std::move(args[0])); + return std::make_unique(loc, std::move(args[0])); } // Call to a user-defined function - return std::make_unique(std::move(loc), name, std::move(args)); + return std::make_unique(loc, std::string(name), + std::move(args)); + } + + /// identifierexpr + /// ::= identifier + /// ::= identifier '(' expression ')' + std::unique_ptr parseIdentifierExpr() { + std::string name(lexer.getId()); + + auto loc = lexer.getLastLocation(); + lexer.getNextToken(); // eat identifier. + + if (lexer.getCurToken() != '(') // Simple variable ref. + return std::make_unique(std::move(loc), name); + + // This is a function call. 
+ return parseCallExpr(name, loc); } /// primary @@ -229,6 +296,8 @@ class Parser { return parseParenExpr(); case '[': return parseTensorLiteralExpr(); + case '{': + return parseStructLiteralExpr(); case ';': return nullptr; case '}': @@ -307,11 +376,72 @@ class Parser { return type; } + /// Parse either a variable declaration or a call expression. + std::unique_ptr parseDeclarationOrCallExpr() { + auto loc = lexer.getLastLocation(); + std::string id(lexer.getId()); + lexer.consume(tok_identifier); + + // Check for a call expression. + if (lexer.getCurToken() == '(') + return parseCallExpr(id, loc); + + // Otherwise, this is a variable declaration. + return parseTypedDeclaration(id, /*requiresInitializer=*/true, loc); + } + + /// Parse a typed variable declaration. + std::unique_ptr + parseTypedDeclaration(llvm::StringRef typeName, bool requiresInitializer, + const Location &loc) { + // Parse the variable name. + if (lexer.getCurToken() != tok_identifier) + return parseError("name", "in variable declaration"); + std::string id(lexer.getId()); + lexer.getNextToken(); // eat id + + // Parse the initializer. + std::unique_ptr expr; + if (requiresInitializer) { + if (lexer.getCurToken() != '=') + return parseError("initializer", + "in variable declaration"); + lexer.consume(Token('=')); + expr = parseExpression(); + } + + VarType type; + type.name = std::string(typeName); + return std::make_unique(loc, std::move(id), std::move(type), + std::move(expr)); + } + + /// Parse a variable declaration, for either a tensor value or a struct value, + /// with an optionally required initializer. + /// decl ::= var identifier [ type ] (= expr)? + /// decl ::= identifier identifier (= expr)? + std::unique_ptr parseDeclaration(bool requiresInitializer) { + // Check to see if this is a 'var' declaration. + if (lexer.getCurToken() == tok_var) + return parseVarDeclaration(requiresInitializer); + + // Parse the type name. 
+ if (lexer.getCurToken() != tok_identifier) + return parseError("type name", "in variable declaration"); + auto loc = lexer.getLastLocation(); + std::string typeName(lexer.getId()); + lexer.getNextToken(); // eat id + + // Parse the rest of the declaration. + return parseTypedDeclaration(typeName, requiresInitializer, loc); + } + /// Parse a variable declaration, it starts with a `var` keyword followed by /// and identifier and an optional type (shape specification) before the - /// initializer. - /// decl ::= var identifier [ type ] = expr - std::unique_ptr parseDeclaration() { + /// optionally required initializer. + /// decl ::= var identifier [ type ] (= expr)? + std::unique_ptr + parseVarDeclaration(bool requiresInitializer) { if (lexer.getCurToken() != tok_var) return parseError("var", "to begin declaration"); auto loc = lexer.getLastLocation(); @@ -329,11 +459,14 @@ class Parser { if (!type) return nullptr; } - if (!type) type = std::make_unique(); - lexer.consume(Token('=')); - auto expr = parseExpression(); + + std::unique_ptr expr; + if (requiresInitializer) { + lexer.consume(Token('=')); + expr = parseExpression(); + } return std::make_unique(std::move(loc), std::move(id), std::move(*type), std::move(expr)); } @@ -356,9 +489,15 @@ class Parser { lexer.consume(Token(';')); while (lexer.getCurToken() != '}' && lexer.getCurToken() != tok_eof) { - if (lexer.getCurToken() == tok_var) { + if (lexer.getCurToken() == tok_identifier) { + // Variable declaration or call + auto expr = parseDeclarationOrCallExpr(); + if (!expr) + return nullptr; + exprList->push_back(std::move(expr)); + } else if (lexer.getCurToken() == tok_var) { // Variable declaration - auto varDecl = parseDeclaration(); + auto varDecl = parseDeclaration(/*requiresInitializer=*/true); if (!varDecl) return nullptr; exprList->push_back(std::move(varDecl)); @@ -410,14 +549,31 @@ class Parser { return parseError("(", "in prototype"); lexer.consume(Token('(')); - std::vector> args; + std::vector> 
args; if (lexer.getCurToken() != ')') { do { - std::string name(lexer.getId()); + VarType type; + std::string name; + + // Parse either the name of the variable, or its type. + std::string nameOrType(lexer.getId()); auto loc = lexer.getLastLocation(); lexer.consume(tok_identifier); - auto decl = std::make_unique(std::move(loc), name); - args.push_back(std::move(decl)); + + // If the next token is an identifier, we just parsed the type. + if (lexer.getCurToken() == tok_identifier) { + type.name = std::move(nameOrType); + + // Parse the name. + name = std::string(lexer.getId()); + lexer.consume(tok_identifier); + } else { + // Otherwise, we just parsed the name. + name = std::move(nameOrType); + } + + args.push_back( + std::make_unique(std::move(loc), name, type)); if (lexer.getCurToken() != ',') break; lexer.consume(Token(',')); @@ -449,6 +605,43 @@ class Parser { return nullptr; } + /// Parse a struct definition, we expect a struct initiated with the + /// `struct` keyword, followed by a block containing a list of variable + /// declarations. 
+ /// + /// definition ::= `struct` identifier `{` decl+ `}` + std::unique_ptr parseStruct() { + auto loc = lexer.getLastLocation(); + lexer.consume(tok_struct); + if (lexer.getCurToken() != tok_identifier) + return parseError("name", "in struct definition"); + std::string name(lexer.getId()); + lexer.consume(tok_identifier); + + // Parse: '{' + if (lexer.getCurToken() != '{') + return parseError("{", "in struct definition"); + lexer.consume(Token('{')); + + // Parse: decl+ + std::vector> decls; + do { + auto decl = parseDeclaration(/*requiresInitializer=*/false); + if (!decl) + return nullptr; + decls.push_back(std::move(decl)); + + if (lexer.getCurToken() != ';') + return parseError(";", + "after variable in struct definition"); + lexer.consume(Token(';')); + } while (lexer.getCurToken() != '}'); + + // Parse: '}' + lexer.consume(Token('}')); + return std::make_unique(loc, name, std::move(decls)); + } + /// Get the precedence of the pending binary operator token. int getTokPrecedence() { if (!isascii(lexer.getCurToken())) @@ -462,6 +655,8 @@ class Parser { return 20; case '*': return 40; + case '.': + return 60; default: return -1; } diff --git a/mlir/example/Ch8/mlir/Dialect.cpp b/mlir/example/Ch8/mlir/Dialect.cpp index d750782..0f5152d 100644 --- a/mlir/example/Ch8/mlir/Dialect.cpp +++ b/mlir/example/Ch8/mlir/Dialect.cpp @@ -13,11 +13,32 @@ #include "toy/Dialect.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" +#include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/FunctionImplementation.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" #include "mlir/IR/OpImplementation.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/TypeSupport.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Interfaces/CallInterfaces.h" +#include "mlir/Interfaces/FunctionImplementation.h" +#include 
"mlir/Support/LLVM.h" #include "mlir/Transforms/InliningUtils.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/Hashing.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include +#include +#include +#include +#include using namespace mlir; using namespace mlir::toy; @@ -59,8 +80,7 @@ struct ToyInlinerInterface : public DialectInlinerInterface { /// Handle the given inlined terminator(toy.return) by replacing it with a new /// operation as necessary. - void handleTerminator(Operation *op, - ArrayRef valuesToRepl) const final { + void handleTerminator(Operation *op, ValueRange valuesToRepl) const final { // Only "toy.return" needs to be handled here. auto returnOp = cast(op); @@ -82,20 +102,6 @@ struct ToyInlinerInterface : public DialectInlinerInterface { } }; -//===----------------------------------------------------------------------===// -// ToyDialect -//===----------------------------------------------------------------------===// - -/// Dialect initialization, the instance will be owned by the context. This is -/// the point of registration of types and operations for the dialect. -void ToyDialect::initialize() { - addOperations< -#define GET_OP_LIST -#include "toy/Ops.cpp.inc" - >(); - addInterfaces(); -} - //===----------------------------------------------------------------------===// // Toy Operations //===----------------------------------------------------------------------===// @@ -188,37 +194,78 @@ void ConstantOp::print(mlir::OpAsmPrinter &printer) { printer << getValue(); } -/// Verifier for the constant operation. This corresponds to the -/// `let hasVerifier = 1` in the op definition. -mlir::LogicalResult ConstantOp::verify() { - // If the return type of the constant is not an unranked tensor, the shape - // must match the shape of the attribute holding the data. 
- auto resultType = - llvm::dyn_cast(getResult().getType()); - if (!resultType) - return success(); - - // Check that the rank of the attribute type matches the rank of the constant - // result type. - auto attrType = llvm::cast(getValue().getType()); - if (attrType.getRank() != resultType.getRank()) { - return emitOpError("return type must match the one of the attached value " - "attribute: ") - << attrType.getRank() << " != " << resultType.getRank(); - } +/// Verify that the given attribute value is valid for the given type. +static llvm::LogicalResult verifyConstantForType(mlir::Type type, + mlir::Attribute opaqueValue, + mlir::Operation *op) { + if (llvm::isa(type)) { + // Check that the value is an elements attribute. + auto attrValue = llvm::dyn_cast(opaqueValue); + if (!attrValue) + return op->emitError("constant of TensorType must be initialized by " + "a DenseFPElementsAttr, got ") + << opaqueValue; + + // If the return type of the constant is not an unranked tensor, the shape + // must match the shape of the attribute holding the data. + auto resultType = llvm::dyn_cast(type); + if (!resultType) + return success(); + + // Check that the rank of the attribute type matches the rank of the + // constant result type. + auto attrType = llvm::cast(attrValue.getType()); + if (attrType.getRank() != resultType.getRank()) { + return op->emitOpError("return type must match the one of the attached " + "value attribute: ") + << attrType.getRank() << " != " << resultType.getRank(); + } - // Check that each of the dimensions match between the two types. - for (int dim = 0, dimE = attrType.getRank(); dim < dimE; ++dim) { - if (attrType.getShape()[dim] != resultType.getShape()[dim]) { - return emitOpError( - "return type shape mismatches its attribute at dimension ") - << dim << ": " << attrType.getShape()[dim] - << " != " << resultType.getShape()[dim]; + // Check that each of the dimensions match between the two types. 
+ for (int dim = 0, dimE = attrType.getRank(); dim < dimE; ++dim) { + if (attrType.getShape()[dim] != resultType.getShape()[dim]) { + return op->emitOpError( + "return type shape mismatches its attribute at dimension ") + << dim << ": " << attrType.getShape()[dim] + << " != " << resultType.getShape()[dim]; + } } + return mlir::success(); } + auto resultType = llvm::cast(type); + llvm::ArrayRef resultElementTypes = resultType.getElementTypes(); + + // Verify that the initializer is an Array. + auto attrValue = llvm::dyn_cast(opaqueValue); + if (!attrValue || attrValue.getValue().size() != resultElementTypes.size()) + return op->emitError("constant of StructType must be initialized by an " + "ArrayAttr with the same number of elements, got ") + << opaqueValue; + + // Check that each of the elements are valid. + llvm::ArrayRef attrElementValues = attrValue.getValue(); + for (const auto it : llvm::zip(resultElementTypes, attrElementValues)) + if (failed(verifyConstantForType(std::get<0>(it), std::get<1>(it), op))) + return mlir::failure(); return mlir::success(); } +/// Verifier for the constant operation. This corresponds to the `::verify(...)` +/// in the op definition. +llvm::LogicalResult ConstantOp::verify() { + return verifyConstantForType(getResult().getType(), getValue(), *this); +} + +llvm::LogicalResult StructConstantOp::verify() { + return verifyConstantForType(getResult().getType(), getValue(), *this); +} + +/// Infer the output shape of the ConstantOp, this is required by the shape +/// inference interface. +void ConstantOp::inferShapes() { + getResult().setType(cast(getValue().getType())); +} + //===----------------------------------------------------------------------===// // AddOp //===----------------------------------------------------------------------===// @@ -299,27 +346,6 @@ void FuncOp::print(mlir::OpAsmPrinter &p) { getArgAttrsAttrName(), getResAttrsAttrName()); } -/// Returns the region on the function operation that is callable. 
-mlir::Region *FuncOp::getCallableRegion() { return &getBody(); } - -/// Returns the results types that the callable region produces when -/// executed. -llvm::ArrayRef FuncOp::getCallableResults() { - return getFunctionType().getResults(); -} - -/// Returns the argument attributes for all callable region arguments or -/// null if there are none. -ArrayAttr FuncOp::getCallableArgAttrs() { - return getArgAttrs().value_or(nullptr); -} - -/// Returns the result attributes for all callable region results or -/// null if there are none. -ArrayAttr FuncOp::getCallableResAttrs() { - return getResAttrs().value_or(nullptr); -} - //===----------------------------------------------------------------------===// // GenericCallOp //===----------------------------------------------------------------------===// @@ -349,6 +375,12 @@ void GenericCallOp::setCalleeFromCallable(CallInterfaceCallable callee) { /// call interface. Operation::operand_range GenericCallOp::getArgOperands() { return getInputs(); } +/// Get the argument operands to the called function as a mutable range, this is +/// required by the call interface. +MutableOperandRange GenericCallOp::getArgOperandsMutable() { + return getInputsMutable(); +} + //===----------------------------------------------------------------------===// // MulOp //===----------------------------------------------------------------------===// @@ -374,7 +406,7 @@ void MulOp::inferShapes() { getResult().setType(getLhs().getType()); } // ReturnOp //===----------------------------------------------------------------------===// -mlir::LogicalResult ReturnOp::verify() { +llvm::LogicalResult ReturnOp::verify() { // We know that the parent operation is a function, because of the 'HasParent' // trait attached to the operation definition. 
auto function = cast((*this)->getParentOp()); @@ -408,6 +440,34 @@ mlir::LogicalResult ReturnOp::verify() { << ")"; } +//===----------------------------------------------------------------------===// +// StructAccessOp +//===----------------------------------------------------------------------===// + +void StructAccessOp::build(mlir::OpBuilder &b, mlir::OperationState &state, + mlir::Value input, size_t index) { + // Extract the result type from the input type. + StructType structTy = llvm::cast(input.getType()); + assert(index < structTy.getNumElementTypes()); + mlir::Type resultType = structTy.getElementTypes()[index]; + + // Call into the auto-generated build method. + build(b, state, resultType, input, b.getI64IntegerAttr(index)); +} + +llvm::LogicalResult StructAccessOp::verify() { + StructType structTy = llvm::cast(getInput().getType()); + size_t indexValue = getIndex(); + if (indexValue >= structTy.getNumElementTypes()) + return emitOpError() + << "index should be within the range of the input struct type"; + mlir::Type resultType = getResult().getType(); + if (resultType != structTy.getElementTypes()[indexValue]) + return emitOpError() << "must have the same result type as the struct " + "element referred to by the index"; + return mlir::success(); +} + //===----------------------------------------------------------------------===// // TransposeOp //===----------------------------------------------------------------------===// @@ -424,7 +484,7 @@ void TransposeOp::inferShapes() { getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType())); } -mlir::LogicalResult TransposeOp::verify() { +llvm::LogicalResult TransposeOp::verify() { auto inputType = llvm::dyn_cast(getOperand().getType()); auto resultType = llvm::dyn_cast(getType()); if (!inputType || !resultType) @@ -449,17 +509,24 @@ void MatMulOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, state.addOperands({lhs, rhs}); } -// mlir::ParseResult 
MatMulOp::parse(mlir::OpAsmParser &parser, -// mlir::OperationState &result) { -// return parseBinaryOp(parser, result); -// } - -// void MatMulOp::print(mlir::OpAsmPrinter &p) { printBinaryOp(p, *this); } +/// Infer the output shape of the MatMulOp, this is required by the shape +/// inference interface. +void MatMulOp::inferShapes() { + RankedTensorType lhsType = + llvm::dyn_cast(getLhs().getType()); + RankedTensorType rhsType = + llvm::dyn_cast(getRhs().getType()); + auto lhsShape = lhsType.getShape(); + auto rhsShape = rhsType.getShape(); + RankedTensorType res_type = RankedTensorType::get({lhsShape[0], rhsShape[1]}, + lhsType.getElementType()); + getResult().setType(res_type); +} -mlir::LogicalResult MatMulOp::verify() { - auto lhsType = getLhs().getType().dyn_cast(); - auto rhsType = getRhs().getType().dyn_cast(); - auto resultType = getType().dyn_cast(); +llvm::LogicalResult MatMulOp::verify() { + auto lhsType = llvm::dyn_cast(getLhs().getType()); + auto rhsType = llvm::dyn_cast(getRhs().getType()); + auto resultType = llvm::dyn_cast(getType()); if (!lhsType || !rhsType || !resultType) return mlir::success(); @@ -484,16 +551,137 @@ mlir::LogicalResult MatMulOp::verify() { return mlir::success(); } -/// Infer the output shape of the MatMulOp, this is required by the shape -/// inference interface. -void MatMulOp::inferShapes() { - RankedTensorType lhsType = getLhs().getType().cast(); - RankedTensorType rhsType = getRhs().getType().cast(); - auto lhsShape = lhsType.getShape(); - auto rhsShape = rhsType.getShape(); - RankedTensorType res_type = RankedTensorType::get({lhsShape[0], rhsShape[1]}, - lhsType.getElementType()); - getResult().setType(res_type); +//===----------------------------------------------------------------------===// +// Toy Types +//===----------------------------------------------------------------------===// + +namespace mlir { +namespace toy { +namespace detail { +/// This class represents the internal storage of the Toy `StructType`. 
+struct StructTypeStorage : public mlir::TypeStorage { + /// The `KeyTy` is a required type that provides an interface for the storage + /// instance. This type will be used when uniquing an instance of the type + /// storage. For our struct type, we will unique each instance structurally on + /// the elements that it contains. + using KeyTy = llvm::ArrayRef; + + /// A constructor for the type storage instance. + StructTypeStorage(llvm::ArrayRef elementTypes) + : elementTypes(elementTypes) {} + + /// Define the comparison function for the key type with the current storage + /// instance. This is used when constructing a new instance to ensure that we + /// haven't already uniqued an instance of the given key. + bool operator==(const KeyTy &key) const { return key == elementTypes; } + + /// Define a hash function for the key type. This is used when uniquing + /// instances of the storage, see the `StructType::get` method. + /// Note: This method isn't necessary as both llvm::ArrayRef and mlir::Type + /// have hash functions available, so we could just omit this entirely. + static llvm::hash_code hashKey(const KeyTy &key) { + return llvm::hash_value(key); + } + + /// Define a construction function for the key type from a set of parameters. + /// These parameters will be provided when constructing the storage instance + /// itself. + /// Note: This method isn't necessary because KeyTy can be directly + /// constructed with the given parameters. + static KeyTy getKey(llvm::ArrayRef elementTypes) { + return KeyTy(elementTypes); + } + + /// Define a construction method for creating a new instance of this storage. + /// This method takes an instance of a storage allocator, and an instance of a + /// `KeyTy`. The given allocator must be used for *all* necessary dynamic + /// allocations used to create the type storage and its internal. 
+ static StructTypeStorage *construct(mlir::TypeStorageAllocator &allocator, + const KeyTy &key) { + // Copy the elements from the provided `KeyTy` into the allocator. + llvm::ArrayRef elementTypes = allocator.copyInto(key); + + // Allocate the storage instance and construct it. + return new (allocator.allocate()) + StructTypeStorage(elementTypes); + } + + /// The following field contains the element types of the struct. + llvm::ArrayRef elementTypes; +}; +} // namespace detail +} // namespace toy +} // namespace mlir + +/// Create an instance of a `StructType` with the given element types. There +/// *must* be at least one element type. +StructType StructType::get(llvm::ArrayRef elementTypes) { + assert(!elementTypes.empty() && "expected at least 1 element type"); + + // Call into a helper 'get' method in 'TypeBase' to get a uniqued instance + // of this type. The first parameter is the context to unique in. The + // parameters after the context are forwarded to the storage instance. + mlir::MLIRContext *ctx = elementTypes.front().getContext(); + return Base::get(ctx, elementTypes); +} + +/// Returns the element types of this struct type. +llvm::ArrayRef StructType::getElementTypes() { + // 'getImpl' returns a pointer to the internal storage instance. + return getImpl()->elementTypes; +} + +/// Parse an instance of a type registered to the toy dialect. +mlir::Type ToyDialect::parseType(mlir::DialectAsmParser &parser) const { + // Parse a struct type in the following form: + // struct-type ::= `struct` `<` type (`,` type)* `>` + + // NOTE: All MLIR parser function return a ParseResult. This is a + // specialization of LogicalResult that auto-converts to a `true` boolean + // value on failure to allow for chaining, but may be used with explicit + // `mlir::failed/mlir::succeeded` as desired. + + // Parse: `struct` `<` + if (parser.parseKeyword("struct") || parser.parseLess()) + return Type(); + + // Parse the element types of the struct. 
+ SmallVector elementTypes; + do { + // Parse the current element type. + SMLoc typeLoc = parser.getCurrentLocation(); + mlir::Type elementType; + if (parser.parseType(elementType)) + return nullptr; + + // Check that the type is either a TensorType or another StructType. + if (!llvm::isa(elementType)) { + parser.emitError(typeLoc, "element type for a struct must either " + "be a TensorType or a StructType, got: ") + << elementType; + return Type(); + } + elementTypes.push_back(elementType); + + // Parse the optional: `,` + } while (succeeded(parser.parseOptionalComma())); + + // Parse: `>` + if (parser.parseGreater()) + return Type(); + return StructType::get(elementTypes); +} + +/// Print an instance of a type registered to the toy dialect. +void ToyDialect::printType(mlir::Type type, + mlir::DialectAsmPrinter &printer) const { + // Currently the only toy type is a struct type. + StructType structType = llvm::cast(type); + + // Print the struct type according to the parser format. + printer << "struct<"; + llvm::interleaveComma(structType.getElementTypes(), printer); + printer << '>'; } //===----------------------------------------------------------------------===// @@ -502,3 +690,29 @@ void MatMulOp::inferShapes() { #define GET_OP_CLASSES #include "toy/Ops.cpp.inc" + +//===----------------------------------------------------------------------===// +// ToyDialect +//===----------------------------------------------------------------------===// + +/// Dialect initialization, the instance will be owned by the context. This is +/// the point of registration of types and operations for the dialect. 
+void ToyDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "toy/Ops.cpp.inc" + >(); + addInterfaces(); + addTypes(); +} + +mlir::Operation *ToyDialect::materializeConstant(mlir::OpBuilder &builder, + mlir::Attribute value, + mlir::Type type, + mlir::Location loc) { + if (llvm::isa(type)) + return builder.create(loc, type, + llvm::cast(value)); + return builder.create(loc, type, + llvm::cast(value)); +} diff --git a/mlir/example/Ch8/mlir/LowerToAffineLoops.cpp b/mlir/example/Ch8/mlir/LowerToAffineLoops.cpp index 03e87ba..7413214 100644 --- a/mlir/example/Ch8/mlir/LowerToAffineLoops.cpp +++ b/mlir/example/Ch8/mlir/LowerToAffineLoops.cpp @@ -12,16 +12,34 @@ // //===----------------------------------------------------------------------===// +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinDialect.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/TypeID.h" +#include "toy/Dialect.h" +#include "toy/Passes.h" + #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/IR/BuiltinDialect.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" -#include "toy/Dialect.h" -#include "toy/Passes.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Sequence.h" +#include "llvm/Support/Casting.h" +#include +#include +#include +#include +#include using namespace mlir; @@ -154,8 +172,7 @@ struct ConstantOpLowering : public OpRewritePattern { SmallVector constantIndices; if (!valueShape.empty()) { - for (auto i : llvm::seq( - 0, *std::max_element(valueShape.begin(), valueShape.end()))) + for (auto i : llvm::seq(0, *llvm::max_element(valueShape))) 
constantIndices.push_back( rewriter.create(loc, i)); } else { @@ -241,8 +258,8 @@ struct PrintOpLowering : public OpConversionPattern { ConversionPatternRewriter &rewriter) const final { // We don't lower "toy.print" in this pass, but we need to update its // operands. - rewriter.updateRootInPlace(op, - [&] { op->setOperands(adaptor.getOperands()); }); + rewriter.modifyOpInPlace(op, + [&] { op->setOperands(adaptor.getOperands()); }); return success(); } }; diff --git a/mlir/example/Ch8/mlir/LowerToLLVM.cpp b/mlir/example/Ch8/mlir/LowerToLLVM.cpp index ab28f02..3ad70e7 100644 --- a/mlir/example/Ch8/mlir/LowerToLLVM.cpp +++ b/mlir/example/Ch8/mlir/LowerToLLVM.cpp @@ -22,6 +22,16 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/LLVMIR/LLVMAttrs.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/TypeID.h" +#include "toy/Dialect.h" +#include "toy/Passes.h" + #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" @@ -31,7 +41,6 @@ #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" -#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" @@ -39,9 +48,9 @@ #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" -#include "toy/Dialect.h" -#include "toy/Passes.h" -#include "llvm/ADT/Sequence.h" +#include "llvm/Support/Casting.h" +#include +#include using namespace mlir; @@ -60,6 +69,7 @@ class PrintOpLowering : public ConversionPattern { 
LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { + auto *context = rewriter.getContext(); auto memRefType = llvm::cast((*op->operand_type_begin())); auto memRefShape = memRefType.getShape(); auto loc = op->getLoc(); @@ -91,8 +101,8 @@ class PrintOpLowering : public ConversionPattern { // Insert a newline after each of the inner dimensions of the shape. if (i != e - 1) - rewriter.create(loc, printfRef, - rewriter.getIntegerType(32), newLineCst); + rewriter.create(loc, getPrintfType(context), printfRef, + newLineCst); rewriter.create(loc); rewriter.setInsertionPointToStart(loop.getBody()); } @@ -101,8 +111,8 @@ class PrintOpLowering : public ConversionPattern { auto printOp = cast(op); auto elementLoad = rewriter.create(loc, printOp.getInput(), loopIvs); - rewriter.create( - loc, printfRef, rewriter.getIntegerType(32), + rewriter.create( + loc, getPrintfType(context), printfRef, ArrayRef({formatSpecifierCst, elementLoad})); // Notify the rewriter that this operation has been removed. @@ -111,6 +121,16 @@ class PrintOpLowering : public ConversionPattern { } private: + /// Create a function declaration for printf, the signature is: + /// * `i32 (i8*, ...)` + static LLVM::LLVMFunctionType getPrintfType(MLIRContext *context) { + auto llvmI32Ty = IntegerType::get(context, 32); + auto llvmPtrTy = LLVM::LLVMPointerType::get(context); + auto llvmFnType = LLVM::LLVMFunctionType::get(llvmI32Ty, llvmPtrTy, + /*isVarArg=*/true); + return llvmFnType; + } + /// Return a symbol reference to the printf function, inserting it into the /// module if necessary. 
static FlatSymbolRefAttr getOrInsertPrintf(PatternRewriter &rewriter, @@ -119,17 +139,11 @@ class PrintOpLowering : public ConversionPattern { if (module.lookupSymbol("printf")) return SymbolRefAttr::get(context, "printf"); - // Create a function declaration for printf, the signature is: - // * `i32 (i8*, ...)` - auto llvmI32Ty = IntegerType::get(context, 32); - auto llvmI8PtrTy = LLVM::LLVMPointerType::get(IntegerType::get(context, 8)); - auto llvmFnType = LLVM::LLVMFunctionType::get(llvmI32Ty, llvmI8PtrTy, - /*isVarArg=*/true); - // Insert the printf function into the body of the parent module. PatternRewriter::InsertionGuard insertGuard(rewriter); rewriter.setInsertionPointToStart(module.getBody()); - rewriter.create(module.getLoc(), "printf", llvmFnType); + rewriter.create(module.getLoc(), "printf", + getPrintfType(context)); return SymbolRefAttr::get(context, "printf"); } @@ -156,8 +170,7 @@ class PrintOpLowering : public ConversionPattern { Value cst0 = builder.create(loc, builder.getI64Type(), builder.getIndexAttr(0)); return builder.create( - loc, - LLVM::LLVMPointerType::get(IntegerType::get(builder.getContext(), 8)), + loc, LLVM::LLVMPointerType::get(builder.getContext()), global.getType(), globalPtr, ArrayRef({cst0, cst0})); } }; diff --git a/mlir/example/Ch8/mlir/MLIRGen.cpp b/mlir/example/Ch8/mlir/MLIRGen.cpp index d0e5491..090e5ff 100644 --- a/mlir/example/Ch8/mlir/MLIRGen.cpp +++ b/mlir/example/Ch8/mlir/MLIRGen.cpp @@ -12,8 +12,11 @@ //===----------------------------------------------------------------------===// #include "toy/MLIRGen.h" - -#include +#include "mlir/IR/Block.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/Value.h" +#include "toy/AST.h" +#include "toy/Dialect.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" @@ -21,11 +24,24 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Verifier.h" -#include "toy/AST.h" -#include "toy/Dialect.h" +#include "toy/Lexer.h" + #include 
"llvm/ADT/STLExtras.h" #include "llvm/ADT/ScopedHashTable.h" -#include "llvm/Support/raw_ostream.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/Twine.h" +#include "llvm/Support/ErrorHandling.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include using namespace mlir::toy; using namespace toy; @@ -57,8 +73,19 @@ class MLIRGenImpl { // add them to the module. theModule = mlir::ModuleOp::create(builder.getUnknownLoc()); - for (FunctionAST &f : moduleAST) - mlirGen(f); + for (auto &record : moduleAST) { + if (FunctionAST *funcAST = llvm::dyn_cast(record.get())) { + mlir::toy::FuncOp func = mlirGen(*funcAST); + if (!func) + return nullptr; + functionMap.insert({func.getName(), func}); + } else if (StructAST *str = llvm::dyn_cast(record.get())) { + if (failed(mlirGen(*str))) + return nullptr; + } else { + llvm_unreachable("unknown record type"); + } + } // Verify the module after we have finished constructing it, this will check // the structural properties of the IR and invoke any specific verifiers we @@ -84,7 +111,18 @@ class MLIRGenImpl { /// Entering a function creates a new scope, and the function arguments are /// added to the mapping. When the processing of a function is terminated, the /// scope is destroyed and the mappings created in this scope are dropped. - llvm::ScopedHashTable symbolTable; + llvm::ScopedHashTable> + symbolTable; + using SymbolTableScopeT = + llvm::ScopedHashTableScope>; + + /// A mapping for the functions that have been code generated to MLIR. + llvm::StringMap functionMap; + + /// A mapping for named struct types to the underlying MLIR type and the + /// original AST node. + llvm::StringMap> structMap; /// Helper conversion for a Toy AST location to an MLIR location. 
mlir::Location loc(const Location &loc) { @@ -94,10 +132,39 @@ class MLIRGenImpl { /// Declare a variable in the current scope, return success if the variable /// wasn't declared yet. - mlir::LogicalResult declare(llvm::StringRef var, mlir::Value value) { - if (symbolTable.count(var)) + llvm::LogicalResult declare(VarDeclExprAST &var, mlir::Value value) { + if (symbolTable.count(var.getName())) return mlir::failure(); - symbolTable.insert(var, value); + symbolTable.insert(var.getName(), {value, &var}); + return mlir::success(); + } + + /// Create an MLIR type for the given struct. + llvm::LogicalResult mlirGen(StructAST &str) { + if (structMap.count(str.getName())) + return emitError(loc(str.loc())) << "error: struct type with name `" + << str.getName() << "' already exists"; + + auto variables = str.getVariables(); + std::vector elementTypes; + elementTypes.reserve(variables.size()); + for (auto &variable : variables) { + if (variable->getInitVal()) + return emitError(loc(variable->loc())) + << "error: variables within a struct definition must not have " + "initializers"; + if (!variable->getType().shape.empty()) + return emitError(loc(variable->loc())) + << "error: variables within a struct definition must not have " + "initializers"; + + mlir::Type type = getType(variable->getType(), variable->loc()); + if (!type) + return mlir::failure(); + elementTypes.push_back(type); + } + + structMap.try_emplace(str.getName(), StructType::get(elementTypes), &str); return mlir::success(); } @@ -107,9 +174,14 @@ class MLIRGenImpl { auto location = loc(proto.loc()); // This is a generic function, the return type will be inferred later. - // Arguments type are uniformly unranked tensors. 
- llvm::SmallVector argTypes(proto.getArgs().size(), - getType(VarType{})); + llvm::SmallVector argTypes; + argTypes.reserve(proto.getArgs().size()); + for (auto &arg : proto.getArgs()) { + mlir::Type type = getType(arg->getType(), arg->loc()); + if (!type) + return nullptr; + argTypes.push_back(type); + } auto funcType = builder.getFunctionType(argTypes, std::nullopt); return builder.create(location, proto.getName(), funcType); @@ -118,7 +190,7 @@ class MLIRGenImpl { /// Emit a new function and add it to the MLIR module. mlir::toy::FuncOp mlirGen(FunctionAST &funcAST) { // Create a scope in the symbol table to hold variable declarations. - ScopedHashTableScope varScope(symbolTable); + SymbolTableScopeT varScope(symbolTable); // Create an MLIR function for the given prototype. builder.setInsertionPointToEnd(theModule.getBody()); @@ -133,8 +205,7 @@ class MLIRGenImpl { // Declare all the function arguments in the symbol table. for (const auto nameValue : llvm::zip(protoArgs, entryBlock.getArguments())) { - if (failed(declare(std::get<0>(nameValue)->getName(), - std::get<1>(nameValue)))) + if (failed(declare(*std::get<0>(nameValue), std::get<1>(nameValue)))) return nullptr; } @@ -160,8 +231,9 @@ class MLIRGenImpl { } else if (returnOp.hasOperand()) { // Otherwise, if this return operation has an operand then add a result to // the function. - function.setType(builder.getFunctionType( - function.getFunctionType().getInputs(), getType(VarType{}))); + function.setType( + builder.getFunctionType(function.getFunctionType().getInputs(), + *returnOp.operand_type_begin())); } // If this function isn't main, then set the visibility to private. @@ -171,6 +243,71 @@ class MLIRGenImpl { return function; } + /// Return the struct type that is the result of the given expression, or null + /// if it cannot be inferred. 
+ StructAST *getStructFor(ExprAST *expr) { + llvm::StringRef structName; + if (auto *decl = llvm::dyn_cast(expr)) { + auto varIt = symbolTable.lookup(decl->getName()); + if (!varIt.first) + return nullptr; + structName = varIt.second->getType().name; + } else if (auto *access = llvm::dyn_cast(expr)) { + if (access->getOp() != '.') + return nullptr; + // The name being accessed should be in the RHS. + auto *name = llvm::dyn_cast(access->getRHS()); + if (!name) + return nullptr; + StructAST *parentStruct = getStructFor(access->getLHS()); + if (!parentStruct) + return nullptr; + + // Get the element within the struct corresponding to the name. + VarDeclExprAST *decl = nullptr; + for (auto &var : parentStruct->getVariables()) { + if (var->getName() == name->getName()) { + decl = var.get(); + break; + } + } + if (!decl) + return nullptr; + structName = decl->getType().name; + } + if (structName.empty()) + return nullptr; + + // If the struct name was valid, check for an entry in the struct map. + auto structIt = structMap.find(structName); + if (structIt == structMap.end()) + return nullptr; + return structIt->second.second; + } + + /// Return the numeric member index of the given struct access expression. + std::optional getMemberIndex(BinaryExprAST &accessOp) { + assert(accessOp.getOp() == '.' && "expected access operation"); + + // Lookup the struct node for the LHS. + StructAST *structAST = getStructFor(accessOp.getLHS()); + if (!structAST) + return std::nullopt; + + // Get the name from the RHS. 
+ VariableExprAST *name = llvm::dyn_cast(accessOp.getRHS()); + if (!name) + return std::nullopt; + + auto structVars = structAST->getVariables(); + const auto *it = llvm::find_if(structVars, [&](auto &var) { + return var->getName() == name->getName(); + }); + if (it == structVars.end()) + return std::nullopt; + return it - structVars.begin(); + } + /// Emit a binary operation mlir::Value mlirGen(BinaryExprAST &binop) { // First emit the operations for each side of the operation before emitting @@ -187,10 +324,22 @@ class MLIRGenImpl { mlir::Value lhs = mlirGen(*binop.getLHS()); if (!lhs) return nullptr; + auto location = loc(binop.loc()); + + // If this is an access operation, handle it immediately. + if (binop.getOp() == '.') { + std::optional accessIndex = getMemberIndex(binop); + if (!accessIndex) { + emitError(location, "invalid access into struct expression"); + return nullptr; + } + return builder.create(location, lhs, *accessIndex); + } + + // Otherwise, this is a normal binary op. mlir::Value rhs = mlirGen(*binop.getRHS()); if (!rhs) return nullptr; - auto location = loc(binop.loc()); // Derive the operation name from the binary operator. At the moment we only // support '+' and '*'. @@ -209,7 +358,7 @@ class MLIRGenImpl { /// expected to have been declared and so should have a value in the symbol /// table, otherwise emit an error and return nullptr. mlir::Value mlirGen(VariableExprAST &expr) { - if (auto variable = symbolTable.lookup(expr.getName())) + if (auto variable = symbolTable.lookup(expr.getName()).first) return variable; emitError(loc(expr.loc()), "error: unknown variable '") @@ -218,7 +367,7 @@ class MLIRGenImpl { } /// Emit a return operation. This will return failure if any generation fails. - mlir::LogicalResult mlirGen(ReturnExprAST &ret) { + llvm::LogicalResult mlirGen(ReturnExprAST &ret) { auto location = loc(ret.loc()); // 'return' takes an optional expression, handle that case here. 
@@ -234,10 +383,10 @@ class MLIRGenImpl { return mlir::success(); } - /// Emit a literal/constant array. It will be emitted as a flattened array of - /// data in an Attribute attached to a `toy.constant` operation. - /// See documentation on [Attributes](LangRef.md#attributes) for more details. - /// Here is an excerpt: + /// Emit a constant for a literal/constant array. It will be emitted as a + /// flattened array of data in an Attribute attached to a `toy.constant` + /// operation. See documentation on [Attributes](LangRef.md#attributes) for + /// more details. Here is an excerpt: /// /// Attributes are the mechanism for specifying constant data in MLIR in /// places where a variable is never allowed [...]. They consist of a name @@ -252,9 +401,7 @@ class MLIRGenImpl { /// [[1.000000e+00, 2.000000e+00, 3.000000e+00], /// [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> tensor<2x3xf64> /// - mlir::Value mlirGen(LiteralExprAST &lit) { - auto type = getType(lit.getDims()); - + mlir::DenseElementsAttr getConstantAttr(LiteralExprAST &lit) { // The attribute is a vector with a floating point value per element // (number) in the array, see `collectData()` below for more details. std::vector data; @@ -269,14 +416,70 @@ class MLIRGenImpl { // This is the actual attribute that holds the list of values for this // tensor literal. - auto dataAttribute = - mlir::DenseElementsAttr::get(dataType, llvm::ArrayRef(data)); + return mlir::DenseElementsAttr::get(dataType, llvm::ArrayRef(data)); + } + mlir::DenseElementsAttr getConstantAttr(NumberExprAST &lit) { + // The type of this attribute is tensor of 64-bit floating-point with no + // shape. + mlir::Type elementType = builder.getF64Type(); + auto dataType = mlir::RankedTensorType::get({}, elementType); + + // This is the actual attribute that holds the list of values for this + // tensor literal. 
+ return mlir::DenseElementsAttr::get(dataType, + llvm::ArrayRef(lit.getValue())); + } + /// Emit a constant for a struct literal. It will be emitted as an array of + /// other literals in an Attribute attached to a `toy.struct_constant` + /// operation. This function returns the generated constant, along with the + /// corresponding struct type. + std::pair + getConstantAttr(StructLiteralExprAST &lit) { + std::vector attrElements; + std::vector typeElements; + + for (auto &var : lit.getValues()) { + if (auto *number = llvm::dyn_cast(var.get())) { + attrElements.push_back(getConstantAttr(*number)); + typeElements.push_back(getType(std::nullopt)); + } else if (auto *lit = llvm::dyn_cast(var.get())) { + attrElements.push_back(getConstantAttr(*lit)); + typeElements.push_back(getType(std::nullopt)); + } else { + auto *structLit = llvm::cast(var.get()); + auto attrTypePair = getConstantAttr(*structLit); + attrElements.push_back(attrTypePair.first); + typeElements.push_back(attrTypePair.second); + } + } + mlir::ArrayAttr dataAttr = builder.getArrayAttr(attrElements); + mlir::Type dataType = StructType::get(typeElements); + return std::make_pair(dataAttr, dataType); + } + + /// Emit an array literal. + mlir::Value mlirGen(LiteralExprAST &lit) { + mlir::Type type = getType(lit.getDims()); + mlir::DenseElementsAttr dataAttribute = getConstantAttr(lit); // Build the MLIR op `toy.constant`. This invokes the `ConstantOp::build` // method. return builder.create(loc(lit.loc()), type, dataAttribute); } + /// Emit a struct literal. It will be emitted as an array of + /// other literals in an Attribute attached to a `toy.struct_constant` + /// operation. + mlir::Value mlirGen(StructLiteralExprAST &lit) { + mlir::ArrayAttr dataAttr; + mlir::Type dataType; + std::tie(dataAttr, dataType) = getConstantAttr(lit); + + // Build the MLIR op `toy.struct_constant`. This invokes the + // `StructConstantOp::build` method. 
+ return builder.create(loc(lit.loc()), dataType, dataAttr); + } + /// Recursive helper function to accumulate the data that compose an array /// literal. It flattens the nested structure in the supplied vector. For /// example with this array: @@ -325,12 +528,20 @@ class MLIRGenImpl { // Otherwise this is a call to a user-defined function. Calls to // user-defined functions are mapped to a custom call that takes the callee // name as an attribute. - return builder.create(location, callee, operands); + auto calledFuncIt = functionMap.find(callee); + if (calledFuncIt == functionMap.end()) { + emitError(location) << "no defined function found for '" << callee << "'"; + return nullptr; + } + mlir::toy::FuncOp calledFunc = calledFuncIt->second; + return builder.create( + location, calledFunc.getFunctionType().getResult(0), + mlir::SymbolRefAttr::get(builder.getContext(), callee), operands); } /// Emit a print expression. It emits specific operations for two builtins: /// transpose(x) and print(x). - mlir::LogicalResult mlirGen(PrintExprAST &call) { + llvm::LogicalResult mlirGen(PrintExprAST &call) { auto arg = mlirGen(*call.getArg()); if (!arg) return mlir::failure(); @@ -353,6 +564,8 @@ class MLIRGenImpl { return mlirGen(cast(expr)); case toy::ExprAST::Expr_Literal: return mlirGen(cast(expr)); + case toy::ExprAST::Expr_StructLiteral: + return mlirGen(cast(expr)); case toy::ExprAST::Expr_Call: return mlirGen(cast(expr)); case toy::ExprAST::Expr_Num: @@ -381,23 +594,39 @@ class MLIRGenImpl { if (!value) return nullptr; - // We have the initializer value, but in case the variable was declared - // with specific shape, we emit a "reshape" operation. It will get - // optimized out later as needed. - if (!vardecl.getType().shape.empty()) { + // Handle the case where we are initializing a struct value. + VarType varType = vardecl.getType(); + if (!varType.name.empty()) { + // Check that the initializer type is the same as the variable + // declaration. 
+ mlir::Type type = getType(varType, vardecl.loc()); + if (!type) + return nullptr; + if (type != value.getType()) { + emitError(loc(vardecl.loc())) + << "struct type of initializer is different than the variable " + "declaration. Got " + << value.getType() << ", but expected " << type; + return nullptr; + } + + // Otherwise, we have the initializer value, but in case the variable was + // declared with specific shape, we emit a "reshape" operation. It will + // get optimized out later as needed. + } else if (!varType.shape.empty()) { value = builder.create(loc(vardecl.loc()), - getType(vardecl.getType()), value); + getType(varType.shape), value); } // Register the value in the symbol table. - if (failed(declare(vardecl.getName(), value))) + if (failed(declare(vardecl, value))) return nullptr; return value; } /// Codegen a list of expression, return failure if one of them hit an error. - mlir::LogicalResult mlirGen(ExprASTList &blockAST) { - ScopedHashTableScope varScope(symbolTable); + llvm::LogicalResult mlirGen(ExprASTList &blockAST) { + SymbolTableScopeT varScope(symbolTable); for (auto &expr : blockAST) { // Specific handling for variable declarations, return statement, and // print. These can only appear in block list and not in nested @@ -433,8 +662,20 @@ class MLIRGenImpl { } /// Build an MLIR type from a Toy AST variable type (forward to the generic - /// getType above). - mlir::Type getType(const VarType &type) { return getType(type.shape); } + /// getType above for non-struct types). 
+ mlir::Type getType(const VarType &type, const Location &location) { + if (!type.name.empty()) { + auto it = structMap.find(type.name); + if (it == structMap.end()) { + emitError(loc(location)) + << "error: unknown struct type '" << type.name << "'"; + return nullptr; + } + return it->second.first; + } + + return getType(type.shape); + } }; } // namespace diff --git a/mlir/example/Ch8/mlir/ShapeInferencePass.cpp b/mlir/example/Ch8/mlir/ShapeInferencePass.cpp index d45baa1..a9e995e 100644 --- a/mlir/example/Ch8/mlir/ShapeInferencePass.cpp +++ b/mlir/example/Ch8/mlir/ShapeInferencePass.cpp @@ -11,13 +11,21 @@ // //===----------------------------------------------------------------------===// +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Types.h" #include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/TypeID.h" #include "toy/Dialect.h" #include "toy/Passes.h" #include "toy/ShapeInferenceInterface.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" +#include #define DEBUG_TYPE "shape-inference" diff --git a/mlir/example/Ch8/mlir/ToyCombine.cpp b/mlir/example/Ch8/mlir/ToyCombine.cpp index 43ffc5e..1d8cf74 100644 --- a/mlir/example/Ch8/mlir/ToyCombine.cpp +++ b/mlir/example/Ch8/mlir/ToyCombine.cpp @@ -11,11 +11,14 @@ // //===----------------------------------------------------------------------===// -#include - -#include "mlir/IR/Matchers.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OpDefinition.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Value.h" #include "toy/Dialect.h" +#include "llvm/Support/Casting.h" +#include using namespace mlir; using namespace toy; @@ -24,6 +27,23 @@ namespace { #include "ToyCombine.inc" } // namespace +/// Fold constants. 
+OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); } + +/// Fold struct constants. +OpFoldResult StructConstantOp::fold(FoldAdaptor adaptor) { return getValue(); } + +/// Fold simple struct access operations that access into a constant. +OpFoldResult StructAccessOp::fold(FoldAdaptor adaptor) { + auto structAttr = + llvm::dyn_cast_if_present(adaptor.getInput()); + if (!structAttr) + return nullptr; + + size_t elementIndex = getIndex(); + return structAttr[elementIndex]; +} + /// This is an example of a c++ rewrite pattern for the TransposeOp. It /// optimizes the following scenario: transpose(transpose(x)) -> x struct SimplifyRedundantTranspose : public mlir::OpRewritePattern { @@ -36,7 +56,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern { /// This method attempts to match a pattern and rewrite it. The rewriter /// argument is the orchestrator of the sequence of rewrites. The pattern is /// expected to interact with it to perform any changes to the IR from here. - mlir::LogicalResult + llvm::LogicalResult matchAndRewrite(TransposeOp op, mlir::PatternRewriter &rewriter) const override { // Look through the input of the current transpose. 
diff --git a/mlir/example/Ch8/parser/AST.cpp b/mlir/example/Ch8/parser/AST.cpp index 2eaabb1..e38a743 100644 --- a/mlir/example/Ch8/parser/AST.cpp +++ b/mlir/example/Ch8/parser/AST.cpp @@ -12,9 +12,12 @@ #include "toy/AST.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Twine.h" #include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/raw_ostream.h" +#include using namespace toy; @@ -41,6 +44,7 @@ class ASTDumper { void dump(ExprASTList *exprList); void dump(NumberExprAST *num); void dump(LiteralExprAST *node); + void dump(StructLiteralExprAST *node); void dump(VariableExprAST *node); void dump(ReturnExprAST *node); void dump(BinaryExprAST *node); @@ -48,6 +52,7 @@ class ASTDumper { void dump(PrintExprAST *node); void dump(PrototypeAST *node); void dump(FunctionAST *node); + void dump(StructAST *node); // Actually print spaces matching the current indentation level void indent() { @@ -78,8 +83,8 @@ static std::string loc(T *node) { void ASTDumper::dump(ExprAST *expr) { llvm::TypeSwitch(expr) .Case( - [&](auto *node) { this->dump(node); }) + PrintExprAST, ReturnExprAST, StructLiteralExprAST, VarDeclExprAST, + VariableExprAST>([&](auto *node) { this->dump(node); }) .Default([&](ExprAST *) { // No match, fallback to a generic message INDENT(); @@ -94,7 +99,8 @@ void ASTDumper::dump(VarDeclExprAST *varDecl) { llvm::errs() << "VarDecl " << varDecl->getName(); dump(varDecl->getType()); llvm::errs() << " " << loc(varDecl) << "\n"; - dump(varDecl->getInitVal()); + if (auto *initVal = varDecl->getInitVal()) + dump(initVal); } /// A "block", or a list of expression @@ -145,6 +151,16 @@ void ASTDumper::dump(LiteralExprAST *node) { llvm::errs() << " " << loc(node) << "\n"; } +/// Print a struct literal. 
+void ASTDumper::dump(StructLiteralExprAST *node) { + INDENT(); + llvm::errs() << "Struct Literal: "; + for (auto &value : node->getValues()) + dump(value.get()); + indent(); + llvm::errs() << " " << loc(node) << "\n"; +} + /// Print a variable reference (just a name). void ASTDumper::dump(VariableExprAST *node) { INDENT(); @@ -194,7 +210,10 @@ void ASTDumper::dump(PrintExprAST *node) { /// Print type: only the shape is printed in between '<' and '>' void ASTDumper::dump(const VarType &type) { llvm::errs() << "<"; - llvm::interleaveComma(type.shape, llvm::errs()); + if (!type.name.empty()) + llvm::errs() << type.name; + else + llvm::interleaveComma(type.shape, llvm::errs()); llvm::errs() << ">"; } @@ -218,12 +237,33 @@ void ASTDumper::dump(FunctionAST *node) { dump(node->getBody()); } +/// Print a struct. +void ASTDumper::dump(StructAST *node) { + INDENT(); + llvm::errs() << "Struct: " << node->getName() << " " << loc(node) << "\n"; + + { + INDENT(); + llvm::errs() << "Variables: [\n"; + for (auto &variable : node->getVariables()) + dump(variable.get()); + indent(); + llvm::errs() << "]\n"; + } +} + /// Print a module, actually loop over the functions and print them in sequence. void ASTDumper::dump(ModuleAST *node) { INDENT(); llvm::errs() << "Module:\n"; - for (auto &f : *node) - dump(&f); + for (auto &record : *node) { + if (FunctionAST *function = llvm::dyn_cast(record.get())) + dump(function); + else if (StructAST *str = llvm::dyn_cast(record.get())) + dump(str); + else + llvm::errs() << "getKind() << ">\n"; + } } namespace toy { diff --git a/mlir/example/Ch8/struct-codegen.toy b/mlir/example/Ch8/struct-codegen.toy new file mode 100644 index 0000000..fa639c0 --- /dev/null +++ b/mlir/example/Ch8/struct-codegen.toy @@ -0,0 +1,19 @@ +struct Struct { + var a; + var b; +} + +# User defined generic function may operate on struct types as well. +def multiply_transpose(Struct value) { + # We can access the elements of a struct via the '.' operator. 
+ return transpose(value.a) * transpose(value.b); +} + +def main() { + # We initialize struct values using a composite initializer. + Struct value = {[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]}; + + # We pass these arguments to functions like we do with variables. + var c = multiply_transpose(value); + print(c); +} diff --git a/mlir/example/Ch8/toyc.cpp b/mlir/example/Ch8/toyc.cpp index a7bc397..fea5679 100644 --- a/mlir/example/Ch8/toyc.cpp +++ b/mlir/example/Ch8/toyc.cpp @@ -10,8 +10,16 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Affine/Passes.h" #include "mlir/Dialect/Func/Extensions/AllExtensions.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "toy/AST.h" +#include "toy/Dialect.h" +#include "toy/Lexer.h" +#include "toy/MLIRGen.h" +#include "toy/Parser.h" +#include "toy/Passes.h" + +#include "mlir/Dialect/Affine/Passes.h" #include "mlir/Dialect/LLVMIR/Transforms/Passes.h" #include "mlir/ExecutionEngine/ExecutionEngine.h" #include "mlir/ExecutionEngine/OptUtils.h" @@ -21,17 +29,14 @@ #include "mlir/IR/Verifier.h" #include "mlir/InitAllDialects.h" #include "mlir/Parser/Parser.h" -#include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Export.h" #include "mlir/Transforms/Passes.h" -#include "toy/Dialect.h" -#include "toy/MLIRGen.h" -#include "toy/Parser.h" -#include "toy/Passes.h" + #include "llvm/ADT/StringRef.h" +#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" #include "llvm/IR/Module.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/ErrorOr.h" @@ -39,6 +44,11 @@ #include "llvm/Support/SourceMgr.h" #include "llvm/Support/TargetSelect.h" #include "llvm/Support/raw_ostream.h" +#include +#include +#include +#include +#include using namespace toy; namespace cl = llvm::cl; @@ 
-101,7 +111,7 @@ int loadMLIR(mlir::MLIRContext &context, mlir::OwningOpRef &module) { // Handle '.toy' input to the compiler. if (inputType != InputType::MLIR && - !llvm::StringRef(inputFilename).endswith(".mlir")) { + !llvm::StringRef(inputFilename).ends_with(".mlir")) { auto moduleAST = parseInputFile(inputFilename); if (!moduleAST) return 6; @@ -149,6 +159,7 @@ int loadAndProcessMLIR(mlir::MLIRContext &context, // Now that there is only one function, we can infer the shapes of each of // the operations. mlir::OpPassManager &optPM = pm.nest(); + optPM.addPass(mlir::createCanonicalizerPass()); optPM.addPass(mlir::toy::createShapeInferencePass()); optPM.addPass(mlir::createCanonicalizerPass()); optPM.addPass(mlir::createCSEPass()); @@ -176,8 +187,7 @@ int loadAndProcessMLIR(mlir::MLIRContext &context, // This is necessary to have line tables emitted and basic // debugger working. In the future we will add proper debug information // emission directly from our frontend. - pm.addNestedPass( - mlir::LLVM::createDIScopeForLLVMFuncOpPass()); + pm.addPass(mlir::LLVM::createDIScopeForLLVMFuncOpPass()); } if (mlir::failed(pm.run(*module))) @@ -216,7 +226,7 @@ int dumpLLVMIR(mlir::ModuleOp module) { llvm::InitializeNativeTarget(); llvm::InitializeNativeTargetAsmPrinter(); - // Configure the LLVM Module + // Create target machine and configure the LLVM Module auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost(); if (!tmBuilderOrError) { llvm::errs() << "Could not create JITTargetMachineBuilder\n"; diff --git a/mlir/example/README.md b/mlir/example/README.md index 99b885e..178c467 100644 --- a/mlir/example/README.md +++ b/mlir/example/README.md @@ -4,17 +4,20 @@ ## Environment Setup +### Environment Preparation with conda + - OS must be higher than ubuntu 22.04. 
- install gcc-11 and g++-11 ```bash apt update -y && \ -apt install -yq gcc-11 g++-11 +apt install -yq gcc-13 g++-13 # apt install -yq software-properties-common \ # add-apt-repository -y ppa:ubuntu-toolchain-r/test \ +# apt update -y # apt install -yq gcc-11 g++-11 -update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-11 20 -update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-11 20 +update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-13 20 +update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-13 20 ``` - install cmake and ninja you can choose one way you like. conda is best for me. @@ -23,17 +26,38 @@ update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-11 20 conda create -n mlir -y conda activate mlir # conda install cmake ninja clang-format clang lld ncurses mlir llvm -c conda-forge -conda install cmake ninja clang-format clang clang-tools mlir zlib spdlog fmt lit llvm=17.* -c conda-forge -y +conda install cmake ninja clang-format clang clang-tools mlir zlib spdlog fmt lit llvm=19.* -c conda-forge -y # create -n mlir cmake ninja clang-format clang mlir zlib spdlog fmt lit llvm -c conda-forge -y ``` -## build example +### build example with conda + +```bash +cd example +bash build_with_conda.sh all +``` + +### Environment Preparation with dev containers + +Please choose the `Dev Containers: Open Folder in Container...` + +### build example with dev containers ```bash cd example +bash scripts/sync_deps.sh +bash scripts/build_deps.sh bash build.sh all ``` +## Configure the Clangd + +```bash +cd example +# after you configure the project with cmake, you can configure the clangd by run the following command +compdb -p build list > compile_commands.json +``` + ## Run These code and understand mlir - Ch1 @@ -902,19 +926,29 @@ $ ./build/Ch6/mlir-example-ch6 Ch6/example.toy -emit=llvm --mlir-print-ir-after- ```bash $ ./build/Ch7/mlir-example-ch7 Ch7/struct-codegen.toy -emit=jit +# 1.000000 16.000000 +# 4.000000 25.000000 +# 9.000000 
36.000000 +``` + +- Ch8 + +```bash +$ ./vscode_build/Ch8/mlir-example-ch8 Ch8/matmul.toy.mlir -emit=mlir # module { -# toy.func private @multiply_transpose(%arg0: !toy.struct, tensor<*xf64>>) -> tensor<*xf64> { -# %0 = toy.struct_access %arg0[0] : !toy.struct, tensor<*xf64>> -> tensor<*xf64> -# %1 = toy.transpose(%0 : tensor<*xf64>) to tensor<*xf64> -# %2 = toy.struct_access %arg0[1] : !toy.struct, tensor<*xf64>> -> tensor<*xf64> -# %3 = toy.transpose(%2 : tensor<*xf64>) to tensor<*xf64> -# %4 = toy.mul %1, %3 : tensor<*xf64> -# toy.return %4 : tensor<*xf64> +# toy.func private @matmul_transpose(%arg0: tensor<*xf64>, %arg1: tensor<*xf64>) -> tensor<*xf64> { +# %0 = toy.transpose(%arg0 : tensor<*xf64>) to tensor<*xf64> +# %1 = toy.transpose(%arg1 : tensor<*xf64>) to tensor<*xf64> +# %2 = toy.matmul(%0 : tensor<*xf64>, %1 : tensor<*xf64>) to tensor<*xf64> +# toy.return %2 : tensor<*xf64> # } # toy.func @main() { -# %0 = toy.struct_constant [dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>, dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>] : !toy.struct, tensor<*xf64>> -# %1 = toy.generic_call @multiply_transpose(%0) : (!toy.struct, tensor<*xf64>>) -> tensor<*xf64> -# toy.print %1 : tensor<*xf64> +# %0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64> +# %1 = toy.reshape(%0 : tensor<2x3xf64>) to tensor<2x3xf64> +# %2 = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64> +# %3 = toy.reshape(%2 : tensor<6xf64>) to tensor<3x2xf64> +# %4 = toy.generic_call @matmul_transpose(%1, %3) : (tensor<2x3xf64>, tensor<3x2xf64>) -> tensor<*xf64> +# toy.print %4 : tensor<*xf64> # toy.return # } # } diff --git a/mlir/example/build.sh b/mlir/example/build.sh index ba9a755..8879550 100644 --- 
a/mlir/example/build.sh +++ b/mlir/example/build.sh @@ -10,25 +10,15 @@ _workspaceFolder=$(pwd) cd build # For non-conda users: -# cmake .. -G Ninja \ -# -DCMAKE_EXPORT_COMPILE_COMMANDS:BOOL=TRUE \ -# -DCMAKE_BUILD_TYPE:STRING=Debug \ -# -DCMAKE_C_COMPILER:FILEPATH=/usr/bin/gcc \ -# -DCMAKE_CXX_COMPILER:FILEPATH=/usr/bin/g++ \ -# -DMLIR_DIR=${_workspaceFolder}/third_party/lib/cmake/mlir \ -# -DLLVM_DIR=${_workspaceFolder}/third_party/lib/cmake/llvm \ -# -DCMAKE_MODULE_PATH="${_workspaceFolder}/third_party/lib/cmake/mlir;${_workspaceFolder}/third_party/lib/cmake/llvm" \ -# -DMLIR_TABLEGEN_EXE=${_workspaceFolder}/third_party/bin/mlir-tblgen -# # -DLibEdit_DIR=/root/miniconda3/envs/mlir/lib - -cmake .. -G Ninja --no-warn-unused-cli \ - -Wno-dev \ - -DCMAKE_MODULE_PATH="/root/miniconda3/envs/mlir/lib/cmake/mlir;/root/miniconda3/envs/mlir/lib/cmake/llvm" \ - -DMLIR_TABLEGEN_EXE:FILEPATH=/root/miniconda3/envs/mlir/bin/mlir-tblgen \ +cmake .. -Wno-dev -G Ninja \ -DCMAKE_EXPORT_COMPILE_COMMANDS:BOOL=TRUE \ -DCMAKE_BUILD_TYPE:STRING=Debug \ -DCMAKE_C_COMPILER:FILEPATH=/usr/bin/gcc \ - -DCMAKE_CXX_COMPILER:FILEPATH=/usr/bin/g++ + -DCMAKE_CXX_COMPILER:FILEPATH=/usr/bin/g++ \ + -DMLIR_DIR=${_workspaceFolder}/third_party/llvm/lib/cmake/mlir \ + -DLLVM_DIR=${_workspaceFolder}/third_party/llvm/lib/cmake/llvm \ + -DCMAKE_MODULE_PATH="${_workspaceFolder}/third_party/llvm/lib/cmake/mlir;${_workspaceFolder}/third_party/llvm/lib/cmake/llvm" \ + -DMLIR_TABLEGEN_EXE=${_workspaceFolder}/third_party/llvm/bin/mlir-tblgen # ninja cmake \ diff --git a/mlir/example/build_with_conda.sh b/mlir/example/build_with_conda.sh new file mode 100644 index 0000000..c81f22d --- /dev/null +++ b/mlir/example/build_with_conda.sh @@ -0,0 +1,25 @@ +#!/bin/bash + +_target=${1:-'all'} + +rm -rf build +mkdir build + +_workspaceFolder=$(pwd) + +cd build + + +cmake .. 
-G Ninja --no-warn-unused-cli \ + -Wno-dev \ + -DCMAKE_MODULE_PATH="/root/miniconda3/envs/mlir/lib/cmake/mlir;/root/miniconda3/envs/mlir/lib/cmake/llvm" \ + -DMLIR_TABLEGEN_EXE:FILEPATH=/root/miniconda3/envs/mlir/bin/mlir-tblgen \ + -DCMAKE_EXPORT_COMPILE_COMMANDS:BOOL=TRUE \ + -DCMAKE_BUILD_TYPE:STRING=Debug \ + -DCMAKE_C_COMPILER:FILEPATH=/usr/bin/gcc \ + -DCMAKE_CXX_COMPILER:FILEPATH=/usr/bin/g++ + +# ninja +cmake \ + --build ${_workspaceFolder}/build \ + --config Debug --target ${_target} diff --git a/mlir/example/scripts/build_deps.sh b/mlir/example/scripts/build_deps.sh new file mode 100644 index 0000000..07d4088 --- /dev/null +++ b/mlir/example/scripts/build_deps.sh @@ -0,0 +1,80 @@ +#!/bin/bash +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set -e + +if [[ $# -ne 2 ]] ; then + echo "Usage: $0 " + exit 1 +fi + +# LLVM source +LLVM_SRC_DIR="${1:-third_party/llvm-project}" +build_dir="${LLVM_SRC_DIR}/build" +install_dir="${2:-third_party}"/llvm + +if ! [ -f "$LLVM_SRC_DIR/llvm/CMakeLists.txt" ]; then + echo "Expected the path to LLVM to be set correctly (got '$LLVM_SRC_DIR'): can't find CMakeLists.txt" + exit 1 +fi +echo "Using LLVM source dir: $LLVM_SRC_DIR" + +# Setup directories. 
+echo "Building MLIR in $build_dir" +rm -rf "$build_dir" +mkdir -p "$build_dir" + +echo "Installing MLIR in $install_dir" +rm -rf ${install_dir} +mkdir -p ${install_dir} + +echo "Beginning build (commands will echo)" +set -x + +cmake -GNinja \ + "-H$LLVM_SRC_DIR/llvm" \ + "-B$build_dir" \ + -DCMAKE_BUILD_TYPE=Debug \ + -DLLVM_ENABLE_PROJECTS=mlir \ + -DLLVM_TARGETS_TO_BUILD="X86;NVPTX;AMDGPU" \ + -DLLVM_ENABLE_LLD=ON \ + -DLLVM_ENABLE_BACKTRACES=OFF \ + -DLLVM_INCLUDE_UTILS=ON \ + -DCMAKE_INSTALL_PREFIX=${install_dir} \ + -DLLVM_INSTALL_UTILS=ON \ + -DLLVM_BUILD_UTILS=ON \ + -DLLVM_INCLUDE_TOOLS=ON \ + -DLLVM_BUILD_TOOLS=ON \ + -DLLVM_BUILD_LLVM_DYLIB=ON \ + -DCMAKE_CXX_COMPILER=clang++ \ + -DCMAKE_C_COMPILER=clang \ + -DLLVM_LINK_LLVM_DYLIB=ON + + # -DLLVM_ENABLE_RTTI=ON \ + # -DLLVM_ENABLE_LIBEDIT=OFF \ + # -DLLVM_ENABLE_BINDINGS=OFF \ + # -DLLVM_INCLUDE_DOCS=OFF \ + # -DLLVM_INCLUDE_TESTS=ON \ + # -DLLVM_INCLUDE_BENCHMARKS=OFF \ + # -DLLVM_ENABLE_BACKTRACES=ON \ + # -DLLVM_INCLUDE_EXAMPLES=OFF \ + # -DLLVM_ENABLE_ASSERTIONS=On + # -DBUILD_SHARED_LIBS=ON \ + +# cmake --build "$build_dir" +cmake --build "$build_dir" + +pushd "$build_dir" +ninja install +popd diff --git a/mlir/example/scripts/patch/matmul.patch b/mlir/example/scripts/patch/matmul.patch index e68fb37..f8a1417 100644 --- a/mlir/example/scripts/patch/matmul.patch +++ b/mlir/example/scripts/patch/matmul.patch @@ -74,17 +74,24 @@ index 6ec105a..d750782 100644 + state.addOperands({lhs, rhs}); +} + -+// mlir::ParseResult MatMulOp::parse(mlir::OpAsmParser &parser, -+// mlir::OperationState &result) { -+// return parseBinaryOp(parser, result); -+// } -+ -+// void MatMulOp::print(mlir::OpAsmPrinter &p) { printBinaryOp(p, *this); } ++/// Infer the output shape of the MatMulOp, this is required by the shape ++/// inference interface. 
++void MatMulOp::inferShapes() { ++ RankedTensorType lhsType = ++ llvm::dyn_cast(getLhs().getType()); ++ RankedTensorType rhsType = ++ llvm::dyn_cast(getRhs().getType()); ++ auto lhsShape = lhsType.getShape(); ++ auto rhsShape = rhsType.getShape(); ++ RankedTensorType res_type = RankedTensorType::get({lhsShape[0], rhsShape[1]}, ++ lhsType.getElementType()); ++ getResult().setType(res_type); ++} + -+mlir::LogicalResult MatMulOp::verify() { -+ auto lhsType = getLhs().getType().dyn_cast(); -+ auto rhsType = getRhs().getType().dyn_cast(); -+ auto resultType = getType().dyn_cast(); ++llvm::LogicalResult MatMulOp::verify() { ++ auto lhsType = llvm::dyn_cast(getLhs().getType()); ++ auto rhsType = llvm::dyn_cast(getRhs().getType()); ++ auto resultType = llvm::dyn_cast(getType()); + + if (!lhsType || !rhsType || !resultType) + return mlir::success(); @@ -108,26 +115,6 @@ index 6ec105a..d750782 100644 + + return mlir::success(); +} -+ -+/// Infer the output shape of the MatMulOp, this is required by the shape -+/// inference interface. 
-+void MatMulOp::inferShapes() { -+ RankedTensorType lhsType = getLhs().getType().cast(); -+ RankedTensorType rhsType = getRhs().getType().cast(); -+ auto lhsShape = lhsType.getShape(); -+ auto rhsShape = rhsType.getShape(); -+ RankedTensorType res_type = RankedTensorType::get({lhsShape[0], rhsShape[1]}, -+ lhsType.getElementType()); -+ getResult().setType(res_type); -+} -+ //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// -diff --git a/torch/pytorch b/torch/pytorch -index 256fed0..138e289 160000 ---- a/torch/pytorch -+++ b/torch/pytorch -@@ -1 +1 @@ --Subproject commit 256fed02e930210dbcd7e5e23fcf142362098c2a -+Subproject commit 138e2895d08a6517c5718b2a0118c1b23ff4664c-dirty diff --git a/mlir/example/scripts/sync_deps.sh b/mlir/example/scripts/sync_deps.sh new file mode 100644 index 0000000..a1edad4 --- /dev/null +++ b/mlir/example/scripts/sync_deps.sh @@ -0,0 +1,5 @@ +#!/bin/bash + +mkdir -p third_party + +git clone -b release/19.x --depth 1 https://github.com/llvm/llvm-project.git third_party/llvm-project diff --git a/mlir/example/scripts/update.sh b/mlir/example/scripts/update.sh index 894f3b4..e6cf698 100644 --- a/mlir/example/scripts/update.sh +++ b/mlir/example/scripts/update.sh @@ -2,13 +2,19 @@ WORKSPACE=`pwd` -_llvm_branch=$1 +_llvm_branch=${1:-"release/19.x"} _dirs="Ch1 Ch2 Ch3 Ch4 Ch5 Ch6 Ch7" +_transform_dirs="Ch2 Ch3 Ch4" -_mlir_example_dir="llvm-project/mlir/examples/toy" +_example_in_llvm_project="third_party/llvm-project/mlir/examples" -[[ -d llvm-project ]] || git clone -b $_llvm_branch https://github.com/llvm/llvm-project.git +_mlir_example_dir="${_example_in_llvm_project}/toy" +_mlir_transform_dir="${_example_in_llvm_project}/transform" + +[[ -d "third_party/llvm-project" ]] || git clone -b $_llvm_branch https://github.com/llvm/llvm-project.git third_party/llvm-project + +# update the mlir Toy 
examples for dir in $_dirs; do @@ -18,7 +24,7 @@ for dir in $_dirs; do rm -rf $(find ./ -name "*.td") popd - pushd "$WORKSPACE/llvm-project/mlir/examples/toy/$dir" + pushd "$WORKSPACE/${_mlir_example_dir}/$dir" for cpps in $(find ./ -name "*.cpp"); do cp ${cpps} "$WORKSPACE/$dir/${cpps}" @@ -35,3 +41,34 @@ for dir in $_dirs; do popd done + +# update the mlir transform examples + +for tdir in $_transform_dirs; do + + pushd "$WORKSPACE/transform_$tdir" + rm -rf $(find ./ -name "*.cpp") + rm -rf $(find ./ -name "*.h") + rm -rf $(find ./ -name "*.td") + popd + + pushd "$WORKSPACE/${_mlir_transform_dir}/$tdir" + + for cpps in $(find ./ -name "*.cpp"); do + cp ${cpps} "$WORKSPACE/transform_$tdir/${cpps}" + # echo "cp ${cpps} $WORKSPACE/transform_$tdir/${cpps}" + done + + for hs in $(find ./ -name "*.h"); do + cp ${hs} "$WORKSPACE/transform_$tdir/${hs}" + # echo "cp ${hs} $WORKSPACE/transform_$tdir/${hs}" + done + + for tds in $(find ./ -name "*.td"); do + cp ${tds} "$WORKSPACE/transform_$tdir/${tds}" + # echo "cp ${tds} $WORKSPACE/transform_$tdir/${tds}" + done + + popd + +done diff --git a/mlir/example/Toy/Ch1/ast.toy b/mlir/example/tests/Toy/Ch1/ast.toy similarity index 100% rename from mlir/example/Toy/Ch1/ast.toy rename to mlir/example/tests/Toy/Ch1/ast.toy diff --git a/mlir/example/Toy/Ch1/empty.toy b/mlir/example/tests/Toy/Ch1/empty.toy similarity index 100% rename from mlir/example/Toy/Ch1/empty.toy rename to mlir/example/tests/Toy/Ch1/empty.toy diff --git a/mlir/example/Toy/Ch2/ast.toy b/mlir/example/tests/Toy/Ch2/ast.toy similarity index 100% rename from mlir/example/Toy/Ch2/ast.toy rename to mlir/example/tests/Toy/Ch2/ast.toy diff --git a/mlir/example/Toy/Ch2/codegen.toy b/mlir/example/tests/Toy/Ch2/codegen.toy similarity index 100% rename from mlir/example/Toy/Ch2/codegen.toy rename to mlir/example/tests/Toy/Ch2/codegen.toy diff --git a/mlir/example/Toy/Ch2/empty.toy b/mlir/example/tests/Toy/Ch2/empty.toy similarity index 100% rename from 
mlir/example/Toy/Ch2/empty.toy rename to mlir/example/tests/Toy/Ch2/empty.toy diff --git a/mlir/example/Toy/Ch2/invalid.mlir b/mlir/example/tests/Toy/Ch2/invalid.mlir similarity index 100% rename from mlir/example/Toy/Ch2/invalid.mlir rename to mlir/example/tests/Toy/Ch2/invalid.mlir diff --git a/mlir/example/Toy/Ch2/scalar.toy b/mlir/example/tests/Toy/Ch2/scalar.toy similarity index 100% rename from mlir/example/Toy/Ch2/scalar.toy rename to mlir/example/tests/Toy/Ch2/scalar.toy diff --git a/mlir/example/Toy/Ch3/ast.toy b/mlir/example/tests/Toy/Ch3/ast.toy similarity index 100% rename from mlir/example/Toy/Ch3/ast.toy rename to mlir/example/tests/Toy/Ch3/ast.toy diff --git a/mlir/example/Toy/Ch3/codegen.toy b/mlir/example/tests/Toy/Ch3/codegen.toy similarity index 100% rename from mlir/example/Toy/Ch3/codegen.toy rename to mlir/example/tests/Toy/Ch3/codegen.toy diff --git a/mlir/example/Toy/Ch3/empty.toy b/mlir/example/tests/Toy/Ch3/empty.toy similarity index 100% rename from mlir/example/Toy/Ch3/empty.toy rename to mlir/example/tests/Toy/Ch3/empty.toy diff --git a/mlir/example/Toy/Ch3/invalid.mlir b/mlir/example/tests/Toy/Ch3/invalid.mlir similarity index 100% rename from mlir/example/Toy/Ch3/invalid.mlir rename to mlir/example/tests/Toy/Ch3/invalid.mlir diff --git a/mlir/example/Toy/Ch3/scalar.toy b/mlir/example/tests/Toy/Ch3/scalar.toy similarity index 100% rename from mlir/example/Toy/Ch3/scalar.toy rename to mlir/example/tests/Toy/Ch3/scalar.toy diff --git a/mlir/example/Toy/Ch3/transpose_transpose.toy b/mlir/example/tests/Toy/Ch3/transpose_transpose.toy similarity index 100% rename from mlir/example/Toy/Ch3/transpose_transpose.toy rename to mlir/example/tests/Toy/Ch3/transpose_transpose.toy diff --git a/mlir/example/Toy/Ch3/trivial_reshape.toy b/mlir/example/tests/Toy/Ch3/trivial_reshape.toy similarity index 100% rename from mlir/example/Toy/Ch3/trivial_reshape.toy rename to mlir/example/tests/Toy/Ch3/trivial_reshape.toy diff --git 
a/mlir/example/Toy/Ch4/ast.toy b/mlir/example/tests/Toy/Ch4/ast.toy similarity index 100% rename from mlir/example/Toy/Ch4/ast.toy rename to mlir/example/tests/Toy/Ch4/ast.toy diff --git a/mlir/example/Toy/Ch4/codegen.toy b/mlir/example/tests/Toy/Ch4/codegen.toy similarity index 100% rename from mlir/example/Toy/Ch4/codegen.toy rename to mlir/example/tests/Toy/Ch4/codegen.toy diff --git a/mlir/example/Toy/Ch4/empty.toy b/mlir/example/tests/Toy/Ch4/empty.toy similarity index 100% rename from mlir/example/Toy/Ch4/empty.toy rename to mlir/example/tests/Toy/Ch4/empty.toy diff --git a/mlir/example/Toy/Ch4/invalid.mlir b/mlir/example/tests/Toy/Ch4/invalid.mlir similarity index 100% rename from mlir/example/Toy/Ch4/invalid.mlir rename to mlir/example/tests/Toy/Ch4/invalid.mlir diff --git a/mlir/example/Toy/Ch4/scalar.toy b/mlir/example/tests/Toy/Ch4/scalar.toy similarity index 100% rename from mlir/example/Toy/Ch4/scalar.toy rename to mlir/example/tests/Toy/Ch4/scalar.toy diff --git a/mlir/example/Toy/Ch4/shape_inference.mlir b/mlir/example/tests/Toy/Ch4/shape_inference.mlir similarity index 100% rename from mlir/example/Toy/Ch4/shape_inference.mlir rename to mlir/example/tests/Toy/Ch4/shape_inference.mlir diff --git a/mlir/example/Toy/Ch4/transpose_transpose.toy b/mlir/example/tests/Toy/Ch4/transpose_transpose.toy similarity index 100% rename from mlir/example/Toy/Ch4/transpose_transpose.toy rename to mlir/example/tests/Toy/Ch4/transpose_transpose.toy diff --git a/mlir/example/Toy/Ch4/trivial_reshape.toy b/mlir/example/tests/Toy/Ch4/trivial_reshape.toy similarity index 100% rename from mlir/example/Toy/Ch4/trivial_reshape.toy rename to mlir/example/tests/Toy/Ch4/trivial_reshape.toy diff --git a/mlir/example/Toy/Ch5/affine-lowering.mlir b/mlir/example/tests/Toy/Ch5/affine-lowering.mlir similarity index 100% rename from mlir/example/Toy/Ch5/affine-lowering.mlir rename to mlir/example/tests/Toy/Ch5/affine-lowering.mlir diff --git a/mlir/example/Toy/Ch5/ast.toy 
b/mlir/example/tests/Toy/Ch5/ast.toy similarity index 100% rename from mlir/example/Toy/Ch5/ast.toy rename to mlir/example/tests/Toy/Ch5/ast.toy diff --git a/mlir/example/Toy/Ch5/codegen.toy b/mlir/example/tests/Toy/Ch5/codegen.toy similarity index 100% rename from mlir/example/Toy/Ch5/codegen.toy rename to mlir/example/tests/Toy/Ch5/codegen.toy diff --git a/mlir/example/Toy/Ch5/empty.toy b/mlir/example/tests/Toy/Ch5/empty.toy similarity index 100% rename from mlir/example/Toy/Ch5/empty.toy rename to mlir/example/tests/Toy/Ch5/empty.toy diff --git a/mlir/example/Toy/Ch5/invalid.mlir b/mlir/example/tests/Toy/Ch5/invalid.mlir similarity index 100% rename from mlir/example/Toy/Ch5/invalid.mlir rename to mlir/example/tests/Toy/Ch5/invalid.mlir diff --git a/mlir/example/Toy/Ch5/scalar.toy b/mlir/example/tests/Toy/Ch5/scalar.toy similarity index 100% rename from mlir/example/Toy/Ch5/scalar.toy rename to mlir/example/tests/Toy/Ch5/scalar.toy diff --git a/mlir/example/Toy/Ch5/shape_inference.mlir b/mlir/example/tests/Toy/Ch5/shape_inference.mlir similarity index 100% rename from mlir/example/Toy/Ch5/shape_inference.mlir rename to mlir/example/tests/Toy/Ch5/shape_inference.mlir diff --git a/mlir/example/Toy/Ch5/transpose_transpose.toy b/mlir/example/tests/Toy/Ch5/transpose_transpose.toy similarity index 100% rename from mlir/example/Toy/Ch5/transpose_transpose.toy rename to mlir/example/tests/Toy/Ch5/transpose_transpose.toy diff --git a/mlir/example/Toy/Ch5/trivial_reshape.toy b/mlir/example/tests/Toy/Ch5/trivial_reshape.toy similarity index 100% rename from mlir/example/Toy/Ch5/trivial_reshape.toy rename to mlir/example/tests/Toy/Ch5/trivial_reshape.toy diff --git a/mlir/example/Toy/Ch6/affine-lowering.mlir b/mlir/example/tests/Toy/Ch6/affine-lowering.mlir similarity index 100% rename from mlir/example/Toy/Ch6/affine-lowering.mlir rename to mlir/example/tests/Toy/Ch6/affine-lowering.mlir diff --git a/mlir/example/Toy/Ch6/ast.toy b/mlir/example/tests/Toy/Ch6/ast.toy 
similarity index 100% rename from mlir/example/Toy/Ch6/ast.toy rename to mlir/example/tests/Toy/Ch6/ast.toy diff --git a/mlir/example/Toy/Ch6/codegen.toy b/mlir/example/tests/Toy/Ch6/codegen.toy similarity index 100% rename from mlir/example/Toy/Ch6/codegen.toy rename to mlir/example/tests/Toy/Ch6/codegen.toy diff --git a/mlir/example/Toy/Ch6/empty.toy b/mlir/example/tests/Toy/Ch6/empty.toy similarity index 100% rename from mlir/example/Toy/Ch6/empty.toy rename to mlir/example/tests/Toy/Ch6/empty.toy diff --git a/mlir/example/Toy/Ch6/invalid.mlir b/mlir/example/tests/Toy/Ch6/invalid.mlir similarity index 100% rename from mlir/example/Toy/Ch6/invalid.mlir rename to mlir/example/tests/Toy/Ch6/invalid.mlir diff --git a/mlir/example/Toy/Ch6/jit.toy b/mlir/example/tests/Toy/Ch6/jit.toy similarity index 100% rename from mlir/example/Toy/Ch6/jit.toy rename to mlir/example/tests/Toy/Ch6/jit.toy diff --git a/mlir/example/Toy/Ch6/scalar.toy b/mlir/example/tests/Toy/Ch6/scalar.toy similarity index 100% rename from mlir/example/Toy/Ch6/scalar.toy rename to mlir/example/tests/Toy/Ch6/scalar.toy diff --git a/mlir/example/Toy/Ch6/shape_inference.mlir b/mlir/example/tests/Toy/Ch6/shape_inference.mlir similarity index 100% rename from mlir/example/Toy/Ch6/shape_inference.mlir rename to mlir/example/tests/Toy/Ch6/shape_inference.mlir diff --git a/mlir/example/Toy/Ch6/transpose_transpose.toy b/mlir/example/tests/Toy/Ch6/transpose_transpose.toy similarity index 100% rename from mlir/example/Toy/Ch6/transpose_transpose.toy rename to mlir/example/tests/Toy/Ch6/transpose_transpose.toy diff --git a/mlir/example/Toy/Ch6/trivial_reshape.toy b/mlir/example/tests/Toy/Ch6/trivial_reshape.toy similarity index 100% rename from mlir/example/Toy/Ch6/trivial_reshape.toy rename to mlir/example/tests/Toy/Ch6/trivial_reshape.toy diff --git a/mlir/example/Toy/Ch7/affine-lowering.mlir b/mlir/example/tests/Toy/Ch7/affine-lowering.mlir similarity index 100% rename from 
mlir/example/Toy/Ch7/affine-lowering.mlir rename to mlir/example/tests/Toy/Ch7/affine-lowering.mlir diff --git a/mlir/example/Toy/Ch7/ast.toy b/mlir/example/tests/Toy/Ch7/ast.toy similarity index 100% rename from mlir/example/Toy/Ch7/ast.toy rename to mlir/example/tests/Toy/Ch7/ast.toy diff --git a/mlir/example/Toy/Ch7/codegen.toy b/mlir/example/tests/Toy/Ch7/codegen.toy similarity index 100% rename from mlir/example/Toy/Ch7/codegen.toy rename to mlir/example/tests/Toy/Ch7/codegen.toy diff --git a/mlir/example/Toy/Ch7/empty.toy b/mlir/example/tests/Toy/Ch7/empty.toy similarity index 100% rename from mlir/example/Toy/Ch7/empty.toy rename to mlir/example/tests/Toy/Ch7/empty.toy diff --git a/mlir/example/Toy/Ch7/invalid.mlir b/mlir/example/tests/Toy/Ch7/invalid.mlir similarity index 100% rename from mlir/example/Toy/Ch7/invalid.mlir rename to mlir/example/tests/Toy/Ch7/invalid.mlir diff --git a/mlir/example/Toy/Ch7/jit.toy b/mlir/example/tests/Toy/Ch7/jit.toy similarity index 100% rename from mlir/example/Toy/Ch7/jit.toy rename to mlir/example/tests/Toy/Ch7/jit.toy diff --git a/mlir/example/Toy/Ch7/scalar.toy b/mlir/example/tests/Toy/Ch7/scalar.toy similarity index 100% rename from mlir/example/Toy/Ch7/scalar.toy rename to mlir/example/tests/Toy/Ch7/scalar.toy diff --git a/mlir/example/Toy/Ch7/shape_inference.mlir b/mlir/example/tests/Toy/Ch7/shape_inference.mlir similarity index 100% rename from mlir/example/Toy/Ch7/shape_inference.mlir rename to mlir/example/tests/Toy/Ch7/shape_inference.mlir diff --git a/mlir/example/Toy/Ch7/struct-ast.toy b/mlir/example/tests/Toy/Ch7/struct-ast.toy similarity index 100% rename from mlir/example/Toy/Ch7/struct-ast.toy rename to mlir/example/tests/Toy/Ch7/struct-ast.toy diff --git a/mlir/example/Toy/Ch7/struct-codegen.toy b/mlir/example/tests/Toy/Ch7/struct-codegen.toy similarity index 100% rename from mlir/example/Toy/Ch7/struct-codegen.toy rename to mlir/example/tests/Toy/Ch7/struct-codegen.toy diff --git 
a/mlir/example/Toy/Ch7/struct-opt.mlir b/mlir/example/tests/Toy/Ch7/struct-opt.mlir similarity index 100% rename from mlir/example/Toy/Ch7/struct-opt.mlir rename to mlir/example/tests/Toy/Ch7/struct-opt.mlir diff --git a/mlir/example/Toy/Ch7/transpose_transpose.toy b/mlir/example/tests/Toy/Ch7/transpose_transpose.toy similarity index 100% rename from mlir/example/Toy/Ch7/transpose_transpose.toy rename to mlir/example/tests/Toy/Ch7/transpose_transpose.toy diff --git a/mlir/example/Toy/Ch7/trivial_reshape.toy b/mlir/example/tests/Toy/Ch7/trivial_reshape.toy similarity index 100% rename from mlir/example/Toy/Ch7/trivial_reshape.toy rename to mlir/example/tests/Toy/Ch7/trivial_reshape.toy diff --git a/mlir/example/tests/transform/Ch1/invalidation-1.mlir b/mlir/example/tests/transform/Ch1/invalidation-1.mlir new file mode 100644 index 0000000..2264ade --- /dev/null +++ b/mlir/example/tests/transform/Ch1/invalidation-1.mlir @@ -0,0 +1,101 @@ +// RUN: mlir-opt %s \ +// RUN: --pass-pipeline="builtin.module(transform-interpreter{ \ +// RUN: debug-bind-trailing-args=linalg.matmul,linalg.elemwise_binary},\ +// RUN: canonicalize,cse,symbol-dce)" \ +// RUN: --split-input-file --verify-diagnostics + +// ****************************** IMPORTANT NOTE ****************************** +// +// If you are changing this file, you may also need to change +// mlir/docs/Tutorials/Transform accordingly. +// +// **************************************************************************** + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main( + %arg0: !transform.any_op, + // expected-note @below {{handle to invalidated ops}} + %arg1: !transform.op<"linalg.matmul">, + %arg2: !transform.op<"linalg.elemwise_binary">) { + // The actual tiling transformation takes tile sizes as attributes. 
+ // expected-note @below {{invalidated by this transform op that consumes its operand #0 and invalidates all handles to payload IR entities associated with this operand and entities nested in them}} + %tiled, %loop = transform.structured.tile_using_forall %arg1 tile_sizes [4, 32] + : (!transform.op<"linalg.matmul">) -> (!transform.any_op, !transform.any_op) + + // This is trying to use an invalidated handle leading to undefined behavior. + // expected-error @below {{uses a handle invalidated by a previously executed transform op}} + transform.debug.emit_remark_at %arg1, "remark" : !transform.op<"linalg.matmul"> + transform.yield + } +} + +// Original function to optimize. +func.func @fc_relu(%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>, + %bias: tensor<512x512xf32>, %output: tensor<512x512xf32>) + -> tensor<512x512xf32> { + // Matrix-matrix multiplication. + // expected-note @below {{payload op}} + %matmul = linalg.matmul ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>) + outs(%output: tensor<512x512xf32>) -> tensor<512x512xf32> + + // Elementwise addition. + %biased = linalg.elemwise_binary { fun = #linalg.binary_fn } + ins(%matmul, %bias : tensor<512x512xf32>, tensor<512x512xf32>) + outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32> + + // Elementwise max with 0 (ReLU). + %c0f = arith.constant 0.0 : f32 + %relued = linalg.elemwise_binary { fun = #linalg.binary_fn } + ins(%biased, %c0f : tensor<512x512xf32>, f32) + outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32> + func.return %relued : tensor<512x512xf32> +} + +// ----- + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main( + %arg0: !transform.any_op, + %arg1: !transform.op<"linalg.matmul">, + %arg2: !transform.op<"linalg.elemwise_binary">) { + // We can cast one type to another as long as operations are compatible + // with both types. This creates "aliasing" handles. 
+ // expected-note @below {{handle to invalidated ops}} + %casted = transform.cast %arg1 : !transform.op<"linalg.matmul"> to + !transform.any_op + + // The actual tiling transformation takes tile sizes as attributes. + // expected-note @below {{invalidated by this transform op that consumes its operand #0 and invalidates all handles to payload IR entities associated with this operand and entities nested in them}} + %tiled, %loop = transform.structured.tile_using_forall %arg1 tile_sizes [4, 32] + : (!transform.op<"linalg.matmul">) -> (!transform.any_op, !transform.any_op) + + // Consuming an operand invalidates the consumed handle and any other handle that is + // associated with the same payload operations, or payload operations nested in them. + // expected-error @below {{uses a handle invalidated by a previously executed transform op}} + transform.debug.emit_remark_at %casted, "remark" + : !transform.any_op + transform.yield + } +} + +// Original function to optimize. +func.func @fc_relu(%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>, + %bias: tensor<512x512xf32>, %output: tensor<512x512xf32>) + -> tensor<512x512xf32> { + // Matrix-matrix multiplication. + // expected-note @below {{payload op}} + %matmul = linalg.matmul ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>) + outs(%output: tensor<512x512xf32>) -> tensor<512x512xf32> + + // Elementwise addition. + %biased = linalg.elemwise_binary { fun = #linalg.binary_fn } + ins(%matmul, %bias : tensor<512x512xf32>, tensor<512x512xf32>) + outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32> + + // Elementwise max with 0 (ReLU). 
+ %c0f = arith.constant 0.0 : f32 + %relued = linalg.elemwise_binary { fun = #linalg.binary_fn } + ins(%biased, %c0f : tensor<512x512xf32>, f32) + outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32> + func.return %relued : tensor<512x512xf32> +} diff --git a/mlir/example/tests/transform/Ch1/invalidation-2.mlir b/mlir/example/tests/transform/Ch1/invalidation-2.mlir new file mode 100644 index 0000000..b8d74a3 --- /dev/null +++ b/mlir/example/tests/transform/Ch1/invalidation-2.mlir @@ -0,0 +1,102 @@ +// RUN: mlir-opt %s \ +// RUN: --pass-pipeline="builtin.module(transform-interpreter{ \ +// RUN: debug-bind-trailing-args=linalg.matmul,linalg.elemwise_binary},\ +// RUN: canonicalize,cse,symbol-dce)" \ +// RUN: --split-input-file --verify-diagnostics +// ****************************** IMPORTANT NOTE ****************************** +// +// If you are changing this file, you may also need to change +// mlir/docs/Tutorials/Transform accordingly. +// +// **************************************************************************** + +// Original function to optimize. +func.func @fc_relu(%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>, + %bias: tensor<512x512xf32>, %output: tensor<512x512xf32>) + -> tensor<512x512xf32> { + // Matrix-matrix multiplication. + + // expected-note @below {{nested payload op}} + %matmul = linalg.matmul ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>) + outs(%output: tensor<512x512xf32>) -> tensor<512x512xf32> + + // Elementwise addition. + + // expected-note @below {{ancestor payload op}} + %biased = linalg.elemwise_binary { fun = #linalg.binary_fn } + ins(%matmul, %bias : tensor<512x512xf32>, tensor<512x512xf32>) + outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32> + + // Elementwise max with 0 (ReLU). 
+ %c0f = arith.constant 0.0 : f32 + %relued = linalg.elemwise_binary { fun = #linalg.binary_fn } + ins(%biased, %c0f : tensor<512x512xf32>, f32) + outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32> + func.return %relued : tensor<512x512xf32> +} + +// Declaration of the "microkernel" function that we will be targeting. +func.func private @microkernel( + %lhs: tensor<4x512xf32>, + %rhs: tensor<512x4xf32>, + %bias: tensor<4x4xf32>, + %init: tensor<4x4xf32>, + %output: tensor<4x4xf32>) -> tensor<4x4xf32> + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main( + %arg0: !transform.any_op, + %arg1: !transform.op<"linalg.matmul">, + %arg2: !transform.op<"linalg.elemwise_binary">) { + // Since the %arg2 handle is associated with both elementwise operations, + // we need to split it into two handles so we can target only the second + // elementwise operation. + %add, %max = transform.split_handle %arg2 : (!transform.op<"linalg.elemwise_binary">) + -> (!transform.any_op, !transform.any_op) + + // The actual tiling transformation takes tile sizes as attributes. It produces a + // handle to the loop generated during tiling. + %tiled, %loop = transform.structured.tile_using_forall %max tile_sizes [8, 32] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + + // We can now fuse the other operations into the loop. Here, we fuse + // operations one-by-one. This requires the operation that is being fused + // to define the value used within the loop, so the order of such fusions + // is important. We could also use "transform.merge_handles" to obtain + // a single handle to all operations and give it to `fuse_into_containing_op` + // that would take care of the ordering in this case. 
+ %add_fused, %loop2 = transform.structured.fuse_into_containing_op %add into %loop + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + %matmul_fused, %loop3 = transform.structured.fuse_into_containing_op %arg1 into %loop2 + : (!transform.op<"linalg.matmul">, !transform.any_op) -> (!transform.any_op, !transform.any_op) + + // Tile again to get the desired size. Note that this time this tiles the + // "add" operation and fuses matmul into the loop, but doesn't affect the + // "max" operation. This illustrates the precise targeting with the transform + // dialect. Otherwise, it is difficult to differentiate "add" and "max", both + // of which having the same kind. + %tiled_second, %loop_second = transform.structured.tile_using_forall %add_fused tile_sizes [4, 4] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %matmul_fused_2, %loop_second_2 = + transform.structured.fuse_into_containing_op %matmul_fused into %loop_second + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + + // Since outlining is currently only implemented for region-holding operations + // such as loops, use tiling to size 1 to materialize the outer loop that is + // going to be outlined. 
+ %_0, %loop_third = transform.structured.tile_using_forall %tiled_second tile_sizes [1] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + // expected-note @below {{handle to invalidated ops}} + %f, %outline_target = transform.structured.fuse_into_containing_op %matmul_fused_2 into %loop_third + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + + // expected-note @below {{invalidated by this transform op that consumes its operand #0 and invalidates all handles to payload IR entities associated with this operand and entities nested in them}} + %func, %call = transform.loop.outline %outline_target {func_name = "outlined"} + : (!transform.any_op) -> (!transform.any_op, !transform.op<"func.call">) + + // expected-error @below {{uses a handle invalidated by a previously executed transform op}} + transform.debug.emit_remark_at %f, "fused" : !transform.any_op + + transform.yield + } +} diff --git a/mlir/example/tests/transform/Ch1/sequence.mlir b/mlir/example/tests/transform/Ch1/sequence.mlir new file mode 100644 index 0000000..df87fc5 --- /dev/null +++ b/mlir/example/tests/transform/Ch1/sequence.mlir @@ -0,0 +1,112 @@ +// RUN: mlir-opt %s \ +// RUN: --pass-pipeline="builtin.module(transform-interpreter{ \ +// RUN: debug-bind-trailing-args=linalg.matmul,linalg.elemwise_binary},\ +// RUN: canonicalize,cse,symbol-dce)" |\ +// RUN: FileCheck %s + +// ****************************** IMPORTANT NOTE ****************************** +// +// If you are changing this file, you may also need to change +// mlir/docs/Tutorials/Transform accordingly. +// +// **************************************************************************** + +// Original function to optimize. +func.func @fc_relu(%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>, + %bias: tensor<512x512xf32>, %output: tensor<512x512xf32>) + -> tensor<512x512xf32> { + // Matrix-matrix multiplication. 
+ %matmul = linalg.matmul ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>) + outs(%output: tensor<512x512xf32>) -> tensor<512x512xf32> + + // Elementwise addition. + %biased = linalg.elemwise_binary { fun = #linalg.binary_fn } + ins(%matmul, %bias : tensor<512x512xf32>, tensor<512x512xf32>) + outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32> + + // Elementwise max with 0 (ReLU). + %c0f = arith.constant 0.0 : f32 + %relued = linalg.elemwise_binary { fun = #linalg.binary_fn } + ins(%biased, %c0f : tensor<512x512xf32>, f32) + outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32> + func.return %relued : tensor<512x512xf32> +} + +// CHECK: func @outlined +// CHECK: linalg.matmul +// CHECK: linalg.elemwise_binary {fun = #linalg.binary_fn} + +// CHECK-LABEL: func @fc_relu +// CHECK: scf.forall +// CHECK: scf.forall +// CHECK: %[[SLICE4:.+]] = tensor.extract_slice +// CHECK: %[[SLICE5:.+]] = tensor.extract_slice +// CHECK: %[[SLICE6:.+]] = tensor.extract_slice +// CHECK: %[[SLICE7:.+]] = tensor.extract_slice +// CHECK: %[[SLICE8:.+]] = tensor.extract_slice +// CHECK: func.call @outlined(%[[SLICE4]], %[[SLICE5]], %[[SLICE6]], %[[SLICE7]], %[[SLICE8]]) +// CHECK-NOT: linalg.matmul +// CHECK-NOT: linalg.elemwise_binary +// CHECK: scf.forall.in_parallel +// CHECK: linalg.elemwise_binary {fun = #linalg.binary_fn} +// CHECK: scf.forall.in_parallel + +// Declaration of the "microkernel" function that we will be targeting. 
+func.func private @microkernel( + %lhs: tensor<4x512xf32>, + %rhs: tensor<512x4xf32>, + %bias: tensor<4x4xf32>, + %init: tensor<4x4xf32>, + %output: tensor<4x4xf32>) -> tensor<4x4xf32> + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main( + %arg0: !transform.any_op, + %arg1: !transform.op<"linalg.matmul">, + %arg2: !transform.op<"linalg.elemwise_binary">) { + // Since the %arg2 handle is associated with both elementwise operations, + // we need to split it into two handles so we can target only the second + // elementwise operation. + %add, %max = transform.split_handle %arg2 : (!transform.op<"linalg.elemwise_binary">) + -> (!transform.any_op, !transform.any_op) + + // The actual tiling transformation takes tile sizes as attributes. It produces a + // handle to the loop generated during tiling. + %tiled, %loop = transform.structured.tile_using_forall %max tile_sizes [8, 32] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + + // We can now fuse the other operations into the loop. Here, we fuse + // operations one-by-one. This requires the operation that is being fused + // to define the value used within the loop, so the order of such fusions + // is important. We could also use "transform.merge_handles" to obtain + // a single handle to all operations and give it to `fuse_into_containing_op` + // that would take care of the ordering in this case. + %add_fused, %loop2 = transform.structured.fuse_into_containing_op %add into %loop + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + %matmul_fused, %loop3 = transform.structured.fuse_into_containing_op %arg1 into %loop2 + : (!transform.op<"linalg.matmul">, !transform.any_op) -> (!transform.any_op, !transform.any_op) + + // Tile again to get the desired size. Note that this time this tiles the + // "add" operation and fuses matmul into the loop, but doesn't affect the + // "max" operation. 
This illustrates the precise targeting with the transform + // dialect. Otherwise, it is difficult to differentiate "add" and "max", both + // of which having the same kind. + %tiled_second, %loop_second = transform.structured.tile_using_forall %add_fused tile_sizes [4, 4] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %matmul_fused_2, %loop_second_2 = + transform.structured.fuse_into_containing_op %matmul_fused into %loop_second + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + + // Since outlining is currently only implemented for region-holding operations + // such as loops, use tiling to size 1 to materialize the outer loop that is + // going to be outlined. + %_0, %loop_third = transform.structured.tile_using_forall %tiled_second tile_sizes [1] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %_1, %outline_target = transform.structured.fuse_into_containing_op %matmul_fused_2 into %loop_third + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + %func, %call = transform.loop.outline %outline_target {func_name = "outlined"} + : (!transform.any_op) -> (!transform.any_op, !transform.op<"func.call">) + + transform.yield + } +} diff --git a/mlir/example/tests/transform/Ch2/invalid.mlir b/mlir/example/tests/transform/Ch2/invalid.mlir new file mode 100644 index 0000000..cb67389 --- /dev/null +++ b/mlir/example/tests/transform/Ch2/invalid.mlir @@ -0,0 +1,11 @@ +// RUN: transform-opt-ch2 %s --transform-interpreter --split-input-file \ +// RUN: --verify-diagnostics + +// expected-note @below {{offending payload}} +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op) { + // expected-error @below {{only applies to func.call payloads}} + transform.my.change_call_target %arg0, "updated" : !transform.any_op + transform.yield + } +} diff --git a/mlir/example/tests/transform/Ch2/ops.mlir 
b/mlir/example/tests/transform/Ch2/ops.mlir new file mode 100644 index 0000000..410a6e3 --- /dev/null +++ b/mlir/example/tests/transform/Ch2/ops.mlir @@ -0,0 +1,27 @@ +// RUN: transform-opt-ch2 %s --transform-interpreter | FileCheck %s + +// ****************************** IMPORTANT NOTE ****************************** +// +// If you are changing this file, you may also need to change +// mlir/docs/Tutorials/Transform accordingly. +// +// **************************************************************************** + +func.func private @orig() +func.func private @updated() + +// CHECK-LABEL: func @test +func.func @test() { + // CHECK: call @updated + call @orig() : () -> () + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op) { + %call = transform.structured.match ops{["func.call"]} in %arg0 : (!transform.any_op) -> !transform.any_op + // CHECK: transform.my.change_call_target %{{.*}}, "updated" : !transform.any_op + transform.my.change_call_target %call, "updated" : !transform.any_op + transform.yield + } +} diff --git a/mlir/example/tests/transform/Ch2/sequence.mlir b/mlir/example/tests/transform/Ch2/sequence.mlir new file mode 100644 index 0000000..976df1d --- /dev/null +++ b/mlir/example/tests/transform/Ch2/sequence.mlir @@ -0,0 +1,111 @@ +// RUN: transform-opt-ch2 %s \ +// RUN: --pass-pipeline="builtin.module(transform-interpreter{ \ +// RUN: debug-bind-trailing-args=linalg.matmul,linalg.elemwise_binary},\ +// RUN: canonicalize,cse,symbol-dce)" |\ +// RUN: FileCheck %s + +// ****************************** IMPORTANT NOTE ****************************** +// +// If you are changing this file, you may also need to change +// mlir/docs/Tutorials/Transform accordingly. +// +// **************************************************************************** + +// Original function to optimize. 
+func.func @fc_relu(%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>, + %bias: tensor<512x512xf32>, %output: tensor<512x512xf32>) + -> tensor<512x512xf32> { + // Matrix-matrix multiplication. + %matmul = linalg.matmul ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>) + outs(%output: tensor<512x512xf32>) -> tensor<512x512xf32> + + // Elementwise addition. + %biased = linalg.elemwise_binary { fun = #linalg.binary_fn } + ins(%matmul, %bias : tensor<512x512xf32>, tensor<512x512xf32>) + outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32> + + // Elementwise max with 0 (ReLU). + %c0f = arith.constant 0.0 : f32 + %relued = linalg.elemwise_binary { fun = #linalg.binary_fn } + ins(%biased, %c0f : tensor<512x512xf32>, f32) + outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32> + func.return %relued : tensor<512x512xf32> +} + +// CHECK-LABEL: func @fc_relu +// CHECK: scf.forall +// CHECK: scf.forall +// CHECK: %[[SLICE4:.+]] = tensor.extract_slice +// CHECK: %[[SLICE5:.+]] = tensor.extract_slice +// CHECK: %[[SLICE6:.+]] = tensor.extract_slice +// CHECK: %[[SLICE7:.+]] = tensor.extract_slice +// CHECK: %[[SLICE8:.+]] = tensor.extract_slice +// CHECK: func.call @microkernel(%[[SLICE4]], %[[SLICE5]], %[[SLICE6]], %[[SLICE7]], %[[SLICE8]]) +// CHECK-NOT: linalg.matmul +// CHECK-NOT: linalg.elemwise_binary +// CHECK: scf.forall.in_parallel +// CHECK: linalg.elemwise_binary {fun = #linalg.binary_fn} +// CHECK: scf.forall.in_parallel + +// Declaration of the "microkernel" function that we will be targeting. 
+func.func private @microkernel( + %lhs: tensor<4x512xf32>, + %rhs: tensor<512x4xf32>, + %bias: tensor<4x4xf32>, + %init: tensor<4x4xf32>, + %output: tensor<4x4xf32>) -> tensor<4x4xf32> + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main( + %arg0: !transform.any_op, + %arg1: !transform.op<"linalg.matmul">, + %arg2: !transform.op<"linalg.elemwise_binary">) { + // Since the %arg2 handle is associated with both elementwise operations, + // we need to split it into two handles so we can target only the second + // elementwise operation. + %add, %max = transform.split_handle %arg2 : (!transform.op<"linalg.elemwise_binary">) + -> (!transform.any_op, !transform.any_op) + + // The actual tiling transformation takes tile sizes as attributes. It produces a + // handle to the loop generated during tiling. + %tiled, %loop = transform.structured.tile_using_forall %max tile_sizes [8, 32] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + + // We can now fuse the other operations into the loop. Here, we fuse + // operations one-by-one. This requires the operation that is being fused + // to define the value used within the loop, so the order of such fusions + // is important. We could also use "transform.merge_handles" to obtain + // a single handle to all operations and give it to `fuse_into_containing_op` + // that would take care of the ordering in this case. + %add_fused, %loop2 = transform.structured.fuse_into_containing_op %add into %loop + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + %matmul_fused, %loop3 = transform.structured.fuse_into_containing_op %arg1 into %loop2 + : (!transform.op<"linalg.matmul">, !transform.any_op) -> (!transform.any_op, !transform.any_op) + + // Tile again to get the desired size. Note that this time this tiles the + // "add" operation and fuses matmul into the loop, but doesn't affect the + // "max" operation. 
This illustrates the precise targeting with the transform + // dialect. Otherwise, it is difficult to differentiate "add" and "max", both + // of which have the same kind. + %tiled_second, %loop_second = transform.structured.tile_using_forall %add_fused tile_sizes [4, 4] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %matmul_fused_2, %loop_second_2 = + transform.structured.fuse_into_containing_op %matmul_fused into %loop_second + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + + // Since outlining is currently only implemented for region-holding operations + // such as loops, use tiling to size 1 to materialize the outer loop that is + // going to be outlined. + %_0, %loop_third = transform.structured.tile_using_forall %tiled_second tile_sizes [1] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %_1, %outline_target = transform.structured.fuse_into_containing_op %matmul_fused_2 into %loop_third + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + %func, %call = transform.loop.outline %outline_target {func_name = "outlined"} + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + + // Rewrite the call target. 
+ transform.my.change_call_target %call, "microkernel" : !transform.any_op + + transform.yield + } +} diff --git a/mlir/example/tests/transform/Ch3/invalid.mlir b/mlir/example/tests/transform/Ch3/invalid.mlir new file mode 100644 index 0000000..acaabd5 --- /dev/null +++ b/mlir/example/tests/transform/Ch3/invalid.mlir @@ -0,0 +1,10 @@ +// RUN: transform-opt-ch3 %s --transform-interpreter --split-input-file --verify-diagnostics + +// expected-note @below {{offending operation}} +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main( + // expected-error @below {{expected the payload operation to implement CallOpInterface}} + %arg0: !transform.my.call_op_interface) { + transform.yield + } +} diff --git a/mlir/example/tests/transform/Ch3/ops.mlir b/mlir/example/tests/transform/Ch3/ops.mlir new file mode 100644 index 0000000..b2d47cc --- /dev/null +++ b/mlir/example/tests/transform/Ch3/ops.mlir @@ -0,0 +1,48 @@ +// RUN: transform-opt-ch3 %s --transform-interpreter \ +// RUN: --allow-unregistered-dialect --split-input-file | FileCheck %s + +// ****************************** IMPORTANT NOTE ****************************** +// +// If you are changing this file, you may also need to change +// mlir/docs/Tutorials/Transform accordingly. 
+// +// **************************************************************************** + +func.func private @orig() +func.func private @updated() + +// CHECK-LABEL: func @test1 +func.func @test1() { + // CHECK: call @updated + call @orig() : () -> () + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op) { + %call = transform.structured.match ops{["func.call"]} in %arg0 : (!transform.any_op) -> !transform.op<"func.call"> + // CHECK: transform.my.change_call_target %{{.*}}, "updated" : !transform.op<"func.call"> + transform.my.change_call_target %call, "updated" : !transform.op<"func.call"> + transform.yield + } +} + +// ----- + +func.func private @orig() + +// CHECK-LABEL: func @test2 +func.func @test2() { + // CHECK: "my.mm4" + call @orig() : () -> () + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op) { + %call = transform.structured.match ops{["func.call"]} in %arg0 : (!transform.any_op) -> !transform.my.call_op_interface + // CHECK: transform.my.call_to_op %{{.*}} : (!transform.my.call_op_interface) -> !transform.any_op + transform.my.call_to_op %call : (!transform.my.call_op_interface) -> !transform.any_op + transform.yield + } +} diff --git a/mlir/example/tests/transform/Ch3/sequence.mlir b/mlir/example/tests/transform/Ch3/sequence.mlir new file mode 100644 index 0000000..b52fe10 --- /dev/null +++ b/mlir/example/tests/transform/Ch3/sequence.mlir @@ -0,0 +1,111 @@ +// RUN: transform-opt-ch3 %s \ +// RUN: --pass-pipeline="builtin.module(transform-interpreter{ \ +// RUN: debug-bind-trailing-args=linalg.matmul,linalg.elemwise_binary},\ +// RUN: canonicalize,cse,symbol-dce)" |\ +// RUN: FileCheck %s + +// ****************************** IMPORTANT NOTE ****************************** +// +// If you are changing this file, you may also need to change +// mlir/docs/Tutorials/Transform accordingly. 
+// +// **************************************************************************** + +// Original function to optimize. +func.func @fc_relu(%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>, + %bias: tensor<512x512xf32>, %output: tensor<512x512xf32>) + -> tensor<512x512xf32> { + // Matrix-matrix multiplication. + %matmul = linalg.matmul ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>) + outs(%output: tensor<512x512xf32>) -> tensor<512x512xf32> + + // Elementwise addition. + %biased = linalg.elemwise_binary { fun = #linalg.binary_fn } + ins(%matmul, %bias : tensor<512x512xf32>, tensor<512x512xf32>) + outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32> + + // Elementwise max with 0 (ReLU). + %c0f = arith.constant 0.0 : f32 + %relued = linalg.elemwise_binary { fun = #linalg.binary_fn } + ins(%biased, %c0f : tensor<512x512xf32>, f32) + outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32> + func.return %relued : tensor<512x512xf32> +} + +// CHECK-LABEL: func @fc_relu +// CHECK: scf.forall +// CHECK: scf.forall +// CHECK: %[[SLICE4:.+]] = tensor.extract_slice +// CHECK: %[[SLICE5:.+]] = tensor.extract_slice +// CHECK: %[[SLICE6:.+]] = tensor.extract_slice +// CHECK: %[[SLICE7:.+]] = tensor.extract_slice +// CHECK: %[[SLICE8:.+]] = tensor.extract_slice +// CHECK: func.call @microkernel(%[[SLICE4]], %[[SLICE5]], %[[SLICE6]], %[[SLICE7]], %[[SLICE8]]) +// CHECK-NOT: linalg.matmul +// CHECK-NOT: linalg.elemwise_binary +// CHECK: scf.forall.in_parallel +// CHECK: linalg.elemwise_binary {fun = #linalg.binary_fn} +// CHECK: scf.forall.in_parallel + +// Declaration of the "microkernel" function that we will be targeting. 
+func.func private @microkernel( + %lhs: tensor<4x512xf32>, + %rhs: tensor<512x4xf32>, + %bias: tensor<4x4xf32>, + %init: tensor<4x4xf32>, + %output: tensor<4x4xf32>) -> tensor<4x4xf32> + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main( + %arg0: !transform.any_op, + %arg1: !transform.op<"linalg.matmul">, + %arg2: !transform.op<"linalg.elemwise_binary">) { + // Since the %arg2 handle is associated with both elementwise operations, + // we need to split it into two handles so we can target only the second + // elementwise operation. + %add, %max = transform.split_handle %arg2 : (!transform.op<"linalg.elemwise_binary">) + -> (!transform.any_op, !transform.any_op) + + // The actual tiling transformation takes tile sizes as attributes. It produces a + // handle to the loop generated during tiling. + %tiled, %loop = transform.structured.tile_using_forall %max tile_sizes [8, 32] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + + // We can now fuse the other operations into the loop. Here, we fuse + // operations one-by-one. This requires the operation that is being fused + // to define the value used within the loop, so the order of such fusions + // is important. We could also use "transform.merge_handles" to obtain + // a single handle to all operations and give it to `fuse_into_containing_op` + // that would take care of the ordering in this case. + %add_fused, %loop2 = transform.structured.fuse_into_containing_op %add into %loop + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + %matmul_fused, %loop3 = transform.structured.fuse_into_containing_op %arg1 into %loop2 + : (!transform.op<"linalg.matmul">, !transform.any_op) -> (!transform.any_op, !transform.any_op) + + // Tile again to get the desired size. Note that this time this tiles the + // "add" operation and fuses matmul into the loop, but doesn't affect the + // "max" operation. 
This illustrates the precise targeting with the transform + // dialect. Otherwise, it is difficult to differentiate "add" and "max", both + // of which have the same kind. + %tiled_second, %loop_second = transform.structured.tile_using_forall %add_fused tile_sizes [4, 4] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %matmul_fused_2, %loop_second_2 = + transform.structured.fuse_into_containing_op %matmul_fused into %loop_second + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + + // Since outlining is currently only implemented for region-holding operations + // such as loops, use tiling to size 1 to materialize the outer loop that is + // going to be outlined. + %_0, %loop_third = transform.structured.tile_using_forall %tiled_second tile_sizes [1] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %_1, %outline_target = transform.structured.fuse_into_containing_op %matmul_fused_2 into %loop_third + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + %func, %call = transform.loop.outline %outline_target {func_name = "outlined"} + : (!transform.any_op) -> (!transform.any_op, !transform.op<"func.call">) + + // Rewrite the call target. + transform.my.change_call_target %call, "microkernel" : !transform.op<"func.call"> + + transform.yield + } +} diff --git a/mlir/example/tests/transform/Ch4/features.mlir b/mlir/example/tests/transform/Ch4/features.mlir new file mode 100644 index 0000000..7f43cb3 --- /dev/null +++ b/mlir/example/tests/transform/Ch4/features.mlir @@ -0,0 +1,123 @@ +// RUN: transform-opt-ch4 %s --transform-interpreter --verify-diagnostics + +// Matmul as a named operation. 
+func.func @named( + %lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>, + %bias: tensor<512x512xf32>, %output: tensor<512x512xf32>) + -> tensor<512x512xf32> { + // expected-remark @below {{matmul}} + %matmul = linalg.matmul ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>) + outs(%output: tensor<512x512xf32>) -> tensor<512x512xf32> + func.return %matmul : tensor<512x512xf32> +} + +// Matmul as a generic operation. +func.func @generic( + %lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>, + %bias: tensor<512x512xf32>, %output: tensor<512x512xf32>) + -> tensor<512x512xf32> { + // expected-remark @below {{matmul}} + %matmul = linalg.generic { + iterator_types = ["parallel", "parallel", "reduction"], + indexing_maps = [ + affine_map<(d0, d1, d2) -> (d0, d2)>, + affine_map<(d0, d1, d2) -> (d2, d1)>, + affine_map<(d0, d1, d2) -> (d0, d1)>] + } ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>) + outs(%output: tensor<512x512xf32>) { + ^bb0(%arg0: f32, %arg1: f32, %arg2: f32): + %0 = arith.mulf %arg0, %arg1 : f32 + %1 = arith.addf %0, %arg2 : f32 + linalg.yield %1 : f32 + } -> tensor<512x512xf32> + return %matmul : tensor<512x512xf32> +} + +// The module containing named sequences must have an attribute allowing them +// to enable verification. +module @transforms attributes { transform.with_named_sequence } { + // Entry point. This takes as the only argument the root operation (typically + // pass root) given to the transform interpreter. + transform.named_sequence @__transform_main( + %root: !transform.any_op {transform.consumed}) { + + // Traverses the payload IR associated with the operand handle, invoking + // @match_generic_matmul on each of the operations. If the named sequence + // succeeds, i.e., if none of the nested match (transform) operations + // produced a silenceable failure, invokes @print_generic_matmul and + // forwards the values yielded as arguments of the new invocation. 
If the + // named sequence fails with a silenceable failure, silences it (the message + // is forwarded to the debug stream). Definite failures are propagated + // immediately and unconditionally, as usual. + transform.foreach_match in %root + @match_generic_matmul -> @print_generic_matmul + : (!transform.any_op) -> !transform.any_op + + transform.yield + } + + // This is an action sequence. + transform.named_sequence @print_generic_matmul( + %matmul: !transform.any_op {transform.readonly}) { + transform.debug.emit_remark_at %matmul, "matmul" : !transform.any_op + transform.yield + } + + transform.named_sequence @match_generic_matmul( + %candidate: !transform.any_op {transform.readonly}) -> !transform.any_op { + // Match a structured linear algebra operation. + transform.match.structured %candidate : !transform.any_op { + ^bb0(%c: !transform.any_op): + // With a rank equal to 3. + %rank = transform.match.structured.rank %c + : (!transform.any_op) -> !transform.param + %c3 = transform.param.constant 3 : i64 -> !transform.param + transform.match.param.cmpi eq %rank, %c3 : !transform.param + + // With 2 inputs. + %n_ins = transform.match.structured.num_inputs %c + : (!transform.any_op) -> !transform.param + %c2 = transform.param.constant 2 : i64 -> !transform.param + transform.match.param.cmpi eq %n_ins, %c2 : !transform.param + + // With 1 output (note that structured ops in destination passing style + // has as many inits as outputs). + %n_inits = transform.match.structured.num_inits %c + : (!transform.any_op) -> !transform.param + %c1 = transform.param.constant 1 : i64 -> !transform.param + transform.match.param.cmpi eq %n_inits, %c1 : !transform.param + + // All inputs and inits are accessed with a projected permutation. + transform.match.structured.input %c[all] {projected_permutation} + : !transform.any_op + transform.match.structured.init %c[0] {projected_permutation} + : !transform.any_op + + // The body is a mulf/addf contraction with appropriate dimensions. 
+ transform.match.structured.body %c + { contraction = ["arith.mulf", "arith.addf"] } : !transform.any_op + %batch, %lhs, %rhs, %reduction = + transform.match.structured.classify_contraction_dims %c + : (!transform.any_op) + -> (!transform.param, !transform.param, !transform.param, + !transform.param) + + // There is one of lhs, rhs and reduction dimensions and zero batch + // dimensions. + %n_batch = transform.num_associations %batch + : (!transform.param) -> !transform.param + %n_lhs = transform.num_associations %lhs + : (!transform.param) -> !transform.param + %n_rhs = transform.num_associations %rhs + : (!transform.param) -> !transform.param + %n_reduction = transform.num_associations %reduction + : (!transform.param) -> !transform.param + %c0 = transform.param.constant 0 : i64 -> !transform.param + transform.match.param.cmpi eq %n_batch, %c0 : !transform.param + transform.match.param.cmpi eq %n_lhs, %c1 : !transform.param + transform.match.param.cmpi eq %n_rhs, %c1 : !transform.param + transform.match.param.cmpi eq %n_reduction, %c1 : !transform.param + } + transform.yield %candidate : !transform.any_op + } +} diff --git a/mlir/example/tests/transform/Ch4/multiple.mlir b/mlir/example/tests/transform/Ch4/multiple.mlir new file mode 100644 index 0000000..a9f7155 --- /dev/null +++ b/mlir/example/tests/transform/Ch4/multiple.mlir @@ -0,0 +1,131 @@ +// RUN: transform-opt-ch4 %s --transform-interpreter --verify-diagnostics + +// Matmul+ReLU. +func.func @fc_relu_operands_00( + %lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>, + %bias: tensor<512x512xf32>, %output: tensor<512x512xf32>) + -> tensor<512x512xf32> { + // Matrix-matrix multiplication. + // expected-remark @below {{matmul # 0}} + %matmul = linalg.matmul ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>) + outs(%output: tensor<512x512xf32>) -> tensor<512x512xf32> + + // Elementwise addition. 
+ // expected-remark @below {{add # 0}} + %biased = linalg.elemwise_binary { fun = #linalg.binary_fn } + ins(%matmul, %bias : tensor<512x512xf32>, tensor<512x512xf32>) + outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32> + + // Elementwise max with 0 (ReLU). + %c0f = arith.constant 0.0 : f32 + // expected-remark @below {{max # 0}} + %relued = linalg.elemwise_binary { fun = #linalg.binary_fn } + ins(%biased, %c0f : tensor<512x512xf32>, f32) + outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32> + func.return %relued : tensor<512x512xf32> +} + +// Matmul+ReLU with swapped operands. +func.func @fc_relu_operands_01( + %lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>, + %bias: tensor<512x512xf32>, %output: tensor<512x512xf32>) + -> tensor<512x512xf32> { + // Matrix-matrix multiplication. + // expected-remark @below {{matmul # 1}} + %matmul = linalg.matmul ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>) + outs(%output: tensor<512x512xf32>) -> tensor<512x512xf32> + + // Elementwise addition. + // expected-remark @below {{add # 1}} + %biased = linalg.elemwise_binary { fun = #linalg.binary_fn } + ins(%matmul, %bias : tensor<512x512xf32>, tensor<512x512xf32>) + outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32> + + // Elementwise max with 0 (ReLU). + %c0f = arith.constant 0.0 : f32 + // expected-remark @below {{max # 1}} + %relued = linalg.elemwise_binary { fun = #linalg.binary_fn } + ins(%c0f, %biased : f32, tensor<512x512xf32>) + outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32> + func.return %relued : tensor<512x512xf32> +} + +// The module containing named sequences must have an attribute allowing them +// to enable verification. +module @transforms attributes { transform.with_named_sequence } { + // Entry point. This takes as the only argument the root operation (typically + // pass root) given to the transform interpreter. 
+ transform.named_sequence @__transform_main( + %root: !transform.any_op {transform.consumed}) { + + // Traverses the payload IR associated with the operand handle, invoking + // @match_matmul_elemwise on each of the operations. If the named sequence + // succeeds, i.e., if none of the nested match (transform) operations + // produced a silenceable failure, invokes @print_matmul_elemwise and + // forwards the values yielded as arguments of the new invocation. If the + // named sequence fails with a silenceable failure, silences it (the message + // is forwarded to the debug stream). Definite failures are propagated + // immediately and unconditionally, as usual. + transform.foreach_match in %root + @match_matmul_elemwise -> @print_matmul_elemwise + : (!transform.any_op) -> !transform.any_op + + transform.yield + } + + // This is an action sequence. + transform.named_sequence @print_matmul_elemwise( + %matmul: !transform.any_op {transform.readonly}, + %add: !transform.any_op {transform.readonly}, + %max: !transform.any_op {transform.readonly}, + %pos: !transform.param {transform.readonly}) { + transform.debug.emit_param_as_remark %pos, "matmul #" at %matmul + : !transform.param, !transform.any_op + transform.debug.emit_param_as_remark %pos, "add #" at %add + : !transform.param, !transform.any_op + transform.debug.emit_param_as_remark %pos, "max #" at %max + : !transform.param, !transform.any_op + transform.yield + } + + // This is also a matcher sequence. It is similarly given an operation to + // match and nested operations must succeed in order for a match to be deemed + // successful. It starts matching from the last operation in the use-def chain + // and goes back because each operand (use) has exactly one definition. + transform.named_sequence @match_matmul_elemwise( + %last: !transform.any_op {transform.readonly}) + -> (!transform.any_op, !transform.any_op, !transform.any_op, + !transform.param) { + // The last operation must be an elementwise binary. 
+ transform.match.operation_name %last ["linalg.elemwise_binary"] + : !transform.any_op + + // One of its operands must be defined by another operation, to which we + // will get a handle here. This is achieved thanks to a newly defined + // operation that tries to match operands one by one using the match + // operations nested in its region. + %pos, %middle = transform.match.my.has_operand_satisfying %last + : (!transform.any_op) -> (!transform.param, !transform.any_op) { + ^bb0(%operand: !transform.any_value): + // The operand must be defined by an operation. + %def = transform.get_defining_op %operand + : (!transform.any_value) -> !transform.any_op + // The defining operation must itself be an elementwise binary. + transform.match.operation_name %def ["linalg.elemwise_binary"] + : !transform.any_op + transform.yield %def : !transform.any_op + } + + // And the first operand of that operation must be defined by yet another + // operation. + %matmul = transform.get_producer_of_operand %middle[0] + : (!transform.any_op) -> !transform.any_op + // And that operation is a matmul. + transform.match.operation_name %matmul ["linalg.matmul"] : !transform.any_op + // We will yield the handles to the matmul and the two elementwise + // operations separately. 
+ transform.yield %matmul, %middle, %last, %pos + : !transform.any_op, !transform.any_op, !transform.any_op, + !transform.param + } +} diff --git a/mlir/example/tests/transform/Ch4/sequence.mlir b/mlir/example/tests/transform/Ch4/sequence.mlir new file mode 100644 index 0000000..0dd8a9b --- /dev/null +++ b/mlir/example/tests/transform/Ch4/sequence.mlir @@ -0,0 +1,139 @@ +// RUN: transform-opt-ch4 %s --transform-interpreter --verify-diagnostics +// +// RUN: transform-opt-ch4 %s \ +// RUN: --transform-interpreter='entry-point=__transform_main_v2' \ +// RUN: --verify-diagnostics + +// ****************************** IMPORTANT NOTE ****************************** +// +// If you are changing this file, you may also need to change +// mlir/docs/Tutorials/Transform accordingly. +// +// **************************************************************************** + +// Original function to optimize. +func.func @fc_relu(%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>, + %bias: tensor<512x512xf32>, %output: tensor<512x512xf32>) + -> tensor<512x512xf32> { + // Matrix-matrix multiplication. + // expected-remark @below {{matmul}} + %matmul = linalg.matmul ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>) + outs(%output: tensor<512x512xf32>) -> tensor<512x512xf32> + + // Elementwise addition. + // expected-remark @below {{elementwise binary}} + %biased = linalg.elemwise_binary { fun = #linalg.binary_fn } + ins(%matmul, %bias : tensor<512x512xf32>, tensor<512x512xf32>) + outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32> + + // Elementwise max with 0 (ReLU). 
+ %c0f = arith.constant 0.0 : f32 + // expected-remark @below {{elementwise binary}} + %relued = linalg.elemwise_binary { fun = #linalg.binary_fn } + ins(%biased, %c0f : tensor<512x512xf32>, f32) + outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32> + func.return %relued : tensor<512x512xf32> +} + +// The module containing named sequences must have an attribute allowing them +// to enable verification. +module @transforms attributes { transform.with_named_sequence } { + // Entry point. This takes as the only argument the root operation (typically + // pass root) given to the transform interpreter. + transform.named_sequence @__transform_main( + %root: !transform.any_op {transform.readonly}) { + // Collect operations that match the criteria specified in the named + // sequence. If the named sequence fails with a silenceable failure, + // silences it (the message is forwarded to the debug stream). If the named + // sequence succeeds, appends its results to the results of this operation. + %elemwise = transform.collect_matching @match_elemwise in %root + : (!transform.any_op) -> !transform.any_op + %matmul = transform.collect_matching @match_matmul in %root + : (!transform.any_op) -> !transform.any_op + + transform.include @print_elemwise failures(propagate) (%elemwise) + : (!transform.any_op) -> () + transform.include @print_matmul failures(propagate) (%matmul) + : (!transform.any_op) -> () + + transform.yield + } + + // Alternative entry point. + transform.named_sequence @__transform_main_v2( + %root: !transform.any_op {transform.readonly}) { + // Collect groups of operations that match the criteria specified in the + // named sequence. 
+ %matmul, %el1, %el2 = transform.collect_matching @match_matmul_elemwise in %root + : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + %elemwise = transform.merge_handles %el1, %el2 : !transform.any_op + + transform.include @print_elemwise failures(propagate) (%elemwise) + : (!transform.any_op) -> () + transform.include @print_matmul failures(propagate) (%matmul) + : (!transform.any_op) -> () + + transform.yield + } + + // This is a matcher sequence. It is given an operation to match and the + // match is considered successful unless any nested operation produces a + // failure. The values yielded by this operation will be forwarded to the + // rewriter sequence on success. + transform.named_sequence @match_elemwise( + %entry: !transform.any_op {transform.readonly}) -> !transform.any_op { + transform.match.operation_name %entry ["linalg.elemwise_binary"] + : !transform.any_op + transform.yield %entry : !transform.any_op + } + transform.named_sequence @match_matmul( + %entry: !transform.any_op {transform.readonly}) -> !transform.any_op { + transform.match.operation_name %entry ["linalg.matmul"] : !transform.any_op + transform.yield %entry : !transform.any_op + } + + // This is an action sequence. + transform.named_sequence @print_elemwise( + %elemwise_binary: !transform.any_op {transform.readonly}) { + transform.debug.emit_remark_at + %elemwise_binary, "elementwise binary" : !transform.any_op + transform.yield + } + transform.named_sequence @print_matmul( + %matmul: !transform.any_op {transform.readonly}) { + transform.debug.emit_remark_at %matmul, "matmul" : !transform.any_op + transform.yield + } + + // This is also a matcher sequence. It is similarly given an operation to + // match and nested operations must succeed in order for a match to be deemed + // successful. It starts matching from the last operation in the use-def chain + // and goes back because each operand (use) has exactly one definition. 
+ transform.named_sequence @match_matmul_elemwise( + %last: !transform.any_op {transform.readonly}) + -> (!transform.any_op, !transform.any_op, !transform.any_op) { + // The last operation must be an elementwise binary. + transform.match.operation_name %last ["linalg.elemwise_binary"] + : !transform.any_op + // Its first operand must be defined by another operation, to which we + // will get a handle here. We are guaranteed that the first operand exists + // because we know the operation is binary, but even in absence of such a + // guarantee, this operation would have produced a silenceable failure when + // `%last` does not have enough operands. + %middle = transform.get_producer_of_operand %last[0] + : (!transform.any_op) -> !transform.any_op + // The defining operation must itself be an elementwise binary. + transform.match.operation_name %middle ["linalg.elemwise_binary"] + : !transform.any_op + // And the first operand of that operation must be defined by yet another + // operation. + %matmul = transform.get_producer_of_operand %middle[0] + : (!transform.any_op) -> !transform.any_op + // And that operation is a matmul. + transform.match.operation_name %matmul ["linalg.matmul"] : !transform.any_op + // We will yield the handles to the matmul and the two elementwise + // operations separately. + transform.yield %matmul, %middle, %last + : !transform.any_op, !transform.any_op, !transform.any_op + } +} diff --git a/mlir/example/tests/transform/ChH/full.mlir b/mlir/example/tests/transform/ChH/full.mlir new file mode 100644 index 0000000..259475e --- /dev/null +++ b/mlir/example/tests/transform/ChH/full.mlir @@ -0,0 +1,408 @@ +// RUN: mlir-opt %s --transform-interpreter \ +// RUN: --test-transform-dialect-erase-schedule \ +// RUN: --math-uplift-to-fma \ +// RUN: --convert-bufferization-to-memref \ +// RUN: --test-lower-to-llvm |\ +// RUN: FileCheck %s + +// Fixed-size tensor types to be used in convolution. +// Named sizes are: N=5 OH=80 OW=100 F=C=128 KH=KW=3. 
+// Input is NHWC. +// Filter is CHWF. +// Output is NHWF. +!tinput = tensor<5x82x102x128xf32> +!tfilter = tensor<128x3x3x128xf32> +!tbias = tensor<128xf32> +!toutput = tensor<5x80x100x128xf32> + +// Function containing the convolution. Note that its arguments and results are +// tensors annotated with attributes from the `bufferization` dialect. These +// attributes hint the bufferization pass to assume buffers can be directly +// used for these tensors without reshaping. +func.func @conv( + %input: !tinput {bufferization.writable = false, + bufferization.access = "read", + bufferization.buffer_layout = + affine_map<(d0,d1,d2,d3)->(d0,d1,d2,d3)>}, + %filter: !tfilter {bufferization.writable = false, + bufferization.access = "read", + bufferization.buffer_layout = + affine_map<(d0,d1,d2,d3)->(d0,d1,d2,d3)>}, + %bias: !tbias {bufferization.writable = false, + bufferization.access = "read", + bufferization.buffer_layout = affine_map<(d0)->(d0)>}, + %output: !toutput {bufferization.writable = true, + bufferization.buffer_layout = + affine_map<(d0,d1,d2,d3)->(d0,d1,d2,d3)>, + bufferization.access = "write"}) -> !toutput + // This requests a C-compatible interface to be emitted for the function + // when translating to LLVM IR. + attributes { llvm.emit_c_interface } +{ + // Bias. Using a named Linalg operation for brevity. + %bias_init = tensor.empty() : !toutput + %biased = linalg.broadcast ins(%bias : !tbias) + outs(%bias_init : !toutput) dimensions = [0, 1, 2] + + // Convolution proper. While Linalg has named operations for 2D convolutions, + // the one in the Halide example has an uncommon order of filter dimensions + // and is not supported. It also takes the filter as first argument. This + // code recreates it faithfully using the generic form. 
+ %convolved = linalg.generic { + iterator_types = ["parallel", "parallel", "parallel", "parallel", + "reduction", "reduction", "reduction"], + indexing_maps = [ + affine_map<(n, y, x, c, rz, ry, rx) -> (rx, rz, ry, c)>, + affine_map<(n, y, x, c, rz, ry, rx) -> (n, y+rz, x+ry, rx)>, + affine_map<(n, y, x, c, rz, ry, rx) -> (n, y, x, c)> + ] + } ins(%filter, %input: !tfilter, !tinput) outs(%biased : !toutput) { + ^bb0(%in: f32, %f: f32, %b: f32): + // Note the fastmath attributes that allow operations to be recombined into + // %0 = math.fma %in, %f, %b : f32 + // later on and to reorder reductions. + %m1 = arith.mulf %in, %f {fastmath = #arith.fastmath} : f32 + %0 = arith.addf %b, %m1 {fastmath = #arith.fastmath} : f32 + linalg.yield %0 : f32 + } -> !toutput + + // ReLU is just a max(0, x). + %c0 = arith.constant 0.0 : f32 + %relued = linalg.generic { + iterator_types = ["parallel", "parallel", "parallel", "parallel"], + indexing_maps = [ + affine_map<(d0, d1, d2, d3) -> ()>, + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> + ] + } ins(%c0, %convolved : f32, !toutput) + outs(%output : !toutput) { + ^bb0(%cst: f32, %in: f32, %out: f32): + %0 = llvm.intr.maxnum(%cst, %in) : (f32, f32) -> f32 + linalg.yield %0 : f32 + } -> !toutput + + return %relued : !toutput +} + +// Module containing the transformation script to be applied. The attribute +// is required to correctly verify the use of named (macro-like) sequences. 
+module attributes { transform.with_named_sequence } { + // Apply transformations in a sequence to recreate the following Halide + // schedule: + // + // Var co, ci, xo, xi; + // relu.split(c, co, ci, vec * tile_w) + // .split(x, xo, xi, tile_h) + // .reorder(ci, xi, xo, y, n, co) + // .vectorize(ci, vec) + // .unroll(ci) + // .unroll(xi); + // conv.compute_at(relu, xo) + // .vectorize(c, vec) + // .unroll(c) + // .unroll(x) + // .unroll(y) + // .update() + // .reorder(c, x, y, r.x, r.y, r.z, n) + // .vectorize(c, vec) + // .unroll(c) + // .unroll(x) + // .unroll(y) + // .unroll(r.x, 2); + // + // where tile_w = 4, tile_h = 5, vec = 16. Note that unroll(y) and unroll(r.x) + // have no effect on the Halide IR as of 294f80c49bf3bb8582446613c25fcce03b82. + // Also note that the order of dimensions in Halide is inverted, e.g., co and + // n are the outermost loops in the respective reorder directives. + transform.named_sequence @__transform_main( + // This argument will point to the top-level module. + %arg0: !transform.any_op) { + + // 1. Find the operations we are going to transform using their names. This + // is a simplistic approach that works when there are few operations in the + // IR to be transformed. More complex scenarios should rely on operations + // with `transform.match` prefix that are out of scope for this chapter. + %bias = transform.structured.match ops{["linalg.broadcast"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + %generics = transform.structured.match ops{["linalg.generic"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + %conv, %relu = transform.split_handle %generics + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + + // 2. Initial tiling to start producing the loop structure. Note that the + // linalg.generic operation has the implicit loop order (n, y, x, c).
Since + // the desired order of dimensions is (co, n, y, xo, xi, ci), we first tile + // only the c dimension to materialize the outermost co loop, and then tile + // the other dimensions since they are already in the expected order. Tiling + // by 1 produces the loop that iterates along the entire dimension. Tiling + // by 0 does not produce a loop. The size 64 is chosen as tiling by 4*16 + // where 16 is the AVX512 vector length. Note that structured tiling doesn't + // remove the dimensions that became trivial (unit size) so the resulting + // structure is technically (co, no=n, yo=y, xo, [ni=1, yi=1, xi, ci]) + // where brackets indicate implicit loops of the `linalg.generic` operation + // inside the loops produced by tiling. + // + // [n y x c] + %relu2, %co = transform.structured.tile_using_forall %relu + tile_sizes [0, 0, 0, 64] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %relu3, %n_y_xo = transform.structured.tile_using_forall %relu2 + tile_sizes [1, 1, 5, 0] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + + // Compute_at is actually fusion into the given loop (given that we start + // with totally fissioned form, Halide starts with a fused form by reusing + // the loop iterators). + %conv2, %co2 = transform.structured.fuse_into_containing_op %conv into %co + : (!transform.any_op, !transform.any_op) + -> (!transform.any_op, !transform.any_op) + %conv3, %n_y_xo2 = transform.structured.fuse_into_containing_op %conv2 + into %n_y_xo + : (!transform.any_op, !transform.any_op) + -> (!transform.any_op, !transform.any_op) + + // Also fuse the bias that we represent as a separate operation and Halide + // represents as the "pure" (as opposed to "update") part of the conv + // expression. Note that fusion consumes both handles and produces new + // handles for chaining purposes.
+ %bias2, %co3 = transform.structured.fuse_into_containing_op %bias into %co2 + : (!transform.any_op, !transform.any_op) + -> (!transform.any_op, !transform.any_op) + %bias3, %n_y_xo3 = transform.structured.fuse_into_containing_op %bias2 + into %n_y_xo2 + : (!transform.any_op, !transform.any_op) + -> (!transform.any_op, !transform.any_op) + + // Clean up the result of fusion, which mechanically duplicates the producer + // operation in the consumer loop without removing the original operation. + // The original operation is now "dead": it has no uses and no side effects + // so it can be removed by dead-code elimination (DCE) that runs as part of + // pattern rewriting. The transform dialect allows applying a combination + // of named pattern sets, exposed as operations, in one sweep to an + // isolated-from-above container payload operation. Note that we don't + // actually need any patterns for DCE to run, just trigger the rewriting. + // + // This step is optional. The transformation can continue without it and + // produce the same final IR, but this step makes it easier to manually examine the + // intermediate stages. + %f00 = transform.structured.match ops{["func.func"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %f00 { + } : !transform.any_op + + // The loop reordering requested for the convolution operation requires + // putting reduction loops (r.z, r.y, r.x) before the "inner" loops xi, ci. + // The "inner" loops are still implicit as part of the linalg.generic + // operation, and we need to materialize reduction loops around it by tiling + // with size 1. Since we are producing reduction loops, we indicate that we + // are tiling a reduction and request sequential `scf.for` loops (parallel + // reductions are supported by `scf.forall`, but we don't need those here).
+ // + // This transform operation is more capable than merely producing + // (reduction) loops: the transformed code performs `tile_size` partial + // reductions of `N / tile_size` elements, potentially in parallel by + // changing the dimension kind of the structured operation inside the loop, + // and then performs a final reduction of these partial results by producing + // a new “combiner” structured operation after the loops. In our case, + // tile_size = 1 along all dimensions, so the reduction is entirely + // performed by the generated loops. The combiner structured operation is + // still produced and adds up the reduction result with the initial value. + %red_fill, %conv4, %combining, %rz_ry_rx + = transform.structured.tile_reduction_using_for %conv3 by + // n y x c rz ry rx + tile_sizes=[0, 0, 0, 0, 1, 1, 1] + : (!transform.any_op) + -> (!transform.any_op, !transform.any_op, !transform.any_op, + !transform.any_op) + + // At this point, the inner Linalg operations have implicit iteration spaces + // of 5x64 size, with some additional unit-size dimensions. Completely + // replicating Halide schedule would require materializing the loops with + // 5 and 4 iterations, respectively, unrolling those loops and marking the + // remaining 16-point iteration space for vectorization. + // + // This is unnecessary in MLIR that supports multi-dimensional vectors, + // which will be decomposed into target-specific sizes during the lowering. + // Therefore, this schedule stops here. + + // Transform the named broadcast operation used for bias into the generic + // form before vectorization to prevent special cases from kicking in. + transform.structured.generalize %bias3 + : (!transform.any_op) -> !transform.any_op + + // Use the named macro to perform most of the lowering. 
+ transform.include @lower failures(propagate) (%arg0) + : (!transform.any_op) -> () + transform.yield + } + + // Named sequence of transformations is a macro-like object that can be + // included from another place in the transform dialect, but doesn't allow for + // recursion. This can be reused in other scenarios. + transform.named_sequence @lower( + %arg0: !transform.any_op {transform.consumed}) { + %f00 = transform.structured.match ops{["func.func"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + + // Simplify the code as tiling and fusion may have produced a lot of + // operations computing tensor subsets and loop ranges, some of which may be + // duplicated or excessively complex. Simplification involving + // canonicalization, common subexpression elimination, loop invariant code + // motion and various rewrite patterns can be applied directly from the + // transform dialect. Furthermore, an arbitrary combination of rewrite + // patterns can be applied in one sweep to a given scope, a functionality + // that cannot be achieved with conventional compiler passes that apply each + // group of patterns separately (at least without creating a new pass for + // each combination of pattern groups). + transform.apply_patterns to %f00 { + transform.apply_patterns.canonicalization + transform.apply_patterns.linalg.tiling_canonicalization + } : !transform.any_op + transform.apply_cse to %f00 : !transform.any_op + %all_loops = transform.structured.match interface{LoopLikeInterface} + in %arg0 + : (!transform.any_op) -> !transform.any_op + transform.apply_licm to %all_loops : !transform.any_op + + // Tiling-by-one as a way of materializing loops produced operations + // processing 4+D types where only a handful of dimensions aren’t unit-sized, + // e.g., tensor<1x1x1x5x64xf32> where 5 and 64 are tile sizes. Remove such + // unit dimensions before vectorization, for clarity.
+ transform.apply_patterns to %f00 { + transform.apply_patterns.linalg.fold_unit_extent_dims_via_reshapes + } : !transform.any_op + + // Vectorize the remaining non-unit dimensions in structured operations. + // This essentially rewrites operations on `tensor<5x64xf32>` into + // operations on `vector<5x64xf32>`. Further lowering in MLIR and LLVM will + // decompose this into a sequence of operations on single-dimensional + // vectors of the platform-relevant size, e.g., `vector<16xf32>` for AVX512. + // High-level vector primitives, such as `vector.transpose` and + // `vector.broadcast` can be introduced at this stage. They will be later + // lowered to sequences of lower-level primitives such as `vector.shuffle` + // depending on the selected lowering strategy. + %fv = transform.structured.vectorize_children_and_apply_patterns %f00 + : (!transform.any_op) -> !transform.any_op + + // Vectorization may have created new opportunities for cleanups. In + // particular, tensor subsetting operations can be composed with vector + // operations, and vector transfer (multi-dimensional load/store) operations + // can be recombined and hoisted out of loops. + transform.apply_patterns to %fv { + transform.apply_patterns.canonicalization + transform.apply_patterns.tensor.fold_tensor_subset_ops_into_vector_transfers + } : !transform.any_op + transform.apply_cse to %fv : !transform.any_op + transform.structured.hoist_redundant_vector_transfers %fv + : (!transform.any_op) -> !transform.any_op + + // Apply bufferization that rewrites the remaining operations on tensors + // as operations on structured buffer (memref) types, including the function + // API. MLIR bufferization uses destination-passing style meaning that a + // buffer is shared between one of the operation's operands and its result. + // + // Since bufferization rewrites function signatures, it is applied as a + // module-wise transformation. Therefore, it invalidates all previously + // defined handles.
Bufferization is usually a late step in the + // transformation process, so invalidation is not an issue. However, if + // other transformations, such as loop unrolling, are required after + // bufferization, new handles should be produced using the match operations. + // + // One-shot bufferization itself does not produce buffer deallocations, + // which may lead to leaks. So we have to run the buffer deallocation pass + // pipeline to avoid them. Note that the transform dialect seamlessly runs + // named passes and pass pipelines: if desired, one could replace complex + // --pass-pipeline expressions with operations. Note that we apply the + // pipeline to functions rather than entire module to avoid running it + // on the transform IR that is contained in the module. + %arg1 = transform.bufferization.one_shot_bufferize %arg0 { + bufferize_function_boundaries = true, + function_boundary_type_conversion = 1 : i32 } + : (!transform.any_op) -> !transform.any_op + %f = transform.structured.match ops{["func.func"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + transform.apply_registered_pass "buffer-deallocation-pipeline" to %f + : (!transform.any_op) -> !transform.any_op + + // Apply general canonicalization and CSE to each function after + // bufferization as new simplification opportunities may have appeared. + %fb = transform.structured.match ops{["func.func"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %fb { + transform.apply_patterns.canonicalization + } : !transform.any_op + transform.apply_cse to %fb : !transform.any_op + + // Lower complex, multidimensional vector operations into simpler + // primitives. This particular selection of the pattern groups corresponds + // to vector dialect operations present in the payload IR at this stage. + // Many of these groups can be parameterized to use different strategies or + // lower-level primitives offering performance trade-offs. 
In this case, we + // are selecting the simplest strategies. + transform.apply_patterns to %fb { + transform.apply_patterns.vector.lower_contraction + lowering_strategy = parallelarith + transform.apply_patterns.vector.lower_transfer + max_transfer_rank = 1 + transform.apply_patterns.vector.lower_transpose + lowering_strategy = eltwise + transform.apply_patterns.vector.lower_shape_cast + } : !transform.any_op + + // These patterns apply in a separate sweep to avoid transfer-to-scf + // patterns overlap with lower-transfer patterns as they apply to the same + // kind of operations. These patterns may produce local allocations to act + // as temporary caches deep inside loops, which could lead to catastrophic + // performance. Such allocations are moved onto the stack and hoisted from + // all the surrounding loops. + transform.apply_patterns to %fb { + transform.apply_patterns.vector.transfer_to_scf + transform.apply_patterns.memref.alloc_to_alloca + } : !transform.any_op + transform.bufferization.buffer_loop_hoisting %fb : !transform.any_op + + // A final round of cleanups additionally includes patterns to simplify + // buffer aliasing operations that may have been introduced during + // bufferization and could result in excessively complex address + // computation. + transform.apply_patterns to %fb { + transform.apply_patterns.memref.fold_memref_alias_ops + transform.apply_patterns.canonicalization + } : !transform.any_op + transform.apply_cse to %fb : !transform.any_op + + transform.yield + } +} + +// The core computation, at the LLVM dialect level, must correspond to five +// immediately adjacent fma on vector<64xf32>. 
+ +// CHECK: %[[R0:.+]] = llvm.mlir.undef : !llvm.array<5 x vector<64xf32>> + +// CHECK: %[[V:.+]] = llvm.load %{{.*}} : !llvm.ptr -> !llvm.array<5 x vector<64xf32>> +// CHECK-NEXT: %[[LINE0:.+]] = llvm.extractvalue %[[V]][0] : !llvm.array<5 x vector<64xf32>> +// CHECK-NEXT: %[[FMA0:.+]] = llvm.intr.fma(%{{.*}}, %{{.*}}, %[[LINE0]]) +// CHECK-SAME: -> vector<64xf32> +// CHECK-NEXT: %[[R1:.+]] = llvm.insertvalue %[[FMA0]], %[[R0]][0] + +// CHECK-NEXT: %[[LINE1:.+]] = llvm.extractvalue %[[V]][1] : !llvm.array<5 x vector<64xf32>> +// CHECK-NEXT: %[[FMA1:.+]] = llvm.intr.fma(%{{.*}}, %{{.*}}, %[[LINE1]]) +// CHECK-SAME: -> vector<64xf32> +// CHECK-NEXT: %[[R2:.+]] = llvm.insertvalue %[[FMA1]], %[[R1]][1] + +// CHECK-NEXT: %[[LINE2:.+]] = llvm.extractvalue %[[V]][2] : !llvm.array<5 x vector<64xf32>> +// CHECK-NEXT: %[[FMA2:.+]] = llvm.intr.fma(%{{.*}}, %{{.*}}, %[[LINE2]]) +// CHECK-SAME: -> vector<64xf32> +// CHECK-NEXT: %[[R3:.+]] = llvm.insertvalue %[[FMA2]], %[[R2]][2] + +// CHECK-NEXT: %[[LINE3:.+]] = llvm.extractvalue %[[V]][3] : !llvm.array<5 x vector<64xf32>> +// CHECK-NEXT: %[[FMA3:.+]] = llvm.intr.fma(%{{.*}}, %{{.*}}, %[[LINE3]]) +// CHECK-SAME: -> vector<64xf32> +// CHECK-NEXT: %[[R4:.+]] = llvm.insertvalue %[[FMA3]], %[[R3]][3] + +// CHECK-NEXT: %[[LINE4:.+]] = llvm.extractvalue %[[V]][4] : !llvm.array<5 x vector<64xf32>> +// CHECK-NEXT: %[[FMA4:.+]] = llvm.intr.fma(%{{.*}}, %{{.*}}, %[[LINE4]]) +// CHECK-SAME: -> vector<64xf32> +// CHECK-NEXT: %[[R5:.+]] = llvm.insertvalue %[[FMA4]], %[[R4]][4] diff --git a/mlir/example/transform_Ch2/CMakeLists.txt b/mlir/example/transform_Ch2/CMakeLists.txt new file mode 100644 index 0000000..4f988c7 --- /dev/null +++ b/mlir/example/transform_Ch2/CMakeLists.txt @@ -0,0 +1,16 @@ +# For a better top-level template to copy, see examples/standalone. 
+ +include_directories(${CMAKE_CURRENT_BINARY_DIR}) +include_directories(${CMAKE_CURRENT_BINARY_DIR}/include) +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) + +add_subdirectory(include) +add_subdirectory(lib) + +add_executable(transform-opt-ch2 transform-opt/transform-opt.cpp) + +add_dependencies(transform-opt-ch2 MyExtensionCh2IncGen) + +target_link_libraries( + transform-opt-ch2 PRIVATE MLIRIR MLIRMlirOptMain MLIRSideEffectInterfaces + MyExtensionCh2) diff --git a/mlir/example/transform_Ch2/include/CMakeLists.txt b/mlir/example/transform_Ch2/include/CMakeLists.txt new file mode 100644 index 0000000..3171365 --- /dev/null +++ b/mlir/example/transform_Ch2/include/CMakeLists.txt @@ -0,0 +1,14 @@ +# Tell Tablegen to use MyExtension.td as input. +set(LLVM_TARGET_DEFINITIONS MyExtension.td) + +# Ask Tablegen to generate op declarations and definitions from ODS. +mlir_tablegen(MyExtension.h.inc -gen-op-decls) +mlir_tablegen(MyExtension.cpp.inc -gen-op-defs) + +# Add a CMakeTarget we can depend on to ensure the generation happens before the +# compilation. +add_public_tablegen_target(MyExtensionCh2IncGen) + +# Don't forget to generate the documentation, this will produce a +# MyExtensionCh2.md under Tutorials/transform +add_mlir_doc(MyExtension MyExtensionCh2 Tutorials/transform/ -gen-op-doc) diff --git a/mlir/example/transform_Ch2/include/MyExtension.h b/mlir/example/transform_Ch2/include/MyExtension.h new file mode 100644 index 0000000..5ab70a5 --- /dev/null +++ b/mlir/example/transform_Ch2/include/MyExtension.h @@ -0,0 +1,22 @@ +//===-- MyExtension.h - Transform dialect tutorial --------------*- c++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines Transform dialect extension operations used in the +// Chapter 2 of the Transform dialect tutorial. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Bytecode/BytecodeOpInterface.h" +#include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" + +#define GET_OP_CLASSES +#include "MyExtension.h.inc" + +// Registers our Transform dialect extension. +void registerMyExtension(::mlir::DialectRegistry ®istry); diff --git a/mlir/example/transform_Ch2/include/MyExtension.td b/mlir/example/transform_Ch2/include/MyExtension.td new file mode 100644 index 0000000..15cd1e6 --- /dev/null +++ b/mlir/example/transform_Ch2/include/MyExtension.td @@ -0,0 +1,56 @@ +//===-- MyExtension.td - Transform dialect tutorial --------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines Transform dialect extension operations used in the +// Chapter 2 of the Transform dialect tutorial. +// +//===----------------------------------------------------------------------===// + +#ifndef MY_EXTENSION +#define MY_EXTENSION + +include "mlir/Dialect/Transform/IR/TransformDialect.td" +include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td" +include "mlir/IR/OpBase.td" +include "mlir/Interfaces/SideEffectInterfaces.td" + +// Define the new operation. By convention, prefix its name with the name of the dialect +// extension, "my.". The full operation name will be further prefixed with "transform.". 
+def ChangeCallTargetOp : Op, + DeclareOpInterfaceMethods]> { + // Provide a brief and a full description. It is recommended that the latter describes + // the effects on the operands and how the operation processes various failure modes. + let summary = "Changes the callee of a call operation to the specified one"; + let description = [{ + For each `func.call` payload operation associated with the handle, changes its + callee to be the symbol whose name is provided as an attribute to this operation. + + Generates a silenceable failure if the operand is associated with payload operations + that are not `func.call`. + Only reads the operand. + }]; + + // The arguments include the handle to the payload operations and the attribute that + // specifies the new callee. The handle must implement TransformHandleTypeInterface. + // We use a string attribute as the symbol may not exist in the transform IR so the + // verification may fail. + let arguments = (ins + TransformHandleTypeInterface:$call, + StrAttr:$new_target); + + // The results are empty as the transformation does not produce any new payload. + let results = (outs); + + // Provide nice syntax. + let assemblyFormat = "$call `,` $new_target attr-dict `:` type($call)"; +} + +#endif // MY_EXTENSION diff --git a/mlir/example/transform_Ch2/lib/CMakeLists.txt b/mlir/example/transform_Ch2/lib/CMakeLists.txt new file mode 100644 index 0000000..5609e4d --- /dev/null +++ b/mlir/example/transform_Ch2/lib/CMakeLists.txt @@ -0,0 +1,19 @@ +# Outside examples, this should be `add_mlir_library`. +add_mlir_example_library( + # Library called MyExtension. + MyExtensionCh2 + # Built from the following source files. + MyExtension.cpp + # Make includes visible without top-level path. + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/examples/transform/Ch2/include + # Make sure ODS declaration and definitions are generated before compiling + # this. 
+ DEPENDS + MyExtensionCh2IncGen + # Link in the transform dialect, and all generated dialects. + LINK_LIBS + PRIVATE + MLIRTransformDialect + MLIRFuncDialect + MLIRSCFDialect) diff --git a/mlir/example/transform_Ch2/lib/MyExtension.cpp b/mlir/example/transform_Ch2/lib/MyExtension.cpp new file mode 100644 index 0000000..68d538a --- /dev/null +++ b/mlir/example/transform_Ch2/lib/MyExtension.cpp @@ -0,0 +1,140 @@ +//===-- MyExtension.cpp - Transform dialect tutorial ----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines Transform dialect extension operations used in the +// Chapter 2 of the Transform dialect tutorial. +// +//===----------------------------------------------------------------------===// + +#include "MyExtension.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/Dialect/Transform/IR/TransformTypes.h" +#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/Operation.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" + +// Define a new transform dialect extension. This uses the CRTP idiom to +// identify extensions. +class MyExtension + : public ::mlir::transform::TransformDialectExtension { +public: + // The extension must derive the base constructor. + using Base::Base; + + // This function initializes the extension, similarly to `initialize` in + // dialect definitions. List individual operations and dependent dialects + // here.
+ void init(); +}; + +void MyExtension::init() { + // Similarly to dialects, an extension can declare a dependent dialect. This + // dialect will be loaded along with the extension and, therefore, along with + // the Transform dialect. Only declare as dependent the dialects that contain + // the attributes or types used by transform operations. Do NOT declare as + // dependent the dialects produced during the transformation. + // declareDependentDialect(); + + // When transformations are applied, they may produce new operations from + // previously unloaded dialects. Typically, a pass would need to declare + // itself dependent on the dialects containing such new operations. To avoid + // confusion with the dialects the extension itself depends on, the Transform + // dialect differentiates between: + // - dependent dialects, which are used by the transform operations, and + // - generated dialects, which contain the entities (attributes, operations, + // types) that may be produced by applying the transformation even when + // not present in the original payload IR. + // In the following chapter, we will add operations that generate function + // calls and structured control flow operations, so let's declare the + // corresponding dialects as generated. + declareGeneratedDialect<::mlir::scf::SCFDialect>(); + declareGeneratedDialect<::mlir::func::FuncDialect>(); + + // Finally, we register the additional transform operations with the dialect. + // List all operations generated from ODS. This call will perform additional + // checks that the operations implement the transform and memory effect + // interfaces required by the dialect interpreter and assert if they do not.
+ registerTransformOps< +#define GET_OP_LIST +#include "MyExtension.cpp.inc" + >(); +} + +#define GET_OP_CLASSES +#include "MyExtension.cpp.inc" + +static void updateCallee(mlir::func::CallOp call, llvm::StringRef newTarget) { + call.setCallee(newTarget); +} + +// Implementation of our transform dialect operation. +// This operation returns a tri-state result that can be one of: +// - success when the transformation succeeded; +// - definite failure when the transformation failed in such a way that +// following transformations are impossible or undesirable, typically it could +// have left payload IR in an invalid state; it is expected that a diagnostic +// is emitted immediately before returning the definite error; +// - silenceable failure when the transformation failed but following +// transformations are still applicable, typically this means a precondition +// for the transformation is not satisfied and the payload IR has not been +// modified. The silenceable failure additionally carries a Diagnostic that +// can be emitted to the user. +::mlir::DiagnosedSilenceableFailure mlir::transform::ChangeCallTargetOp::apply( + // The rewriter that should be used when modifying IR. + ::mlir::transform::TransformRewriter &rewriter, + // The list of payload IR entities that will be associated with the + // transform IR values defined by this transform operation. In this case, it + // can remain empty as there are no results. + ::mlir::transform::TransformResults &results, + // The transform application state. This object can be used to query the + // current associations between transform IR values and payload IR entities. + // It can also carry additional user-defined state. + ::mlir::transform::TransformState &state) { + + // First, we need to obtain the list of payload operations that are associated + // with the operand handle. 
+ auto payload = state.getPayloadOps(getCall()); + + // Then, we iterate over the list of payload operations and call the actual IR-mutating + // function. We also check the preconditions here. + for (Operation *payloadOp : payload) { + auto call = dyn_cast<::mlir::func::CallOp>(payloadOp); + if (!call) { + DiagnosedSilenceableFailure diag = + emitSilenceableError() << "only applies to func.call payloads"; + diag.attachNote(payloadOp->getLoc()) << "offending payload"; + return diag; + } + + updateCallee(call, getNewTarget()); + } + + // If everything went well, return success. + return DiagnosedSilenceableFailure::success(); +} + +void mlir::transform::ChangeCallTargetOp::getEffects( + ::llvm::SmallVectorImpl<::mlir::MemoryEffects::EffectInstance> &effects) { + // Indicate that the `call` handle is only read by this operation because the + // associated operation is not erased but rather modified in-place, so the + // reference to it remains valid. + onlyReadsHandle(getCallMutable(), effects); + + // Indicate that the payload is modified by this operation. + modifiesPayload(effects); +} + +void registerMyExtension(::mlir::DialectRegistry ®istry) { + registry.addExtensions(); +} diff --git a/mlir/example/transform_Ch2/transform-opt/transform-opt.cpp b/mlir/example/transform_Ch2/transform-opt/transform-opt.cpp new file mode 100644 index 0000000..874ad78 --- /dev/null +++ b/mlir/example/transform_Ch2/transform-opt/transform-opt.cpp @@ -0,0 +1,49 @@ +//===-- transform-opt.cpp - Transform dialect tutorial entry point --------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is the top-level file for the Transform dialect tutorial chapter 2.
+// +//===----------------------------------------------------------------------===// + +#include "MyExtension.h" + +#include "mlir/Dialect/Transform/Transforms/Passes.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/InitAllDialects.h" +#include "mlir/InitAllExtensions.h" +#include "mlir/Tools/mlir-opt/MlirOptMain.h" +#include "mlir/Transforms/Passes.h" +#include + +namespace test { +void registerTestTransformDialectExtension(mlir::DialectRegistry &); +} // namespace test + +int main(int argc, char **argv) { + // Register all "core" dialects and our transform dialect extension. + mlir::DialectRegistry registry; + mlir::registerAllDialects(registry); + mlir::registerAllExtensions(registry); + registerMyExtension(registry); + + // Register transform interpreter pass. + mlir::transform::registerInterpreterPass(); + + // Register a handful of cleanup passes that we can run to make the output IR + // look nicer. + mlir::registerCanonicalizerPass(); + mlir::registerCSEPass(); + mlir::registerSymbolDCEPass(); + + // Delegate to the MLIR utility for parsing and pass management. + return mlir::MlirOptMain(argc, argv, "transform-opt-ch2", registry) + .succeeded() + ? EXIT_SUCCESS + : EXIT_FAILURE; +} diff --git a/mlir/example/transform_Ch3/CMakeLists.txt b/mlir/example/transform_Ch3/CMakeLists.txt new file mode 100644 index 0000000..3714e5f --- /dev/null +++ b/mlir/example/transform_Ch3/CMakeLists.txt @@ -0,0 +1,16 @@ +# For a better top-level template to copy, see examples/standalone. 
+ +include_directories(${CMAKE_CURRENT_BINARY_DIR}) +include_directories(${CMAKE_CURRENT_BINARY_DIR}/include) +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) + +add_subdirectory(include) +add_subdirectory(lib) + +add_executable(transform-opt-ch3 transform-opt/transform-opt.cpp) + +add_dependencies(transform-opt-ch3 MyExtensionCh3IncGen) + +target_link_libraries( + transform-opt-ch3 PRIVATE MLIRIR MLIRMlirOptMain MLIRSideEffectInterfaces + MyExtensionCh3) diff --git a/mlir/example/transform_Ch3/include/CMakeLists.txt b/mlir/example/transform_Ch3/include/CMakeLists.txt new file mode 100644 index 0000000..32b94dc --- /dev/null +++ b/mlir/example/transform_Ch3/include/CMakeLists.txt @@ -0,0 +1,21 @@ +# Tell Tablegen to use MyExtension.td as input. +set(LLVM_TARGET_DEFINITIONS MyExtension.td) + +# Ask Tablegen to generate op declarations and definitions from ODS. +mlir_tablegen(MyExtension.h.inc -gen-op-decls) +mlir_tablegen(MyExtension.cpp.inc -gen-op-defs) + +# Tell Tablegen to use MyExtensionTypes.td as input. +set(LLVM_TARGET_DEFINITIONS MyExtensionTypes.td) + +# Ask Tablegen to generate type declarations and definitions from ODS. +mlir_tablegen(MyExtensionTypes.h.inc -gen-typedef-decls) +mlir_tablegen(MyExtensionTypes.cpp.inc -gen-typedef-defs) + +# Add a CMakeTarget we can depend on to ensure the generation happens before the +# compilation. 
+add_public_tablegen_target(MyExtensionCh3IncGen) + +# Don't forget to generate the documentation, this will produce a +# MyExtensionCh3.md under Tutorials/transform +add_mlir_doc(MyExtension MyExtensionCh3 Tutorials/transform/ -gen-op-doc) diff --git a/mlir/example/transform_Ch3/include/MyExtension.h b/mlir/example/transform_Ch3/include/MyExtension.h new file mode 100644 index 0000000..0868504 --- /dev/null +++ b/mlir/example/transform_Ch3/include/MyExtension.h @@ -0,0 +1,35 @@ +//===-- MyExtension.h - Transform dialect tutorial --------------*- c++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines Transform dialect extension operations used in the +// Chapter 3 of the Transform dialect tutorial. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Bytecode/BytecodeOpInterface.h" +#include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" + +namespace mlir { +class CallOpInterface; +namespace func { +class CallOp; +} // namespace func +namespace transform { +class OperationType; +} // namespace transform +} // namespace mlir + +#define GET_TYPEDEF_CLASSES +#include "MyExtensionTypes.h.inc" + +#define GET_OP_CLASSES +#include "MyExtension.h.inc" + +// Registers our Transform dialect extension. 
+void registerMyExtension(::mlir::DialectRegistry &registry); diff --git a/mlir/example/transform_Ch3/include/MyExtension.td b/mlir/example/transform_Ch3/include/MyExtension.td new file mode 100644 index 0000000..7944f91 --- /dev/null +++ b/mlir/example/transform_Ch3/include/MyExtension.td @@ -0,0 +1,100 @@ +//===-- MyExtension.td - Transform dialect tutorial --------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines Transform dialect extension operations used in the +// Chapter 3 of the Transform dialect tutorial. +// +//===----------------------------------------------------------------------===// + +#ifndef MY_EXTENSION +#define MY_EXTENSION + +include "MyExtensionTypes.td" +include "mlir/Dialect/Transform/IR/TransformDialect.td" +include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td" +include "mlir/Dialect/Transform/IR/TransformTypes.td" +include "mlir/IR/OpBase.td" +include "mlir/Interfaces/SideEffectInterfaces.td" + +// Define the new operation. By convention, prefix its name with the name of the dialect +// extension, "my.". The full operation name will be further prefixed with "transform.". +def ChangeCallTargetOp : Op]> { + // Provide a brief and a full description. It is recommended that the latter describes + // the effects on the operands and how the operation processes various failure modes. + let summary = "Changes the callee of a call operation to the specified one"; + let description = [{ + For each `func.call` payload operation associated with the handle, changes its + callee to be the symbol whose name is provided as an attribute to this operation. 
+ + Generates a silenceable failure if the operand is associated with payload operations + that are not `func.call`. + Only reads the operand. + }]; + + // The arguments include the handle to the payload operations and the attribute that + // specifies the new callee. The handle must implement TransformHandleTypeInterface. + // We use a string attribute as the symbol may not exist in the transform IR so the + // verification may fail. + let arguments = (ins + // Specify the type constraint on the input accepting only `func.call` payload + // operations. + Transform_ConcreteOpType<"func.call">:$call, + StrAttr:$new_target); + + // The results are empty as the transformation does not produce any new payload. + let results = (outs); + + // Provide nice syntax. + let assemblyFormat = "$call `,` $new_target attr-dict `:` qualified(type($call))"; + + // Declare the function implementing the interface for a single payload operation. + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, + ::mlir::func::CallOp call, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + +// Define another transform operation. +def CallToOp : Op { + // Summary and description omitted for brevity. + + // The argument is the handle to the payload operations. + let arguments = (ins CallOpInterfaceHandle:$call); + + // The result is the handle to the payload operations produced during the + // transformation. + let results = (outs TransformHandleTypeInterface:$transformed); + + // Provide nice syntax. + let assemblyFormat = "$call attr-dict `:` functional-type(operands, results)"; + + // Declare the function implementing the interface for a single payload operation. 
+ let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, + ::mlir::CallOpInterface call, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + +#endif // MY_EXTENSION diff --git a/mlir/example/transform_Ch3/include/MyExtensionTypes.td b/mlir/example/transform_Ch3/include/MyExtensionTypes.td new file mode 100644 index 0000000..d0df02a --- /dev/null +++ b/mlir/example/transform_Ch3/include/MyExtensionTypes.td @@ -0,0 +1,34 @@ +//===-- MyExtensionTypes.td - Transform dialect tutorial ---*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines Transform dialect extension types used in the +// Chapter 3 of the Transform dialect tutorial. +// +//===----------------------------------------------------------------------===// + +#ifndef MY_EXTENSIONTYPES +#define MY_EXTENSIONTYPES + +include "mlir/IR/AttrTypeBase.td" +include "mlir/Dialect/Transform/IR/TransformDialect.td" +include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td" + +// Transform dialect allows additional types to be defined and injected. +def CallOpInterfaceHandle + : TypeDef]> { + + // The usual components of a type such as description, mnemonic and assembly format + // should be provided. 
+ let summary = "handle to payload operations implementing CallOpInterface"; + let mnemonic = "my.call_op_interface"; + let assemblyFormat = ""; +} + +#endif // MY_EXTENSIONTYPES diff --git a/mlir/example/transform_Ch3/lib/CMakeLists.txt b/mlir/example/transform_Ch3/lib/CMakeLists.txt new file mode 100644 index 0000000..b75463e --- /dev/null +++ b/mlir/example/transform_Ch3/lib/CMakeLists.txt @@ -0,0 +1,19 @@ +# Outside examples, this should be `add_mlir_library`. +add_mlir_example_library( + # Library called MyExtensionCh3. + MyExtensionCh3 + # Built from the following source files. + MyExtension.cpp + # Make includes visible without top-level path. + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/examples/transform/Ch3/include + # Make sure ODS declaration and definitions are generated before compiling + # this. + DEPENDS + MyExtensionCh3IncGen + # Link in the transform dialect, and all generated dialects. + LINK_LIBS + PRIVATE + MLIRTransformDialect + MLIRFuncDialect + MLIRSCFDialect) diff --git a/mlir/example/transform_Ch3/lib/MyExtension.cpp b/mlir/example/transform_Ch3/lib/MyExtension.cpp new file mode 100644 index 0000000..f7a9942 --- /dev/null +++ b/mlir/example/transform_Ch3/lib/MyExtension.cpp @@ -0,0 +1,221 @@ +//===-- MyExtension.cpp - Transform dialect tutorial ----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines Transform dialect extension operations used in the +// Chapter 3 of the Transform dialect tutorial. 
+// +//===----------------------------------------------------------------------===// + +#include "MyExtension.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/Dialect/Transform/IR/TransformTypes.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/Interfaces/CallInterfaces.h" +#include "llvm/ADT/TypeSwitch.h" + +#define GET_TYPEDEF_CLASSES +#include "MyExtensionTypes.cpp.inc" + +#define GET_OP_CLASSES +#include "MyExtension.cpp.inc" + +//===---------------------------------------------------------------------===// +// MyExtension +//===---------------------------------------------------------------------===// + +// Define a new transform dialect extension. This uses the CRTP idiom to +// identify extensions. +class MyExtension + : public ::mlir::transform::TransformDialectExtension { +public: + // The extension must derive the base constructor. + using Base::Base; + + // This function initializes the extension, similarly to `initialize` in + // dialect definitions. List individual operations and dependent dialects + // here. + void init(); +}; + +void MyExtension::init() { + // Similarly to dialects, an extension can declare a dependent dialect. This + // dialect will be loaded along with the extension and, therefore, along with + // the Transform dialect. Only declare as dependent the dialects that contain + // the attributes or types used by transform operations. Do NOT declare as + // dependent the dialects produced during the transformation. + // declareDependentDialect(); + + // When transformations are applied, they may produce new operations from + // previously unloaded dialects. Typically, a pass would need to declare + // itself dependent on the dialects containing such new operations. 
To avoid + // confusion with the dialects the extension itself depends on, the Transform + // dialect differentiates between: + // - dependent dialects, which are used by the transform operations, and + // - generated dialects, which contain the entities (attributes, operations, + // types) that may be produced by applying the transformation even when + // not present in the original payload IR. + // In the following chapter, we will add operations that generate function + // calls and structured control flow operations, so let's declare the + // corresponding dialects as generated. + declareGeneratedDialect<::mlir::scf::SCFDialect>(); + declareGeneratedDialect<::mlir::func::FuncDialect>(); + + // Register the additional transform dialect types with the dialect. List all + // types generated from ODS. + registerTypes< +#define GET_TYPEDEF_LIST +#include "MyExtensionTypes.cpp.inc" + >(); + + // ODS generates these helpers for type printing and parsing, but the + // Transform dialect provides its own support for types supplied by the + // extension. Reference these functions to avoid a compiler warning. + (void)&generatedTypeParser; + (void)&generatedTypePrinter; + + // Finally, we register the additional transform operations with the dialect. + // List all operations generated from ODS. This call will perform additional + // checks that the operations implement the transform and memory effect + // interfaces required by the dialect interpreter and assert if they do not. + registerTransformOps< +#define GET_OP_LIST +#include "MyExtension.cpp.inc" + >(); +} + +//===---------------------------------------------------------------------===// +// ChangeCallTargetOp +//===---------------------------------------------------------------------===// + +static void updateCallee(mlir::func::CallOp call, llvm::StringRef newTarget) { + call.setCallee(newTarget); +} + +// Implementation of our transform dialect operation. 
+// This operation returns a tri-state result that can be one of: +// - success when the transformation succeeded; +// - definite failure when the transformation failed in such a way that +// following +// transformations are impossible or undesirable, typically it could have left +// payload IR in an invalid state; it is expected that a diagnostic is emitted +// immediately before returning the definite error; +// - silenceable failure when the transformation failed but following +// transformations +// are still applicable, typically this means a precondition for the +// transformation is not satisfied and the payload IR has not been modified. +// The silenceable failure additionally carries a Diagnostic that can be emitted +// to the user. +::mlir::DiagnosedSilenceableFailure +mlir::transform::ChangeCallTargetOp::applyToOne( + // The rewriter that should be used when modifying IR. + ::mlir::transform::TransformRewriter &rewriter, + // The single payload operation to which the transformation is applied. + ::mlir::func::CallOp call, + // The payload IR entities that will be appended to lists associated with + // the results of this transform operation. This list contains one entry per + // result. + ::mlir::transform::ApplyToEachResultList &results, + // The transform application state. This object can be used to query the + // current associations between transform IR values and payload IR entities. + // It can also carry additional user-defined state. + ::mlir::transform::TransformState &state) { + + // Dispatch to the actual transformation. + updateCallee(call, getNewTarget()); + + // If everything went well, return success. 
+ return DiagnosedSilenceableFailure::success(); +} + +void mlir::transform::ChangeCallTargetOp::getEffects( + ::llvm::SmallVectorImpl<::mlir::MemoryEffects::EffectInstance> &effects) { + // Indicate that the `call` handle is only read by this operation because the + // associated operation is not erased but rather modified in-place, so the + // reference to it remains valid. + onlyReadsHandle(getCallMutable(), effects); + + // Indicate that the payload is modified by this operation. + modifiesPayload(effects); +} + +//===---------------------------------------------------------------------===// +// CallToOp +//===---------------------------------------------------------------------===// + +static mlir::Operation *replaceCallWithOp(mlir::RewriterBase &rewriter, + mlir::CallOpInterface call) { + // Construct an operation from an unregistered dialect. This is discouraged + // and is only used here for brevity of the overall example. + mlir::OperationState state(call.getLoc(), "my.mm4"); + state.types.assign(call->result_type_begin(), call->result_type_end()); + state.operands.assign(call->operand_begin(), call->operand_end()); + + mlir::Operation *replacement = rewriter.create(state); + rewriter.replaceOp(call, replacement->getResults()); + return replacement; +} + +// See above for the signature description. +mlir::DiagnosedSilenceableFailure mlir::transform::CallToOp::applyToOne( + mlir::transform::TransformRewriter &rewriter, mlir::CallOpInterface call, + mlir::transform::ApplyToEachResultList &results, + mlir::transform::TransformState &state) { + + // Dispatch to the actual transformation. + Operation *replacement = replaceCallWithOp(rewriter, call); + + // Associate the payload operation produced by the rewrite with the result + // handle of this transform operation. + results.push_back(replacement); + + // If everything went well, return success. 
+ return DiagnosedSilenceableFailure::success(); +} + +//===---------------------------------------------------------------------===// +// CallOpInterfaceHandleType +//===---------------------------------------------------------------------===// + +// The interface declares this method to verify constraints this type has on +// payload operations. It returns the now familiar tri-state result. +mlir::DiagnosedSilenceableFailure +mlir::transform::CallOpInterfaceHandleType::checkPayload( + // Location at which diagnostics should be emitted. + mlir::Location loc, + // List of payload operations that are about to be associated with the + // handle that has this type. + llvm::ArrayRef payload) const { + + // All payload operations are expected to implement CallOpInterface, check + // this. + for (Operation *op : payload) { + if (llvm::isa(op)) + continue; + + // By convention, these verifiers always emit a silenceable failure since + // they are checking a precondition. + DiagnosedSilenceableFailure diag = + emitSilenceableError(loc) + << "expected the payload operation to implement CallOpInterface"; + diag.attachNote(op->getLoc()) << "offending operation"; + return diag; + } + + // If everything is okay, return success. 
+ return DiagnosedSilenceableFailure::success(); +} + +//===---------------------------------------------------------------------===// +// Extension registration +//===---------------------------------------------------------------------===// + +void registerMyExtension(::mlir::DialectRegistry &registry) { + registry.addExtensions(); +} diff --git a/mlir/example/transform_Ch3/transform-opt/transform-opt.cpp b/mlir/example/transform_Ch3/transform-opt/transform-opt.cpp new file mode 100644 index 0000000..c9150c6 --- /dev/null +++ b/mlir/example/transform_Ch3/transform-opt/transform-opt.cpp @@ -0,0 +1,45 @@ +//===-- transform-opt.cpp - Transform dialect tutorial entry point --------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is the top-level file for the Transform dialect tutorial chapter 3. +// +//===----------------------------------------------------------------------===// + +#include "MyExtension.h" + +#include "mlir/Dialect/Transform/Transforms/Passes.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/InitAllDialects.h" +#include "mlir/InitAllExtensions.h" +#include "mlir/Tools/mlir-opt/MlirOptMain.h" +#include "mlir/Transforms/Passes.h" +#include + +int main(int argc, char **argv) { + // Register all "core" dialects and our transform dialect extension. + mlir::DialectRegistry registry; + mlir::registerAllDialects(registry); + mlir::registerAllExtensions(registry); + registerMyExtension(registry); + + // Register the interpreter pass. + mlir::transform::registerInterpreterPass(); + + // Register a handful of cleanup passes that we can run to make the output IR + // look nicer. 
+ mlir::registerCanonicalizerPass(); + mlir::registerCSEPass(); + mlir::registerSymbolDCEPass(); + + // Delegate to the MLIR utility for parsing and pass management. + return mlir::MlirOptMain(argc, argv, "transform-opt-ch3", registry) + .succeeded() + ? EXIT_SUCCESS + : EXIT_FAILURE; +} diff --git a/mlir/example/transform_Ch4/CMakeLists.txt b/mlir/example/transform_Ch4/CMakeLists.txt new file mode 100644 index 0000000..04cd265 --- /dev/null +++ b/mlir/example/transform_Ch4/CMakeLists.txt @@ -0,0 +1,16 @@ +# For a better top-level template to copy, see examples/standalone. + +include_directories(${CMAKE_CURRENT_BINARY_DIR}) +include_directories(${CMAKE_CURRENT_BINARY_DIR}/include) +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) + +add_subdirectory(include) +add_subdirectory(lib) + +add_executable(transform-opt-ch4 transform-opt/transform-opt.cpp) + +add_dependencies(transform-opt-ch4 MyExtensionCh4IncGen) + +target_link_libraries( + transform-opt-ch4 PRIVATE MLIRIR MLIRMlirOptMain MLIRSideEffectInterfaces + MLIRTransformDialectTransforms MyExtensionCh4) diff --git a/mlir/example/transform_Ch4/include/CMakeLists.txt b/mlir/example/transform_Ch4/include/CMakeLists.txt new file mode 100644 index 0000000..1f960e5 --- /dev/null +++ b/mlir/example/transform_Ch4/include/CMakeLists.txt @@ -0,0 +1,14 @@ +# Tell Tablegen to use MyExtension.td as input. +set(LLVM_TARGET_DEFINITIONS MyExtension.td) + +# Ask Tablegen to generate op declarations and definitions from ODS. +mlir_tablegen(MyExtension.h.inc -gen-op-decls) +mlir_tablegen(MyExtension.cpp.inc -gen-op-defs) + +# Add a CMakeTarget we can depend on to ensure the generation happens before the +# compilation. 
+add_public_tablegen_target(MyExtensionCh4IncGen) + +# Don't forget to generate the documentation, this will produce a +# MyExtensionCh4.md under Tutorials/transform +add_mlir_doc(MyExtension MyExtensionCh4 Tutorials/transform/ -gen-op-doc) diff --git a/mlir/example/transform_Ch4/include/MyExtension.h b/mlir/example/transform_Ch4/include/MyExtension.h new file mode 100644 index 0000000..620ec8f --- /dev/null +++ b/mlir/example/transform_Ch4/include/MyExtension.h @@ -0,0 +1,30 @@ +//===-- MyExtension.h - Transform dialect tutorial --------------*- c++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines Transform dialect extension operations used in the +// Chapter 4 of the Transform dialect tutorial. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Bytecode/BytecodeOpInterface.h" +#include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/Dialect/Transform/IR/TransformOps.h" +#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" + +namespace mlir { +class CallOpInterface; +namespace func { +class CallOp; +} // namespace func +} // namespace mlir + +#define GET_OP_CLASSES +#include "MyExtension.h.inc" + +// Registers our Transform dialect extension. +void registerMyExtension(::mlir::DialectRegistry &registry); diff --git a/mlir/example/transform_Ch4/include/MyExtension.td b/mlir/example/transform_Ch4/include/MyExtension.td new file mode 100644 index 0000000..6606803 --- /dev/null +++ b/mlir/example/transform_Ch4/include/MyExtension.td @@ -0,0 +1,46 @@ +//===-- MyExtension.td - Transform dialect tutorial --------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines Transform dialect extension operations used in the +// Chapter 4 of the Transform dialect tutorial. +// +//===----------------------------------------------------------------------===// + +#ifndef MY_EXTENSION +#define MY_EXTENSION + +include "mlir/Dialect/Transform/Interfaces/MatchInterfaces.td" +include "mlir/Dialect/Transform/IR/TransformDialect.td" +include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td" +include "mlir/IR/OpBase.td" +include "mlir/Interfaces/SideEffectInterfaces.td" + +// Define the new operation. By convention, prefix its name with `match` +// followed by the name of the dialect extension. +def HasOperandSatisfyingOp : TransformDialectOp<"match.my.has_operand_satisfying", + [DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + // Indicate that the operation implements MatchOpInterface in addition to + // the TransformOpInterface. This interface is only used as a tag at this + // point and has no methods that are mandatory to implement. + MatchOpInterface, + SingleBlockImplicitTerminator<"::mlir::transform::YieldOp">]> { + let summary = "Succeed if any of the operands matches all nested criteria"; + let arguments = (ins TransformHandleTypeInterface:$op); + let results = (outs TransformParamTypeInterface:$position, + Variadic:$results); + + // Match operations can be arbitrarily complex, e.g., containing regions. 
+ let regions = (region SizedRegion<1>:$body); + let hasVerifier = 1; + let assemblyFormat = [{ + $op `:` functional-type($op, results) attr-dict-with-keyword $body + }]; +} + +#endif // MY_EXTENSION diff --git a/mlir/example/transform_Ch4/lib/CMakeLists.txt b/mlir/example/transform_Ch4/lib/CMakeLists.txt new file mode 100644 index 0000000..b6a33d4 --- /dev/null +++ b/mlir/example/transform_Ch4/lib/CMakeLists.txt @@ -0,0 +1,17 @@ +# Outside examples, this should be `add_mlir_library`. +add_mlir_example_library( + # Library called MyExtensionCh4. + MyExtensionCh4 + # Built from the following source files. + MyExtension.cpp + # Make includes visible without top-level path. + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/examples/transform/Ch4/include + # Make sure ODS declaration and definitions are generated before compiling + # this. + DEPENDS + MyExtensionCh4IncGen + # Link in the transform dialect, and all generated dialects. + LINK_LIBS + PRIVATE + MLIRTransformDialect) diff --git a/mlir/example/transform_Ch4/lib/MyExtension.cpp b/mlir/example/transform_Ch4/lib/MyExtension.cpp new file mode 100644 index 0000000..38c8ca1 --- /dev/null +++ b/mlir/example/transform_Ch4/lib/MyExtension.cpp @@ -0,0 +1,206 @@ +//===-- MyExtension.cpp - Transform dialect tutorial ----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines Transform dialect extension operations used in the +// Chapter 4 of the Transform dialect tutorial. 
+// +//===----------------------------------------------------------------------===// + +#include "MyExtension.h" +#include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE_MATCHER "transform-matcher" +#define DBGS_MATCHER() (llvm::dbgs() << "[" DEBUG_TYPE_MATCHER "] ") +#define DEBUG_MATCHER(x) DEBUG_WITH_TYPE(DEBUG_TYPE_MATCHER, x) + +#define GET_OP_CLASSES +#include "MyExtension.cpp.inc" + +//===---------------------------------------------------------------------===// +// MyExtension +//===---------------------------------------------------------------------===// + +// Define a new transform dialect extension. This uses the CRTP idiom to +// identify extensions. +class MyExtension + : public ::mlir::transform::TransformDialectExtension { +public: + // The extension must derive the base constructor. + using Base::Base; + + // This function initializes the extension, similarly to `initialize` in + // dialect definitions. List individual operations and dependent dialects + // here. + void init(); +}; + +void MyExtension::init() { + // Register the additional match operations with the dialect similarly to + // other transform operations. List all operations generated from ODS. This + // call will perform additional checks that the operations implement the + // transform and memory effect interfaces required by the dialect interpreter + // and assert if they do not. + registerTransformOps< +#define GET_OP_LIST +#include "MyExtension.cpp.inc" + >(); +} + +//===---------------------------------------------------------------------===// +// HasOperandSatisfyingOp +//===---------------------------------------------------------------------===// + +/// Returns `true` if both types implement one of the interfaces provided as +/// template parameters. +template +static bool implementSameInterface(mlir::Type t1, mlir::Type t2) { + return ((llvm::isa(t1) && llvm::isa(t2)) || ... 
|| false); +} + +/// Returns `true` if both types implement one of the transform dialect +/// interfaces. +static bool implementSameTransformInterface(mlir::Type t1, mlir::Type t2) { + return implementSameInterface< + mlir::transform::TransformHandleTypeInterface, + mlir::transform::TransformParamTypeInterface, + mlir::transform::TransformValueHandleTypeInterface>(t1, t2); +} + +// Matcher ops implement `apply` similarly to other transform ops. They are not +// expected to modify payload, but use the tri-state result to signal failure or +// success to match, as well as potential irrecoverable errors. +mlir::DiagnosedSilenceableFailure +mlir::transform::HasOperandSatisfyingOp::apply( + mlir::transform::TransformRewriter &rewriter, + mlir::transform::TransformResults &results, + mlir::transform::TransformState &state) { + // For simplicity, only handle a single payload op. Actual implementations + // can use `SingleOpMatcher` trait to simplify implementation and document + // this expectation. + auto payloadOps = state.getPayloadOps(getOp()); + if (!llvm::hasSingleElement(payloadOps)) + return emitSilenceableError() << "expected single payload"; + + // Iterate over all operands of the payload op to see if they can be matched + // using the body of this op. + Operation *payload = *payloadOps.begin(); + for (OpOperand &operand : payload->getOpOperands()) { + // Create a scope for transform values defined in the body. This corresponds + // to the syntactic scope of the region attached to this op. Any values + // associated with payloads from now on will be automatically dissociated + // when this object is destroyed, i.e. at the end of the iteration. + // Associate the block argument handle with the operand. 
+ auto matchScope = state.make_region_scope(getBody()); + if (failed(state.mapBlockArgument(getBody().getArgument(0), + {operand.get()}))) { + return DiagnosedSilenceableFailure::definiteFailure(); + } + + // Iterate over all nested matchers with the current mapping and see if they + // succeed. + bool matchSucceeded = true; + for (Operation &matcher : getBody().front().without_terminator()) { + // Matcher ops are applied similarly to any other transform op. + DiagnosedSilenceableFailure diag = + state.applyTransform(cast(matcher)); + + // Definite failures are immediately propagated as they are irrecoverable. + if (diag.isDefiniteFailure()) + return diag; + + // On success, keep checking the remaining conditions. + if (diag.succeeded()) + continue; + + // Report failure-to-match for debugging purposes and stop matching this + // operand. + assert(diag.isSilenceableFailure()); + DEBUG_MATCHER(DBGS_MATCHER() + << "failed to match operand #" << operand.getOperandNumber() + << ": " << diag.getMessage()); + (void)diag.silence(); + matchSucceeded = false; + break; + } + // If failed to match this operand, try other operands. + if (!matchSucceeded) + continue; + + // If we reached this point, the matching succeeded for the current operand. + // Remap the values associated with terminator operands to be associated + // with op results, and also map the parameter result to the operand's + // position. Note that it is safe to do here despite the end of the scope + // as `results` are integrated into `state` by the interpreter after `apply` + // returns rather than immediately. 
+ SmallVector> yieldedMappings; + transform::detail::prepareValueMappings( + yieldedMappings, getBody().front().getTerminator()->getOperands(), + state); + results.setParams(cast(getPosition()), + {rewriter.getI32IntegerAttr(operand.getOperandNumber())}); + for (auto &&[result, mapping] : llvm::zip(getResults(), yieldedMappings)) + results.setMappedValues(result, mapping); + return DiagnosedSilenceableFailure::success(); + } + + // If we reached this point, none of the operands succeeded the match. + return emitSilenceableError() + << "none of the operands satisfied the conditions"; +} + +// By convention, operations implementing MatchOpInterface must not modify +// payload IR and must therefore specify that they only read operand handles and +// payload as their effects. +void mlir::transform::HasOperandSatisfyingOp::getEffects( + llvm::SmallVectorImpl &effects) { + onlyReadsPayload(effects); + onlyReadsHandle(getOpMutable(), effects); + producesHandle(getOperation()->getOpResults(), effects); +} + +// Verify well-formedness of the operation and emit diagnostics if it is +// ill-formed. 
+llvm::LogicalResult mlir::transform::HasOperandSatisfyingOp::verify() {
+  mlir::Block &bodyBlock = getBody().front();
+  if (bodyBlock.getNumArguments() != 1 ||
+      !isa<TransformValueHandleTypeInterface>(
+          bodyBlock.getArgument(0).getType())) {
+    return emitOpError()
+           << "expects the body to have one value handle argument";
+  }
+  if (bodyBlock.getTerminator()->getNumOperands() != getNumResults() - 1) {
+    return emitOpError() << "expects the body to yield "
+                         << (getNumResults() - 1) << " values, got "
+                         << bodyBlock.getTerminator()->getNumOperands();
+  }
+  for (auto &&[i, operand, result] :
+       llvm::enumerate(bodyBlock.getTerminator()->getOperands().getTypes(),
+                       getResults().getTypes())) {
+    if (implementSameTransformInterface(operand, result))
+      continue;
+    return emitOpError() << "expects terminator operand #" << i
+                         << " and result #" << (i + 1)
+                         << " to implement the same transform interface";
+  }
+
+  for (Operation &op : bodyBlock.without_terminator()) {
+    if (!isa<TransformOpInterface>(op) || !isa<MatchOpInterface>(op)) {
+      InFlightDiagnostic diag = emitOpError()
+                                << "expects body to contain match ops";
+      diag.attachNote(op.getLoc()) << "non-match operation";
+      return diag;
+    }
+  }
+
+  return success();
+}
+
+void registerMyExtension(::mlir::DialectRegistry &registry) {
+  registry.addExtensions<MyExtension>();
+}
diff --git a/mlir/example/transform_Ch4/transform-opt/transform-opt.cpp b/mlir/example/transform_Ch4/transform-opt/transform-opt.cpp
new file mode 100644
index 0000000..03c84bd
--- /dev/null
+++ b/mlir/example/transform_Ch4/transform-opt/transform-opt.cpp
@@ -0,0 +1,43 @@
+//===-- transform-opt.cpp - Transform dialect tutorial entry point --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is the top-level file for the Transform dialect tutorial chapter 4.
+//
+//===----------------------------------------------------------------------===//
+
+#include "MyExtension.h"
+
+#include "mlir/Dialect/Transform/Transforms/Passes.h"
+#include "mlir/IR/DialectRegistry.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/InitAllDialects.h"
+#include "mlir/InitAllExtensions.h"
+#include "mlir/Tools/mlir-opt/MlirOptMain.h"
+#include "mlir/Transforms/Passes.h"
+#include <cstdlib>
+
+int main(int argc, char **argv) {
+  // Register all "core" dialects and our transform dialect extension.
+  mlir::DialectRegistry registry;
+  mlir::registerAllDialects(registry);
+  mlir::registerAllExtensions(registry);
+  registerMyExtension(registry);
+
+  // Register a handful of cleanup passes that we can run to make the output IR
+  // look nicer.
+  mlir::registerCanonicalizerPass();
+  mlir::registerCSEPass();
+  mlir::registerSymbolDCEPass();
+  mlir::transform::registerInterpreterPass();
+
+  // Delegate to the MLIR utility for parsing and pass management.
+  return mlir::MlirOptMain(argc, argv, "transform-opt-ch4", registry)
+                 .succeeded()
+             ? EXIT_SUCCESS
+             : EXIT_FAILURE;
+}