blob: 804acba0d5aee5f7f7b8c62c5a62cd9028960e9b [file] [log] [blame]
MLIR Teamf28e4df2018-11-01 14:26:001//===- LoopFusion.cpp - Code to perform loop fusion -----------------------===//
2//
3// Copyright 2019 The MLIR Authors.
4//
5// Licensed under the Apache License, Version 2.0 (the "License");
6// you may not use this file except in compliance with the License.
7// You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing, software
12// distributed under the License is distributed on an "AS IS" BASIS,
13// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14// See the License for the specific language governing permissions and
15// limitations under the License.
16// =============================================================================
17//
18// This file implements loop fusion.
19//
20//===----------------------------------------------------------------------===//
21
22#include "mlir/Analysis/AffineAnalysis.h"
MLIR Team3b692302018-12-17 17:57:1423#include "mlir/Analysis/AffineStructures.h"
MLIR Teamf28e4df2018-11-01 14:26:0024#include "mlir/Analysis/LoopAnalysis.h"
MLIR Team3b692302018-12-17 17:57:1425#include "mlir/Analysis/Utils.h"
MLIR Teamf28e4df2018-11-01 14:26:0026#include "mlir/IR/AffineExpr.h"
27#include "mlir/IR/AffineMap.h"
28#include "mlir/IR/Builders.h"
29#include "mlir/IR/BuiltinOps.h"
Chris Lattner456ad6a2018-12-29 00:05:3530#include "mlir/IR/InstVisitor.h"
MLIR Teamf28e4df2018-11-01 14:26:0031#include "mlir/Pass.h"
32#include "mlir/StandardOps/StandardOps.h"
33#include "mlir/Transforms/LoopUtils.h"
34#include "mlir/Transforms/Passes.h"
35#include "llvm/ADT/DenseMap.h"
MLIR Team3b692302018-12-17 17:57:1436#include "llvm/ADT/DenseSet.h"
37#include "llvm/ADT/SetVector.h"
MLIR Team4eef7952018-12-21 19:06:2338#include "llvm/Support/CommandLine.h"
MLIR Team38c2fe32019-01-14 19:26:2539#include "llvm/Support/Debug.h"
MLIR Team3b692302018-12-17 17:57:1440#include "llvm/Support/raw_ostream.h"
41
MLIR Team38c2fe32019-01-14 19:26:2542#define DEBUG_TYPE "loop-fusion"
43
MLIR Team3b692302018-12-17 17:57:1444using llvm::SetVector;
MLIR Teamf28e4df2018-11-01 14:26:0045
46using namespace mlir;
47
48namespace {
49
/// Loop fusion pass. This pass currently supports a greedy fusion policy,
/// which fuses loop nests with single-writer/single-reader memref dependences
/// with the goal of improving locality.

// TODO(andydavis) Support fusion of source loop nests which write to multiple
// memrefs, where each memref can have multiple users (if profitable).
// TODO(andydavis) Extend this pass to check for fusion preventing dependences,
// and add support for more general loop fusion algorithms.

struct LoopFusion : public FunctionPass {
  // Registers this pass with the framework using the address of 'passID' as
  // its unique identifier.
  LoopFusion() : FunctionPass(&LoopFusion::passID) {}

  // Entry point: runs greedy fusion over the single function 'f'.
  PassResult runOnFunction(Function *f) override;
  // Storage whose address uniquely identifies this pass.
  static char passID;
};
65
MLIR Teamf28e4df2018-11-01 14:26:0066} // end anonymous namespace
67
// Out-of-line definition of the pass's unique-ID storage (value is unused;
// only the address matters).
char LoopFusion::passID = 0;

