//===- LoopFusion.cpp - Code to perform loop fusion -----------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements loop fusion.
//
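// As a simplified sketch (illustrative IR only), given two loop nests that
// communicate through a single memref:
//
//   for %i = 0 to 100 {
//     store %v, %m[%i] : memref<100xf32>
//   }
//   for %j = 0 to 100 {
//     %u = load %m[%j] : memref<100xf32>
//   }
//
// fusing the producer nest into the consumer nest yields roughly:
//
//   for %j = 0 to 100 {
//     store %v, %m[%j] : memref<100xf32>
//     %u = load %m[%j] : memref<100xf32>
//   }
//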
//===----------------------------------------------------------------------===//

#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/AffineStructures.h"
#include "mlir/Analysis/LoopAnalysis.h"
#include "mlir/Analysis/Utils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/StmtVisitor.h"
#include "mlir/Pass.h"
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/Transforms/LoopUtils.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <limits>
#include <numeric>

using llvm::SetVector;

using namespace mlir;

// TODO(andydavis) These flags are global for the pass to be used for
// experimentation. Find a way to provide more fine-grained control (e.g.
// depth per loop nest, or depth per load/store op) for this pass utilizing a
// cost model.
static llvm::cl::opt<unsigned> clSrcLoopDepth(
    "src-loop-depth", llvm::cl::Hidden,
    llvm::cl::desc("Controls the depth of the source loop nest at which "
                   "to apply loop iteration slicing before fusion."));

static llvm::cl::opt<unsigned> clDstLoopDepth(
    "dst-loop-depth", llvm::cl::Hidden,
    llvm::cl::desc("Controls the depth of the destination loop nest at which "
                   "to fuse the source loop nest slice."));

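// For experimentation, the pass and the flags above can be exercised through
// the standard mlir-opt driver; an illustrative invocation (input file name
// hypothetical) is:
//
//   mlir-opt -loop-fusion -src-loop-depth=1 -dst-loop-depth=1 input.mlir
//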
namespace {

/// Loop fusion pass. This pass currently supports a greedy fusion policy,
/// which fuses loop nests with single-writer/single-reader memref dependences
/// with the goal of improving locality.

// TODO(andydavis) Support fusion of source loop nests which write to multiple
// memrefs, where each memref can have multiple users (if profitable).
// TODO(andydavis) Extend this pass to check for fusion-preventing dependences,
// and add support for more general loop fusion algorithms.

struct LoopFusion : public FunctionPass {
  LoopFusion() : FunctionPass(&LoopFusion::passID) {}

  PassResult runOnMLFunction(MLFunction *f) override;
  static char passID;
};

} // end anonymous namespace

char LoopFusion::passID = 0;

FunctionPass *mlir::createLoopFusionPass() { return new LoopFusion; }

// Populates 'access' with the memref, operation statement, and indices of the
// single load or store operation 'loadOrStoreOpStmt'.
static void getSingleMemRefAccess(OperationStmt *loadOrStoreOpStmt,
                                  MemRefAccess *access) {
  if (auto loadOp = loadOrStoreOpStmt->dyn_cast<LoadOp>()) {
    access->memref = loadOp->getMemRef();
    access->opStmt = loadOrStoreOpStmt;
    auto loadMemrefType = loadOp->getMemRefType();
    access->indices.reserve(loadMemrefType.getRank());
    for (auto *index : loadOp->getIndices()) {
      access->indices.push_back(index);
    }
  } else {
    assert(loadOrStoreOpStmt->isa<StoreOp>());
    auto storeOp = loadOrStoreOpStmt->dyn_cast<StoreOp>();
    access->opStmt = loadOrStoreOpStmt;
    access->memref = storeOp->getMemRef();
    auto storeMemrefType = storeOp->getMemRefType();
    access->indices.reserve(storeMemrefType.getRank());
    for (auto *index : storeOp->getIndices()) {
      access->indices.push_back(index);
    }
  }
}

// FusionCandidate encapsulates source and destination memref accesses within
// loop nests which are candidates for loop fusion.
struct FusionCandidate {
  // Load or store access within src loop nest to be fused into dst loop nest.
  MemRefAccess srcAccess;
  // Load or store access within dst loop nest.
  MemRefAccess dstAccess;
};

// Builds a FusionCandidate from the store access in the source loop nest
// ('srcStoreOpStmt') and the load access in the destination loop nest
// ('dstLoadOpStmt').
static FusionCandidate buildFusionCandidate(OperationStmt *srcStoreOpStmt,
                                            OperationStmt *dstLoadOpStmt) {
  FusionCandidate candidate;
  // Get store access for src loop nest.
  getSingleMemRefAccess(srcStoreOpStmt, &candidate.srcAccess);
  // Get load access for dst loop nest.
  getSingleMemRefAccess(dstLoadOpStmt, &candidate.dstAccess);
  return candidate;
}

// Returns the loop depth of the loop nest surrounding 'opStmt'.
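// For example (illustrative IR only), for a load in the nest
//
//   for %i = 0 to 10 {
//     for %j = 0 to 10 {
//       %v = load %m[%i, %j] : memref<10x10xf32>
//     }
//   }
//
// the returned depth is 2; the count stops at the first non-ForStmt ancestor.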
static unsigned getLoopDepth(OperationStmt *opStmt) {
  unsigned loopDepth = 0;
  auto *currStmt = opStmt->getParentStmt();
  while (currStmt && isa<ForStmt>(currStmt)) {
    ++loopDepth;
    currStmt = currStmt->getParentStmt();
  }
  return loopDepth;
}

namespace {

// LoopNestStateCollector walks loop nests and collects load and store
// operations, and whether or not an IfStmt was encountered in the loop nest.
class LoopNestStateCollector : public StmtWalker<LoopNestStateCollector> {
public:
  SmallVector<ForStmt *, 4> forStmts;
  SmallVector<OperationStmt *, 4> loadOpStmts;
  SmallVector<OperationStmt *, 4> storeOpStmts;
  bool hasIfStmt = false;

  void visitForStmt(ForStmt *forStmt) { forStmts.push_back(forStmt); }

  void visitIfStmt(IfStmt *ifStmt) { hasIfStmt = true; }

  void visitOperationStmt(OperationStmt *opStmt) {
    if (opStmt->isa<LoadOp>())
      loadOpStmts.push_back(opStmt);
    if (opStmt->isa<StoreOp>())
      storeOpStmts.push_back(opStmt);
  }
};

// MemRefDependenceGraph is a graph data structure where graph nodes are
// top-level statements in an MLFunction which contain load/store ops, and
// edges are memref dependences between the nodes.
// TODO(andydavis) Add a depth parameter to dependence graph construction.
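//
// As an illustrative example, for an MLFunction containing the two loop
// nests from the sketch in the file header comment, init() would create
// node 0 for the producer nest, node 1 for the consumer nest, and a single
// edge from node 0 to node 1 on the shared memref (edges are added only
// when at least one endpoint stores to that memref).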
struct MemRefDependenceGraph {
public:
  // Node represents a node in the graph. A Node is either an entire loop nest
  // rooted at the top level which contains loads/stores, or a top-level
  // load/store.
  struct Node {
    // The unique identifier of this node in the graph.
    unsigned id;
    // The top-level statement which is (or contains) loads/stores.
    Statement *stmt;
    // List of load op stmts.
    SmallVector<OperationStmt *, 4> loads;
    // List of store op stmts.
    SmallVector<OperationStmt *, 4> stores;
    Node(unsigned id, Statement *stmt) : id(id), stmt(stmt) {}

    // Returns the load op count for 'memref'.
    unsigned getLoadOpCount(Value *memref) {
      unsigned loadOpCount = 0;
      for (auto *loadOpStmt : loads) {
        if (memref == loadOpStmt->cast<LoadOp>()->getMemRef())
          ++loadOpCount;
      }
      return loadOpCount;
    }

    // Returns the store op count for 'memref'.
    unsigned getStoreOpCount(Value *memref) {
      unsigned storeOpCount = 0;
      for (auto *storeOpStmt : stores) {
        if (memref == storeOpStmt->cast<StoreOp>()->getMemRef())
          ++storeOpCount;
      }
      return storeOpCount;
    }
  };

  // Edge represents a memref data dependence between nodes in the graph.
  struct Edge {
    // The id of the node at the other end of the edge.
    unsigned id;
    // The memref on which this edge represents a dependence.
    Value *memref;
  };

  // Map from node id to Node.
  DenseMap<unsigned, Node> nodes;
  // Map from node id to list of input edges.
  DenseMap<unsigned, SmallVector<Edge, 2>> inEdges;
  // Map from node id to list of output edges.
  DenseMap<unsigned, SmallVector<Edge, 2>> outEdges;

  MemRefDependenceGraph() {}

  // Initializes the dependence graph based on operations in 'f'.
  // Returns true on success, false otherwise.
  bool init(MLFunction *f);

  // Returns the graph node for 'id'.
  Node *getNode(unsigned id) {
    auto it = nodes.find(id);
    assert(it != nodes.end());
    return &it->second;
  }

  // Adds an edge from node 'srcId' to node 'dstId' for 'memref'.
  void addEdge(unsigned srcId, unsigned dstId, Value *memref) {
    outEdges[srcId].push_back({dstId, memref});
    inEdges[dstId].push_back({srcId, memref});
  }

  // Removes an edge from node 'srcId' to node 'dstId' for 'memref'.
  void removeEdge(unsigned srcId, unsigned dstId, Value *memref) {
    assert(inEdges.count(dstId) > 0);
    assert(outEdges.count(srcId) > 0);
    // Remove 'srcId' from 'inEdges[dstId]'.
    for (auto it = inEdges[dstId].begin(); it != inEdges[dstId].end(); ++it) {
      if ((*it).id == srcId && (*it).memref == memref) {
        inEdges[dstId].erase(it);
        break;
      }
    }
    // Remove 'dstId' from 'outEdges[srcId]'.
    for (auto it = outEdges[srcId].begin(); it != outEdges[srcId].end(); ++it) {
      if ((*it).id == dstId && (*it).memref == memref) {
        outEdges[srcId].erase(it);
        break;
      }
    }
  }

  // Returns the input edge count for node 'id' and 'memref'.
  unsigned getInEdgeCount(unsigned id, Value *memref) {
    unsigned inEdgeCount = 0;
    if (inEdges.count(id) > 0)
      for (auto &inEdge : inEdges[id])
        if (inEdge.memref == memref)
          ++inEdgeCount;
    return inEdgeCount;
  }

  // Returns the output edge count for node 'id' and 'memref'.
  unsigned getOutEdgeCount(unsigned id, Value *memref) {
    unsigned outEdgeCount = 0;
    if (outEdges.count(id) > 0)
      for (auto &outEdge : outEdges[id])
        if (outEdge.memref == memref)
          ++outEdgeCount;
    return outEdgeCount;
  }

  // Returns the min node id of all output edges from node 'id'.
  unsigned getMinOutEdgeNodeId(unsigned id) {
    unsigned minId = std::numeric_limits<unsigned>::max();
    if (outEdges.count(id) > 0)
      for (auto &outEdge : outEdges[id])
        minId = std::min(minId, outEdge.id);
    return minId;
  }

  // Updates edge mappings from node 'srcId' to node 'dstId' and removes
  // state associated with node 'srcId'.
  void updateEdgesAndRemoveSrcNode(unsigned srcId, unsigned dstId) {
    // For each edge in 'inEdges[srcId]': add a new edge remapping to 'dstId'.
    if (inEdges.count(srcId) > 0) {
      SmallVector<Edge, 2> oldInEdges = inEdges[srcId];
      for (auto &inEdge : oldInEdges) {
        // Remove edge from 'inEdge.id' to 'srcId'.
        removeEdge(inEdge.id, srcId, inEdge.memref);
        // Add edge from 'inEdge.id' to 'dstId'.
        addEdge(inEdge.id, dstId, inEdge.memref);
      }
    }
    // For each edge in 'outEdges[srcId]': add a new edge remapping to 'dstId'.
    if (outEdges.count(srcId) > 0) {
      SmallVector<Edge, 2> oldOutEdges = outEdges[srcId];
      for (auto &outEdge : oldOutEdges) {
        // Remove edge from 'srcId' to 'outEdge.id'.
        removeEdge(srcId, outEdge.id, outEdge.memref);
        // Add edge from 'dstId' to 'outEdge.id' (if 'outEdge.id' != 'dstId').
        if (outEdge.id != dstId)
          addEdge(dstId, outEdge.id, outEdge.memref);
      }
    }
    // Remove 'srcId' from graph state.
    inEdges.erase(srcId);
    outEdges.erase(srcId);
    nodes.erase(srcId);
  }

  // Adds ops in 'loads' and 'stores' to node at 'id'.
  void addToNode(unsigned id, const SmallVectorImpl<OperationStmt *> &loads,
                 const SmallVectorImpl<OperationStmt *> &stores) {
    Node *node = getNode(id);
    for (auto *loadOpStmt : loads)
      node->loads.push_back(loadOpStmt);
    for (auto *storeOpStmt : stores)
      node->stores.push_back(storeOpStmt);
  }

  void print(raw_ostream &os) const {
    os << "\nMemRefDependenceGraph\n";
    os << "\nNodes:\n";
    for (auto &idAndNode : nodes) {
      os << "Node: " << idAndNode.first << "\n";
      auto it = inEdges.find(idAndNode.first);
      if (it != inEdges.end()) {
        for (const auto &e : it->second)
          os << " InEdge: " << e.id << " " << e.memref << "\n";
      }
      it = outEdges.find(idAndNode.first);
      if (it != outEdges.end()) {
        for (const auto &e : it->second)
          os << " OutEdge: " << e.id << " " << e.memref << "\n";
      }
    }
  }
  void dump() const { print(llvm::errs()); }
};

// Initializes the data dependence graph by walking statements in 'f'.
// Assigns each node in the graph a node id based on program order in 'f'.
// TODO(andydavis) Add support for taking a StmtBlock arg to construct the
// dependence graph at a different depth.
bool MemRefDependenceGraph::init(MLFunction *f) {
  unsigned id = 0;
  DenseMap<Value *, SetVector<unsigned>> memrefAccesses;
  for (auto &stmt : *f->getBody()) {
    if (auto *forStmt = dyn_cast<ForStmt>(&stmt)) {
      // Create graph node 'id' to represent top-level 'forStmt' and record
      // all load and store accesses it contains.
      LoopNestStateCollector collector;
      collector.walkForStmt(forStmt);
      // Return false if IfStmts are found (not currently supported).
      if (collector.hasIfStmt)
        return false;
      Node node(id++, &stmt);
      for (auto *opStmt : collector.loadOpStmts) {
        node.loads.push_back(opStmt);
        auto *memref = opStmt->cast<LoadOp>()->getMemRef();
        memrefAccesses[memref].insert(node.id);
      }
      for (auto *opStmt : collector.storeOpStmts) {
        node.stores.push_back(opStmt);
        auto *memref = opStmt->cast<StoreOp>()->getMemRef();
        memrefAccesses[memref].insert(node.id);
      }
      nodes.insert({node.id, node});
    }
    if (auto *opStmt = dyn_cast<OperationStmt>(&stmt)) {
      if (auto loadOp = opStmt->dyn_cast<LoadOp>()) {
        // Create graph node for top-level load op.
        Node node(id++, &stmt);
        node.loads.push_back(opStmt);
        auto *memref = opStmt->cast<LoadOp>()->getMemRef();
        memrefAccesses[memref].insert(node.id);
        nodes.insert({node.id, node});
      }
      if (auto storeOp = opStmt->dyn_cast<StoreOp>()) {
        // Create graph node for top-level store op.
        Node node(id++, &stmt);
        node.stores.push_back(opStmt);
        auto *memref = opStmt->cast<StoreOp>()->getMemRef();
        memrefAccesses[memref].insert(node.id);
        nodes.insert({node.id, node});
      }
    }
    // Return false if IfStmts are found (not currently supported).
    if (isa<IfStmt>(&stmt))
      return false;
  }

  // Walk memref access lists and add graph edges between dependent nodes.
  for (auto &memrefAndList : memrefAccesses) {
    unsigned n = memrefAndList.second.size();
    for (unsigned i = 0; i < n; ++i) {
      unsigned srcId = memrefAndList.second[i];
      bool srcHasStore =
          getNode(srcId)->getStoreOpCount(memrefAndList.first) > 0;
      for (unsigned j = i + 1; j < n; ++j) {
        unsigned dstId = memrefAndList.second[j];
        bool dstHasStore =
            getNode(dstId)->getStoreOpCount(memrefAndList.first) > 0;
        if (srcHasStore || dstHasStore)
          addEdge(srcId, dstId, memrefAndList.first);
      }
    }
  }
  return true;
}

// GreedyFusion greedily fuses loop nests which have a producer/consumer
// relationship on a memref, with the goal of improving locality. Currently,
// the producer/consumer relationship is required to be unique within the
// MLFunction (there are TODOs to relax this constraint in the future).
//
// The steps of the algorithm are as follows:
//
// *) A worklist is initialized with node ids from the dependence graph.
// *) For each node id in the worklist:
//    *) Pop a ForStmt off the worklist. This 'dstForStmt' will be a candidate
//       destination ForStmt into which fusion will be attempted.
//    *) Add each LoadOp currently in 'dstForStmt' into list 'dstLoadOps'.
//    *) For each LoadOp in 'dstLoadOps' do:
//       *) Look up dependent loop nests at earlier positions in the
//          MLFunction which have a single store op to the same memref.
//       *) Check if dependences would be violated by the fusion. For example,
//          the src loop nest may load from memrefs which are different from
//          the producer-consumer memref between the src and dest loop nests.
//       *) Get a computation slice of 'srcLoopNest', which adjusts its loop
//          bounds to be functions of 'dstLoopNest' IVs and symbols.
//       *) Fuse the 'srcLoopNest' computation slice into the 'dstLoopNest',
//          just before the dst load op user.
//       *) Add the newly fused load/store operation statements to the state,
//          and also add newly fused load ops to 'dstLoadOps' to be considered
//          as fusion dst load ops in another iteration.
//       *) Remove the old src loop nest and its associated state.
//
// Given a graph where top-level statements are vertices in the set 'V' and
// edges in the set 'E' are dependences between vertices, this algorithm
// takes O(V) time for initialization, and has runtime O(V + E).
//
// This greedy algorithm is not 'maximal' due to the current restriction of
// fusing along single producer-consumer edges, but there is a TODO to fix
// this.
//
// TODO(andydavis) Experiment with other fusion policies.
// TODO(andydavis) Add support for fusing for input reuse (perhaps by
// constructing a graph with edges which represent loads from the same memref
// in two different loop nests).
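//
// As an illustrative sketch of the slicing step (schematic IR and bound maps
// only): fusing the producer from the file header sketch at a destination
// loop depth of 1 first materializes a single-iteration slice whose bounds
// are functions of the destination IV:
//
//   for %j = 0 to 100 {
//     for %ii = #map0(%j) to #map1(%j) {  // bounds pin %ii to the point %j
//       store %v, %m[%ii] : memref<100xf32>
//     }
//     %u = load %m[%j] : memref<100xf32>
//   }
//
// promoteIfSingleIteration() then replaces the single-iteration loop with a
// direct use of the destination IV, yielding the fused nest shown in the
// header sketch.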
struct GreedyFusion {
public:
  MemRefDependenceGraph *mdg;
  SmallVector<unsigned, 4> worklist;

  GreedyFusion(MemRefDependenceGraph *mdg) : mdg(mdg) {
    // Initialize worklist with nodes from 'mdg'.
    worklist.resize(mdg->nodes.size());
    std::iota(worklist.begin(), worklist.end(), 0);
  }

  void run() {
    while (!worklist.empty()) {
      unsigned dstId = worklist.back();
      worklist.pop_back();
      // Skip if this node was removed (fused into another node).
      if (mdg->nodes.count(dstId) == 0)
        continue;
      // Get 'dstNode' into which to attempt fusion.
      auto *dstNode = mdg->getNode(dstId);
      // Skip if 'dstNode' is not a loop nest.
      if (!isa<ForStmt>(dstNode->stmt))
        continue;

      SmallVector<OperationStmt *, 4> loads = dstNode->loads;
      while (!loads.empty()) {
        auto *dstLoadOpStmt = loads.pop_back_val();
        auto *memref = dstLoadOpStmt->cast<LoadOp>()->getMemRef();
        // Skip 'dstLoadOpStmt' if there are multiple loads of 'memref' in
        // 'dstNode'.
        if (dstNode->getLoadOpCount(memref) != 1)
          continue;
        // Skip if no input edges along which to fuse.
        if (mdg->inEdges.count(dstId) == 0)
          continue;
        // Iterate through in edges for 'dstId'. Iterate over a copy, since a
        // successful fusion below mutates the underlying edge lists.
        SmallVector<Edge, 2> srcEdges = mdg->inEdges[dstId];
        for (auto &srcEdge : srcEdges) {
          // Skip 'srcEdge' if not for 'memref'.
          if (srcEdge.memref != memref)
            continue;
          // Skip 'srcEdge' if its source node was already fused away.
          if (mdg->nodes.count(srcEdge.id) == 0)
            continue;
          auto *srcNode = mdg->getNode(srcEdge.id);
          // Skip if 'srcNode' is not a loop nest.
          if (!isa<ForStmt>(srcNode->stmt))
            continue;
          // Skip if 'srcNode' has more than one store to 'memref'.
          if (srcNode->getStoreOpCount(memref) != 1)
            continue;
          // Skip 'srcNode' if it has out edges on 'memref' other than the one
          // to 'dstId'.
          if (mdg->getOutEdgeCount(srcNode->id, memref) != 1)
            continue;
          // Skip 'srcNode' if it has in dependence edges. NOTE: This is
          // overly conservative.
          // TODO(andydavis) Track dependence type with edges, and just check
          // for WAW dependence edge here.
          if (mdg->getInEdgeCount(srcNode->id, memref) != 0)
            continue;
          // Skip 'srcNode' unless 'dstId' is its minimum out-edge node id,
          // i.e. unless all of its other out edges target nodes after 'dstId'.
          if (mdg->getMinOutEdgeNodeId(srcNode->id) != dstId)
            continue;
          // Get the unique 'srcNode' store op.
          auto *srcStoreOpStmt = srcNode->stores.front();
          // Build fusion candidate out of 'srcStoreOpStmt' and
          // 'dstLoadOpStmt'.
          FusionCandidate candidate =
              buildFusionCandidate(srcStoreOpStmt, dstLoadOpStmt);
          // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'.
          unsigned srcLoopDepth = clSrcLoopDepth.getNumOccurrences() > 0
                                      ? clSrcLoopDepth
                                      : getLoopDepth(srcStoreOpStmt);
          unsigned dstLoopDepth = clDstLoopDepth.getNumOccurrences() > 0
                                      ? clDstLoopDepth
                                      : getLoopDepth(dstLoadOpStmt);
          auto *sliceLoopNest = mlir::insertBackwardComputationSlice(
              &candidate.srcAccess, &candidate.dstAccess, srcLoopDepth,
              dstLoopDepth);
          if (sliceLoopNest != nullptr) {
            // Remove edges between 'srcNode' and 'dstNode', and remove
            // 'srcNode'.
            mdg->updateEdgesAndRemoveSrcNode(srcNode->id, dstNode->id);
            // Record all load/store accesses in 'sliceLoopNest' at 'dstId'.
            LoopNestStateCollector collector;
            collector.walkForStmt(sliceLoopNest);
            mdg->addToNode(dstId, collector.loadOpStmts,
                           collector.storeOpStmts);
            // Add new load ops to the current node's load op list 'loads' to
            // continue fusing based on new operands.
            for (auto *loadOpStmt : collector.loadOpStmts)
              loads.push_back(loadOpStmt);
            // Promote single-iteration loops to a single IV value.
            for (auto *forStmt : collector.forStmts) {
              promoteIfSingleIteration(forStmt);
            }
            // Remove the old src loop nest.
            cast<ForStmt>(srcNode->stmt)->erase();
          }
        }
      }
    }
  }
};

} // end anonymous namespace

PassResult LoopFusion::runOnMLFunction(MLFunction *f) {
  MemRefDependenceGraph g;
  if (g.init(f))
    GreedyFusion(&g).run();
  return success();
}

static PassRegistration<LoopFusion> pass("loop-fusion", "Fuse loop nests");