//===- LoopFusion.cpp - Code to perform loop fusion -----------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements loop fusion.
//
//===----------------------------------------------------------------------===//

#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/AffineStructures.h"
#include "mlir/Analysis/LoopAnalysis.h"
#include "mlir/Analysis/Utils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/StmtVisitor.h"
#include "mlir/Pass.h"
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/Transforms/LoopUtils.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/raw_ostream.h"

// Standard headers for std::min, std::numeric_limits, and std::iota below.
#include <algorithm>
#include <limits>
#include <numeric>

using llvm::SetVector;

using namespace mlir;

// TODO(andydavis) These flags are global to the pass, for use in
// experimentation. Find a way to provide more fine-grained control (e.g.
// depth per loop nest, or depth per load/store op) for this pass utilizing a
// cost model.
static llvm::cl::opt<unsigned> clSrcLoopDepth(
    "src-loop-depth", llvm::cl::Hidden,
    llvm::cl::desc("Controls the depth of the source loop nest at which "
                   "to apply loop iteration slicing before fusion."));

static llvm::cl::opt<unsigned> clDstLoopDepth(
    "dst-loop-depth", llvm::cl::Hidden,
    llvm::cl::desc("Controls the depth of the destination loop nest at which "
                   "to fuse the source loop nest slice."));

namespace {

/// Loop fusion pass. This pass currently supports a greedy fusion policy,
/// which fuses loop nests with single-writer/single-reader memref dependences
/// with the goal of improving locality.

// TODO(andydavis) Support fusion of source loop nests which write to multiple
// memrefs, where each memref can have multiple users (if profitable).
// TODO(andydavis) Extend this pass to check for fusion-preventing dependences,
// and add support for more general loop fusion algorithms.

struct LoopFusion : public FunctionPass {
  LoopFusion() : FunctionPass(&LoopFusion::passID) {}

  PassResult runOnMLFunction(Function *f) override;
  static char passID;
};

} // end anonymous namespace

char LoopFusion::passID = 0;

FunctionPass *mlir::createLoopFusionPass() { return new LoopFusion; }

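// Populates 'access' with the memref, op statement, and indices of the single
// load or store access in 'loadOrStoreOpStmt'.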
static void getSingleMemRefAccess(OperationInst *loadOrStoreOpStmt,
                                  MemRefAccess *access) {
  if (auto loadOp = loadOrStoreOpStmt->dyn_cast<LoadOp>()) {
    access->memref = loadOp->getMemRef();
    access->opStmt = loadOrStoreOpStmt;
    auto loadMemrefType = loadOp->getMemRefType();
    access->indices.reserve(loadMemrefType.getRank());
    for (auto *index : loadOp->getIndices()) {
      access->indices.push_back(index);
    }
  } else {
    assert(loadOrStoreOpStmt->isa<StoreOp>());
    auto storeOp = loadOrStoreOpStmt->cast<StoreOp>();
    access->opStmt = loadOrStoreOpStmt;
    access->memref = storeOp->getMemRef();
    auto storeMemrefType = storeOp->getMemRefType();
    access->indices.reserve(storeMemrefType.getRank());
    for (auto *index : storeOp->getIndices()) {
      access->indices.push_back(index);
    }
  }
}

// FusionCandidate encapsulates the source and destination memref accesses
// within loop nests which are candidates for loop fusion.
struct FusionCandidate {
  // Load or store access within src loop nest to be fused into dst loop nest.
  MemRefAccess srcAccess;
  // Load or store access within dst loop nest.
  MemRefAccess dstAccess;
};

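// Returns a FusionCandidate built from the store access in 'srcStoreOpStmt'
// and the load access in 'dstLoadOpStmt'.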
static FusionCandidate buildFusionCandidate(OperationInst *srcStoreOpStmt,
                                            OperationInst *dstLoadOpStmt) {
  FusionCandidate candidate;
  // Get store access for src loop nest.
  getSingleMemRefAccess(srcStoreOpStmt, &candidate.srcAccess);
  // Get load access for dst loop nest.
  getSingleMemRefAccess(dstLoadOpStmt, &candidate.dstAccess);
  return candidate;
}

// Returns the loop depth of the loop nest surrounding 'opStmt'.
static unsigned getLoopDepth(OperationInst *opStmt) {
  unsigned loopDepth = 0;
  auto *currStmt = opStmt->getParentStmt();
  while (currStmt && isa<ForStmt>(currStmt)) {
    ++loopDepth;
    currStmt = currStmt->getParentStmt();
  }
  return loopDepth;
}
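// For example (illustrative only), for a store op statement nested within two
// ForStmts:
//   for %i0 ... { for %i1 ... { store ... } }
// getLoopDepth on the store returns 2.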

namespace {

// LoopNestStateCollector walks loop nests and collects load and store
// operations, and records whether an IfStmt was encountered in the loop nest.
class LoopNestStateCollector : public StmtWalker<LoopNestStateCollector> {
public:
  SmallVector<ForStmt *, 4> forStmts;
  SmallVector<OperationInst *, 4> loadOpStmts;
  SmallVector<OperationInst *, 4> storeOpStmts;
  bool hasIfStmt = false;

  void visitForStmt(ForStmt *forStmt) { forStmts.push_back(forStmt); }

  void visitIfStmt(IfStmt *ifStmt) { hasIfStmt = true; }

  void visitOperationInst(OperationInst *opStmt) {
    if (opStmt->isa<LoadOp>())
      loadOpStmts.push_back(opStmt);
    if (opStmt->isa<StoreOp>())
      storeOpStmts.push_back(opStmt);
  }
};

// MemRefDependenceGraph is a graph data structure where graph nodes are
// top-level statements in a Function which contain load/store ops, and edges
// are memref dependences between the nodes.
// TODO(andydavis) Add a depth parameter to dependence graph construction.
struct MemRefDependenceGraph {
public:
  // Node represents a node in the graph. A Node is either an entire loop nest
  // rooted at the top level which contains loads/stores, or a top-level
  // load/store.
  struct Node {
    // The unique identifier of this node in the graph.
    unsigned id;
    // The top-level statement which is (or contains) loads/stores.
    Statement *stmt;
    // List of load operations.
    SmallVector<OperationInst *, 4> loads;
    // List of store operations.
    SmallVector<OperationInst *, 4> stores;
    Node(unsigned id, Statement *stmt) : id(id), stmt(stmt) {}

    // Returns the load op count for 'memref'.
    unsigned getLoadOpCount(Value *memref) {
      unsigned loadOpCount = 0;
      for (auto *loadOpStmt : loads) {
        if (memref == loadOpStmt->cast<LoadOp>()->getMemRef())
          ++loadOpCount;
      }
      return loadOpCount;
    }

    // Returns the store op count for 'memref'.
    unsigned getStoreOpCount(Value *memref) {
      unsigned storeOpCount = 0;
      for (auto *storeOpStmt : stores) {
        if (memref == storeOpStmt->cast<StoreOp>()->getMemRef())
          ++storeOpCount;
      }
      return storeOpCount;
    }
  };

  // Edge represents a memref data dependence between nodes in the graph.
  struct Edge {
    // The id of the node at the other end of the edge.
    unsigned id;
    // The memref on which this edge represents a dependence.
    Value *memref;
  };
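
  // For example (illustrative only): if node 0 stores to memref '%m' and a
  // later node 1 loads from '%m', graph construction adds edge {1, %m} to
  // 'outEdges[0]' and edge {0, %m} to 'inEdges[1]'.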

  // Map from node id to Node.
  DenseMap<unsigned, Node> nodes;
  // Map from node id to list of input edges.
  DenseMap<unsigned, SmallVector<Edge, 2>> inEdges;
  // Map from node id to list of output edges.
  DenseMap<unsigned, SmallVector<Edge, 2>> outEdges;

  MemRefDependenceGraph() {}

  // Initializes the dependence graph based on operations in 'f'.
  // Returns true on success, false otherwise.
  bool init(Function *f);

  // Returns the graph node for 'id'.
  Node *getNode(unsigned id) {
    auto it = nodes.find(id);
    assert(it != nodes.end());
    return &it->second;
  }

  // Adds an edge from node 'srcId' to node 'dstId' for 'memref'.
  void addEdge(unsigned srcId, unsigned dstId, Value *memref) {
    outEdges[srcId].push_back({dstId, memref});
    inEdges[dstId].push_back({srcId, memref});
  }

  // Removes an edge from node 'srcId' to node 'dstId' for 'memref'.
  void removeEdge(unsigned srcId, unsigned dstId, Value *memref) {
    assert(inEdges.count(dstId) > 0);
    assert(outEdges.count(srcId) > 0);
    // Remove 'srcId' from 'inEdges[dstId]'.
    for (auto it = inEdges[dstId].begin(); it != inEdges[dstId].end(); ++it) {
      if ((*it).id == srcId && (*it).memref == memref) {
        inEdges[dstId].erase(it);
        break;
      }
    }
    // Remove 'dstId' from 'outEdges[srcId]'.
    for (auto it = outEdges[srcId].begin(); it != outEdges[srcId].end(); ++it) {
      if ((*it).id == dstId && (*it).memref == memref) {
        outEdges[srcId].erase(it);
        break;
      }
    }
  }

  // Returns the input edge count for node 'id' and 'memref'.
  unsigned getInEdgeCount(unsigned id, Value *memref) {
    unsigned inEdgeCount = 0;
    if (inEdges.count(id) > 0)
      for (auto &inEdge : inEdges[id])
        if (inEdge.memref == memref)
          ++inEdgeCount;
    return inEdgeCount;
  }

  // Returns the output edge count for node 'id' and 'memref'.
  unsigned getOutEdgeCount(unsigned id, Value *memref) {
    unsigned outEdgeCount = 0;
    if (outEdges.count(id) > 0)
      for (auto &outEdge : outEdges[id])
        if (outEdge.memref == memref)
          ++outEdgeCount;
    return outEdgeCount;
  }

  // Returns the min node id of all output edges from node 'id'.
  unsigned getMinOutEdgeNodeId(unsigned id) {
    unsigned minId = std::numeric_limits<unsigned>::max();
    if (outEdges.count(id) > 0)
      for (auto &outEdge : outEdges[id])
        minId = std::min(minId, outEdge.id);
    return minId;
  }

  // Updates edge mappings from node 'srcId' to node 'dstId' and removes
  // state associated with node 'srcId'.
  void updateEdgesAndRemoveSrcNode(unsigned srcId, unsigned dstId) {
    // For each edge in 'inEdges[srcId]': add new edge remapping to 'dstId'.
    if (inEdges.count(srcId) > 0) {
      SmallVector<Edge, 2> oldInEdges = inEdges[srcId];
      for (auto &inEdge : oldInEdges) {
        // Remove edge from 'inEdge.id' to 'srcId'.
        removeEdge(inEdge.id, srcId, inEdge.memref);
        // Add edge from 'inEdge.id' to 'dstId'.
        addEdge(inEdge.id, dstId, inEdge.memref);
      }
    }
    // For each edge in 'outEdges[srcId]': add new edge remapping to 'dstId'.
    if (outEdges.count(srcId) > 0) {
      SmallVector<Edge, 2> oldOutEdges = outEdges[srcId];
      for (auto &outEdge : oldOutEdges) {
        // Remove edge from 'srcId' to 'outEdge.id'.
        removeEdge(srcId, outEdge.id, outEdge.memref);
        // Add edge from 'dstId' to 'outEdge.id' (if 'outEdge.id' != 'dstId').
        if (outEdge.id != dstId)
          addEdge(dstId, outEdge.id, outEdge.memref);
      }
    }
    // Remove 'srcId' from graph state.
    inEdges.erase(srcId);
    outEdges.erase(srcId);
    nodes.erase(srcId);
  }

  // Adds ops in 'loads' and 'stores' to node at 'id'.
  void addToNode(unsigned id, const SmallVectorImpl<OperationInst *> &loads,
                 const SmallVectorImpl<OperationInst *> &stores) {
    Node *node = getNode(id);
    for (auto *loadOpStmt : loads)
      node->loads.push_back(loadOpStmt);
    for (auto *storeOpStmt : stores)
      node->stores.push_back(storeOpStmt);
  }

  void print(raw_ostream &os) const {
    os << "\nMemRefDependenceGraph\n";
    os << "\nNodes:\n";
    for (auto &idAndNode : nodes) {
      os << "Node: " << idAndNode.first << "\n";
      auto it = inEdges.find(idAndNode.first);
      if (it != inEdges.end()) {
        for (const auto &e : it->second)
          os << "  InEdge: " << e.id << " " << e.memref << "\n";
      }
      it = outEdges.find(idAndNode.first);
      if (it != outEdges.end()) {
        for (const auto &e : it->second)
          os << "  OutEdge: " << e.id << " " << e.memref << "\n";
      }
    }
  }
  void dump() const { print(llvm::errs()); }
};

// Initializes the data dependence graph by walking statements in 'f'.
// Assigns each node in the graph a node id based on program order in 'f'.
// TODO(andydavis) Add support for taking a Block arg to construct the
// dependence graph at a different depth.
bool MemRefDependenceGraph::init(Function *f) {
  unsigned id = 0;
  DenseMap<Value *, SetVector<unsigned>> memrefAccesses;
  for (auto &stmt : *f->getBody()) {
    if (auto *forStmt = dyn_cast<ForStmt>(&stmt)) {
      // Create graph node 'id' to represent top-level 'forStmt' and record
      // all load and store accesses it contains.
      LoopNestStateCollector collector;
      collector.walkForStmt(forStmt);
      // Return false if IfStmts are found (not currently supported).
      if (collector.hasIfStmt)
        return false;
      Node node(id++, &stmt);
      for (auto *opStmt : collector.loadOpStmts) {
        node.loads.push_back(opStmt);
        auto *memref = opStmt->cast<LoadOp>()->getMemRef();
        memrefAccesses[memref].insert(node.id);
      }
      for (auto *opStmt : collector.storeOpStmts) {
        node.stores.push_back(opStmt);
        auto *memref = opStmt->cast<StoreOp>()->getMemRef();
        memrefAccesses[memref].insert(node.id);
      }
      nodes.insert({node.id, node});
    }
    if (auto *opStmt = dyn_cast<OperationInst>(&stmt)) {
      if (auto loadOp = opStmt->dyn_cast<LoadOp>()) {
        // Create graph node for top-level load op.
        Node node(id++, &stmt);
        node.loads.push_back(opStmt);
        auto *memref = opStmt->cast<LoadOp>()->getMemRef();
        memrefAccesses[memref].insert(node.id);
        nodes.insert({node.id, node});
      }
      if (auto storeOp = opStmt->dyn_cast<StoreOp>()) {
        // Create graph node for top-level store op.
        Node node(id++, &stmt);
        node.stores.push_back(opStmt);
        auto *memref = opStmt->cast<StoreOp>()->getMemRef();
        memrefAccesses[memref].insert(node.id);
        nodes.insert({node.id, node});
      }
    }
    // Return false if IfStmts are found (not currently supported).
    if (isa<IfStmt>(&stmt))
      return false;
  }

  // Walk memref access lists and add graph edges between dependent nodes.
  for (auto &memrefAndList : memrefAccesses) {
    unsigned n = memrefAndList.second.size();
    for (unsigned i = 0; i < n; ++i) {
      unsigned srcId = memrefAndList.second[i];
      bool srcHasStore =
          getNode(srcId)->getStoreOpCount(memrefAndList.first) > 0;
      for (unsigned j = i + 1; j < n; ++j) {
        unsigned dstId = memrefAndList.second[j];
        bool dstHasStore =
            getNode(dstId)->getStoreOpCount(memrefAndList.first) > 0;
        if (srcHasStore || dstHasStore)
          addEdge(srcId, dstId, memrefAndList.first);
      }
    }
  }
  return true;
}

// GreedyFusion greedily fuses loop nests which have a producer/consumer
// relationship on a memref, with the goal of improving locality. Currently,
// the producer/consumer relationship is required to be unique in the
// Function (there are TODOs to relax this constraint in the future).
//
// The steps of the algorithm are as follows (see the illustrative sketch
// below):
//
//  *) A worklist is initialized with node ids from the dependence graph.
//  *) For each node id in the worklist:
//    *) Pop a ForStmt off the worklist. This 'dstForStmt' will be a candidate
//       destination ForStmt into which fusion will be attempted.
//    *) Add each LoadOp currently in 'dstForStmt' into list 'dstLoadOps'.
//    *) For each LoadOp in 'dstLoadOps' do:
//      *) Look up dependent loop nests at earlier positions in the Function
//         which have a single store op to the same memref.
//      *) Check if dependences would be violated by the fusion. For example,
//         the src loop nest may load from memrefs which are different than
//         the producer-consumer memref between the src and dst loop nests.
//      *) Get a computation slice of 'srcLoopNest', which adjusts its loop
//         bounds to be functions of 'dstLoopNest' IVs and symbols.
//      *) Fuse the 'srcLoopNest' computation slice into the 'dstLoopNest',
//         just before the dst load op user.
//      *) Add the newly fused load/store operation statements to the state,
//         and also add newly fused load ops to 'dstLoadOps' to be considered
//         as fusion dst load ops in another iteration.
//      *) Remove the old src loop nest and its associated state.
//
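// As an illustrative sketch (hypothetical IR, not drawn from an actual test
// case), fusing a producer loop nest into a consumer loop nest on memref
// '%m':
//
//   for %i0 = 0 to 100 {
//     store %v, %m[%i0]
//   }
//   for %i1 = 0 to 100 {
//     %u = load %m[%i1]
//   }
//
// might become, after slicing and fusion:
//
//   for %i1 = 0 to 100 {
//     store %v, %m[%i1]
//     %u = load %m[%i1]
//   }
//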
// Given a graph where top-level statements are vertices in the set 'V' and
// edges in the set 'E' are dependences between vertices, this algorithm
// takes O(V) time for initialization, and has runtime O(V + E).
//
// This greedy algorithm is not 'maximal' due to the current restriction of
// fusing along single producer-consumer edges, but there is a TODO to fix
// this.
//
// TODO(andydavis) Experiment with other fusion policies.
// TODO(andydavis) Add support for fusing for input reuse (perhaps by
// constructing a graph with edges which represent loads from the same memref
// in two different loop nests).
struct GreedyFusion {
public:
  MemRefDependenceGraph *mdg;
  SmallVector<unsigned, 4> worklist;

  GreedyFusion(MemRefDependenceGraph *mdg) : mdg(mdg) {
    // Initialize worklist with nodes from 'mdg'.
    worklist.resize(mdg->nodes.size());
    std::iota(worklist.begin(), worklist.end(), 0);
  }

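  // Runs the greedy fusion algorithm described in the comment above, mutating
  // both the IR and the dependence graph as loop nests are fused.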
  void run() {
    while (!worklist.empty()) {
      unsigned dstId = worklist.back();
      worklist.pop_back();
      // Skip if this node was removed (fused into another node).
      if (mdg->nodes.count(dstId) == 0)
        continue;
      // Get 'dstNode' into which to attempt fusion.
      auto *dstNode = mdg->getNode(dstId);
      // Skip if 'dstNode' is not a loop nest.
      if (!isa<ForStmt>(dstNode->stmt))
        continue;

      SmallVector<OperationInst *, 4> loads = dstNode->loads;
      while (!loads.empty()) {
        auto *dstLoadOpStmt = loads.pop_back_val();
        auto *memref = dstLoadOpStmt->cast<LoadOp>()->getMemRef();
        // Skip 'dstLoadOpStmt' if multiple loads to 'memref' in 'dstNode'.
        if (dstNode->getLoadOpCount(memref) != 1)
          continue;
        // Skip if no input edges along which to fuse.
        if (mdg->inEdges.count(dstId) == 0)
          continue;
        // Iterate through in edges for 'dstId'. Operate on a copy: fusing
        // below updates 'mdg->inEdges[dstId]' and would otherwise invalidate
        // the range iterators.
        SmallVector<Edge, 2> inEdgesCopy = mdg->inEdges[dstId];
        for (auto &srcEdge : inEdgesCopy) {
          // Skip 'srcEdge' if not for 'memref'.
          if (srcEdge.memref != memref)
            continue;
          // Skip 'srcEdge' if its source node was already fused and removed.
          if (mdg->nodes.count(srcEdge.id) == 0)
            continue;
          auto *srcNode = mdg->getNode(srcEdge.id);
          // Skip if 'srcNode' is not a loop nest.
          if (!isa<ForStmt>(srcNode->stmt))
            continue;
          // Skip if 'srcNode' has more than one store to 'memref'.
          if (srcNode->getStoreOpCount(memref) != 1)
            continue;
          // Skip 'srcNode' if it has out edges on 'memref' other than 'dstId'.
          if (mdg->getOutEdgeCount(srcNode->id, memref) != 1)
            continue;
          // Skip 'srcNode' if it has in dependence edges. NOTE: This is overly
          // conservative.
          // TODO(andydavis) Track dependence type with edges, and just check
          // for WAW dependence edge here.
          if (mdg->getInEdgeCount(srcNode->id, memref) != 0)
            continue;
          // Skip 'srcNode' if it has an out edge, on any memref, to a node
          // which appears before 'dstId'.
          if (mdg->getMinOutEdgeNodeId(srcNode->id) != dstId)
            continue;
          // Get unique 'srcNode' store op.
          auto *srcStoreOpStmt = srcNode->stores.front();
          // Build fusion candidate out of 'srcStoreOpStmt' and 'dstLoadOpStmt'.
          FusionCandidate candidate =
              buildFusionCandidate(srcStoreOpStmt, dstLoadOpStmt);
          // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'.
          unsigned srcLoopDepth = clSrcLoopDepth.getNumOccurrences() > 0
                                      ? clSrcLoopDepth
                                      : getLoopDepth(srcStoreOpStmt);
          unsigned dstLoopDepth = clDstLoopDepth.getNumOccurrences() > 0
                                      ? clDstLoopDepth
                                      : getLoopDepth(dstLoadOpStmt);
          auto *sliceLoopNest = mlir::insertBackwardComputationSlice(
              &candidate.srcAccess, &candidate.dstAccess, srcLoopDepth,
              dstLoopDepth);
          if (sliceLoopNest != nullptr) {
            // Remove edges between 'srcNode' and 'dstNode', and remove
            // 'srcNode' from the graph.
            mdg->updateEdgesAndRemoveSrcNode(srcNode->id, dstNode->id);
            // Record all load/store accesses in 'sliceLoopNest' at 'dstId'.
            LoopNestStateCollector collector;
            collector.walkForStmt(sliceLoopNest);
            mdg->addToNode(dstId, collector.loadOpStmts,
                           collector.storeOpStmts);
            // Add new load ops to current Node load op list 'loads' to
            // continue fusing based on new operands.
            for (auto *loadOpStmt : collector.loadOpStmts)
              loads.push_back(loadOpStmt);
            // Promote single-iteration loops to single IV value.
            for (auto *forStmt : collector.forStmts) {
              promoteIfSingleIteration(forStmt);
            }
            // Remove old src loop nest.
            cast<ForStmt>(srcNode->stmt)->erase();
          }
        }
      }
    }
  }
};

} // end anonymous namespace

PassResult LoopFusion::runOnMLFunction(Function *f) {
  MemRefDependenceGraph g;
  if (g.init(f))
    GreedyFusion(&g).run();
  return success();
}

static PassRegistration<LoopFusion> pass("loop-fusion", "Fuse loop nests");