//===- LoopFusion.cpp - Code to perform loop fusion -----------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements loop fusion.
//
//===----------------------------------------------------------------------===//

#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/AffineStructures.h"
#include "mlir/Analysis/LoopAnalysis.h"
#include "mlir/Analysis/Utils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/StmtVisitor.h"
#include "mlir/Pass.h"
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/Transforms/LoopUtils.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/raw_ostream.h"

using llvm::SetVector;

using namespace mlir;

namespace {

/// Loop fusion pass. This pass currently supports a greedy fusion policy,
/// which fuses loop nests with single-writer/single-reader memref dependences
/// with the goal of improving locality.

// TODO(andydavis) Support fusion of source loop nests which write to multiple
// memrefs, where each memref can have multiple users (if profitable).
// TODO(andydavis) Extend this pass to check for fusion preventing dependences,
// and add support for more general loop fusion algorithms.
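
// As a rough illustration of the producer/consumer case this policy targets
// (a hand-written sketch: memref names, shapes, and IR syntax below are
// illustrative only, not output of the pass), fusing
//
//   for %i = 0 to 100 {
//     %v = "compute"() : () -> f32
//     store %v, %m[%i] : memref<100xf32>
//   }
//   for %j = 0 to 100 {
//     %u = load %m[%j] : memref<100xf32>
//   }
//
// places the computation slice of the producer loop inside the consumer loop,
// so the store to and load from %m occur in the same iteration:
//
//   for %j = 0 to 100 {
//     %v = "compute"() : () -> f32
//     store %v, %m[%j] : memref<100xf32>
//     %u = load %m[%j] : memref<100xf32>
//   }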

struct LoopFusion : public FunctionPass {
  LoopFusion() : FunctionPass(&LoopFusion::passID) {}

  PassResult runOnMLFunction(MLFunction *f) override;
  static char passID;
};

} // end anonymous namespace

char LoopFusion::passID = 0;

FunctionPass *mlir::createLoopFusionPass() { return new LoopFusion; }

// Populates 'access' with the memref, op statement, and access indices of the
// single load or store in 'loadOrStoreOpStmt'.
static void getSingleMemRefAccess(OperationStmt *loadOrStoreOpStmt,
                                  MemRefAccess *access) {
  if (auto loadOp = loadOrStoreOpStmt->dyn_cast<LoadOp>()) {
    access->memref = cast<MLValue>(loadOp->getMemRef());
    access->opStmt = loadOrStoreOpStmt;
    auto loadMemrefType = loadOp->getMemRefType();
    access->indices.reserve(loadMemrefType.getRank());
    for (auto *index : loadOp->getIndices()) {
      access->indices.push_back(cast<MLValue>(index));
    }
  } else {
    assert(loadOrStoreOpStmt->isa<StoreOp>());
    auto storeOp = loadOrStoreOpStmt->dyn_cast<StoreOp>();
    access->opStmt = loadOrStoreOpStmt;
    access->memref = cast<MLValue>(storeOp->getMemRef());
    auto storeMemrefType = storeOp->getMemRefType();
    access->indices.reserve(storeMemrefType.getRank());
    for (auto *index : storeOp->getIndices()) {
      access->indices.push_back(cast<MLValue>(index));
    }
  }
}

// FusionCandidate encapsulates the source and destination memref accesses
// within loop nests which are candidates for loop fusion.
struct FusionCandidate {
  // Load or store access within src loop nest to be fused into dst loop nest.
  MemRefAccess srcAccess;
  // Load or store access within dst loop nest.
  MemRefAccess dstAccess;
};

static FusionCandidate buildFusionCandidate(OperationStmt *srcStoreOpStmt,
                                            OperationStmt *dstLoadOpStmt) {
  FusionCandidate candidate;
  // Get store access for src loop nest.
  getSingleMemRefAccess(srcStoreOpStmt, &candidate.srcAccess);
  // Get load access for dst loop nest.
  getSingleMemRefAccess(dstLoadOpStmt, &candidate.dstAccess);
  return candidate;
}

namespace {

// LoopNestStateCollector walks loop nests and collects load and store
// operations, and whether or not an IfStmt was encountered in the loop nest.
class LoopNestStateCollector : public StmtWalker<LoopNestStateCollector> {
public:
  SmallVector<ForStmt *, 4> forStmts;
  SmallVector<OperationStmt *, 4> loadOpStmts;
  SmallVector<OperationStmt *, 4> storeOpStmts;
  bool hasIfStmt = false;

  void visitForStmt(ForStmt *forStmt) { forStmts.push_back(forStmt); }

  void visitIfStmt(IfStmt *ifStmt) { hasIfStmt = true; }

  void visitOperationStmt(OperationStmt *opStmt) {
    if (opStmt->isa<LoadOp>())
      loadOpStmts.push_back(opStmt);
    if (opStmt->isa<StoreOp>())
      storeOpStmts.push_back(opStmt);
  }
};

// GreedyFusionPolicy greedily fuses loop nests which have a producer/consumer
// relationship on a memref, with the goal of improving locality. Currently,
// the producer/consumer relationship is required to be unique within the
// MLFunction (there are TODOs to relax this constraint in the future).
//
// The steps of the algorithm are as follows:
//
// *) Initialize. While visiting each statement in the MLFunction do:
//   *) Assign each top-level ForStmt a 'position' which is its initial
//      position in the MLFunction's StmtBlock at the start of the pass.
//   *) Gather memref load/store state aggregated by top-level statement. For
//      example, all loads and stores contained in a loop nest are aggregated
//      under the loop nest's top-level ForStmt.
//   *) Add each top-level ForStmt to a worklist.
//
// *) Run. The algorithm processes the worklist with the following steps:
//   *) The worklist is processed in reverse order (starting from the last
//      top-level ForStmt in the MLFunction).
//   *) Pop a ForStmt off the worklist. This 'dstForStmt' will be a candidate
//      destination ForStmt into which fusion will be attempted.
//   *) Add each LoadOp currently in 'dstForStmt' into list 'dstLoadOps'.
//   *) For each LoadOp in 'dstLoadOps' do:
//      *) Look up dependent loop nests at earlier positions in the MLFunction
//         which have a single store op to the same memref.
//      *) Check if dependences would be violated by the fusion. For example,
//         the src loop nest may load from memrefs which are different than
//         the producer-consumer memref between the src and dest loop nests.
//      *) Get a computation slice of 'srcLoopNest', which adjusts its loop
//         bounds to be functions of 'dstLoopNest' IVs and symbols.
//      *) Fuse the 'srcLoopNest' computation slice into the 'dstLoopNest',
//         just before the dst load op user.
//      *) Add the newly fused load/store operation statements to the state,
//         and also add newly fused load ops to 'dstLoadOps' to be considered
//         as fusion dst load ops in another iteration.
//      *) Remove the old src loop nest and its associated state.
//
// Given a graph where top-level statements are vertices in the set 'V' and
// edges in the set 'E' are dependences between vertices, this algorithm
// takes O(V) time for initialization, and has runtime O(V * E).
// TODO(andydavis) Reduce this time complexity to O(V + E).
//
// This greedy algorithm is not 'maximal'; there is a TODO to fix this.
//
// TODO(andydavis) Experiment with other fusion policies.
struct GreedyFusionPolicy {
  // Convenience wrapper with information about 'stmt' ready to access.
  struct StmtInfo {
    Statement *stmt;
    bool isOrContainsIfStmt = false;
  };
  // The worklist of top-level loop nest positions.
  SmallVector<unsigned, 4> worklist;
  // Mapping from top-level position to StmtInfo.
  DenseMap<unsigned, StmtInfo> posToStmtInfo;
  // Mapping from memref MLValue to set of top-level positions of loop nests
  // which contain load ops on that memref.
  DenseMap<MLValue *, DenseSet<unsigned>> memrefToLoadPosSet;
  // Mapping from memref MLValue to set of top-level positions of loop nests
  // which contain store ops on that memref.
  DenseMap<MLValue *, DenseSet<unsigned>> memrefToStorePosSet;
  // Mapping from top-level loop nest to the set of load ops it contains.
  DenseMap<ForStmt *, SetVector<OperationStmt *>> forStmtToLoadOps;
  // Mapping from top-level loop nest to the set of store ops it contains.
  DenseMap<ForStmt *, SetVector<OperationStmt *>> forStmtToStoreOps;

  GreedyFusionPolicy(MLFunction *f) { init(f); }

  void run() {
    if (hasIfStmts())
      return;

    while (!worklist.empty()) {
      // Pop the position of a loop nest into which fusion will be attempted.
      unsigned dstPos = worklist.back();
      worklist.pop_back();
      // Skip if 'dstPos' is not tracked (was fused into another loop nest).
      if (posToStmtInfo.count(dstPos) == 0)
        continue;
      // Get the top-level ForStmt at 'dstPos'.
      auto *dstForStmt = getForStmtAtPos(dstPos);
      // Skip if this ForStmt contains no load ops.
      if (forStmtToLoadOps.count(dstForStmt) == 0)
        continue;

      // Greedy Policy: iterate through load ops in 'dstForStmt', greedily
      // fusing in src loop nests which have a single store op on the same
      // memref, until a fixed point is reached where there is nothing left to
      // fuse.
      SetVector<OperationStmt *> dstLoadOps = forStmtToLoadOps[dstForStmt];
      while (!dstLoadOps.empty()) {
        auto *dstLoadOpStmt = dstLoadOps.pop_back_val();

        auto dstLoadOp = dstLoadOpStmt->cast<LoadOp>();
        auto *memref = cast<MLValue>(dstLoadOp->getMemRef());
        // Skip if not single src store / dst load pair on 'memref'.
        if (memrefToLoadPosSet[memref].size() != 1 ||
            memrefToStorePosSet[memref].size() != 1)
          continue;
        unsigned srcPos = *memrefToStorePosSet[memref].begin();
        if (srcPos >= dstPos)
          continue;
        auto *srcForStmt = getForStmtAtPos(srcPos);
        // Skip if 'srcForStmt' has more than one store op.
        if (forStmtToStoreOps[srcForStmt].size() > 1)
          continue;
        // Skip if fusion would violate dependences between accesses to
        // 'memref' for loop nests between 'srcPos' and 'dstPos':
        // For each src load op: check for store ops in range (srcPos, dstPos).
        // For each src store op: check for load ops in range (srcPos, dstPos).
        if (moveWouldViolateDependences(srcPos, dstPos))
          continue;
        auto *srcStoreOpStmt = forStmtToStoreOps[srcForStmt].front();
        // Build fusion candidate out of 'srcStoreOpStmt' and 'dstLoadOpStmt'.
        FusionCandidate candidate =
            buildFusionCandidate(srcStoreOpStmt, dstLoadOpStmt);
        // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'.
        auto *sliceLoopNest = mlir::insertBackwardComputationSlice(
            &candidate.srcAccess, &candidate.dstAccess);
        if (sliceLoopNest != nullptr) {
          // Remove mappings associated with 'srcPos'.
          moveAccessesAndRemovePos(srcPos, dstPos);
          // Record all load/store accesses in 'sliceLoopNest' at 'dstPos'.
          LoopNestStateCollector collector;
          collector.walkForStmt(sliceLoopNest);
          // Record mappings for loads and stores from 'collector'.
          for (auto *opStmt : collector.loadOpStmts) {
            addLoadOpStmtAt(dstPos, opStmt, dstForStmt);
            // Add newly fused load ops to 'dstLoadOps' to be considered for
            // fusion on subsequent iterations.
            dstLoadOps.insert(opStmt);
          }
          for (auto *opStmt : collector.storeOpStmts) {
            addStoreOpStmtAt(dstPos, opStmt, dstForStmt);
          }
          for (auto *forStmt : collector.forStmts) {
            promoteIfSingleIteration(forStmt);
          }
          // Remove the old src loop nest.
          srcForStmt->erase();
        }
      }
    }
  }

  // Walk MLFunction 'f', assigning each top-level statement a position, and
  // gathering state on load and store ops.
  void init(MLFunction *f) {
    unsigned pos = 0;
    for (auto &stmt : *f) {
      if (auto *forStmt = dyn_cast<ForStmt>(&stmt)) {
        // Record all load and store accesses in 'forStmt' at 'pos'.
        LoopNestStateCollector collector;
        collector.walkForStmt(forStmt);
        // Create StmtInfo for 'forStmt' for top-level loop nests.
        addStmtInfoAt(pos, forStmt, collector.hasIfStmt);
        // Record mappings for loads and stores from 'collector'.
        for (auto *opStmt : collector.loadOpStmts) {
          addLoadOpStmtAt(pos, opStmt, forStmt);
        }
        for (auto *opStmt : collector.storeOpStmts) {
          addStoreOpStmtAt(pos, opStmt, forStmt);
        }
        // Add 'pos' associated with 'forStmt' to the worklist.
        worklist.push_back(pos);
      }
      if (auto *opStmt = dyn_cast<OperationStmt>(&stmt)) {
        if (auto loadOp = opStmt->dyn_cast<LoadOp>()) {
          // Create StmtInfo for top-level load op.
          addStmtInfoAt(pos, &stmt, /*hasIfStmt=*/false);
          addLoadOpStmtAt(pos, opStmt, /*containingForStmt=*/nullptr);
        }
        if (auto storeOp = opStmt->dyn_cast<StoreOp>()) {
          // Create StmtInfo for top-level store op.
          addStmtInfoAt(pos, &stmt, /*hasIfStmt=*/false);
          addStoreOpStmtAt(pos, opStmt, /*containingForStmt=*/nullptr);
        }
      }
      if (auto *ifStmt = dyn_cast<IfStmt>(&stmt)) {
        addStmtInfoAt(pos, &stmt, /*hasIfStmt=*/true);
      }
      ++pos;
    }
  }

  // Returns true if fusing the loop nest at 'srcPos' into the loop nest at
  // 'dstPos' would violate any dependences w.r.t. other loop nests in that
  // range.
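  //
  // For example (a hypothetical scenario, not drawn from an existing test):
  // if the nest at 'srcPos' stores to memref %m (the producer-consumer
  // memref) but also loads from another memref %n, and a nest at some
  // position in (srcPos, dstPos) stores to %n, then fusing the src nest into
  // the dst nest would move that load of %n below the store to %n, so this
  // check returns true and fusion is skipped.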
  bool moveWouldViolateDependences(unsigned srcPos, unsigned dstPos) {
    // Lookup src ForStmt at 'srcPos'.
    auto *srcForStmt = getForStmtAtPos(srcPos);
    // For each src load op: check for store ops in range (srcPos, dstPos).
    if (forStmtToLoadOps.count(srcForStmt) > 0) {
      for (auto *opStmt : forStmtToLoadOps[srcForStmt]) {
        auto loadOp = opStmt->cast<LoadOp>();
        auto *memref = cast<MLValue>(loadOp->getMemRef());
        for (unsigned pos = srcPos + 1; pos < dstPos; ++pos) {
          if (memrefToStorePosSet.count(memref) > 0 &&
              memrefToStorePosSet[memref].count(pos) > 0)
            return true;
        }
      }
    }
    // For each src store op: check for load ops in range (srcPos, dstPos).
    if (forStmtToStoreOps.count(srcForStmt) > 0) {
      for (auto *opStmt : forStmtToStoreOps[srcForStmt]) {
        auto storeOp = opStmt->cast<StoreOp>();
        auto *memref = cast<MLValue>(storeOp->getMemRef());
        for (unsigned pos = srcPos + 1; pos < dstPos; ++pos) {
          if (memrefToLoadPosSet.count(memref) > 0 &&
              memrefToLoadPosSet[memref].count(pos) > 0)
            return true;
        }
      }
    }
    return false;
  }

  // Removes all load/store access bookkeeping for the loop nest at 'srcPos'
  // (accesses of the fused slice are re-registered at 'dstPos' by the caller).
  void moveAccessesAndRemovePos(unsigned srcPos, unsigned dstPos) {
    // Lookup ForStmt at 'srcPos'.
    auto *srcForStmt = getForStmtAtPos(srcPos);
    // Remove load op access mappings for the src loop nest.
    if (forStmtToLoadOps.count(srcForStmt) > 0) {
      for (auto *opStmt : forStmtToLoadOps[srcForStmt]) {
        auto loadOp = opStmt->cast<LoadOp>();
        auto *memref = cast<MLValue>(loadOp->getMemRef());
        // Remove 'memref' to 'srcPos' mapping.
        memrefToLoadPosSet[memref].erase(srcPos);
      }
    }
    // Remove store op access mappings for the src loop nest.
    if (forStmtToStoreOps.count(srcForStmt) > 0) {
      for (auto *opStmt : forStmtToStoreOps[srcForStmt]) {
        auto storeOp = opStmt->cast<StoreOp>();
        auto *memref = cast<MLValue>(storeOp->getMemRef());
        // Remove 'memref' to 'srcPos' mapping.
        memrefToStorePosSet[memref].erase(srcPos);
      }
    }
    // Remove old state.
    forStmtToLoadOps.erase(srcForStmt);
    forStmtToStoreOps.erase(srcForStmt);
    posToStmtInfo.erase(srcPos);
  }

  ForStmt *getForStmtAtPos(unsigned pos) {
    assert(posToStmtInfo.count(pos) > 0);
    assert(isa<ForStmt>(posToStmtInfo[pos].stmt));
    return cast<ForStmt>(posToStmtInfo[pos].stmt);
  }

  void addStmtInfoAt(unsigned pos, Statement *stmt, bool hasIfStmt) {
    StmtInfo stmtInfo;
    stmtInfo.stmt = stmt;
    stmtInfo.isOrContainsIfStmt = hasIfStmt;
    // Add mapping from 'pos' to StmtInfo for 'stmt'.
    posToStmtInfo[pos] = stmtInfo;
  }

  // Adds the following mappings:
  // *) 'containingForStmt' to load 'opStmt'
  // *) 'memref' of load 'opStmt' to 'topLevelPos'.
  void addLoadOpStmtAt(unsigned topLevelPos, OperationStmt *opStmt,
                       ForStmt *containingForStmt) {
    if (containingForStmt != nullptr) {
      // Add mapping from 'containingForStmt' to 'opStmt' for load op.
      forStmtToLoadOps[containingForStmt].insert(opStmt);
    }
    auto loadOp = opStmt->cast<LoadOp>();
    auto *memref = cast<MLValue>(loadOp->getMemRef());
    // Add mapping from 'memref' to 'topLevelPos' for load.
    memrefToLoadPosSet[memref].insert(topLevelPos);
  }

  // Adds the following mappings:
  // *) 'containingForStmt' to store 'opStmt'
  // *) 'memref' of store 'opStmt' to 'topLevelPos'.
  void addStoreOpStmtAt(unsigned topLevelPos, OperationStmt *opStmt,
                        ForStmt *containingForStmt) {
    if (containingForStmt != nullptr) {
      // Add mapping from 'containingForStmt' to 'opStmt' for store op.
      forStmtToStoreOps[containingForStmt].insert(opStmt);
    }
    auto storeOp = opStmt->cast<StoreOp>();
    auto *memref = cast<MLValue>(storeOp->getMemRef());
    // Add mapping from 'memref' to 'topLevelPos' for store.
    memrefToStorePosSet[memref].insert(topLevelPos);
  }

  bool hasIfStmts() {
    for (auto &pair : posToStmtInfo)
      if (pair.second.isOrContainsIfStmt)
        return true;
    return false;
  }
};

} // end anonymous namespace

PassResult LoopFusion::runOnMLFunction(MLFunction *f) {
  GreedyFusionPolicy(f).run();
  return success();
}

static PassRegistration<LoopFusion> pass("loop-fusion", "Fuse loop nests");