/// Factory for the loop fusion pass; the caller takes ownership of the result.
FunctionPass *mlir::createLoopFusionPass() { return new LoopFusion; }
71
MLIR Team3b692302018-12-17 17:57:1472namespace {
MLIR Teamf28e4df2018-11-01 14:26:0073
// LoopNestStateCollector walks loop nests and collects load and store
// operations, and whether or not an IfInst was encountered in the loop nest.
class LoopNestStateCollector : public InstWalker<LoopNestStateCollector> {
public:
  // All ForInsts encountered during the walk, in visitation order.
  SmallVector<ForInst *, 4> forInsts;
  // All load operations encountered during the walk.
  SmallVector<OperationInst *, 4> loadOpInsts;
  // All store operations encountered during the walk.
  SmallVector<OperationInst *, 4> storeOpInsts;
  // True iff any IfInst was encountered (callers treat this as unsupported).
  bool hasIfInst = false;

  // InstWalker hook: record each visited ForInst.
  void visitForInst(ForInst *forInst) { forInsts.push_back(forInst); }

  // InstWalker hook: note that an IfInst exists somewhere in the nest.
  void visitIfInst(IfInst *ifInst) { hasIfInst = true; }

  // InstWalker hook: bucket load/store operations; other ops are ignored.
  void visitOperationInst(OperationInst *opInst) {
    if (opInst->isa<LoadOp>())
      loadOpInsts.push_back(opInst);
    if (opInst->isa<StoreOp>())
      storeOpInsts.push_back(opInst);
  }
};
94
// MemRefDependenceGraph is a graph data structure where graph nodes are
// top-level instructions in a Function which contain load/store ops, and edges
// are memref dependences between the nodes.
// TODO(andydavis) Add a depth parameter to dependence graph construction.
struct MemRefDependenceGraph {
public:
  // Node represents a node in the graph. A Node is either an entire loop nest
  // rooted at the top level which contains loads/stores, or a top level
  // load/store.
  struct Node {
    // The unique identifier of this node in the graph.
    unsigned id;
    // The top-level statement which is (or contains) loads/stores.
    Instruction *inst;
    // List of load op insts contained in (or equal to) 'inst'.
    SmallVector<OperationInst *, 4> loads;
    // List of store op insts contained in (or equal to) 'inst'.
    SmallVector<OperationInst *, 4> stores;
    Node(unsigned id, Instruction *inst) : id(id), inst(inst) {}

    // Returns the number of load ops in this node which access 'memref'.
    unsigned getLoadOpCount(Value *memref) {
      unsigned loadOpCount = 0;
      for (auto *loadOpInst : loads) {
        if (memref == loadOpInst->cast<LoadOp>()->getMemRef())
          ++loadOpCount;
      }
      return loadOpCount;
    }

    // Returns the number of store ops in this node which access 'memref'.
    unsigned getStoreOpCount(Value *memref) {
      unsigned storeOpCount = 0;
      for (auto *storeOpInst : stores) {
        if (memref == storeOpInst->cast<StoreOp>()->getMemRef())
          ++storeOpCount;
      }
      return storeOpCount;
    }
  };

  // Edge represents a memref data dependence between nodes in the graph.
  // An Edge is stored in both the source node's out-edge list and the
  // destination node's in-edge list; 'id' names the node at the other end.
  struct Edge {
    // The id of the node at the other end of the edge.
    unsigned id;
    // The memref on which this edge represents a dependence.
    Value *memref;
  };

  // Map from node id to Node.
  DenseMap<unsigned, Node> nodes;
  // Map from node id to list of input edges.
  DenseMap<unsigned, SmallVector<Edge, 2>> inEdges;
  // Map from node id to list of output edges.
  DenseMap<unsigned, SmallVector<Edge, 2>> outEdges;

  MemRefDependenceGraph() {}

  // Initializes the dependence graph based on operations in 'f'.
  // Returns true on success, false otherwise.
  bool init(Function *f);

  // Returns the graph node for 'id'. Asserts that 'id' exists.
  Node *getNode(unsigned id) {
    auto it = nodes.find(id);
    assert(it != nodes.end());
    return &it->second;
  }

  // Returns true iff there is an edge from node 'srcId' to node 'dstId' for
  // 'memref'. Returns false otherwise. Checks both edge lists since edges are
  // stored symmetrically (out-edge on src, in-edge on dst).
  bool hasEdge(unsigned srcId, unsigned dstId, Value *memref) {
    if (outEdges.count(srcId) == 0 || inEdges.count(dstId) == 0) {
      return false;
    }
    bool hasOutEdge = llvm::any_of(outEdges[srcId], [=](Edge &edge) {
      return edge.id == dstId && edge.memref == memref;
    });
    bool hasInEdge = llvm::any_of(inEdges[dstId], [=](Edge &edge) {
      return edge.id == srcId && edge.memref == memref;
    });
    return hasOutEdge && hasInEdge;
  }

  // Adds an edge from node 'srcId' to node 'dstId' for 'memref', unless an
  // identical edge already exists (edges are kept duplicate-free).
  void addEdge(unsigned srcId, unsigned dstId, Value *memref) {
    if (!hasEdge(srcId, dstId, memref)) {
      outEdges[srcId].push_back({dstId, memref});
      inEdges[dstId].push_back({srcId, memref});
    }
  }

  // Removes an edge from node 'srcId' to node 'dstId' for 'memref'.
  // Erases at most one matching entry from each of the two edge lists.
  void removeEdge(unsigned srcId, unsigned dstId, Value *memref) {
    assert(inEdges.count(dstId) > 0);
    assert(outEdges.count(srcId) > 0);
    // Remove 'srcId' from 'inEdges[dstId]'.
    for (auto it = inEdges[dstId].begin(); it != inEdges[dstId].end(); ++it) {
      if ((*it).id == srcId && (*it).memref == memref) {
        inEdges[dstId].erase(it);
        break;
      }
    }
    // Remove 'dstId' from 'outEdges[srcId]'.
    for (auto it = outEdges[srcId].begin(); it != outEdges[srcId].end(); ++it) {
      if ((*it).id == dstId && (*it).memref == memref) {
        outEdges[srcId].erase(it);
        break;
      }
    }
  }

  // Returns the input edge count for node 'id' and 'memref'.
  unsigned getInEdgeCount(unsigned id, Value *memref) {
    unsigned inEdgeCount = 0;
    if (inEdges.count(id) > 0)
      for (auto &inEdge : inEdges[id])
        if (inEdge.memref == memref)
          ++inEdgeCount;
    return inEdgeCount;
  }

  // Returns the output edge count for node 'id' and 'memref'.
  unsigned getOutEdgeCount(unsigned id, Value *memref) {
    unsigned outEdgeCount = 0;
    if (outEdges.count(id) > 0)
      for (auto &outEdge : outEdges[id])
        if (outEdge.memref == memref)
          ++outEdgeCount;
    return outEdgeCount;
  }

  // Returns the min node id of all output edges from node 'id', or
  // numeric_limits<unsigned>::max() if 'id' has no output edges.
  unsigned getMinOutEdgeNodeId(unsigned id) {
    unsigned minId = std::numeric_limits<unsigned>::max();
    if (outEdges.count(id) > 0)
      for (auto &outEdge : outEdges[id])
        minId = std::min(minId, outEdge.id);
    return minId;
  }

  // Updates edge mappings from node 'srcId' to node 'dstId' and removes
  // state associated with node 'srcId' (used after fusing src into dst).
  void updateEdgesAndRemoveSrcNode(unsigned srcId, unsigned dstId) {
    // For each edge in 'inEdges[srcId]': add new edge remapping to 'dstId'.
    // Iterate over a copy since removeEdge/addEdge mutate the edge maps.
    if (inEdges.count(srcId) > 0) {
      SmallVector<Edge, 2> oldInEdges = inEdges[srcId];
      for (auto &inEdge : oldInEdges) {
        // Remove edge from 'inEdge.id' to 'srcId'.
        removeEdge(inEdge.id, srcId, inEdge.memref);
        // Add edge from 'inEdge.id' to 'dstId'.
        addEdge(inEdge.id, dstId, inEdge.memref);
      }
    }
    // For each edge in 'outEdges[srcId]': add new edge remapping to 'dstId'.
    if (outEdges.count(srcId) > 0) {
      SmallVector<Edge, 2> oldOutEdges = outEdges[srcId];
      for (auto &outEdge : oldOutEdges) {
        // Remove edge from 'srcId' to 'outEdge.id'.
        removeEdge(srcId, outEdge.id, outEdge.memref);
        // Add edge from 'dstId' to 'outEdge.id' (if 'outEdge.id' != 'dstId'),
        // avoiding a self-edge on the fused node.
        if (outEdge.id != dstId)
          addEdge(dstId, outEdge.id, outEdge.memref);
      }
    }
    // Remove 'srcId' from graph state.
    inEdges.erase(srcId);
    outEdges.erase(srcId);
    nodes.erase(srcId);
  }

  // Appends ops in 'loads' and 'stores' to the access lists of node 'id'.
  void addToNode(unsigned id, const SmallVectorImpl<OperationInst *> &loads,
                 const SmallVectorImpl<OperationInst *> &stores) {
    Node *node = getNode(id);
    for (auto *loadOpInst : loads)
      node->loads.push_back(loadOpInst);
    for (auto *storeOpInst : stores)
      node->stores.push_back(storeOpInst);
  }

  // Prints all nodes and their in/out edges to 'os' (debugging aid).
  void print(raw_ostream &os) const {
    os << "\nMemRefDependenceGraph\n";
    os << "\nNodes:\n";
    for (auto &idAndNode : nodes) {
      os << "Node: " << idAndNode.first << "\n";
      auto it = inEdges.find(idAndNode.first);
      if (it != inEdges.end()) {
        for (const auto &e : it->second)
          os << " InEdge: " << e.id << " " << e.memref << "\n";
      }
      it = outEdges.find(idAndNode.first);
      if (it != outEdges.end()) {
        for (const auto &e : it->second)
          os << " OutEdge: " << e.id << " " << e.memref << "\n";
      }
    }
  }
  void dump() const { print(llvm::errs()); }
};
295
Chris Lattner456ad6a2018-12-29 00:05:35296// Intializes the data dependence graph by walking instructions in 'f'.
MLIR Team6892ffb2018-12-20 04:42:55297// Assigns each node in the graph a node id based on program order in 'f'.
Chris Lattner315a4662018-12-28 21:07:39298// TODO(andydavis) Add support for taking a Block arg to construct the
MLIR Team6892ffb2018-12-20 04:42:55299// dependence graph at a different depth.
Chris Lattner69d9e992018-12-28 16:48:09300bool MemRefDependenceGraph::init(Function *f) {
MLIR Team6892ffb2018-12-20 04:42:55301 unsigned id = 0;
Chris Lattner3f190312018-12-27 22:35:10302 DenseMap<Value *, SetVector<unsigned>> memrefAccesses;
Chris Lattnerdffc5892018-12-29 23:33:43303
304 // TODO: support multi-block functions.
305 if (f->getBlocks().size() != 1)
306 return false;
307
308 for (auto &inst : f->front()) {
Chris Lattner456ad6a2018-12-29 00:05:35309 if (auto *forInst = dyn_cast<ForInst>(&inst)) {
310 // Create graph node 'id' to represent top-level 'forInst' and record
MLIR Team6892ffb2018-12-20 04:42:55311 // all loads and store accesses it contains.
312 LoopNestStateCollector collector;
Chris Lattner456ad6a2018-12-29 00:05:35313 collector.walkForInst(forInst);
314 // Return false if IfInsts are found (not currently supported).
315 if (collector.hasIfInst)
MLIR Team6892ffb2018-12-20 04:42:55316 return false;
Chris Lattner456ad6a2018-12-29 00:05:35317 Node node(id++, &inst);
318 for (auto *opInst : collector.loadOpInsts) {
319 node.loads.push_back(opInst);
320 auto *memref = opInst->cast<LoadOp>()->getMemRef();
MLIR Team6892ffb2018-12-20 04:42:55321 memrefAccesses[memref].insert(node.id);
322 }
Chris Lattner456ad6a2018-12-29 00:05:35323 for (auto *opInst : collector.storeOpInsts) {
324 node.stores.push_back(opInst);
325 auto *memref = opInst->cast<StoreOp>()->getMemRef();
MLIR Team6892ffb2018-12-20 04:42:55326 memrefAccesses[memref].insert(node.id);
327 }
328 nodes.insert({node.id, node});
329 }
Chris Lattner456ad6a2018-12-29 00:05:35330 if (auto *opInst = dyn_cast<OperationInst>(&inst)) {
331 if (auto loadOp = opInst->dyn_cast<LoadOp>()) {
MLIR Team6892ffb2018-12-20 04:42:55332 // Create graph node for top-level load op.
Chris Lattner456ad6a2018-12-29 00:05:35333 Node node(id++, &inst);
334 node.loads.push_back(opInst);
335 auto *memref = opInst->cast<LoadOp>()->getMemRef();
MLIR Team6892ffb2018-12-20 04:42:55336 memrefAccesses[memref].insert(node.id);
337 nodes.insert({node.id, node});
338 }
Chris Lattner456ad6a2018-12-29 00:05:35339 if (auto storeOp = opInst->dyn_cast<StoreOp>()) {
MLIR Team6892ffb2018-12-20 04:42:55340 // Create graph node for top-level store op.
Chris Lattner456ad6a2018-12-29 00:05:35341 Node node(id++, &inst);
342 node.stores.push_back(opInst);
343 auto *memref = opInst->cast<StoreOp>()->getMemRef();
MLIR Team6892ffb2018-12-20 04:42:55344 memrefAccesses[memref].insert(node.id);
345 nodes.insert({node.id, node});
346 }
347 }
Chris Lattner456ad6a2018-12-29 00:05:35348 // Return false if IfInsts are found (not currently supported).
349 if (isa<IfInst>(&inst))
MLIR Team6892ffb2018-12-20 04:42:55350 return false;
351 }
352
353 // Walk memref access lists and add graph edges between dependent nodes.
354 for (auto &memrefAndList : memrefAccesses) {
355 unsigned n = memrefAndList.second.size();
356 for (unsigned i = 0; i < n; ++i) {
357 unsigned srcId = memrefAndList.second[i];
358 bool srcHasStore =
359 getNode(srcId)->getStoreOpCount(memrefAndList.first) > 0;
360 for (unsigned j = i + 1; j < n; ++j) {
361 unsigned dstId = memrefAndList.second[j];
362 bool dstHasStore =
363 getNode(dstId)->getStoreOpCount(memrefAndList.first) > 0;
364 if (srcHasStore || dstHasStore)
365 addEdge(srcId, dstId, memrefAndList.first);
366 }
367 }
368 }
369 return true;
370}
371
MLIR Team38c2fe32019-01-14 19:26:25372namespace {
373
// LoopNestStats aggregates various per-loop statistics (eg. loop trip count
// and operation count) for a loop nest up until the innermost loop body.
struct LoopNestStats {
  // Map from ForInst to immediate child ForInsts in its loop body.
  DenseMap<ForInst *, SmallVector<ForInst *, 2>> loopMap;
  // Map from ForInst to count of operations in its loop body (one loop level
  // only; child-loop contents are accounted for via 'loopMap').
  DenseMap<ForInst *, uint64_t> opCountMap;
  // Map from ForInst to its constant trip count (absent if non-constant).
  DenseMap<ForInst *, uint64_t> tripCountMap;
};
384
385// LoopNestStatsCollector walks a single loop nest and gathers per-loop
386// trip count and operation count statistics and records them in 'stats'.
387class LoopNestStatsCollector : public InstWalker<LoopNestStatsCollector> {
388public:
389 LoopNestStats *stats;
390 bool hasLoopWithNonConstTripCount = false;
391
392 LoopNestStatsCollector(LoopNestStats *stats) : stats(stats) {}
393
394 void visitForInst(ForInst *forInst) {
395 auto *parentInst = forInst->getParentInst();
396 if (parentInst != nullptr) {
397 assert(isa<ForInst>(parentInst) && "Expected parent ForInst");
398 // Add mapping to 'forInst' from its parent ForInst.
399 stats->loopMap[cast<ForInst>(parentInst)].push_back(forInst);
400 }
401 // Record the number of op instructions in the body of 'forInst'.
402 unsigned count = 0;
403 stats->opCountMap[forInst] = 0;
404 for (auto &inst : *forInst->getBody()) {
405 if (isa<OperationInst>(&inst))
406 ++count;
407 }
408 stats->opCountMap[forInst] = count;
409 // Record trip count for 'forInst'. Set flag if trip count is not constant.
410 Optional<uint64_t> maybeConstTripCount = getConstantTripCount(*forInst);
411 if (!maybeConstTripCount.hasValue()) {
412 hasLoopWithNonConstTripCount = true;
413 return;
414 }
415 stats->tripCountMap[forInst] = maybeConstTripCount.getValue();
416 }
417};
418
// Computes the total cost of the loop nest rooted at 'forInst'.
// Currently, the total cost is computed by counting the total operation
// instance count (i.e. total number of operations in the loop body * loop
// trip count) for the entire loop nest.
// If 'tripCountOverrideMap' is non-null, overrides the trip count for loops
// specified in the map when computing the total op instance count.
// NOTE: this is used to compute the cost of computation slices, which are
// sliced along the iteration dimension, and thus reduce the trip count.
// If 'computeCostMap' is non-null, the total op count for forInsts specified
// in the map is increased (not overridden) by adding the op count from the
// map to the existing op count for the for loop. This is done before
// multiplying by the loop's trip count, and is used to model the cost of
// inserting a sliced loop nest of known cost into the loop's body.
// NOTE: this is used to compute the cost of fusing a slice of some loop nest
// within another loop.
static uint64_t getComputeCost(
    ForInst *forInst, LoopNestStats *stats,
    llvm::SmallDenseMap<ForInst *, uint64_t, 8> *tripCountOverrideMap,
    DenseMap<ForInst *, uint64_t> *computeCostMap) {
  // 'opCount' is the total number operations in one iteration of 'forInst' body
  uint64_t opCount = stats->opCountMap[forInst];
  // Recurse into child loops, accumulating their (trip-count-scaled) costs.
  if (stats->loopMap.count(forInst) > 0) {
    for (auto *childForInst : stats->loopMap[forInst]) {
      opCount += getComputeCost(childForInst, stats, tripCountOverrideMap,
                                computeCostMap);
    }
  }
  // Add in additional op instances from slice (if specified in map).
  if (computeCostMap != nullptr) {
    auto it = computeCostMap->find(forInst);
    if (it != computeCostMap->end()) {
      opCount += it->second;
    }
  }
  // Override trip count (if specified in map).
  uint64_t tripCount = stats->tripCountMap[forInst];
  if (tripCountOverrideMap != nullptr) {
    auto it = tripCountOverrideMap->find(forInst);
    if (it != tripCountOverrideMap->end()) {
      tripCount = it->second;
    }
  }
  // Returns the total number of dynamic instances of operations in loop body.
  return tripCount * opCount;
}
464
465} // end anonymous namespace
466
MLIR Team27d067e2019-01-16 17:55:02467static Optional<uint64_t> getConstDifference(AffineMap lbMap, AffineMap ubMap) {
Uday Bondhugulac1ca23e2019-01-16 21:13:00468 assert(lbMap.getNumResults() == 1 && "expected single result bound map");
469 assert(ubMap.getNumResults() == 1 && "expected single result bound map");
MLIR Team27d067e2019-01-16 17:55:02470 assert(lbMap.getNumDims() == ubMap.getNumDims());
471 assert(lbMap.getNumSymbols() == ubMap.getNumSymbols());
472 // TODO(andydavis) Merge this code with 'mlir::getTripCountExpr'.
473 // ub_expr - lb_expr
474 AffineExpr lbExpr(lbMap.getResult(0));
475 AffineExpr ubExpr(ubMap.getResult(0));
476 auto loopSpanExpr = simplifyAffineExpr(ubExpr - lbExpr, lbMap.getNumDims(),
477 lbMap.getNumSymbols());
478 auto cExpr = loopSpanExpr.dyn_cast<AffineConstantExpr>();
479 if (!cExpr)
480 return None;
481 return cExpr.getValue();
482}
483
// Builds a map 'tripCountMap' from ForInst to constant trip count for loop
// nest surrounding 'srcAccess' utilizing slice loop bounds in 'sliceState'.
// Returns true on success, false otherwise (if a non-constant trip count
// was encountered).
// TODO(andydavis) Make this work with non-unit step loops.
static bool buildSliceTripCountMap(
    OperationInst *srcOpInst, ComputationSliceState *sliceState,
    llvm::SmallDenseMap<ForInst *, uint64_t, 8> *tripCountMap) {
  SmallVector<ForInst *, 4> srcLoopIVs;
  getLoopIVs(*srcOpInst, &srcLoopIVs);
  unsigned numSrcLoopIVs = srcLoopIVs.size();
  // Populate map from ForInst -> trip count
  for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
    // Null lb/ub maps are the sentinel for "loop IV 'i' was not sliced".
    AffineMap lbMap = sliceState->lbs[i];
    AffineMap ubMap = sliceState->ubs[i];
    if (lbMap == AffineMap::Null() || ubMap == AffineMap::Null()) {
      // The iteration of src loop IV 'i' was not sliced. Use full loop bounds.
      if (srcLoopIVs[i]->hasConstantLowerBound() &&
          srcLoopIVs[i]->hasConstantUpperBound()) {
        // NOTE: ub - lb is the trip count only for unit-step loops (see TODO).
        (*tripCountMap)[srcLoopIVs[i]] =
            srcLoopIVs[i]->getConstantUpperBound() -
            srcLoopIVs[i]->getConstantLowerBound();
        continue;
      }
      // Non-constant full loop bounds: give up.
      return false;
    }
    // Sliced dimension: trip count is the constant span of the slice bounds.
    Optional<uint64_t> tripCount = getConstDifference(lbMap, ubMap);
    if (!tripCount.hasValue())
      return false;
    (*tripCountMap)[srcLoopIVs[i]] = tripCount.getValue();
  }
  return true;
}
517
MLIR Team27d067e2019-01-16 17:55:02518// Removes load operations from 'srcLoads' which operate on 'memref', and
519// adds them to 'dstLoads'.
520static void
521moveLoadsAccessingMemrefTo(Value *memref,
522 SmallVectorImpl<OperationInst *> *srcLoads,
523 SmallVectorImpl<OperationInst *> *dstLoads) {
524 dstLoads->clear();
525 SmallVector<OperationInst *, 4> srcLoadsToKeep;
526 for (auto *load : *srcLoads) {
527 if (load->cast<LoadOp>()->getMemRef() == memref)
528 dstLoads->push_back(load);
529 else
530 srcLoadsToKeep.push_back(load);
MLIR Team38c2fe32019-01-14 19:26:25531 }
MLIR Team27d067e2019-01-16 17:55:02532 srcLoads->swap(srcLoadsToKeep);
MLIR Team38c2fe32019-01-14 19:26:25533}
534
MLIR Team27d067e2019-01-16 17:55:02535// Returns the innermost common loop depth for the set of operations in 'ops'.
536static unsigned getInnermostCommonLoopDepth(ArrayRef<OperationInst *> ops) {
537 unsigned numOps = ops.size();
538 assert(numOps > 0);
539
540 std::vector<SmallVector<ForInst *, 4>> loops(numOps);
541 unsigned loopDepthLimit = std::numeric_limits<unsigned>::max();
542 for (unsigned i = 0; i < numOps; ++i) {
543 getLoopIVs(*ops[i], &loops[i]);
544 loopDepthLimit =
545 std::min(loopDepthLimit, static_cast<unsigned>(loops[i].size()));
MLIR Team38c2fe32019-01-14 19:26:25546 }
MLIR Team27d067e2019-01-16 17:55:02547
548 unsigned loopDepth = 0;
549 for (unsigned d = 0; d < loopDepthLimit; ++d) {
550 unsigned i;
551 for (i = 1; i < numOps; ++i) {
552 if (loops[i - 1][d] != loops[i][d]) {
553 break;
554 }
555 }
556 if (i != numOps)
557 break;
558 ++loopDepth;
559 }
560 return loopDepth;
MLIR Team38c2fe32019-01-14 19:26:25561}
562
// Returns the slice union of 'sliceStateA' and 'sliceStateB' in 'sliceStateB'
// using a rectangular bounding box.
// TODO(andydavis) This function assumes that lower bounds for 'sliceStateA'
// and 'sliceStateB' are aligned.
// Specifically, when taking the union of overlapping intervals, it assumes
// that both intervals start at zero. Support needs to be added to take into
// account interval start offset when computing the union.
// TODO(andydavis) Move this function to an analysis library.
static bool getSliceUnion(const ComputationSliceState &sliceStateA,
                          ComputationSliceState *sliceStateB) {
  assert(sliceStateA.lbs.size() == sliceStateB->lbs.size());
  assert(sliceStateA.ubs.size() == sliceStateB->ubs.size());

  for (unsigned i = 0, e = sliceStateA.lbs.size(); i < e; ++i) {
    AffineMap lbMapA = sliceStateA.lbs[i];
    AffineMap ubMapA = sliceStateA.ubs[i];
    // A null lb map means dimension 'i' of A is unsliced; B's bound (or lack
    // thereof) already covers it, so there is nothing to merge.
    if (lbMapA == AffineMap::Null()) {
      assert(ubMapA == AffineMap::Null());
      continue;
    }
    assert(ubMapA && "expected non-null ub map");

    AffineMap lbMapB = sliceStateB->lbs[i];
    AffineMap ubMapB = sliceStateB->ubs[i];
    if (lbMapB == AffineMap::Null()) {
      assert(ubMapB == AffineMap::Null());
      // Union 'sliceStateB' does not have a bound for 'i' so copy from A.
      sliceStateB->lbs[i] = lbMapA;
      sliceStateB->ubs[i] = ubMapA;
      continue;
    }

    // TODO(andydavis) Change this code to take the min across all lower bounds
    // and max across all upper bounds for each dimension. This code can fail
    // for cases where a unique min or max could not be statically determined.

    // Assumption: both lower bounds are the same.
    if (lbMapA != lbMapB)
      return false;

    // Add bound with the largest trip count to union.
    Optional<uint64_t> tripCountA = getConstDifference(lbMapA, ubMapA);
    Optional<uint64_t> tripCountB = getConstDifference(lbMapB, ubMapB);
    if (!tripCountA.hasValue() || !tripCountB.hasValue())
      return false;

    // Keep A's bounds in the union iff they span more iterations than B's.
    if (tripCountA.getValue() > tripCountB.getValue()) {
      sliceStateB->lbs[i] = lbMapA;
      sliceStateB->ubs[i] = ubMapA;
    }
  }
  return true;
}
616
617// Checks the profitability of fusing a backwards slice of the loop nest
618// surrounding 'srcOpInst' into the loop nest surrounding 'dstOpInsts'.
619// Returns true if it profitable to fuse the candidate loop nests. Returns
620// false otherwise.
MLIR Team38c2fe32019-01-14 19:26:25621// The profitability model executes the following steps:
MLIR Team27d067e2019-01-16 17:55:02622// *) Computes the backward computation slice at 'srcOpInst'. This
623// computation slice of the loop nest surrounding 'srcOpInst' is
MLIR Team38c2fe32019-01-14 19:26:25624// represented by modified src loop bounds in 'sliceState', which are
MLIR Team27d067e2019-01-16 17:55:02625// functions of loop IVs in the loop nest surrounding 'srcOpInst'.
MLIR Team38c2fe32019-01-14 19:26:25626// *) Computes the cost of unfused src/dst loop nests (currently the cost of a
627// loop nest is the total number of dynamic operation instances in the loop
628// nest).
629// *) Computes the cost of fusing a slice of the src loop nest into the dst
MLIR Team27d067e2019-01-16 17:55:02630// loop nest at various values of dst loop depth, attempting to fuse
// the largest computation slice at the maximal dst loop depth (closest to the
632// load) to minimize reuse distance and potentially enable subsequent
633// load/store forwarding.
634// NOTE: If the dst loop nest includes multiple loads in 'dstOpInsts' for
635// the same memref as is written by 'srcOpInst', then the union of slice
636// loop bounds is used to compute the slice and associated slice cost.
// NOTE: 'dstLoopDepth' refers to the loop depth within the destination loop
638// nest, at which the src computation slice is inserted/fused.
MLIR Team27d067e2019-01-16 17:55:02639// NOTE: We attempt to maximize the dst loop depth, but there are cases
640// where a particular setting for 'dstLoopNest' might fuse an unsliced
MLIR Team38c2fe32019-01-14 19:26:25641// loop (within the src computation slice) at a depth which results in
// excessive recomputation (see unit tests for examples).
643// *) Compares the total cost of the unfused loop nests to the min cost fused
644// loop nest computed in the previous step, and returns true if the latter
645// is lower.
MLIR Team27d067e2019-01-16 17:55:02646static bool isFusionProfitable(OperationInst *srcOpInst,
647 ArrayRef<OperationInst *> dstOpInsts,
MLIR Team38c2fe32019-01-14 19:26:25648 ComputationSliceState *sliceState,
MLIR Team27d067e2019-01-16 17:55:02649 unsigned *dstLoopDepth) {
MLIR Team38c2fe32019-01-14 19:26:25650 // Compute cost of sliced and unsliced src loop nest.
651 SmallVector<ForInst *, 4> srcLoopIVs;
MLIR Team27d067e2019-01-16 17:55:02652 getLoopIVs(*srcOpInst, &srcLoopIVs);
MLIR Team38c2fe32019-01-14 19:26:25653 unsigned numSrcLoopIVs = srcLoopIVs.size();
654
655 // Walk src loop nest and collect stats.
656 LoopNestStats srcLoopNestStats;
657 LoopNestStatsCollector srcStatsCollector(&srcLoopNestStats);
658 srcStatsCollector.walk(srcLoopIVs[0]);
659 // Currently only constant trip count loop nests are supported.
660 if (srcStatsCollector.hasLoopWithNonConstTripCount)
661 return false;
662
663 // Compute cost of dst loop nest.
664 SmallVector<ForInst *, 4> dstLoopIVs;
MLIR Team27d067e2019-01-16 17:55:02665 getLoopIVs(*dstOpInsts[0], &dstLoopIVs);
MLIR Team38c2fe32019-01-14 19:26:25666
667 LoopNestStats dstLoopNestStats;
668 LoopNestStatsCollector dstStatsCollector(&dstLoopNestStats);
669 dstStatsCollector.walk(dstLoopIVs[0]);
670 // Currently only constant trip count loop nests are supported.
671 if (dstStatsCollector.hasLoopWithNonConstTripCount)
672 return false;
673
MLIR Team27d067e2019-01-16 17:55:02674 // Compute the innermost common loop for ops in 'dstOpInst'.
675 unsigned maxDstLoopDepth = getInnermostCommonLoopDepth(dstOpInsts);
676 if (maxDstLoopDepth == 0)
677 return false;
678
679 // Search for min cost value for 'dstLoopDepth'. At each value of
680 // 'dstLoopDepth' from 'maxDstLoopDepth' to '1', compute computation slice
681 // bounds between 'srcOpInst' and each op in 'dstOpinsts' (taking the union
682 // of these bounds). Next the union slice bounds are used to calculate
683 // the cost of the slice and the cost of the slice inserted into the dst
684 // loop nest at 'dstLoopDepth'.
MLIR Team38c2fe32019-01-14 19:26:25685 unsigned minFusedLoopNestComputeCost = std::numeric_limits<unsigned>::max();
MLIR Team38c2fe32019-01-14 19:26:25686 unsigned bestDstLoopDepth;
MLIR Team27d067e2019-01-16 17:55:02687 SmallVector<ComputationSliceState, 4> sliceStates;
688 sliceStates.resize(maxDstLoopDepth);
689
690 llvm::SmallDenseMap<ForInst *, uint64_t, 8> sliceTripCountMap;
691 DenseMap<ForInst *, uint64_t> computeCostMap;
692 for (unsigned i = maxDstLoopDepth; i >= 1; --i) {
693 MemRefAccess srcAccess(srcOpInst);
694 // Handle the common case of one dst load without a copy.
695 if (!mlir::getBackwardComputationSliceState(
696 srcAccess, MemRefAccess(dstOpInsts[0]), i, &sliceStates[i - 1]))
697 return false;
698 // Compute the union of slice bound of all ops in 'dstOpInsts'.
699 for (int j = 1, e = dstOpInsts.size(); j < e; ++j) {
700 MemRefAccess dstAccess(dstOpInsts[j]);
701 ComputationSliceState tmpSliceState;
702 if (!mlir::getBackwardComputationSliceState(srcAccess, dstAccess, i,
703 &tmpSliceState))
704 return false;
705 // Compute slice boun dunion of 'tmpSliceState' and 'sliceStates[i - 1]'.
Uday Bondhugulac1ca23e2019-01-16 21:13:00706 getSliceUnion(tmpSliceState, &sliceStates[i - 1]);
MLIR Team38c2fe32019-01-14 19:26:25707 }
MLIR Team27d067e2019-01-16 17:55:02708 // Build trip count map for computation slice.
709 sliceTripCountMap.clear();
710 if (!buildSliceTripCountMap(srcOpInst, &sliceStates[i - 1],
711 &sliceTripCountMap))
712 return false;
713
MLIR Team38c2fe32019-01-14 19:26:25714 // Compute op instance count for the src loop nest with iteration slicing.
715 uint64_t sliceComputeCost =
MLIR Team27d067e2019-01-16 17:55:02716 getComputeCost(srcLoopIVs[0], &srcLoopNestStats, &sliceTripCountMap,
MLIR Team38c2fe32019-01-14 19:26:25717 /*computeCostMap=*/nullptr);
718
MLIR Team27d067e2019-01-16 17:55:02719 // Compute cost of fusion for these values of 'i' and 'j'.
720 computeCostMap.clear();
721 computeCostMap[dstLoopIVs[i - 1]] = sliceComputeCost;
722 uint64_t fusedLoopNestComputeCost =
723 getComputeCost(dstLoopIVs[0], &dstLoopNestStats,
724 /*tripCountOverrideMap=*/nullptr, &computeCostMap);
725 if (fusedLoopNestComputeCost < minFusedLoopNestComputeCost) {
726 minFusedLoopNestComputeCost = fusedLoopNestComputeCost;
727 bestDstLoopDepth = i;
MLIR Team38c2fe32019-01-14 19:26:25728 }
729 }
730
731 // Compute op instance count for the src loop nest without iteration slicing.
732 uint64_t srcLoopNestCost = getComputeCost(srcLoopIVs[0], &srcLoopNestStats,
733 /*tripCountOverrideMap=*/nullptr,
734 /*computeCostMap=*/nullptr);
735 // Compute op instance count for the src loop nest.
736 uint64_t dstLoopNestCost = getComputeCost(dstLoopIVs[0], &dstLoopNestStats,
737 /*tripCountOverrideMap=*/nullptr,
738 /*computeCostMap=*/nullptr);
739
740 LLVM_DEBUG(llvm::dbgs() << "LoopFusion statistics "
MLIR Team38c2fe32019-01-14 19:26:25741 << " bestDstLoopDepth: " << bestDstLoopDepth
742 << " srcLoopNestCost: " << srcLoopNestCost
743 << " dstLoopNestCost: " << dstLoopNestCost
744 << " minFusedLoopNestComputeCost: "
745 << minFusedLoopNestComputeCost << "\n");
746
747 // Do not fuse if fused loop would increase the total cost of the computation.
748 // TODO(andydavis) Use locality/reduction in slice memref size/opportunity
749 // for load/store forwarding in cost model.
750 if (minFusedLoopNestComputeCost > srcLoopNestCost + dstLoopNestCost)
751 return false;
MLIR Team27d067e2019-01-16 17:55:02752 // Update return parameter 'sliceState' with 'bestSliceState'.
753 ComputationSliceState *bestSliceState = &sliceStates[bestDstLoopDepth - 1];
754 sliceState->lbs = bestSliceState->lbs;
755 sliceState->ubs = bestSliceState->ubs;
756 sliceState->lbOperands = bestSliceState->lbOperands;
757 sliceState->ubOperands = bestSliceState->ubOperands;
758 // Set dstLoopDepth based on best values from search.
MLIR Team38c2fe32019-01-14 19:26:25759 *dstLoopDepth = bestDstLoopDepth;
MLIR Team27d067e2019-01-16 17:55:02760 // Canonicalize slice bound affine maps.
MLIR Team38c2fe32019-01-14 19:26:25761 for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
MLIR Team27d067e2019-01-16 17:55:02762 if (sliceState->lbs[i] != AffineMap::Null()) {
763 canonicalizeMapAndOperands(&sliceState->lbs[i],
764 &sliceState->lbOperands[i]);
765 }
766 if (sliceState->ubs[i] != AffineMap::Null()) {
767 canonicalizeMapAndOperands(&sliceState->ubs[i],
768 &sliceState->ubOperands[i]);
MLIR Team38c2fe32019-01-14 19:26:25769 }
770 }
771 return true;
772}
773
MLIR Team6892ffb2018-12-20 04:42:55774// GreedyFusion greedily fuses loop nests which have a producer/consumer
MLIR Team3b692302018-12-17 17:57:14775// relationship on a memref, with the goal of improving locality. Currently,
// the producer/consumer relationship is required to be unique in the
Chris Lattner69d9e992018-12-28 16:48:09777// Function (there are TODOs to relax this constraint in the future).
MLIR Teamf28e4df2018-11-01 14:26:00778//
MLIR Team3b692302018-12-17 17:57:14779// The steps of the algorithm are as follows:
780//
MLIR Team6892ffb2018-12-20 04:42:55781// *) A worklist is initialized with node ids from the dependence graph.
782// *) For each node id in the worklist:
Chris Lattner456ad6a2018-12-29 00:05:35783// *) Pop a ForInst of the worklist. This 'dstForInst' will be a candidate
784// destination ForInst into which fusion will be attempted.
785// *) Add each LoadOp currently in 'dstForInst' into list 'dstLoadOps'.
MLIR Team3b692302018-12-17 17:57:14786// *) For each LoadOp in 'dstLoadOps' do:
Chris Lattner69d9e992018-12-28 16:48:09787// *) Lookup dependent loop nests at earlier positions in the Function
MLIR Team3b692302018-12-17 17:57:14788// which have a single store op to the same memref.
789// *) Check if dependences would be violated by the fusion. For example,
790// the src loop nest may load from memrefs which are different than
791// the producer-consumer memref between src and dest loop nests.
MLIR Team6892ffb2018-12-20 04:42:55792// *) Get a computation slice of 'srcLoopNest', which adjusts its loop
MLIR Team3b692302018-12-17 17:57:14793// bounds to be functions of 'dstLoopNest' IVs and symbols.
794// *) Fuse the 'srcLoopNest' computation slice into the 'dstLoopNest',
795// just before the dst load op user.
Chris Lattner456ad6a2018-12-29 00:05:35796// *) Add the newly fused load/store operation instructions to the state,
MLIR Team3b692302018-12-17 17:57:14797// and also add newly fuse load ops to 'dstLoopOps' to be considered
798// as fusion dst load ops in another iteration.
799// *) Remove old src loop nest and its associated state.
800//
Chris Lattner456ad6a2018-12-29 00:05:35801// Given a graph where top-level instructions are vertices in the set 'V' and
MLIR Team3b692302018-12-17 17:57:14802// edges in the set 'E' are dependences between vertices, this algorithm
MLIR Team6892ffb2018-12-20 04:42:55803// takes O(V) time for initialization, and has runtime O(V + E).
MLIR Team3b692302018-12-17 17:57:14804//
MLIR Team6892ffb2018-12-20 04:42:55805// This greedy algorithm is not 'maximal' due to the current restriction of
806// fusing along single producer consumer edges, but there is a TODO to fix this.
MLIR Team3b692302018-12-17 17:57:14807//
808// TODO(andydavis) Experiment with other fusion policies.
MLIR Team6892ffb2018-12-20 04:42:55809// TODO(andydavis) Add support for fusing for input reuse (perhaps by
810// constructing a graph with edges which represent loads from the same memref
// in two different loop nests).
struct GreedyFusion {
public:
  // Dependence graph over the Function's top-level loop nests; owned by the
  // caller and mutated in place as nodes are fused and removed.
  MemRefDependenceGraph *mdg;
  // Stack of candidate destination node ids still to be visited.
  SmallVector<unsigned, 4> worklist;

