//===- LoopFusion.cpp - Code to perform loop fusion -----------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements loop fusion.
//
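// As an illustrative sketch (names and syntax approximate for this version
// of the IR), fusing a producer loop nest into a consumer loop nest through
// memref %m:
//
//   for %i0 = 0 to 100 {
//     store %v, %m[%i0]
//   }
//   for %i1 = 0 to 100 {
//     %u = load %m[%i1]
//   }
//
// may become, after fusion and privatization of the intermediary buffer
// (%m_priv is a placeholder name for the new private memref):
//
//   for %i1 = 0 to 100 {
//     store %v, %m_priv[0]
//     %u = load %m_priv[0]
//   }
//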
//===----------------------------------------------------------------------===//

#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/AffineStructures.h"
#include "mlir/Analysis/LoopAnalysis.h"
#include "mlir/Analysis/Utils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/Pass.h"
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/Transforms/LoopUtils.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/Utils.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
// Standard headers needed for std::min, std::numeric_limits, std::iota, and
// std::stringstream used below.
#include <algorithm>
#include <iomanip>
#include <limits>
#include <numeric>
#include <sstream>

#define DEBUG_TYPE "loop-fusion"

using llvm::SetVector;

using namespace mlir;

/// Disables the fusion profitability check; fuses whenever valid.
static llvm::cl::opt<bool>
    clMaximalLoopFusion("fusion-maximal", llvm::cl::Hidden,
                        llvm::cl::desc("Enables maximal loop fusion"));

/// A threshold, as a fraction of the total computation, of additional
/// computation tolerated when fusing.
static llvm::cl::opt<double> clFusionAddlComputeTolerance(
    "fusion-compute-tolerance", llvm::cl::Hidden,
    llvm::cl::desc("Fractional increase in additional "
                   "computation tolerated while fusing"));

namespace {

/// Loop fusion pass. This pass currently supports a greedy fusion policy,
/// which fuses loop nests with single-writer/single-reader memref dependences
/// with the goal of improving locality.

// TODO(andydavis) Support fusion of source loop nests which write to multiple
// memrefs, where each memref can have multiple users (if profitable).
// TODO(andydavis) Extend this pass to check for fusion preventing dependences,
// and add support for more general loop fusion algorithms.

struct LoopFusion : public FunctionPass {
  LoopFusion() : FunctionPass(&LoopFusion::passID) {}

  PassResult runOnFunction(Function *f) override;
  static char passID;

  // The amount of additional computation that is tolerated while fusing
  // pair-wise as a fraction of the total computation.
  constexpr static double kComputeToleranceThreshold = 0.30f;
};

} // end anonymous namespace

char LoopFusion::passID = 0;

FunctionPass *mlir::createLoopFusionPass() { return new LoopFusion; }

namespace {

// LoopNestStateCollector walks loop nests and collects load and store
// operations, and whether or not an IfInst was encountered in the loop nest.
class LoopNestStateCollector : public InstWalker<LoopNestStateCollector> {
public:
  SmallVector<ForInst *, 4> forInsts;
  SmallVector<OperationInst *, 4> loadOpInsts;
  SmallVector<OperationInst *, 4> storeOpInsts;
  bool hasIfInst = false;

  void visitForInst(ForInst *forInst) { forInsts.push_back(forInst); }

  void visitIfInst(IfInst *ifInst) { hasIfInst = true; }

  void visitOperationInst(OperationInst *opInst) {
    if (opInst->isa<LoadOp>())
      loadOpInsts.push_back(opInst);
    if (opInst->isa<StoreOp>())
      storeOpInsts.push_back(opInst);
  }
};

// TODO(b/117228571) Replace when this is modeled through side-effects/op traits
static bool isMemRefDereferencingOp(const OperationInst &op) {
  if (op.isa<LoadOp>() || op.isa<StoreOp>() || op.isa<DmaStartOp>() ||
      op.isa<DmaWaitOp>())
    return true;
  return false;
}

// MemRefDependenceGraph is a graph data structure where graph nodes are
// top-level instructions in a Function which contain load/store ops, and edges
// are memref dependences between the nodes.
// TODO(andydavis) Add a more flexible dependence graph representation.
// TODO(andydavis) Add a depth parameter to dependence graph construction.
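//
// For example (illustrative): in a function with two top-level loop nests,
// where the first nest stores to memref %m and the second nest loads from %m,
// the graph has nodes {0, 1} and a dependence edge (0 -> 1) on %m, because
// node 0 writes memory that node 1 subsequently reads.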
struct MemRefDependenceGraph {
public:
  // Node represents a node in the graph. A Node is either an entire loop nest
  // rooted at the top level which contains loads/stores, or a top level
  // load/store.
  struct Node {
    // The unique identifier of this node in the graph.
    unsigned id;
    // The top-level statement which is (or contains) loads/stores.
    Instruction *inst;
    // List of load operations.
    SmallVector<OperationInst *, 4> loads;
    // List of store op insts.
    SmallVector<OperationInst *, 4> stores;
    Node(unsigned id, Instruction *inst) : id(id), inst(inst) {}

    // Returns the load op count for 'memref'.
    unsigned getLoadOpCount(Value *memref) {
      unsigned loadOpCount = 0;
      for (auto *loadOpInst : loads) {
        if (memref == loadOpInst->cast<LoadOp>()->getMemRef())
          ++loadOpCount;
      }
      return loadOpCount;
    }

    // Returns the store op count for 'memref'.
    unsigned getStoreOpCount(Value *memref) {
      unsigned storeOpCount = 0;
      for (auto *storeOpInst : stores) {
        if (memref == storeOpInst->cast<StoreOp>()->getMemRef())
          ++storeOpCount;
      }
      return storeOpCount;
    }
  };

  // Edge represents a memref data dependence between nodes in the graph.
  struct Edge {
    // The id of the node at the other end of the edge.
    unsigned id;
    // The memref on which this edge represents a dependence.
    Value *memref;
  };

  // Map from node id to Node.
  DenseMap<unsigned, Node> nodes;
  // Map from node id to list of input edges.
  DenseMap<unsigned, SmallVector<Edge, 2>> inEdges;
  // Map from node id to list of output edges.
  DenseMap<unsigned, SmallVector<Edge, 2>> outEdges;
  // Map from memref to a count on the dependence edges associated with that
  // memref.
  DenseMap<Value *, unsigned> memrefEdgeCount;

  MemRefDependenceGraph() {}

  // Initializes the dependence graph based on operations in 'f'.
  // Returns true on success, false otherwise.
  bool init(Function *f);

  // Returns the graph node for 'id'.
  Node *getNode(unsigned id) {
    auto it = nodes.find(id);
    assert(it != nodes.end());
    return &it->second;
  }

  // Removes node 'id' (and its associated edges) from the graph.
  void removeNode(unsigned id) {
    // Remove each edge in 'inEdges[id]'.
    if (inEdges.count(id) > 0) {
      SmallVector<Edge, 2> oldInEdges = inEdges[id];
      for (auto &inEdge : oldInEdges) {
        removeEdge(inEdge.id, id, inEdge.memref);
      }
    }
    // Remove each edge in 'outEdges[id]'.
    if (outEdges.count(id) > 0) {
      SmallVector<Edge, 2> oldOutEdges = outEdges[id];
      for (auto &outEdge : oldOutEdges) {
        removeEdge(id, outEdge.id, outEdge.memref);
      }
    }
    // Erase remaining node state.
    inEdges.erase(id);
    outEdges.erase(id);
    nodes.erase(id);
  }

  bool hasOutEdges(unsigned id) {
    return outEdges.count(id) > 0 && !outEdges[id].empty();
  }

  // Returns true if node 'id' writes to any memref which escapes (or is an
  // argument to) the function/block. Returns false otherwise.
  bool writesToLiveInOrEscapingMemrefs(unsigned id) {
    Node *node = getNode(id);
    for (auto *storeOpInst : node->stores) {
      auto *memref = storeOpInst->cast<StoreOp>()->getMemRef();
      auto *inst = memref->getDefiningInst();
      auto *opInst = dyn_cast_or_null<OperationInst>(inst);
      // Return true if 'memref' is a function argument.
      if (opInst == nullptr)
        return true;
      // Return true if any use of 'memref' escapes the function.
      for (auto &use : memref->getUses()) {
        auto *user = dyn_cast<OperationInst>(use.getOwner());
        if (!user || !isMemRefDereferencingOp(*user))
          return true;
      }
    }
    return false;
  }

  // Returns true iff there is an edge from node 'srcId' to node 'dstId' for
  // 'memref'. Returns false otherwise.
  bool hasEdge(unsigned srcId, unsigned dstId, Value *memref) {
    if (outEdges.count(srcId) == 0 || inEdges.count(dstId) == 0) {
      return false;
    }
    bool hasOutEdge = llvm::any_of(outEdges[srcId], [=](Edge &edge) {
      return edge.id == dstId && edge.memref == memref;
    });
    bool hasInEdge = llvm::any_of(inEdges[dstId], [=](Edge &edge) {
      return edge.id == srcId && edge.memref == memref;
    });
    return hasOutEdge && hasInEdge;
  }

  // Adds an edge from node 'srcId' to node 'dstId' for 'memref'.
  void addEdge(unsigned srcId, unsigned dstId, Value *memref) {
    if (!hasEdge(srcId, dstId, memref)) {
      outEdges[srcId].push_back({dstId, memref});
      inEdges[dstId].push_back({srcId, memref});
      memrefEdgeCount[memref]++;
    }
  }

  // Removes an edge from node 'srcId' to node 'dstId' for 'memref'.
  void removeEdge(unsigned srcId, unsigned dstId, Value *memref) {
    assert(inEdges.count(dstId) > 0);
    assert(outEdges.count(srcId) > 0);
    assert(memrefEdgeCount.count(memref) > 0);
    memrefEdgeCount[memref]--;
    // Remove 'srcId' from 'inEdges[dstId]'.
    for (auto it = inEdges[dstId].begin(); it != inEdges[dstId].end(); ++it) {
      if ((*it).id == srcId && (*it).memref == memref) {
        inEdges[dstId].erase(it);
        break;
      }
    }
    // Remove 'dstId' from 'outEdges[srcId]'.
    for (auto it = outEdges[srcId].begin(); it != outEdges[srcId].end(); ++it) {
      if ((*it).id == dstId && (*it).memref == memref) {
        outEdges[srcId].erase(it);
        break;
      }
    }
  }

  // Returns the input edge count for node 'id' and 'memref'.
  unsigned getInEdgeCount(unsigned id, Value *memref) {
    unsigned inEdgeCount = 0;
    if (inEdges.count(id) > 0)
      for (auto &inEdge : inEdges[id])
        if (inEdge.memref == memref)
          ++inEdgeCount;
    return inEdgeCount;
  }

  // Returns the output edge count for node 'id' and 'memref'.
  unsigned getOutEdgeCount(unsigned id, Value *memref) {
    unsigned outEdgeCount = 0;
    if (outEdges.count(id) > 0)
      for (auto &outEdge : outEdges[id])
        if (outEdge.memref == memref)
          ++outEdgeCount;
    return outEdgeCount;
  }

  // Check for a dependence in Block instruction list range (srcId, dstId) on
  // memrefs other than 'memrefToSkip' (which will be privatized for the fused
  // loop).
  bool hasDependenceTargetInRange(unsigned srcId, unsigned dstId,
                                  Value *memrefToSkip) {
    if (outEdges.count(srcId) == 0)
      return false;
    // Check if any of the outgoing edge targets from srcId lie in
    // (srcId, dstId).
    SmallPtrSet<Instruction *, 2> depInsts;
    for (auto &outEdge : outEdges[srcId]) {
      if (outEdge.id != dstId && outEdge.memref != memrefToSkip) {
        Node *node = getNode(outEdge.id);
        depInsts.insert(node->inst);
      }
    }
    // Do a linear walk from 'srcNode.inst' to 'dstNode.inst' and for each
    // instruction 'inst' in range ('srcNode.inst', 'dstNode.inst') test
    // if 'depInsts' contains 'inst', and return true if it does.
    // TODO(andydavis) If this linear search becomes a compile time issue,
    // create a data structure which allows a faster search through ForInsts
    // in a Block.
    Block::iterator it = std::next(Block::iterator(getNode(srcId)->inst));
    Block::iterator itEnd = Block::iterator(getNode(dstId)->inst);
    return std::any_of(it, itEnd, [&](Instruction &inst) {
      return depInsts.count(&inst) > 0;
    });
  }

  // Updates edge mappings from node 'srcId' to node 'dstId'.
  void updateEdges(unsigned srcId, unsigned dstId) {
    // For each edge in 'inEdges[srcId]': add new edge remapping to 'dstId'.
    if (inEdges.count(srcId) > 0) {
      SmallVector<Edge, 2> oldInEdges = inEdges[srcId];
      for (auto &inEdge : oldInEdges) {
        // Add edge from 'inEdge.id' to 'dstId'.
        addEdge(inEdge.id, dstId, inEdge.memref);
      }
    }
    // For each edge in 'outEdges[srcId]': remove edge from 'srcId' to 'dstId'.
    if (outEdges.count(srcId) > 0) {
      SmallVector<Edge, 2> oldOutEdges = outEdges[srcId];
      for (auto &outEdge : oldOutEdges) {
        // Remove any out edges from 'srcId' to 'dstId' across memrefs.
        if (outEdge.id == dstId)
          removeEdge(srcId, outEdge.id, outEdge.memref);
      }
    }
  }

  // Adds ops in 'loads' and 'stores' to node at 'id'.
  void addToNode(unsigned id, const SmallVectorImpl<OperationInst *> &loads,
                 const SmallVectorImpl<OperationInst *> &stores) {
    Node *node = getNode(id);
    for (auto *loadOpInst : loads)
      node->loads.push_back(loadOpInst);
    for (auto *storeOpInst : stores)
      node->stores.push_back(storeOpInst);
  }

  void clearNodeLoadAndStores(unsigned id) {
    Node *node = getNode(id);
    node->loads.clear();
    node->stores.clear();
  }

  void print(raw_ostream &os) const {
    os << "\nMemRefDependenceGraph\n";
    os << "\nNodes:\n";
    for (auto &idAndNode : nodes) {
      os << "Node: " << idAndNode.first << "\n";
      auto it = inEdges.find(idAndNode.first);
      if (it != inEdges.end()) {
        for (const auto &e : it->second)
          os << " InEdge: " << e.id << " " << e.memref << "\n";
      }
      it = outEdges.find(idAndNode.first);
      if (it != outEdges.end()) {
        for (const auto &e : it->second)
          os << " OutEdge: " << e.id << " " << e.memref << "\n";
      }
    }
  }
  void dump() const { print(llvm::errs()); }
};

// Initializes the data dependence graph by walking instructions in 'f'.
// Assigns each node in the graph a node id based on program order in 'f'.
// TODO(andydavis) Add support for taking a Block arg to construct the
// dependence graph at a different depth.
bool MemRefDependenceGraph::init(Function *f) {
  unsigned id = 0;
  DenseMap<Value *, SetVector<unsigned>> memrefAccesses;

  // TODO: support multi-block functions.
  if (f->getBlocks().size() != 1)
    return false;

  for (auto &inst : f->front()) {
    if (auto *forInst = dyn_cast<ForInst>(&inst)) {
      // Create graph node 'id' to represent top-level 'forInst' and record
      // all loads and store accesses it contains.
      LoopNestStateCollector collector;
      collector.walkForInst(forInst);
      // Return false if IfInsts are found (not currently supported).
      if (collector.hasIfInst)
        return false;
      Node node(id++, &inst);
      for (auto *opInst : collector.loadOpInsts) {
        node.loads.push_back(opInst);
        auto *memref = opInst->cast<LoadOp>()->getMemRef();
        memrefAccesses[memref].insert(node.id);
      }
      for (auto *opInst : collector.storeOpInsts) {
        node.stores.push_back(opInst);
        auto *memref = opInst->cast<StoreOp>()->getMemRef();
        memrefAccesses[memref].insert(node.id);
      }
      nodes.insert({node.id, node});
    }
    if (auto *opInst = dyn_cast<OperationInst>(&inst)) {
      if (auto loadOp = opInst->dyn_cast<LoadOp>()) {
        // Create graph node for top-level load op.
        Node node(id++, &inst);
        node.loads.push_back(opInst);
        auto *memref = opInst->cast<LoadOp>()->getMemRef();
        memrefAccesses[memref].insert(node.id);
        nodes.insert({node.id, node});
      }
      if (auto storeOp = opInst->dyn_cast<StoreOp>()) {
        // Create graph node for top-level store op.
        Node node(id++, &inst);
        node.stores.push_back(opInst);
        auto *memref = opInst->cast<StoreOp>()->getMemRef();
        memrefAccesses[memref].insert(node.id);
        nodes.insert({node.id, node});
      }
    }
    // Return false if IfInsts are found (not currently supported).
    if (isa<IfInst>(&inst))
      return false;
  }

  // Walk memref access lists and add graph edges between dependent nodes.
  for (auto &memrefAndList : memrefAccesses) {
    unsigned n = memrefAndList.second.size();
    for (unsigned i = 0; i < n; ++i) {
      unsigned srcId = memrefAndList.second[i];
      bool srcHasStore =
          getNode(srcId)->getStoreOpCount(memrefAndList.first) > 0;
      for (unsigned j = i + 1; j < n; ++j) {
        unsigned dstId = memrefAndList.second[j];
        bool dstHasStore =
            getNode(dstId)->getStoreOpCount(memrefAndList.first) > 0;
        if (srcHasStore || dstHasStore)
          addEdge(srcId, dstId, memrefAndList.first);
      }
    }
  }
  return true;
}

namespace {

// LoopNestStats aggregates various per-loop statistics (e.g. loop trip count
// and operation count) for a loop nest up until the innermost loop body.
struct LoopNestStats {
  // Map from ForInst to immediate child ForInsts in its loop body.
  DenseMap<ForInst *, SmallVector<ForInst *, 2>> loopMap;
  // Map from ForInst to count of operations in its loop body.
  DenseMap<ForInst *, uint64_t> opCountMap;
  // Map from ForInst to its constant trip count.
  DenseMap<ForInst *, uint64_t> tripCountMap;
};

// LoopNestStatsCollector walks a single loop nest and gathers per-loop
// trip count and operation count statistics and records them in 'stats'.
class LoopNestStatsCollector : public InstWalker<LoopNestStatsCollector> {
public:
  LoopNestStats *stats;
  bool hasLoopWithNonConstTripCount = false;

  LoopNestStatsCollector(LoopNestStats *stats) : stats(stats) {}

  void visitForInst(ForInst *forInst) {
    auto *parentInst = forInst->getParentInst();
    if (parentInst != nullptr) {
      assert(isa<ForInst>(parentInst) && "Expected parent ForInst");
      // Add mapping to 'forInst' from its parent ForInst.
      stats->loopMap[cast<ForInst>(parentInst)].push_back(forInst);
    }
    // Record the number of op instructions in the body of 'forInst'.
    unsigned count = 0;
    stats->opCountMap[forInst] = 0;
    for (auto &inst : *forInst->getBody()) {
      if (isa<OperationInst>(&inst))
        ++count;
    }
    stats->opCountMap[forInst] = count;
    // Record trip count for 'forInst'. Set flag if trip count is not constant.
    Optional<uint64_t> maybeConstTripCount = getConstantTripCount(*forInst);
    if (!maybeConstTripCount.hasValue()) {
      hasLoopWithNonConstTripCount = true;
      return;
    }
    stats->tripCountMap[forInst] = maybeConstTripCount.getValue();
  }
};

// Computes the total cost of the loop nest rooted at 'forInst'.
// Currently, the total cost is computed by counting the total operation
// instance count (i.e. the total number of operations in the loop body: loop
// operation count * loop trip count) for the entire loop nest.
// If 'tripCountOverrideMap' is non-null, overrides the trip count for loops
// specified in the map when computing the total op instance count.
// NOTE: this is used to compute the cost of computation slices, which are
// sliced along the iteration dimension, and thus reduce the trip count.
// If 'computeCostMap' is non-null, the total op count for forInsts specified
// in the map is increased (not overridden) by adding the op count from the
// map to the existing op count for the for loop. This is done before
// multiplying by the loop's trip count, and is used to model the cost of
// inserting a sliced loop nest of known cost into the loop's body.
// NOTE: this is used to compute the cost of fusing a slice of some loop nest
// within another loop.
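//
// For example (illustrative): a nest 'for %i' with constant trip count 100
// whose body holds one op plus an inner 'for %j' with trip count 200 and 2
// body ops costs 100 * (1 + 200 * 2) = 40100 op instances.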
static int64_t getComputeCost(
    ForInst *forInst, LoopNestStats *stats,
    llvm::SmallDenseMap<ForInst *, uint64_t, 8> *tripCountOverrideMap,
    DenseMap<ForInst *, int64_t> *computeCostMap) {
  // 'opCount' is the total number of operations in one iteration of the body
  // of 'forInst'.
  int64_t opCount = stats->opCountMap[forInst];
  if (stats->loopMap.count(forInst) > 0) {
    for (auto *childForInst : stats->loopMap[forInst]) {
      opCount += getComputeCost(childForInst, stats, tripCountOverrideMap,
                                computeCostMap);
    }
  }
  // Add in additional op instances from slice (if specified in map).
  if (computeCostMap != nullptr) {
    auto it = computeCostMap->find(forInst);
    if (it != computeCostMap->end()) {
      opCount += it->second;
    }
  }
  // Override trip count (if specified in map).
  int64_t tripCount = stats->tripCountMap[forInst];
  if (tripCountOverrideMap != nullptr) {
    auto it = tripCountOverrideMap->find(forInst);
    if (it != tripCountOverrideMap->end()) {
      tripCount = it->second;
    }
  }
  // Returns the total number of dynamic instances of operations in the loop
  // body.
  return tripCount * opCount;
}

} // end anonymous namespace

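// Returns the difference 'ubMap' - 'lbMap' (each a single-result bound map)
// if it simplifies to a constant, and None otherwise. For example
// (illustrative): with lbMap '(d0) -> (d0)' and ubMap '(d0) -> (d0 + 4)',
// the result is 4.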
static Optional<uint64_t> getConstDifference(AffineMap lbMap, AffineMap ubMap) {
  assert(lbMap.getNumResults() == 1 && "expected single result bound map");
  assert(ubMap.getNumResults() == 1 && "expected single result bound map");
  assert(lbMap.getNumDims() == ubMap.getNumDims());
  assert(lbMap.getNumSymbols() == ubMap.getNumSymbols());
  // TODO(andydavis) Merge this code with 'mlir::getTripCountExpr'.
  // ub_expr - lb_expr
  AffineExpr lbExpr(lbMap.getResult(0));
  AffineExpr ubExpr(ubMap.getResult(0));
  auto loopSpanExpr = simplifyAffineExpr(ubExpr - lbExpr, lbMap.getNumDims(),
                                         lbMap.getNumSymbols());
  auto cExpr = loopSpanExpr.dyn_cast<AffineConstantExpr>();
  if (!cExpr)
    return None;
  return cExpr.getValue();
}

// Builds a map 'tripCountMap' from ForInst to constant trip count for the
// loop nest surrounding 'srcOpInst', utilizing slice loop bounds in
// 'sliceState'. Returns true on success, false otherwise (if a non-constant
// trip count was encountered).
// TODO(andydavis) Make this work with non-unit step loops.
static bool buildSliceTripCountMap(
    OperationInst *srcOpInst, ComputationSliceState *sliceState,
    llvm::SmallDenseMap<ForInst *, uint64_t, 8> *tripCountMap) {
  SmallVector<ForInst *, 4> srcLoopIVs;
  getLoopIVs(*srcOpInst, &srcLoopIVs);
  unsigned numSrcLoopIVs = srcLoopIVs.size();
  // Populate map from ForInst -> trip count.
  for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
    AffineMap lbMap = sliceState->lbs[i];
    AffineMap ubMap = sliceState->ubs[i];
    if (lbMap == AffineMap::Null() || ubMap == AffineMap::Null()) {
      // The iteration of src loop IV 'i' was not sliced. Use full loop bounds.
      if (srcLoopIVs[i]->hasConstantLowerBound() &&
          srcLoopIVs[i]->hasConstantUpperBound()) {
        (*tripCountMap)[srcLoopIVs[i]] =
            srcLoopIVs[i]->getConstantUpperBound() -
            srcLoopIVs[i]->getConstantLowerBound();
        continue;
      }
      return false;
    }
    Optional<uint64_t> tripCount = getConstDifference(lbMap, ubMap);
    if (!tripCount.hasValue())
      return false;
    (*tripCountMap)[srcLoopIVs[i]] = tripCount.getValue();
  }
  return true;
}

// Removes load operations from 'srcLoads' which operate on 'memref', and
// adds them to 'dstLoads'.
static void
moveLoadsAccessingMemrefTo(Value *memref,
                           SmallVectorImpl<OperationInst *> *srcLoads,
                           SmallVectorImpl<OperationInst *> *dstLoads) {
  dstLoads->clear();
  SmallVector<OperationInst *, 4> srcLoadsToKeep;
  for (auto *load : *srcLoads) {
    if (load->cast<LoadOp>()->getMemRef() == memref)
      dstLoads->push_back(load);
    else
      srcLoadsToKeep.push_back(load);
  }
  srcLoads->swap(srcLoadsToKeep);
}

// Returns the innermost common loop depth for the set of operations in 'ops'.
static unsigned getInnermostCommonLoopDepth(ArrayRef<OperationInst *> ops) {
  unsigned numOps = ops.size();
  assert(numOps > 0);

  std::vector<SmallVector<ForInst *, 4>> loops(numOps);
  unsigned loopDepthLimit = std::numeric_limits<unsigned>::max();
  for (unsigned i = 0; i < numOps; ++i) {
    getLoopIVs(*ops[i], &loops[i]);
    loopDepthLimit =
        std::min(loopDepthLimit, static_cast<unsigned>(loops[i].size()));
  }

  unsigned loopDepth = 0;
  for (unsigned d = 0; d < loopDepthLimit; ++d) {
    unsigned i;
    for (i = 1; i < numOps; ++i) {
      if (loops[i - 1][d] != loops[i][d]) {
        break;
      }
    }
    if (i != numOps)
      break;
    ++loopDepth;
  }
  return loopDepth;
}

// Returns the slice union of 'sliceStateA' and 'sliceStateB' in 'sliceStateB'
// using a rectangular bounding box.
// TODO(andydavis) This function assumes that lower bounds for 'sliceStateA'
// and 'sliceStateB' are aligned.
// Specifically, when taking the union of overlapping intervals, it assumes
// that both intervals start at zero. Support needs to be added to take into
// account interval start offset when computing the union.
// TODO(andydavis) Move this function to an analysis library.
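//
// For example (illustrative): if dimension 'i' of 'sliceStateA' spans
// [lb, lb + 10) while the same dimension of 'sliceStateB' spans [lb, lb + 5),
// the union kept in 'sliceStateB' is [lb, lb + 10): the bound pair with the
// larger constant trip count wins.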
static bool getSliceUnion(const ComputationSliceState &sliceStateA,
                          ComputationSliceState *sliceStateB) {
  assert(sliceStateA.lbs.size() == sliceStateB->lbs.size());
  assert(sliceStateA.ubs.size() == sliceStateB->ubs.size());

  for (unsigned i = 0, e = sliceStateA.lbs.size(); i < e; ++i) {
    AffineMap lbMapA = sliceStateA.lbs[i];
    AffineMap ubMapA = sliceStateA.ubs[i];
    if (lbMapA == AffineMap::Null()) {
      assert(ubMapA == AffineMap::Null());
      continue;
    }
    assert(ubMapA && "expected non-null ub map");

    AffineMap lbMapB = sliceStateB->lbs[i];
    AffineMap ubMapB = sliceStateB->ubs[i];
    if (lbMapB == AffineMap::Null()) {
      assert(ubMapB == AffineMap::Null());
      // Union 'sliceStateB' does not have a bound for 'i' so copy from A.
      sliceStateB->lbs[i] = lbMapA;
      sliceStateB->ubs[i] = ubMapA;
      continue;
    }

    // TODO(andydavis) Change this code to take the min across all lower bounds
    // and max across all upper bounds for each dimension. This code can fail
    // for cases where a unique min or max could not be statically determined.

    // Assumption: both lower bounds are the same.
    if (lbMapA != lbMapB)
      return false;

    // Add bound with the largest trip count to union.
    Optional<uint64_t> tripCountA = getConstDifference(lbMapA, ubMapA);
    Optional<uint64_t> tripCountB = getConstDifference(lbMapB, ubMapB);
    if (!tripCountA.hasValue() || !tripCountB.hasValue())
      return false;

    if (tripCountA.getValue() > tripCountB.getValue()) {
      sliceStateB->lbs[i] = lbMapA;
      sliceStateB->ubs[i] = ubMapA;
    }
  }
  return true;
}

// Creates and returns a private (single-user) memref for fused loop rooted
// at 'forInst', with (potentially reduced) memref size based on the
// MemRefRegion written to by 'srcStoreOpInst' at depth 'dstLoopDepth'.
// TODO(bondhugula): consider refactoring the common code from generateDma and
// this one.
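//
// For example (illustrative): if 'srcStoreOpInst' writes only the window
// [%i * 4, %i * 4 + 4) of a memref<1024xf32> for a given dst IV %i, the
// private memref can be shrunk to memref<4xf32>, with accesses remapped by
// subtracting the region's lower bound (here, %i * 4).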
static Value *createPrivateMemRef(ForInst *forInst,
                                  OperationInst *srcStoreOpInst,
                                  unsigned dstLoopDepth) {
  // Create builder to insert alloc op just before 'forInst'.
  FuncBuilder b(forInst);
  // Builder to create constants at the top level.
  FuncBuilder top(forInst->getFunction());
  // Create new memref type based on slice bounds.
  auto *oldMemRef = srcStoreOpInst->cast<StoreOp>()->getMemRef();
  auto oldMemRefType = oldMemRef->getType().cast<MemRefType>();
  unsigned rank = oldMemRefType.getRank();

  // Compute MemRefRegion for 'srcStoreOpInst' at depth 'dstLoopDepth'.
  MemRefRegion region;
  getMemRefRegion(srcStoreOpInst, dstLoopDepth, &region);
  SmallVector<int64_t, 4> newShape;
  std::vector<SmallVector<int64_t, 4>> lbs;
  SmallVector<int64_t, 8> lbDivisors;
  lbs.reserve(rank);
  // Query 'region' for 'newShape' and lower bounds of MemRefRegion accessed
  // by 'srcStoreOpInst' at depth 'dstLoopDepth'.
  Optional<int64_t> numElements =
      region.getConstantBoundingSizeAndShape(&newShape, &lbs, &lbDivisors);
  assert(numElements.hasValue());

  const FlatAffineConstraints *cst = region.getConstraints();
  // 'outerIVs' holds the values that this memory region is symbolic/parametric
  // on; this would correspond to loop IVs surrounding the level at which the
  // slice is being materialized.
  SmallVector<Value *, 8> outerIVs;
  cst->getIdValues(rank, cst->getNumIds(), &outerIVs);

  // Build 'rank' AffineExprs from MemRefRegion 'lbs'.
  SmallVector<AffineExpr, 4> offsets;
  offsets.reserve(rank);
  for (unsigned d = 0; d < rank; ++d) {
    assert(lbs[d].size() == cst->getNumCols() - rank && "incorrect bound size");

    AffineExpr offset = top.getAffineConstantExpr(0);
    for (unsigned j = 0, e = cst->getNumCols() - rank - 1; j < e; j++) {
      offset = offset + lbs[d][j] * top.getAffineDimExpr(j);
    }
    assert(lbDivisors[d] > 0);
    offset =
        (offset + lbs[d][cst->getNumCols() - 1 - rank]).floorDiv(lbDivisors[d]);
    offsets.push_back(offset);
  }

  // Create 'newMemRefType' using 'newShape' from MemRefRegion accessed
  // by 'srcStoreOpInst'.
  auto newMemRefType = b.getMemRefType(newShape, oldMemRefType.getElementType(),
                                       {}, oldMemRefType.getMemorySpace());
  // Gather alloc operands for the dynamic dimensions of the memref.
  SmallVector<Value *, 4> allocOperands;
  unsigned dynamicDimCount = 0;
  for (auto dimSize : oldMemRefType.getShape()) {
    if (dimSize == -1)
      allocOperands.push_back(
          b.create<DimOp>(forInst->getLoc(), oldMemRef, dynamicDimCount++));
  }

  // Create new private memref for fused loop 'forInst'.
  Value *newMemRef =
      b.create<AllocOp>(forInst->getLoc(), newMemRefType, allocOperands);

  // Build an AffineMap to remap access functions based on lower bound offsets.
  SmallVector<AffineExpr, 4> remapExprs;
  remapExprs.reserve(rank);
  unsigned zeroOffsetCount = 0;
  for (unsigned i = 0; i < rank; i++) {
    if (auto constExpr = offsets[i].dyn_cast<AffineConstantExpr>())
      if (constExpr.getValue() == 0)
        ++zeroOffsetCount;
    auto dimExpr = b.getAffineDimExpr(outerIVs.size() + i);

    auto remapExpr =
        simplifyAffineExpr(dimExpr - offsets[i], outerIVs.size() + rank, 0);
    remapExprs.push_back(remapExpr);
  }
  auto indexRemap =
      zeroOffsetCount == rank
          ? AffineMap::Null()
          : b.getAffineMap(outerIVs.size() + rank, 0, remapExprs, {});
  // Replace all users of 'oldMemRef' with 'newMemRef'.
  bool ret =
      replaceAllMemRefUsesWith(oldMemRef, newMemRef, {}, indexRemap,
                               /*extraOperands=*/outerIVs,
                               /*domInstFilter=*/&*forInst->getBody()->begin());
  assert(ret && "replaceAllMemrefUsesWith should always succeed here");
  (void)ret;
  return newMemRef;
}

// Computes the total iteration count of the slice loop nest, i.e. the
// product of the trip counts in 'sliceTripCountMap'. A result of 1 means the
// slice has a single iteration.
static uint64_t getSliceIterationCount(
    const llvm::SmallDenseMap<ForInst *, uint64_t, 8> &sliceTripCountMap) {
  uint64_t iterCount = 1;
  for (const auto &count : sliceTripCountMap) {
    iterCount *= count.second;
  }
  return iterCount;
}

// Checks the profitability of fusing a backwards slice of the loop nest
// surrounding 'srcOpInst' into the loop nest surrounding 'dstOpInsts'.
// Returns true if it is profitable to fuse the candidate loop nests. Returns
// false otherwise. `dstLoopDepth` is set to the most profitable depth at which
// to materialize the source loop nest slice.
// The profitability model executes the following steps:
// *) Computes the backward computation slice at 'srcOpInst'. This
//    computation slice of the loop nest surrounding 'srcOpInst' is
//    represented by modified src loop bounds in 'sliceState', which are
//    functions of loop IVs in the loop nest surrounding 'srcOpInst'.
// *) Computes the cost of unfused src/dst loop nests (currently the cost of a
//    loop nest is the total number of dynamic operation instances in the loop
//    nest).
// *) Computes the cost of fusing a slice of the src loop nest into the dst
//    loop nest at various values of dst loop depth, attempting to fuse
//    the largest computation slice at the maximal dst loop depth (closest to
//    the load) to minimize reuse distance and potentially enable subsequent
//    load/store forwarding.
//    NOTE: If the dst loop nest includes multiple loads in 'dstOpInsts' for
//    the same memref as is written by 'srcOpInst', then the union of slice
//    loop bounds is used to compute the slice and associated slice cost.
//    NOTE: 'dstLoopDepth' refers to the loop depth within the destination loop
//    nest, at which the src computation slice is inserted/fused.
//    NOTE: We attempt to maximize the dst loop depth, but there are cases
//    where a particular setting for 'dstLoopDepth' might fuse an unsliced
//    loop (within the src computation slice) at a depth which results in
//    excessive recomputation (see unit tests for examples).
// *) Compares the total cost of the unfused loop nests to the min cost fused
//    loop nest computed in the previous step, and returns true if the latter
//    is lower.
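//
// As a rough numeric illustration of the final comparison: with
// 'srcLoopNestCost' = 1000, 'dstLoopNestCost' = 1000, and a candidate fused
// nest cost of 2100, the additional compute fraction is
// 2100 / (1000 + 1000) - 1 = 0.05, i.e. 5% redundant computation, within the
// default 'kComputeToleranceThreshold' of 30%.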
static bool isFusionProfitable(OperationInst *srcOpInst,
                               ArrayRef<OperationInst *> dstOpInsts,
                               ComputationSliceState *sliceState,
                               unsigned *dstLoopDepth) {
  LLVM_DEBUG({
    llvm::dbgs() << "Checking whether fusion is profitable between:\n";
    llvm::dbgs() << " ";
    srcOpInst->dump();
    llvm::dbgs() << " and \n";
    for (auto dstOpInst : dstOpInsts) {
      llvm::dbgs() << " ";
      dstOpInst->dump();
    }
  });

  // Compute cost of sliced and unsliced src loop nest.
  SmallVector<ForInst *, 4> srcLoopIVs;
  getLoopIVs(*srcOpInst, &srcLoopIVs);
  unsigned numSrcLoopIVs = srcLoopIVs.size();

  // Walk src loop nest and collect stats.
  LoopNestStats srcLoopNestStats;
  LoopNestStatsCollector srcStatsCollector(&srcLoopNestStats);
  srcStatsCollector.walk(srcLoopIVs[0]);
  // Currently only constant trip count loop nests are supported.
  if (srcStatsCollector.hasLoopWithNonConstTripCount)
    return false;

  // Compute cost of dst loop nest.
  SmallVector<ForInst *, 4> dstLoopIVs;
  getLoopIVs(*dstOpInsts[0], &dstLoopIVs);

  LoopNestStats dstLoopNestStats;
  LoopNestStatsCollector dstStatsCollector(&dstLoopNestStats);
  dstStatsCollector.walk(dstLoopIVs[0]);
  // Currently only constant trip count loop nests are supported.
  if (dstStatsCollector.hasLoopWithNonConstTripCount)
    return false;

  // Compute the innermost common loop depth for ops in 'dstOpInsts'.
  unsigned maxDstLoopDepth = getInnermostCommonLoopDepth(dstOpInsts);
  if (maxDstLoopDepth == 0)
    return false;

  // Search for min cost value for 'dstLoopDepth'. At each value of
  // 'dstLoopDepth' from 'maxDstLoopDepth' to '1', compute computation slice
  // bounds between 'srcOpInst' and each op in 'dstOpInsts' (taking the union
  // of these bounds). Next the union slice bounds are used to calculate
  // the cost of the slice and the cost of the slice inserted into the dst
  // loop nest at 'dstLoopDepth'.
  uint64_t minFusedLoopNestComputeCost = std::numeric_limits<uint64_t>::max();
  uint64_t maxStorageReduction = 0;
  Optional<uint64_t> sliceMemEstimate = None;

  SmallVector<ComputationSliceState, 4> sliceStates;
  sliceStates.resize(maxDstLoopDepth);
  // The best loop depth at which to materialize the slice.
  Optional<unsigned> bestDstLoopDepth = None;

  // Compute op instance count for the src loop nest without iteration slicing.
  uint64_t srcLoopNestCost = getComputeCost(srcLoopIVs[0], &srcLoopNestStats,
                                            /*tripCountOverrideMap=*/nullptr,
                                            /*computeCostMap=*/nullptr);

  // Compute op instance count for the dst loop nest.
  uint64_t dstLoopNestCost = getComputeCost(dstLoopIVs[0], &dstLoopNestStats,
                                            /*tripCountOverrideMap=*/nullptr,
                                            /*computeCostMap=*/nullptr);

  llvm::SmallDenseMap<ForInst *, uint64_t, 8> sliceTripCountMap;
  DenseMap<ForInst *, int64_t> computeCostMap;
  for (unsigned i = maxDstLoopDepth; i >= 1; --i) {
    MemRefAccess srcAccess(srcOpInst);
    // Handle the common case of one dst load without a copy.
    if (!mlir::getBackwardComputationSliceState(
            srcAccess, MemRefAccess(dstOpInsts[0]), i, &sliceStates[i - 1]))
      return false;
    // Compute the union of slice bounds of all ops in 'dstOpInsts'.
    for (int j = 1, e = dstOpInsts.size(); j < e; ++j) {
      MemRefAccess dstAccess(dstOpInsts[j]);
      ComputationSliceState tmpSliceState;
      if (!mlir::getBackwardComputationSliceState(srcAccess, dstAccess, i,
                                                  &tmpSliceState))
        return false;
      // Compute slice bound union of 'tmpSliceState' and 'sliceStates[i - 1]'.
      getSliceUnion(tmpSliceState, &sliceStates[i - 1]);
    }
    // Build trip count map for computation slice. We'll skip cases where the
    // trip count was non-constant.
    sliceTripCountMap.clear();
    if (!buildSliceTripCountMap(srcOpInst, &sliceStates[i - 1],
                                &sliceTripCountMap))
      continue;

    // Checks whether a store to load forwarding will happen.
    int64_t sliceIterationCount = getSliceIterationCount(sliceTripCountMap);
    assert(sliceIterationCount > 0);
    bool storeLoadFwdGuaranteed = (sliceIterationCount == 1);
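
    // (Illustrative) When every slice trip count is 1, the slice computes
    // exactly the value consumed by each dst iteration, so the intermediary
    // store and the dst loads on this memref are expected to be forwarded and
    // removed; the cost model below credits this by subtracting their costs.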

    // Compute cost of fusion for this dest loop depth.

    computeCostMap.clear();

    // The store and loads to this memref will disappear.
    if (storeLoadFwdGuaranteed) {
      // A single store disappears: -1 for that.
      computeCostMap[srcLoopIVs[numSrcLoopIVs - 1]] = -1;
      for (auto *loadOp : dstOpInsts) {
        if (auto *loadLoop = dyn_cast_or_null<ForInst>(loadOp->getParentInst()))
          computeCostMap[loadLoop] = -1;
      }
    }

    // Compute op instance count for the src loop nest with iteration slicing.
    int64_t sliceComputeCost =
        getComputeCost(srcLoopIVs[0], &srcLoopNestStats,
                       /*tripCountOverrideMap=*/&sliceTripCountMap,
                       /*computeCostMap=*/&computeCostMap);

    // Compute cost of fusion for this depth.
    computeCostMap[dstLoopIVs[i - 1]] = sliceComputeCost;

    int64_t fusedLoopNestComputeCost =
        getComputeCost(dstLoopIVs[0], &dstLoopNestStats,
                       /*tripCountOverrideMap=*/nullptr, &computeCostMap);

    double additionalComputeFraction =
        fusedLoopNestComputeCost /
            (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
        1;

    // TODO(bondhugula): This is an ugly approximation. Fix this by finding a
    // good way to calculate the footprint of the memref in the slice and
    // divide it by the total memory footprint of the fused computation.
    double storageReduction =
        static_cast<double>(srcLoopNestCost) / sliceIterationCount;

    LLVM_DEBUG({
      std::stringstream msg;
      msg << " evaluating fusion profitability at depth : " << i << "\n"
          << std::setprecision(2) << " additional compute fraction: "
          << 100.0 * additionalComputeFraction << "%\n"
          << " storage reduction factor: " << storageReduction << "x\n"
          << " fused nest cost: " << fusedLoopNestComputeCost << "\n"
          << " slice iteration count: " << sliceIterationCount << "\n";
      llvm::dbgs() << msg.str();
    });

    double computeToleranceThreshold =
        clFusionAddlComputeTolerance.getNumOccurrences() > 0
            ? clFusionAddlComputeTolerance
            : LoopFusion::kComputeToleranceThreshold;

    // TODO(b/123247369): This is a placeholder cost model.
    // Among all choices that add an acceptable amount of redundant computation
    // (as per computeToleranceThreshold), we will simply pick the one that
    // reduces the intermediary size the most.
    if ((storageReduction > maxStorageReduction) &&
        (clMaximalLoopFusion ||
         (additionalComputeFraction < computeToleranceThreshold))) {
      maxStorageReduction = storageReduction;
      bestDstLoopDepth = i;
      minFusedLoopNestComputeCost = fusedLoopNestComputeCost;
      // TODO(bondhugula,andydavis): find a good way to compute the memory
      // footprint of the materialized slice.
      // Approximating this to the compute cost of the slice. This could be an
      // under-approximation or an over-approximation, but in many cases
      // accurate.
      sliceMemEstimate = sliceIterationCount;
    }
  }

  // A simple cost model: fuse if it reduces the memory footprint. If
  // -fusion-maximal is set, fuse regardless.

  if (!clMaximalLoopFusion && !bestDstLoopDepth.hasValue()) {
    LLVM_DEBUG(llvm::dbgs()
               << "All fusion choices involve more than the threshold amount "
                  "of redundant computation; NOT fusing.\n");
    return false;
  }

  assert(bestDstLoopDepth.hasValue() &&
         "expected to have a value per logic above");

  // Set dstLoopDepth based on best values from search.
  *dstLoopDepth = bestDstLoopDepth.getValue();

  LLVM_DEBUG(
      llvm::dbgs() << " LoopFusion fusion stats:"
                   << "\n best loop depth: " << bestDstLoopDepth
                   << "\n src loop nest compute cost: " << srcLoopNestCost
                   << "\n dst loop nest compute cost: " << dstLoopNestCost
                   << "\n fused loop nest compute cost: "
                   << minFusedLoopNestComputeCost << "\n");

  auto dstMemSize = getMemoryFootprintBytes(*dstLoopIVs[0]);
  auto srcMemSize = getMemoryFootprintBytes(*srcLoopIVs[0]);

  Optional<double> storageReduction = None;

  if (!clMaximalLoopFusion) {
    if (!dstMemSize.hasValue() || !srcMemSize.hasValue()) {
      LLVM_DEBUG(
          llvm::dbgs()
          << " fusion memory benefit cannot be evaluated; NOT fusing.\n");
      return false;
    }

    auto srcMemSizeVal = srcMemSize.getValue();
    auto dstMemSizeVal = dstMemSize.getValue();

    assert(sliceMemEstimate.hasValue() && "expected value");
    // This is an inaccurate estimate since 'sliceMemEstimate' is inaccurate.
    auto fusedMem = dstMemSizeVal + sliceMemEstimate.getValue();

    LLVM_DEBUG(llvm::dbgs() << " src mem: " << srcMemSizeVal << "\n"
                            << " dst mem: " << dstMemSizeVal << "\n"
                            << " fused mem: " << fusedMem << "\n"
                            << " slice mem: " << sliceMemEstimate << "\n");

    if (fusedMem > srcMemSizeVal + dstMemSizeVal) {
      LLVM_DEBUG(llvm::dbgs() << "Fusion is not profitable; NOT fusing.\n");
      return false;
    }
    storageReduction =
        100.0 *
        (1.0 - fusedMem / (static_cast<double>(srcMemSizeVal) + dstMemSizeVal));
  }

  double additionalComputeFraction =
      100.0 * (minFusedLoopNestComputeCost /
                   (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
               1);
  (void)additionalComputeFraction;
  LLVM_DEBUG({
    std::stringstream msg;
    msg << " fusion is most profitable at depth " << *dstLoopDepth << " with "
        << std::setprecision(2) << additionalComputeFraction
        << "% redundant computation and a ";
    msg << (storageReduction.hasValue()
                ? std::to_string(storageReduction.getValue())
                : "<unknown>");
    msg << "% storage reduction.\n";
    llvm::dbgs() << msg.str();
  });

  // Update return parameter 'sliceState' with 'bestSliceState'.
  ComputationSliceState *bestSliceState = &sliceStates[*dstLoopDepth - 1];
  sliceState->lbs = bestSliceState->lbs;
  sliceState->ubs = bestSliceState->ubs;
  sliceState->lbOperands = bestSliceState->lbOperands;
  sliceState->ubOperands = bestSliceState->ubOperands;

  // Canonicalize slice bound affine maps.
  for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
    if (sliceState->lbs[i] != AffineMap::Null()) {
      canonicalizeMapAndOperands(&sliceState->lbs[i],
                                 &sliceState->lbOperands[i]);
    }
    if (sliceState->ubs[i] != AffineMap::Null()) {
      canonicalizeMapAndOperands(&sliceState->ubs[i],
                                 &sliceState->ubOperands[i]);
    }
  }
  return true;
}
1117
// GreedyFusion greedily fuses loop nests which have a producer/consumer
// relationship on a memref, with the goal of improving locality. Currently,
// the producer/consumer relationship is required to be unique within the
// Function (there are TODOs to relax this constraint in the future).
//
// The steps of the algorithm are as follows:
//
// *) A worklist is initialized with node ids from the dependence graph.
// *) For each node id in the worklist:
//   *) Pop a ForInst off the worklist. This 'dstForInst' will be a candidate
//      destination ForInst into which fusion will be attempted.
//   *) Add each LoadOp currently in 'dstForInst' into list 'dstLoadOps'.
//   *) For each LoadOp in 'dstLoadOps' do:
//      *) Look up dependent loop nests at earlier positions in the Function
//         which have a single store op to the same memref.
//      *) Check if dependences would be violated by the fusion. For example,
//         the src loop nest may load from memrefs which are different than
//         the producer-consumer memref between the src and dest loop nests.
//      *) Get a computation slice of 'srcLoopNest', which adjusts its loop
//         bounds to be functions of 'dstLoopNest' IVs and symbols.
//      *) Fuse the 'srcLoopNest' computation slice into the 'dstLoopNest',
//         just before the dst load op user.
//      *) Add the newly fused load/store operation instructions to the state,
//         and also add newly fused load ops to 'dstLoadOps' to be considered
//         as fusion dst load ops in another iteration.
//      *) Remove the old src loop nest and its associated state.
//
// Given a graph where top-level instructions are vertices in the set 'V' and
// edges in the set 'E' are dependences between vertices, this algorithm
// takes O(V) time for initialization, and has runtime O(V + E).
//
// This greedy algorithm is not 'maximal' due to the current restriction of
// fusing along single producer-consumer edges, but there is a TODO to fix
// this.
//
// TODO(andydavis) Experiment with other fusion policies.
// TODO(andydavis) Add support for fusing for input reuse (perhaps by
// constructing a graph with edges which represent loads from the same memref
// in two different loop nests).
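//
// As a rough, hand-written illustration (not drawn from an actual test
// case), the pass aims to turn a producer/consumer pair such as:
//
//   for %i0 = 0 to 100 {
//     store %v0, %m[%i0] : memref<100xf32>
//   }
//   for %i1 = 0 to 100 {
//     %v1 = load %m[%i1] : memref<100xf32>
//   }
//
// into a single fused nest, in which '%m' may further be replaced by a
// private single-element buffer ('%m_priv' and '%c0' are hypothetical names):
//
//   for %i1 = 0 to 100 {
//     store %v0, %m_priv[%c0] : memref<1xf32>
//     %v1 = load %m_priv[%c0] : memref<1xf32>
//   }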
struct GreedyFusion {
public:
  MemRefDependenceGraph *mdg;
  SmallVector<unsigned, 4> worklist;

  GreedyFusion(MemRefDependenceGraph *mdg) : mdg(mdg) {
    // Initialize worklist with nodes from 'mdg'.
    worklist.resize(mdg->nodes.size());
    std::iota(worklist.begin(), worklist.end(), 0);
  }

  void run() {
    while (!worklist.empty()) {
      unsigned dstId = worklist.back();
      worklist.pop_back();
      // Skip if this node was removed (fused into another node).
      if (mdg->nodes.count(dstId) == 0)
        continue;
      // Get 'dstNode' into which to attempt fusion.
      auto *dstNode = mdg->getNode(dstId);
      // Skip if 'dstNode' is not a loop nest.
      if (!isa<ForInst>(dstNode->inst))
        continue;

      SmallVector<OperationInst *, 4> loads = dstNode->loads;
      SmallVector<OperationInst *, 4> dstLoadOpInsts;
      DenseSet<Value *> visitedMemrefs;
      while (!loads.empty()) {
        // Get memref of load on top of the stack.
        auto *memref = loads.back()->cast<LoadOp>()->getMemRef();
        if (visitedMemrefs.count(memref) > 0)
          continue;
        visitedMemrefs.insert(memref);
        // Move all loads in 'loads' accessing 'memref' to 'dstLoadOpInsts'.
        moveLoadsAccessingMemrefTo(memref, &loads, &dstLoadOpInsts);
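        // (Note: this also removes the load at the top of the stack, so the
        // enclosing while loop makes progress. Illustrative: if 'loads' held
        // loads on %A, %B, %A, both loads on %A are now grouped in
        // 'dstLoadOpInsts' and considered together as fusion destinations.)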
        // Skip if no input edges along which to fuse.
        if (mdg->inEdges.count(dstId) == 0)
          continue;
        // Iterate through in edges for 'dstId'.
        for (auto &srcEdge : mdg->inEdges[dstId]) {
          // Skip 'srcEdge' if not for 'memref'.
          if (srcEdge.memref != memref)
            continue;

          auto *srcNode = mdg->getNode(srcEdge.id);
          // Skip if 'srcNode' is not a loop nest.
          if (!isa<ForInst>(srcNode->inst))
            continue;
          // Skip if 'srcNode' has more than one store to any memref.
          // TODO(andydavis) Support fusing multi-output src loop nests.
          if (srcNode->stores.size() != 1)
            continue;

          // Skip 'srcNode' if it has in dependence edges on 'memref'. NOTE:
          // This check is overly conservative.
          // TODO(andydavis) Track dependence type with edges, and just check
          // for WAW dependence edge here.
          if (mdg->getInEdgeCount(srcNode->id, memref) != 0)
            continue;

          // Skip if 'srcNode' has out edges on memrefs other than 'memref'
          // for nodes in instruction list range (srcNode.inst, dstNode.inst).
          if (mdg->hasDependenceTargetInRange(srcNode->id, dstNode->id, memref))
            continue;

          // Check if fusion would be profitable and at what depth.
          // Get the unique 'srcNode' store op.
          auto *srcStoreOpInst = srcNode->stores.front();
          unsigned bestDstLoopDepth;
          mlir::ComputationSliceState sliceState;
          if (!isFusionProfitable(srcStoreOpInst, dstLoadOpInsts, &sliceState,
                                  &bestDstLoopDepth))
            continue;

          // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'.
          auto *sliceLoopNest = mlir::insertBackwardComputationSlice(
              srcStoreOpInst, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState);
          if (sliceLoopNest != nullptr) {
            // Update edges between 'srcNode' and 'dstNode'.
            mdg->updateEdges(srcNode->id, dstNode->id);

            // Collect slice loop stats.
            LoopNestStateCollector sliceCollector;
            sliceCollector.walkForInst(sliceLoopNest);
            // Promote single iteration slice loops to a single IV value.
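            // (Illustrative: a one-iteration loop such as
            // 'for %i2 = 4 to 5 { ... }' is removed, with uses of '%i2' in
            // the body replaced by the value of its lower bound.)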
            for (auto *forInst : sliceCollector.forInsts) {
              promoteIfSingleIteration(forInst);
            }

            // Create private memref for 'memref' in 'dstForInst'.
            auto *dstForInst = cast<ForInst>(dstNode->inst);
            SmallVector<OperationInst *, 4> storesForMemref;
            for (auto *storeOpInst : sliceCollector.storeOpInsts) {
              if (storeOpInst->cast<StoreOp>()->getMemRef() == memref)
                storesForMemref.push_back(storeOpInst);
            }
            assert(storesForMemref.size() == 1);
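            // (Illustrative: if the fused slice writes one element of
            // 'memref' per destination iteration, accesses may be redirected
            // to a fresh, smaller buffer such as a memref<1xf32> alloc,
            // shrinking the footprint of the fused nest.)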
            auto *newMemRef = createPrivateMemRef(
                dstForInst, storesForMemref[0], bestDstLoopDepth);
            visitedMemrefs.insert(newMemRef);

            // Collect dst loop stats after the memref privatization.
            LoopNestStateCollector dstLoopCollector;
            dstLoopCollector.walkForInst(dstForInst);

            // Add new load ops to the current node's load op list 'loads', to
            // continue fusing based on new operands.
            for (auto *loadOpInst : dstLoopCollector.loadOpInsts) {
              auto *loadMemRef = loadOpInst->cast<LoadOp>()->getMemRef();
              if (visitedMemrefs.count(loadMemRef) == 0)
                loads.push_back(loadOpInst);
            }

            // Clear and add back loads and stores.
            mdg->clearNodeLoadAndStores(dstNode->id);
            mdg->addToNode(dstId, dstLoopCollector.loadOpInsts,
                           dstLoopCollector.storeOpInsts);
            // Remove old src loop nest if it no longer has outgoing dependence
            // edges, and it does not write to a memref which escapes the
            // function.
            if (!mdg->hasOutEdges(srcNode->id) &&
                !mdg->writesToLiveInOrEscapingMemrefs(srcNode->id)) {
              mdg->removeNode(srcNode->id);
              cast<ForInst>(srcNode->inst)->erase();
            }
          }
        }
      }
    }
    // Clean up any allocs with no users.
    for (auto &pair : mdg->memrefEdgeCount) {
      if (pair.second > 0)
        continue;
      auto *memref = pair.first;
      // Skip if the memref has other uses (e.g. in a return instruction or a
      // function call).
      if (!memref->use_empty())
        continue;
      // The use list is otherwise expected to match the dependence graph info.
      auto *inst = memref->getDefiningInst();
      auto *opInst = dyn_cast_or_null<OperationInst>(inst);
      if (opInst && opInst->isa<AllocOp>())
        opInst->erase();
    }
  }
};

} // end anonymous namespace

PassResult LoopFusion::runOnFunction(Function *f) {
  MemRefDependenceGraph g;
  if (g.init(f))
    GreedyFusion(&g).run();
  return success();
}

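// Registers the pass under the -loop-fusion flag; a hypothetical invocation
// (not part of this file) would be: mlir-opt -loop-fusion input.mlir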
static PassRegistration<LoopFusion> pass("loop-fusion", "Fuse loop nests");