-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[mlir][affine] Add pass --affine-raise-from-memref (2) #138004
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Restrict isValidDim to induction vars, and not iter_args
@llvm/pr-subscribers-mlir-affine Author: Clément Fournier (oowekyala) ChangesNote: this is a reopening of #114032, which at the time had been approved but not merged. I inadvertently closed it last month and am just realizing. CC @ftynse -- This adds a pass that converts memref.load/store into affine.load/store. This is useful as those memref operators are ignored by passes like --affine-scalrep as they don't implement the Affine[Read/Write]OpInterface. Doing this allows you to put as much of your program in affine form before you apply affine optimization passes. This also slightly changes the implementation of affine::isValidDim. The previous implementation allowed values from the iter_args of affine loops to be used as valid dims. I think this doesn't make sense and what was meant is just the induction vars. In the real world, there is little reason to find an index in the iter_args, but I wrote that in my tests and found out it was treated as an affine dim, so corrected that. Full diff: https://github.com/llvm/llvm-project/pull/138004.diff 6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Affine/Passes.h b/mlir/include/mlir/Dialect/Affine/Passes.h
index e152101236dc7..c1b9c30d302dd 100644
--- a/mlir/include/mlir/Dialect/Affine/Passes.h
+++ b/mlir/include/mlir/Dialect/Affine/Passes.h
@@ -22,6 +22,9 @@ namespace mlir {
namespace func {
class FuncOp;
} // namespace func
+namespace memref {
+class MemRefDialect;
+} // namespace memref
namespace affine {
class AffineForOp;
@@ -48,6 +51,9 @@ createAffineLoopInvariantCodeMotionPass();
/// ops.
std::unique_ptr<OperationPass<func::FuncOp>> createAffineParallelizePass();
+/// Creates a pass that converts some memref operators to affine operators.
+std::unique_ptr<OperationPass<func::FuncOp>> createRaiseMemrefToAffine();
+
/// Apply normalization transformations to affine loop-like ops. If
/// `promoteSingleIter` is true, single iteration loops are promoted (i.e., the
/// loop is replaced by its loop body).
diff --git a/mlir/include/mlir/Dialect/Affine/Passes.td b/mlir/include/mlir/Dialect/Affine/Passes.td
index 77073aa29da73..a77bcac5ed407 100644
--- a/mlir/include/mlir/Dialect/Affine/Passes.td
+++ b/mlir/include/mlir/Dialect/Affine/Passes.td
@@ -397,6 +397,18 @@ def LoopCoalescing : Pass<"affine-loop-coalescing", "func::FuncOp"> {
let dependentDialects = ["affine::AffineDialect","arith::ArithDialect"];
}
+def RaiseMemrefDialect : Pass<"affine-raise-from-memref", "func::FuncOp"> {
+ let summary = "Turn some memref operators to affine operators where supported";
+ let description = [{
+ Raise memref.load and memref.store to affine.store and affine.load, inferring
+ the affine map of those operators if needed. This allows passes like --affine-scalrep
+ to optimize those loads and stores (forwarding them or eliminating them).
+ They can be turned back to memref dialect ops with --lower-affine.
+ }];
+ let constructor = "mlir::affine::createRaiseMemrefToAffine()";
+ let dependentDialects = ["affine::AffineDialect"];
+}
+
def SimplifyAffineStructures : Pass<"affine-simplify-structures", "func::FuncOp"> {
let summary = "Simplify affine expressions in maps/sets and normalize "
"memrefs";
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index dceebbfec586c..06204188e14e2 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -284,10 +284,12 @@ bool mlir::affine::isValidDim(Value value) {
return isValidDim(value, getAffineScope(defOp));
// This value has to be a block argument for an op that has the
- // `AffineScope` trait or for an affine.for or affine.parallel.
+ // `AffineScope` trait or an induction var of an affine.for or
+ // affine.parallel.
+ if (isAffineInductionVar(value))
+ return true;
auto *parentOp = llvm::cast<BlockArgument>(value).getOwner()->getParentOp();
- return parentOp && (parentOp->hasTrait<OpTrait::AffineScope>() ||
- isa<AffineForOp, AffineParallelOp>(parentOp));
+ return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
}
// Value can be used as a dimension id iff it meets one of the following
@@ -306,10 +308,9 @@ bool mlir::affine::isValidDim(Value value, Region *region) {
auto *op = value.getDefiningOp();
if (!op) {
- // This value has to be a block argument for an affine.for or an
+ // This value has to be an induction var for an affine.for or an
// affine.parallel.
- auto *parentOp = llvm::cast<BlockArgument>(value).getOwner()->getParentOp();
- return isa<AffineForOp, AffineParallelOp>(parentOp);
+ return isAffineInductionVar(value);
}
// Affine apply operation is ok if all of its operands are ok.
diff --git a/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
index c42789b01bc9f..1c82822b2bd7f 100644
--- a/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
@@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRAffineTransforms
LoopUnroll.cpp
LoopUnrollAndJam.cpp
PipelineDataTransfer.cpp
+ RaiseMemrefDialect.cpp
ReifyValueBounds.cpp
SuperVectorize.cpp
SimplifyAffineStructures.cpp
diff --git a/mlir/lib/Dialect/Affine/Transforms/RaiseMemrefDialect.cpp b/mlir/lib/Dialect/Affine/Transforms/RaiseMemrefDialect.cpp
new file mode 100644
index 0000000000000..491d2e03c36bc
--- /dev/null
+++ b/mlir/lib/Dialect/Affine/Transforms/RaiseMemrefDialect.cpp
@@ -0,0 +1,187 @@
+//===- RaiseMemrefDialect.cpp - raise memref.store and load to affine ops -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements functionality to convert memref load and store ops to
+// the corresponding affine ops, inferring the affine map as needed.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/Analysis/Utils.h"
+#include "mlir/Dialect/Affine/Passes.h"
+#include "mlir/Dialect/Affine/Transforms/Transforms.h"
+#include "mlir/Dialect/Affine/Utils.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Pass/Pass.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/Debug.h"
+
+namespace mlir {
+namespace affine {
+#define GEN_PASS_DEF_RAISEMEMREFDIALECT
+#include "mlir/Dialect/Affine/Passes.h.inc"
+} // namespace affine
+} // namespace mlir
+
+#define DEBUG_TYPE "raise-memref-to-affine"
+
+using namespace mlir;
+using namespace mlir::affine;
+
+namespace {
+
+/// Find the index of the given value in the `dims` list,
+/// and append it if it was not already in the list. The
+/// dims list is a list of symbols or dimensions of the
+/// affine map. Within the results of an affine map, they
+/// are identified by their index, which is why we need
+/// this function.
+static std::optional<size_t>
+findInListOrAdd(Value value, llvm::SmallVectorImpl<Value> &dims,
+ function_ref<bool(Value)> isValidElement) {
+
+ Value *loopIV = std::find(dims.begin(), dims.end(), value);
+ if (loopIV != dims.end()) {
+ // We found an IV that already has an index, return that index.
+ return {std::distance(dims.begin(), loopIV)};
+ }
+ if (isValidElement(value)) {
+ // This is a valid element for the dim/symbol list, push this as a
+ // parameter.
+ size_t idx = dims.size();
+ dims.push_back(value);
+ return idx;
+ }
+ return std::nullopt;
+}
+
+/// Convert a value to an affine expr if possible. Adds dims and symbols
+/// if needed.
+static AffineExpr toAffineExpr(Value value,
+ llvm::SmallVectorImpl<Value> &affineDims,
+ llvm::SmallVectorImpl<Value> &affineSymbols) {
+ using namespace matchers;
+ IntegerAttr::ValueType cst;
+ if (matchPattern(value, m_ConstantInt(&cst))) {
+ return getAffineConstantExpr(cst.getSExtValue(), value.getContext());
+ }
+
+ Operation *definingOp = value.getDefiningOp();
+ if (llvm::isa_and_nonnull<arith::AddIOp>(definingOp) ||
+ llvm::isa_and_nonnull<arith::MulIOp>(definingOp)) {
+ // TODO: replace recursion with explicit stack.
+ // For the moment this can be tolerated as we only recurse on
+ // arith.addi and arith.muli, so there cannot be any infinite
+ // recursion. The depth of these expressions should be in most
+ // cases very manageable, as affine expressions should be as
+ // simple as `a + b * c`.
+ AffineExpr lhsE =
+ toAffineExpr(definingOp->getOperand(0), affineDims, affineSymbols);
+ AffineExpr rhsE =
+ toAffineExpr(definingOp->getOperand(1), affineDims, affineSymbols);
+
+ if (lhsE && rhsE) {
+ AffineExprKind kind;
+ if (isa<arith::AddIOp>(definingOp)) {
+ kind = mlir::AffineExprKind::Add;
+ } else {
+ kind = mlir::AffineExprKind::Mul;
+
+ if (!lhsE.isSymbolicOrConstant() && !rhsE.isSymbolicOrConstant()) {
+ // This is not an affine expression, give up.
+ return {};
+ }
+ }
+ return getAffineBinaryOpExpr(kind, lhsE, rhsE);
+ }
+ return {};
+ }
+
+ if (auto dimIx = findInListOrAdd(value, affineSymbols, [](Value v) {
+ return affine::isValidSymbol(v);
+ })) {
+ return getAffineSymbolExpr(*dimIx, value.getContext());
+ }
+
+ if (auto dimIx = findInListOrAdd(
+ value, affineDims, [](Value v) { return affine::isValidDim(v); })) {
+
+ return getAffineDimExpr(*dimIx, value.getContext());
+ }
+
+ return {};
+}
+
+static LogicalResult
+computeAffineMapAndArgs(MLIRContext *ctx, ValueRange indices, AffineMap &map,
+ llvm::SmallVectorImpl<Value> &mapArgs) {
+ SmallVector<AffineExpr> results;
+ SmallVector<Value> symbols;
+ SmallVector<Value> dims;
+
+ for (Value indexExpr : indices) {
+ AffineExpr res = toAffineExpr(indexExpr, dims, symbols);
+ if (!res) {
+ return failure();
+ }
+ results.push_back(res);
+ }
+
+ map = AffineMap::get(dims.size(), symbols.size(), results, ctx);
+
+ dims.append(symbols);
+ mapArgs.swap(dims);
+ return success();
+}
+
+struct RaiseMemrefDialect
+ : public affine::impl::RaiseMemrefDialectBase<RaiseMemrefDialect> {
+
+ void runOnOperation() override {
+ auto *ctx = &getContext();
+ Operation *op = getOperation();
+ IRRewriter rewriter(ctx);
+ AffineMap map;
+ SmallVector<Value> mapArgs;
+ op->walk([&](Operation *op) {
+ rewriter.setInsertionPoint(op);
+ if (auto store = llvm::dyn_cast_or_null<memref::StoreOp>(op)) {
+
+ if (succeeded(computeAffineMapAndArgs(ctx, store.getIndices(), map,
+ mapArgs))) {
+ rewriter.replaceOpWithNewOp<AffineStoreOp>(
+ op, store.getValueToStore(), store.getMemRef(), map, mapArgs);
+ return;
+ }
+
+ LLVM_DEBUG(llvm::dbgs()
+ << "[affine] Cannot raise memref op: " << op << "\n");
+
+ } else if (auto load = llvm::dyn_cast_or_null<memref::LoadOp>(op)) {
+ if (succeeded(computeAffineMapAndArgs(ctx, load.getIndices(), map,
+ mapArgs))) {
+ rewriter.replaceOpWithNewOp<AffineLoadOp>(op, load.getMemRef(), map,
+ mapArgs);
+ return;
+ }
+ LLVM_DEBUG(llvm::dbgs()
+ << "[affine] Cannot raise memref op: " << op << "\n");
+ }
+ });
+ }
+};
+
+} // namespace
+
+std::unique_ptr<OperationPass<func::FuncOp>>
+mlir::affine::createRaiseMemrefToAffine() {
+ return std::make_unique<RaiseMemrefDialect>();
+}
diff --git a/mlir/test/Dialect/Affine/raise-memref.mlir b/mlir/test/Dialect/Affine/raise-memref.mlir
new file mode 100644
index 0000000000000..00cc98de1f40f
--- /dev/null
+++ b/mlir/test/Dialect/Affine/raise-memref.mlir
@@ -0,0 +1,138 @@
+// RUN: mlir-opt %s -allow-unregistered-dialect -affine-raise-from-memref --canonicalize | FileCheck %s
+
+// CHECK-LABEL: func @reduce_window_max(
+func.func @reduce_window_max() {
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = memref.alloc() : memref<1x8x8x64xf32>
+ %1 = memref.alloc() : memref<1x18x18x64xf32>
+ affine.for %arg0 = 0 to 1 {
+ affine.for %arg1 = 0 to 8 {
+ affine.for %arg2 = 0 to 8 {
+ affine.for %arg3 = 0 to 64 {
+ memref.store %cst, %0[%arg0, %arg1, %arg2, %arg3] : memref<1x8x8x64xf32>
+ }
+ }
+ }
+ }
+ affine.for %arg0 = 0 to 1 {
+ affine.for %arg1 = 0 to 8 {
+ affine.for %arg2 = 0 to 8 {
+ affine.for %arg3 = 0 to 64 {
+ affine.for %arg4 = 0 to 1 {
+ affine.for %arg5 = 0 to 3 {
+ affine.for %arg6 = 0 to 3 {
+ affine.for %arg7 = 0 to 1 {
+ %2 = memref.load %0[%arg0, %arg1, %arg2, %arg3] : memref<1x8x8x64xf32>
+ %21 = arith.addi %arg0, %arg4 : index
+ %22 = arith.constant 2 : index
+ %23 = arith.muli %arg1, %22 : index
+ %24 = arith.addi %23, %arg5 : index
+ %25 = arith.muli %arg2, %22 : index
+ %26 = arith.addi %25, %arg6 : index
+ %27 = arith.addi %arg3, %arg7 : index
+ %3 = memref.load %1[%21, %24, %26, %27] : memref<1x18x18x64xf32>
+ %4 = arith.cmpf ogt, %2, %3 : f32
+ %5 = arith.select %4, %2, %3 : f32
+ memref.store %5, %0[%arg0, %arg1, %arg2, %arg3] : memref<1x8x8x64xf32>
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ return
+}
+
+// CHECK: %[[cst:.*]] = arith.constant 0
+// CHECK: %[[v0:.*]] = memref.alloc() : memref<1x8x8x64xf32>
+// CHECK: %[[v1:.*]] = memref.alloc() : memref<1x18x18x64xf32>
+// CHECK: affine.for %[[arg0:.*]] =
+// CHECK: affine.for %[[arg1:.*]] =
+// CHECK: affine.for %[[arg2:.*]] =
+// CHECK: affine.for %[[arg3:.*]] =
+// CHECK: affine.store %[[cst]], %[[v0]][%[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]]] :
+// CHECK: affine.for %[[a0:.*]] =
+// CHECK: affine.for %[[a1:.*]] =
+// CHECK: affine.for %[[a2:.*]] =
+// CHECK: affine.for %[[a3:.*]] =
+// CHECK: affine.for %[[a4:.*]] =
+// CHECK: affine.for %[[a5:.*]] =
+// CHECK: affine.for %[[a6:.*]] =
+// CHECK: affine.for %[[a7:.*]] =
+// CHECK: %[[lhs:.*]] = affine.load %[[v0]][%[[a0]], %[[a1]], %[[a2]], %[[a3]]] :
+// CHECK: %[[rhs:.*]] = affine.load %[[v1]][%[[a0]] + %[[a4]], %[[a1]] * 2 + %[[a5]], %[[a2]] * 2 + %[[a6]], %[[a3]] + %[[a7]]] :
+// CHECK: %[[res:.*]] = arith.cmpf ogt, %[[lhs]], %[[rhs]] : f32
+// CHECK: %[[sel:.*]] = arith.select %[[res]], %[[lhs]], %[[rhs]] : f32
+// CHECK: affine.store %[[sel]], %[[v0]][%[[a0]], %[[a1]], %[[a2]], %[[a3]]] :
+
+// CHECK-LABEL: func @symbols(
+func.func @symbols(%N : index) {
+ %0 = memref.alloc() : memref<1024x1024xf32>
+ %1 = memref.alloc() : memref<1024x1024xf32>
+ %2 = memref.alloc() : memref<1024x1024xf32>
+ %cst1 = arith.constant 1 : index
+ %cst2 = arith.constant 2 : index
+ affine.for %i = 0 to %N {
+ affine.for %j = 0 to %N {
+ %7 = memref.load %2[%i, %j] : memref<1024x1024xf32>
+ %10 = affine.for %k = 0 to %N iter_args(%ax = %cst1) -> index {
+ %12 = arith.muli %N, %cst2 : index
+ %13 = arith.addi %12, %cst1 : index
+ %14 = arith.addi %13, %j : index
+ %5 = memref.load %0[%i, %12] : memref<1024x1024xf32>
+ %6 = memref.load %1[%14, %j] : memref<1024x1024xf32>
+ %8 = arith.mulf %5, %6 : f32
+ %9 = arith.addf %7, %8 : f32
+ %4 = arith.addi %N, %cst1 : index
+ %11 = arith.addi %ax, %cst1 : index
+ memref.store %9, %2[%i, %4] : memref<1024x1024xf32> // this uses an expression of the symbol
+ memref.store %9, %2[%i, %11] : memref<1024x1024xf32> // this uses an iter_args and cannot be raised
+ %something = "ab.v"() : () -> index
+ memref.store %9, %2[%i, %something] : memref<1024x1024xf32> // this cannot be raised
+ affine.yield %11 : index
+ }
+ }
+ }
+ return
+}
+
+// CHECK: %[[cst1:.*]] = arith.constant 1 : index
+// CHECK: %[[v0:.*]] = memref.alloc() : memref<
+// CHECK: %[[v1:.*]] = memref.alloc() : memref<
+// CHECK: %[[v2:.*]] = memref.alloc() : memref<
+// CHECK: affine.for %[[a1:.*]] = 0 to %arg0 {
+// CHECK: affine.for %[[a2:.*]] = 0 to %arg0 {
+// CHECK: %[[lhs:.*]] = affine.load %{{.*}}[%[[a1]], %[[a2]]] : memref<1024x1024xf32>
+// CHECK: affine.for %[[a3:.*]] = 0 to %arg0 iter_args(%[[a4:.*]] = %[[cst1]]) -> (index) {
+// CHECK: %[[lhs2:.*]] = affine.load %{{.*}}[%[[a1]], symbol(%arg0) * 2] :
+// CHECK: %[[lhs3:.*]] = affine.load %{{.*}}[%[[a2]] + symbol(%arg0) * 2 + 1, %[[a2]]] :
+// CHECK: %[[lhs4:.*]] = arith.mulf %[[lhs2]], %[[lhs3]]
+// CHECK: %[[lhs5:.*]] = arith.addf %[[lhs]], %[[lhs4]]
+// CHECK: %[[lhs6:.*]] = arith.addi %[[a4]], %[[cst1]]
+// CHECK: affine.store %[[lhs5]], %{{.*}}[%[[a1]], symbol(%arg0) + 1] :
+// CHECK: memref.store %[[lhs5]], %{{.*}}[%[[a1]], %[[lhs6]]] :
+// CHECK: %[[lhs7:.*]] = "ab.v"
+// CHECK: memref.store %[[lhs5]], %{{.*}}[%[[a1]], %[[lhs7]]] :
+// CHECK: affine.yield %[[lhs6]]
+
+
+// CHECK-LABEL: func @non_affine(
+func.func @non_affine(%N : index) {
+ %2 = memref.alloc() : memref<1024x1024xf32>
+ affine.for %i = 0 to %N {
+ affine.for %j = 0 to %N {
+ %ij = arith.muli %i, %j : index
+ %7 = memref.load %2[%i, %ij] : memref<1024x1024xf32>
+ memref.store %7, %2[%ij, %ij] : memref<1024x1024xf32>
+ }
+ }
+ return
+}
+
+// CHECK: affine.for %[[i:.*]] =
+// CHECK: affine.for %[[j:.*]] =
+// CHECK: %[[ij:.*]] = arith.muli %[[i]], %[[j]]
+// CHECK: %[[v:.*]] = memref.load %{{.*}}[%[[i]], %[[ij]]]
+// CHECK: memref.store %[[v]], %{{.*}}[%[[ij]], %[[ij]]]
\ No newline at end of file
|
@llvm/pr-subscribers-mlir Author: Clément Fournier (oowekyala) ChangesNote: this is a reopening of #114032, which at the time had been approved but not merged. I inadvertently closed it last month and am just realizing. CC @ftynse -- This adds a pass that converts memref.load/store into affine.load/store. This is useful as those memref operators are ignored by passes like --affine-scalrep as they don't implement the Affine[Read/Write]OpInterface. Doing this allows you to put as much of your program in affine form before you apply affine optimization passes. This also slightly changes the implementation of affine::isValidDim. The previous implementation allowed values from the iter_args of affine loops to be used as valid dims. I think this doesn't make sense and what was meant is just the induction vars. In the real world, there is little reason to find an index in the iter_args, but I wrote that in my tests and found out it was treated as an affine dim, so corrected that. Full diff: https://github.com/llvm/llvm-project/pull/138004.diff 6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Affine/Passes.h b/mlir/include/mlir/Dialect/Affine/Passes.h
index e152101236dc7..c1b9c30d302dd 100644
--- a/mlir/include/mlir/Dialect/Affine/Passes.h
+++ b/mlir/include/mlir/Dialect/Affine/Passes.h
@@ -22,6 +22,9 @@ namespace mlir {
namespace func {
class FuncOp;
} // namespace func
+namespace memref {
+class MemRefDialect;
+} // namespace memref
namespace affine {
class AffineForOp;
@@ -48,6 +51,9 @@ createAffineLoopInvariantCodeMotionPass();
/// ops.
std::unique_ptr<OperationPass<func::FuncOp>> createAffineParallelizePass();
+/// Creates a pass that converts some memref operators to affine operators.
+std::unique_ptr<OperationPass<func::FuncOp>> createRaiseMemrefToAffine();
+
/// Apply normalization transformations to affine loop-like ops. If
/// `promoteSingleIter` is true, single iteration loops are promoted (i.e., the
/// loop is replaced by its loop body).
diff --git a/mlir/include/mlir/Dialect/Affine/Passes.td b/mlir/include/mlir/Dialect/Affine/Passes.td
index 77073aa29da73..a77bcac5ed407 100644
--- a/mlir/include/mlir/Dialect/Affine/Passes.td
+++ b/mlir/include/mlir/Dialect/Affine/Passes.td
@@ -397,6 +397,18 @@ def LoopCoalescing : Pass<"affine-loop-coalescing", "func::FuncOp"> {
let dependentDialects = ["affine::AffineDialect","arith::ArithDialect"];
}
+def RaiseMemrefDialect : Pass<"affine-raise-from-memref", "func::FuncOp"> {
+ let summary = "Turn some memref operators to affine operators where supported";
+ let description = [{
+ Raise memref.load and memref.store to affine.store and affine.load, inferring
+ the affine map of those operators if needed. This allows passes like --affine-scalrep
+ to optimize those loads and stores (forwarding them or eliminating them).
+ They can be turned back to memref dialect ops with --lower-affine.
+ }];
+ let constructor = "mlir::affine::createRaiseMemrefToAffine()";
+ let dependentDialects = ["affine::AffineDialect"];
+}
+
def SimplifyAffineStructures : Pass<"affine-simplify-structures", "func::FuncOp"> {
let summary = "Simplify affine expressions in maps/sets and normalize "
"memrefs";
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index dceebbfec586c..06204188e14e2 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -284,10 +284,12 @@ bool mlir::affine::isValidDim(Value value) {
return isValidDim(value, getAffineScope(defOp));
// This value has to be a block argument for an op that has the
- // `AffineScope` trait or for an affine.for or affine.parallel.
+ // `AffineScope` trait or an induction var of an affine.for or
+ // affine.parallel.
+ if (isAffineInductionVar(value))
+ return true;
auto *parentOp = llvm::cast<BlockArgument>(value).getOwner()->getParentOp();
- return parentOp && (parentOp->hasTrait<OpTrait::AffineScope>() ||
- isa<AffineForOp, AffineParallelOp>(parentOp));
+ return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
}
// Value can be used as a dimension id iff it meets one of the following
@@ -306,10 +308,9 @@ bool mlir::affine::isValidDim(Value value, Region *region) {
auto *op = value.getDefiningOp();
if (!op) {
- // This value has to be a block argument for an affine.for or an
+ // This value has to be an induction var for an affine.for or an
// affine.parallel.
- auto *parentOp = llvm::cast<BlockArgument>(value).getOwner()->getParentOp();
- return isa<AffineForOp, AffineParallelOp>(parentOp);
+ return isAffineInductionVar(value);
}
// Affine apply operation is ok if all of its operands are ok.
diff --git a/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
index c42789b01bc9f..1c82822b2bd7f 100644
--- a/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
@@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRAffineTransforms
LoopUnroll.cpp
LoopUnrollAndJam.cpp
PipelineDataTransfer.cpp
+ RaiseMemrefDialect.cpp
ReifyValueBounds.cpp
SuperVectorize.cpp
SimplifyAffineStructures.cpp
diff --git a/mlir/lib/Dialect/Affine/Transforms/RaiseMemrefDialect.cpp b/mlir/lib/Dialect/Affine/Transforms/RaiseMemrefDialect.cpp
new file mode 100644
index 0000000000000..491d2e03c36bc
--- /dev/null
+++ b/mlir/lib/Dialect/Affine/Transforms/RaiseMemrefDialect.cpp
@@ -0,0 +1,187 @@
+//===- RaiseMemrefDialect.cpp - raise memref.store and load to affine ops -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements functionality to convert memref load and store ops to
+// the corresponding affine ops, inferring the affine map as needed.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/Analysis/Utils.h"
+#include "mlir/Dialect/Affine/Passes.h"
+#include "mlir/Dialect/Affine/Transforms/Transforms.h"
+#include "mlir/Dialect/Affine/Utils.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Pass/Pass.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/Debug.h"
+
+namespace mlir {
+namespace affine {
+#define GEN_PASS_DEF_RAISEMEMREFDIALECT
+#include "mlir/Dialect/Affine/Passes.h.inc"
+} // namespace affine
+} // namespace mlir
+
+#define DEBUG_TYPE "raise-memref-to-affine"
+
+using namespace mlir;
+using namespace mlir::affine;
+
+namespace {
+
+/// Find the index of the given value in the `dims` list,
+/// and append it if it was not already in the list. The
+/// dims list is a list of symbols or dimensions of the
+/// affine map. Within the results of an affine map, they
+/// are identified by their index, which is why we need
+/// this function.
+static std::optional<size_t>
+findInListOrAdd(Value value, llvm::SmallVectorImpl<Value> &dims,
+ function_ref<bool(Value)> isValidElement) {
+
+ Value *loopIV = std::find(dims.begin(), dims.end(), value);
+ if (loopIV != dims.end()) {
+ // We found an IV that already has an index, return that index.
+ return {std::distance(dims.begin(), loopIV)};
+ }
+ if (isValidElement(value)) {
+ // This is a valid element for the dim/symbol list, push this as a
+ // parameter.
+ size_t idx = dims.size();
+ dims.push_back(value);
+ return idx;
+ }
+ return std::nullopt;
+}
+
+/// Convert a value to an affine expr if possible. Adds dims and symbols
+/// if needed.
+static AffineExpr toAffineExpr(Value value,
+ llvm::SmallVectorImpl<Value> &affineDims,
+ llvm::SmallVectorImpl<Value> &affineSymbols) {
+ using namespace matchers;
+ IntegerAttr::ValueType cst;
+ if (matchPattern(value, m_ConstantInt(&cst))) {
+ return getAffineConstantExpr(cst.getSExtValue(), value.getContext());
+ }
+
+ Operation *definingOp = value.getDefiningOp();
+ if (llvm::isa_and_nonnull<arith::AddIOp>(definingOp) ||
+ llvm::isa_and_nonnull<arith::MulIOp>(definingOp)) {
+ // TODO: replace recursion with explicit stack.
+ // For the moment this can be tolerated as we only recurse on
+ // arith.addi and arith.muli, so there cannot be any infinite
+ // recursion. The depth of these expressions should be in most
+ // cases very manageable, as affine expressions should be as
+ // simple as `a + b * c`.
+ AffineExpr lhsE =
+ toAffineExpr(definingOp->getOperand(0), affineDims, affineSymbols);
+ AffineExpr rhsE =
+ toAffineExpr(definingOp->getOperand(1), affineDims, affineSymbols);
+
+ if (lhsE && rhsE) {
+ AffineExprKind kind;
+ if (isa<arith::AddIOp>(definingOp)) {
+ kind = mlir::AffineExprKind::Add;
+ } else {
+ kind = mlir::AffineExprKind::Mul;
+
+ if (!lhsE.isSymbolicOrConstant() && !rhsE.isSymbolicOrConstant()) {
+ // This is not an affine expression, give up.
+ return {};
+ }
+ }
+ return getAffineBinaryOpExpr(kind, lhsE, rhsE);
+ }
+ return {};
+ }
+
+ if (auto dimIx = findInListOrAdd(value, affineSymbols, [](Value v) {
+ return affine::isValidSymbol(v);
+ })) {
+ return getAffineSymbolExpr(*dimIx, value.getContext());
+ }
+
+ if (auto dimIx = findInListOrAdd(
+ value, affineDims, [](Value v) { return affine::isValidDim(v); })) {
+
+ return getAffineDimExpr(*dimIx, value.getContext());
+ }
+
+ return {};
+}
+
+static LogicalResult
+computeAffineMapAndArgs(MLIRContext *ctx, ValueRange indices, AffineMap &map,
+ llvm::SmallVectorImpl<Value> &mapArgs) {
+ SmallVector<AffineExpr> results;
+ SmallVector<Value> symbols;
+ SmallVector<Value> dims;
+
+ for (Value indexExpr : indices) {
+ AffineExpr res = toAffineExpr(indexExpr, dims, symbols);
+ if (!res) {
+ return failure();
+ }
+ results.push_back(res);
+ }
+
+ map = AffineMap::get(dims.size(), symbols.size(), results, ctx);
+
+ dims.append(symbols);
+ mapArgs.swap(dims);
+ return success();
+}
+
+struct RaiseMemrefDialect
+ : public affine::impl::RaiseMemrefDialectBase<RaiseMemrefDialect> {
+
+ void runOnOperation() override {
+ auto *ctx = &getContext();
+ Operation *op = getOperation();
+ IRRewriter rewriter(ctx);
+ AffineMap map;
+ SmallVector<Value> mapArgs;
+ op->walk([&](Operation *op) {
+ rewriter.setInsertionPoint(op);
+ if (auto store = llvm::dyn_cast_or_null<memref::StoreOp>(op)) {
+
+ if (succeeded(computeAffineMapAndArgs(ctx, store.getIndices(), map,
+ mapArgs))) {
+ rewriter.replaceOpWithNewOp<AffineStoreOp>(
+ op, store.getValueToStore(), store.getMemRef(), map, mapArgs);
+ return;
+ }
+
+ LLVM_DEBUG(llvm::dbgs()
+ << "[affine] Cannot raise memref op: " << op << "\n");
+
+ } else if (auto load = llvm::dyn_cast_or_null<memref::LoadOp>(op)) {
+ if (succeeded(computeAffineMapAndArgs(ctx, load.getIndices(), map,
+ mapArgs))) {
+ rewriter.replaceOpWithNewOp<AffineLoadOp>(op, load.getMemRef(), map,
+ mapArgs);
+ return;
+ }
+ LLVM_DEBUG(llvm::dbgs()
+ << "[affine] Cannot raise memref op: " << op << "\n");
+ }
+ });
+ }
+};
+
+} // namespace
+
+std::unique_ptr<OperationPass<func::FuncOp>>
+mlir::affine::createRaiseMemrefToAffine() {
+ return std::make_unique<RaiseMemrefDialect>();
+}
diff --git a/mlir/test/Dialect/Affine/raise-memref.mlir b/mlir/test/Dialect/Affine/raise-memref.mlir
new file mode 100644
index 0000000000000..00cc98de1f40f
--- /dev/null
+++ b/mlir/test/Dialect/Affine/raise-memref.mlir
@@ -0,0 +1,138 @@
+// RUN: mlir-opt %s -allow-unregistered-dialect -affine-raise-from-memref --canonicalize | FileCheck %s
+
+// CHECK-LABEL: func @reduce_window_max(
+func.func @reduce_window_max() {
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = memref.alloc() : memref<1x8x8x64xf32>
+ %1 = memref.alloc() : memref<1x18x18x64xf32>
+ affine.for %arg0 = 0 to 1 {
+ affine.for %arg1 = 0 to 8 {
+ affine.for %arg2 = 0 to 8 {
+ affine.for %arg3 = 0 to 64 {
+ memref.store %cst, %0[%arg0, %arg1, %arg2, %arg3] : memref<1x8x8x64xf32>
+ }
+ }
+ }
+ }
+ affine.for %arg0 = 0 to 1 {
+ affine.for %arg1 = 0 to 8 {
+ affine.for %arg2 = 0 to 8 {
+ affine.for %arg3 = 0 to 64 {
+ affine.for %arg4 = 0 to 1 {
+ affine.for %arg5 = 0 to 3 {
+ affine.for %arg6 = 0 to 3 {
+ affine.for %arg7 = 0 to 1 {
+ %2 = memref.load %0[%arg0, %arg1, %arg2, %arg3] : memref<1x8x8x64xf32>
+ %21 = arith.addi %arg0, %arg4 : index
+ %22 = arith.constant 2 : index
+ %23 = arith.muli %arg1, %22 : index
+ %24 = arith.addi %23, %arg5 : index
+ %25 = arith.muli %arg2, %22 : index
+ %26 = arith.addi %25, %arg6 : index
+ %27 = arith.addi %arg3, %arg7 : index
+ %3 = memref.load %1[%21, %24, %26, %27] : memref<1x18x18x64xf32>
+ %4 = arith.cmpf ogt, %2, %3 : f32
+ %5 = arith.select %4, %2, %3 : f32
+ memref.store %5, %0[%arg0, %arg1, %arg2, %arg3] : memref<1x8x8x64xf32>
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ return
+}
+
+// CHECK: %[[cst:.*]] = arith.constant 0
+// CHECK: %[[v0:.*]] = memref.alloc() : memref<1x8x8x64xf32>
+// CHECK: %[[v1:.*]] = memref.alloc() : memref<1x18x18x64xf32>
+// CHECK: affine.for %[[arg0:.*]] =
+// CHECK: affine.for %[[arg1:.*]] =
+// CHECK: affine.for %[[arg2:.*]] =
+// CHECK: affine.for %[[arg3:.*]] =
+// CHECK: affine.store %[[cst]], %[[v0]][%[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]]] :
+// CHECK: affine.for %[[a0:.*]] =
+// CHECK: affine.for %[[a1:.*]] =
+// CHECK: affine.for %[[a2:.*]] =
+// CHECK: affine.for %[[a3:.*]] =
+// CHECK: affine.for %[[a4:.*]] =
+// CHECK: affine.for %[[a5:.*]] =
+// CHECK: affine.for %[[a6:.*]] =
+// CHECK: affine.for %[[a7:.*]] =
+// CHECK: %[[lhs:.*]] = affine.load %[[v0]][%[[a0]], %[[a1]], %[[a2]], %[[a3]]] :
+// CHECK: %[[rhs:.*]] = affine.load %[[v1]][%[[a0]] + %[[a4]], %[[a1]] * 2 + %[[a5]], %[[a2]] * 2 + %[[a6]], %[[a3]] + %[[a7]]] :
+// CHECK: %[[res:.*]] = arith.cmpf ogt, %[[lhs]], %[[rhs]] : f32
+// CHECK: %[[sel:.*]] = arith.select %[[res]], %[[lhs]], %[[rhs]] : f32
+// CHECK: affine.store %[[sel]], %[[v0]][%[[a0]], %[[a1]], %[[a2]], %[[a3]]] :
+
+// CHECK-LABEL: func @symbols(
+func.func @symbols(%N : index) {
+ %0 = memref.alloc() : memref<1024x1024xf32>
+ %1 = memref.alloc() : memref<1024x1024xf32>
+ %2 = memref.alloc() : memref<1024x1024xf32>
+ %cst1 = arith.constant 1 : index
+ %cst2 = arith.constant 2 : index
+ affine.for %i = 0 to %N {
+ affine.for %j = 0 to %N {
+ %7 = memref.load %2[%i, %j] : memref<1024x1024xf32>
+ %10 = affine.for %k = 0 to %N iter_args(%ax = %cst1) -> index {
+ %12 = arith.muli %N, %cst2 : index
+ %13 = arith.addi %12, %cst1 : index
+ %14 = arith.addi %13, %j : index
+ %5 = memref.load %0[%i, %12] : memref<1024x1024xf32>
+ %6 = memref.load %1[%14, %j] : memref<1024x1024xf32>
+ %8 = arith.mulf %5, %6 : f32
+ %9 = arith.addf %7, %8 : f32
+ %4 = arith.addi %N, %cst1 : index
+ %11 = arith.addi %ax, %cst1 : index
+ memref.store %9, %2[%i, %4] : memref<1024x1024xf32> // this uses an expression of the symbol
+ memref.store %9, %2[%i, %11] : memref<1024x1024xf32> // this uses an iter_args and cannot be raised
+ %something = "ab.v"() : () -> index
+ memref.store %9, %2[%i, %something] : memref<1024x1024xf32> // this cannot be raised
+ affine.yield %11 : index
+ }
+ }
+ }
+ return
+}
+
+// CHECK: %[[cst1:.*]] = arith.constant 1 : index
+// CHECK: %[[v0:.*]] = memref.alloc() : memref<
+// CHECK: %[[v1:.*]] = memref.alloc() : memref<
+// CHECK: %[[v2:.*]] = memref.alloc() : memref<
+// CHECK: affine.for %[[a1:.*]] = 0 to %arg0 {
+// CHECK: affine.for %[[a2:.*]] = 0 to %arg0 {
+// CHECK: %[[lhs:.*]] = affine.load %{{.*}}[%[[a1]], %[[a2]]] : memref<1024x1024xf32>
+// CHECK: affine.for %[[a3:.*]] = 0 to %arg0 iter_args(%[[a4:.*]] = %[[cst1]]) -> (index) {
+// CHECK: %[[lhs2:.*]] = affine.load %{{.*}}[%[[a1]], symbol(%arg0) * 2] :
+// CHECK: %[[lhs3:.*]] = affine.load %{{.*}}[%[[a2]] + symbol(%arg0) * 2 + 1, %[[a2]]] :
+// CHECK: %[[lhs4:.*]] = arith.mulf %[[lhs2]], %[[lhs3]]
+// CHECK: %[[lhs5:.*]] = arith.addf %[[lhs]], %[[lhs4]]
+// CHECK: %[[lhs6:.*]] = arith.addi %[[a4]], %[[cst1]]
+// CHECK: affine.store %[[lhs5]], %{{.*}}[%[[a1]], symbol(%arg0) + 1] :
+// CHECK: memref.store %[[lhs5]], %{{.*}}[%[[a1]], %[[lhs6]]] :
+// CHECK: %[[lhs7:.*]] = "ab.v"
+// CHECK: memref.store %[[lhs5]], %{{.*}}[%[[a1]], %[[lhs7]]] :
+// CHECK: affine.yield %[[lhs6]]
+
+
+// CHECK-LABEL: func @non_affine(
+func.func @non_affine(%N : index) {
+ %2 = memref.alloc() : memref<1024x1024xf32>
+ affine.for %i = 0 to %N {
+ affine.for %j = 0 to %N {
+ %ij = arith.muli %i, %j : index
+ %7 = memref.load %2[%i, %ij] : memref<1024x1024xf32>
+ memref.store %7, %2[%ij, %ij] : memref<1024x1024xf32>
+ }
+ }
+ return
+}
+
+// CHECK: affine.for %[[i:.*]] =
+// CHECK: affine.for %[[j:.*]] =
+// CHECK: %[[ij:.*]] = arith.muli %[[i]], %[[j]]
+// CHECK: %[[v:.*]] = memref.load %{{.*}}[%[[i]], %[[ij]]]
+// CHECK: memref.store %[[v]], %{{.*}}[%[[ij]], %[[ij]]]
\ No newline at end of file
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/138/builds/12701 Here is the relevant piece of the build log for the reference
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/198/builds/4250 Here is the relevant piece of the build log for the reference
|
This adds a pass that converts memref.load/store into affine.load/store. This is useful as those memref operators are ignored by passes like --affine-scalrep as they don't implement the Affine[Read/Write]OpInterface. Doing this allows you to put as much of your program in affine form before you apply affine optimization passes. This also slightly changes the implementation of affine::isValidDim. The previous implementation allowed values from the iter_args of affine loops to be used as valid dims. I think this doesn't make sense and what was meant is just the induction vars. In the real world, there is little reason to find an index in the iter_args, but I wrote that in my tests and found out it was treated as an affine dim, so corrected that. Co-authored-by: Oleksandr "Alex" Zinenko <[email protected]> Rebased from llvm#114032.
the change to isValidDim breaks at least one test in heir https://github.com/google/heir/blob/main/tests/Dialect/Polynomial/Conversions/heir_polynomial_to_llvm/runner/lower_ntt_perf_runner.mlir
that now fails with |
@ftynse do you have an opinion about isValidDim? This driveby change (at least how it's described) might not be correct |
I have reverted isValidDim part in #139069 |
Note: this is a reopening of #114032, which at the time had been approved but not merged. I inadvertently closed it last month and am just realizing.
CC @ftynse
--
This adds a pass that converts memref.load/store into affine.load/store. This is useful as those memref operators are ignored by passes like --affine-scalrep as they don't implement the Affine[Read/Write]OpInterface. Doing this allows you to put as much of your program in affine form before you apply affine optimization passes.
This also slightly changes the implementation of affine::isValidDim. The previous implementation allowed values from the iter_args of affine loops to be used as valid dims. I think this doesn't make sense and what was meant is just the induction vars. In the real world, there is little reason to find an index in the iter_args, but I wrote that in my tests and found out it was treated as an affine dim, so corrected that.