  // Seeds the worklist with every node id in 'mdg' (ids are assumed to be the
  // contiguous range [0, mdg->nodes.size()) at construction time).
  GreedyFusion(MemRefDependenceGraph *mdg) : mdg(mdg) {
    // Initialize worklist with nodes from 'mdg'.
    worklist.resize(mdg->nodes.size());
    std::iota(worklist.begin(), worklist.end(), 0);
  }

  // Greedily fuses producer loop nests into consumer loop nests along unique
  // producer/consumer memref edges, updating 'mdg' and erasing fused src
  // loop nests as it goes. See the algorithm description above this struct.
  void run() {
    while (!worklist.empty()) {
      unsigned dstId = worklist.back();
      worklist.pop_back();
      // Skip if this node was removed (fused into another node).
      if (mdg->nodes.count(dstId) == 0)
        continue;
      // Get 'dstNode' into which to attempt fusion.
      auto *dstNode = mdg->getNode(dstId);
      // Skip if 'dstNode' is not a loop nest.
      if (!isa<ForInst>(dstNode->inst))
        continue;

      // Work on a copy of the node's loads; newly created loads from fused
      // slices are appended below so they are considered in later iterations.
      SmallVector<OperationInst *, 4> loads = dstNode->loads;
      // NOTE(review): 'dstLoadOpInsts' is declared outside the while loop, so
      // it accumulates loads across *different* memrefs as iterations pop
      // 'loads'; confirm isFusionProfitable is intended to see loads on
      // memrefs other than the current 'memref'.
      SmallVector<OperationInst *, 4> dstLoadOpInsts;
      while (!loads.empty()) {
        // Get memref of load on top of the stack.
        auto *memref = loads.back()->cast<LoadOp>()->getMemRef();
        // Move all loads in 'loads' accessing 'memref' to 'dstLoadOpInsts'.
        moveLoadsAccessingMemrefTo(memref, &loads, &dstLoadOpInsts);
        // Skip if no input edges along which to fuse.
        if (mdg->inEdges.count(dstId) == 0)
          continue;
        // Iterate through in edges for 'dstId'.
        // NOTE(review): updateEdgesAndRemoveSrcNode below presumably mutates
        // 'mdg->inEdges[dstId]' while this range-for iterates it — confirm
        // iterator stability or iterate over a copy of the edge list.
        for (auto &srcEdge : mdg->inEdges[dstId]) {
          // Skip 'srcEdge' if not for 'memref'.
          if (srcEdge.memref != memref)
            continue;
          auto *srcNode = mdg->getNode(srcEdge.id);
          // Skip if 'srcNode' is not a loop nest.
          if (!isa<ForInst>(srcNode->inst))
            continue;
          // Skip if 'srcNode' has more than one store to 'memref'.
          if (srcNode->getStoreOpCount(memref) != 1)
            continue;
          // Skip 'srcNode' if it has out edges on 'memref' other than 'dstId'.
          if (mdg->getOutEdgeCount(srcNode->id, memref) != 1)
            continue;
          // Skip 'srcNode' if it has in dependence edges. NOTE: This is overly
          // conservative.
          // TODO(andydavis) Track dependence type with edges, and just check
          // for WAW dependence edge here.
          if (mdg->getInEdgeCount(srcNode->id, memref) != 0)
            continue;
          // Skip if 'srcNode' has out edges to other memrefs after 'dstId'.
          if (mdg->getMinOutEdgeNodeId(srcNode->id) != dstId)
            continue;
          // Get unique 'srcNode' store op.
          auto *srcStoreOpInst = srcNode->stores.front();
          // Check if fusion would be profitable; on success 'sliceState' and
          // 'dstLoopDepth' describe the best slice found by the cost model.
          unsigned dstLoopDepth;
          mlir::ComputationSliceState sliceState;
          if (!isFusionProfitable(srcStoreOpInst, dstLoadOpInsts, &sliceState,
                                  &dstLoopDepth))
            continue;
          // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'.
          auto *sliceLoopNest = mlir::insertBackwardComputationSlice(
              srcStoreOpInst, dstLoadOpInsts[0], dstLoopDepth, &sliceState);
          if (sliceLoopNest != nullptr) {
            // Remove edges between 'srcNode' and 'dstNode' and remove 'srcNode'
            mdg->updateEdgesAndRemoveSrcNode(srcNode->id, dstNode->id);
            // Record all load/store accesses in 'sliceLoopNest' at 'dstPos'.
            LoopNestStateCollector collector;
            collector.walkForInst(sliceLoopNest);
            mdg->addToNode(dstId, collector.loadOpInsts,
                           collector.storeOpInsts);
            // Add new load ops to current Node load op list 'loads' to
            // continue fusing based on new operands.
            for (auto *loadOpInst : collector.loadOpInsts)
              loads.push_back(loadOpInst);
            // Promote single iteration loops to single IV value.
            for (auto *forInst : collector.forInsts) {
              promoteIfSingleIteration(forInst);
            }
            // Remove old src loop nest; its instructions were either sliced
            // into 'dstNode' or are no longer reachable.
            cast<ForInst>(srcNode->inst)->erase();
          }
        }
      }
    }
  }
};
904
905} // end anonymous namespace
MLIR Teamf28e4df2018-11-01 14:26:00906
Chris Lattner79748892018-12-31 07:10:35907PassResult LoopFusion::runOnFunction(Function *f) {
MLIR Team6892ffb2018-12-20 04:42:55908 MemRefDependenceGraph g;
909 if (g.init(f))
910 GreedyFusion(&g).run();
MLIR Teamf28e4df2018-11-01 14:26:00911 return success();
912}
Jacques Pienaar6f0fb222018-11-07 02:34:18913
// Register this pass with the global pass registry under the '-loop-fusion'
// command-line flag.
static PassRegistration<LoopFusion> pass("loop-fusion", "Fuse loop nests");