//===- LoopFusion.cpp - Code to perform loop fusion -----------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements loop fusion.
//
//===----------------------------------------------------------------------===//

#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/AffineStructures.h"
#include "mlir/Analysis/LoopAnalysis.h"
#include "mlir/Analysis/Utils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/Pass.h"
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/Transforms/LoopUtils.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/Utils.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <iomanip>
#include <limits>
#include <numeric>
#include <sstream>

#define DEBUG_TYPE "loop-fusion"

using llvm::SetVector;

using namespace mlir;

/// Disables the fusion profitability check and fuses whenever valid.
static llvm::cl::opt<bool>
    clMaximalLoopFusion("fusion-maximal", llvm::cl::Hidden,
                        llvm::cl::desc("Enables maximal loop fusion"));

/// A threshold, expressed as a fraction of the total computation, of the
/// additional computation tolerated when fusing.
static llvm::cl::opt<double> clFusionAddlComputeTolerance(
    "fusion-compute-tolerance", llvm::cl::Hidden,
    llvm::cl::desc("Fractional increase in additional "
                   "computation tolerated while fusing"));

namespace {

/// Loop fusion pass. This pass currently supports a greedy fusion policy,
/// which fuses loop nests with single-writer/single-reader memref dependences
/// with the goal of improving locality.

// TODO(andydavis) Support fusion of source loop nests which write to multiple
// memrefs, where each memref can have multiple users (if profitable).
// TODO(andydavis) Extend this pass to check for fusion preventing dependences,
// and add support for more general loop fusion algorithms.

struct LoopFusion : public FunctionPass {
  LoopFusion() : FunctionPass(&LoopFusion::passID) {}

  PassResult runOnFunction(Function *f) override;
  static char passID;

  // The amount of additional computation that is tolerated while fusing
  // pair-wise, as a fraction of the total computation.
  constexpr static double kComputeToleranceThreshold = 0.30f;
};

} // end anonymous namespace

char LoopFusion::passID = 0;

FunctionPass *mlir::createLoopFusionPass() { return new LoopFusion; }

namespace {

// LoopNestStateCollector walks loop nests and collects load and store
// operations, and whether or not an IfInst was encountered in the loop nest.
class LoopNestStateCollector : public InstWalker<LoopNestStateCollector> {
public:
  SmallVector<ForInst *, 4> forInsts;
  SmallVector<OperationInst *, 4> loadOpInsts;
  SmallVector<OperationInst *, 4> storeOpInsts;
  bool hasIfInst = false;

  void visitForInst(ForInst *forInst) { forInsts.push_back(forInst); }

  void visitIfInst(IfInst *ifInst) { hasIfInst = true; }

  void visitOperationInst(OperationInst *opInst) {
    if (opInst->isa<LoadOp>())
      loadOpInsts.push_back(opInst);
    if (opInst->isa<StoreOp>())
      storeOpInsts.push_back(opInst);
  }
};

// TODO(b/117228571) Replace when this is modeled through side-effects/op
// traits.
static bool isMemRefDereferencingOp(const OperationInst &op) {
  if (op.isa<LoadOp>() || op.isa<StoreOp>() || op.isa<DmaStartOp>() ||
      op.isa<DmaWaitOp>())
    return true;
  return false;
}

// MemRefDependenceGraph is a graph data structure where graph nodes are
// top-level instructions in a Function which contain load/store ops, and edges
// are memref dependences between the nodes.
// TODO(andydavis) Add a more flexible dependence graph representation.
// TODO(andydavis) Add a depth parameter to dependence graph construction.
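// For example (an illustration, not from this file): a function with two
// top-level loop nests, where the first stores to memref %m and the second
// loads from %m, yields a two-node graph with a single edge keyed on %m:
//
//   Node 0 (stores to %m) --[%m]--> Node 1 (loads from %m)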
struct MemRefDependenceGraph {
public:
  // Node represents a node in the graph. A Node is either an entire loop nest
  // rooted at the top level which contains loads/stores, or a top level
  // load/store.
  struct Node {
    // The unique identifier of this node in the graph.
    unsigned id;
    // The top-level instruction which is (or contains) loads/stores.
    Instruction *inst;
    // List of load op insts.
    SmallVector<OperationInst *, 4> loads;
    // List of store op insts.
    SmallVector<OperationInst *, 4> stores;
    Node(unsigned id, Instruction *inst) : id(id), inst(inst) {}

    // Returns the load op count for 'memref'.
    unsigned getLoadOpCount(Value *memref) {
      unsigned loadOpCount = 0;
      for (auto *loadOpInst : loads) {
        if (memref == loadOpInst->cast<LoadOp>()->getMemRef())
          ++loadOpCount;
      }
      return loadOpCount;
    }

    // Returns the store op count for 'memref'.
    unsigned getStoreOpCount(Value *memref) {
      unsigned storeOpCount = 0;
      for (auto *storeOpInst : stores) {
        if (memref == storeOpInst->cast<StoreOp>()->getMemRef())
          ++storeOpCount;
      }
      return storeOpCount;
    }
  };

  // Edge represents a memref data dependence between nodes in the graph.
  struct Edge {
    // The id of the node at the other end of the edge.
    unsigned id;
    // The memref on which this edge represents a dependence.
    Value *memref;
  };

  // Map from node id to Node.
  DenseMap<unsigned, Node> nodes;
  // Map from node id to list of input edges.
  DenseMap<unsigned, SmallVector<Edge, 2>> inEdges;
  // Map from node id to list of output edges.
  DenseMap<unsigned, SmallVector<Edge, 2>> outEdges;
  // Map from memref to a count of the dependence edges associated with that
  // memref.
  DenseMap<Value *, unsigned> memrefEdgeCount;

  MemRefDependenceGraph() {}

  // Initializes the dependence graph based on operations in 'f'.
  // Returns true on success, false otherwise.
  bool init(Function *f);

  // Returns the graph node for 'id'.
  Node *getNode(unsigned id) {
    auto it = nodes.find(id);
    assert(it != nodes.end());
    return &it->second;
  }

  // Removes node 'id' (and its associated edges) from the graph.
  void removeNode(unsigned id) {
    // Remove each edge in 'inEdges[id]'.
    if (inEdges.count(id) > 0) {
      SmallVector<Edge, 2> oldInEdges = inEdges[id];
      for (auto &inEdge : oldInEdges) {
        removeEdge(inEdge.id, id, inEdge.memref);
      }
    }
    // Remove each edge in 'outEdges[id]'.
    if (outEdges.count(id) > 0) {
      SmallVector<Edge, 2> oldOutEdges = outEdges[id];
      for (auto &outEdge : oldOutEdges) {
        removeEdge(id, outEdge.id, outEdge.memref);
      }
    }
    // Erase remaining node state.
    inEdges.erase(id);
    outEdges.erase(id);
    nodes.erase(id);
  }

  bool hasOutEdges(unsigned id) {
    return outEdges.count(id) > 0 && !outEdges[id].empty();
  }

  // Returns true if node 'id' writes to any memref which escapes (or is an
  // argument to) the function/block. Returns false otherwise.
  bool writesToLiveInOrEscapingMemrefs(unsigned id) {
    Node *node = getNode(id);
    for (auto *storeOpInst : node->stores) {
      auto *memref = storeOpInst->cast<StoreOp>()->getMemRef();
      auto *inst = memref->getDefiningInst();
      auto *opInst = dyn_cast_or_null<OperationInst>(inst);
      // Return true if 'memref' is a function argument.
      if (opInst == nullptr)
        return true;
      // Return true if any use of 'memref' escapes the function.
      for (auto &use : memref->getUses()) {
        auto *user = dyn_cast<OperationInst>(use.getOwner());
        if (!user || !isMemRefDereferencingOp(*user))
          return true;
      }
    }
    return false;
  }

  // Returns true iff there is an edge from node 'srcId' to node 'dstId' for
  // 'memref'. Returns false otherwise.
  bool hasEdge(unsigned srcId, unsigned dstId, Value *memref) {
    if (outEdges.count(srcId) == 0 || inEdges.count(dstId) == 0) {
      return false;
    }
    bool hasOutEdge = llvm::any_of(outEdges[srcId], [=](Edge &edge) {
      return edge.id == dstId && edge.memref == memref;
    });
    bool hasInEdge = llvm::any_of(inEdges[dstId], [=](Edge &edge) {
      return edge.id == srcId && edge.memref == memref;
    });
    return hasOutEdge && hasInEdge;
  }

  // Adds an edge from node 'srcId' to node 'dstId' for 'memref'.
  void addEdge(unsigned srcId, unsigned dstId, Value *memref) {
    if (!hasEdge(srcId, dstId, memref)) {
      outEdges[srcId].push_back({dstId, memref});
      inEdges[dstId].push_back({srcId, memref});
      memrefEdgeCount[memref]++;
    }
  }

  // Removes an edge from node 'srcId' to node 'dstId' for 'memref'.
  void removeEdge(unsigned srcId, unsigned dstId, Value *memref) {
    assert(inEdges.count(dstId) > 0);
    assert(outEdges.count(srcId) > 0);
    assert(memrefEdgeCount.count(memref) > 0);
    memrefEdgeCount[memref]--;
    // Remove 'srcId' from 'inEdges[dstId]'.
    for (auto it = inEdges[dstId].begin(); it != inEdges[dstId].end(); ++it) {
      if ((*it).id == srcId && (*it).memref == memref) {
        inEdges[dstId].erase(it);
        break;
      }
    }
    // Remove 'dstId' from 'outEdges[srcId]'.
    for (auto it = outEdges[srcId].begin(); it != outEdges[srcId].end(); ++it) {
      if ((*it).id == dstId && (*it).memref == memref) {
        outEdges[srcId].erase(it);
        break;
      }
    }
  }

  // Returns the input edge count for node 'id' and 'memref'.
  unsigned getInEdgeCount(unsigned id, Value *memref) {
    unsigned inEdgeCount = 0;
    if (inEdges.count(id) > 0)
      for (auto &inEdge : inEdges[id])
        if (inEdge.memref == memref)
          ++inEdgeCount;
    return inEdgeCount;
  }

  // Returns the output edge count for node 'id' and 'memref'.
  unsigned getOutEdgeCount(unsigned id, Value *memref) {
    unsigned outEdgeCount = 0;
    if (outEdges.count(id) > 0)
      for (auto &outEdge : outEdges[id])
        if (outEdge.memref == memref)
          ++outEdgeCount;
    return outEdgeCount;
  }

  // Returns the min node id across all outgoing edges from node 'id', skipping
  // edges with 'memrefToSkip'.
  unsigned getMinOutEdgeNodeId(unsigned id, Value *memrefToSkip) {
    unsigned minId = std::numeric_limits<unsigned>::max();
    if (outEdges.count(id) > 0)
      for (auto &outEdge : outEdges[id])
        if (outEdge.memref != memrefToSkip)
          minId = std::min(minId, outEdge.id);
    return minId;
  }

  // Updates edge mappings from node 'srcId' to node 'dstId'.
  void updateEdges(unsigned srcId, unsigned dstId) {
    // For each edge in 'inEdges[srcId]': add a new edge remapping to 'dstId'.
    if (inEdges.count(srcId) > 0) {
      SmallVector<Edge, 2> oldInEdges = inEdges[srcId];
      for (auto &inEdge : oldInEdges) {
        // Add edge from 'inEdge.id' to 'dstId'.
        addEdge(inEdge.id, dstId, inEdge.memref);
      }
    }
    // For each edge in 'outEdges[srcId]': remove edge from 'srcId' to 'dstId'.
    if (outEdges.count(srcId) > 0) {
      SmallVector<Edge, 2> oldOutEdges = outEdges[srcId];
      for (auto &outEdge : oldOutEdges) {
        // Remove any out edges from 'srcId' to 'dstId' across memrefs.
        if (outEdge.id == dstId)
          removeEdge(srcId, outEdge.id, outEdge.memref);
      }
    }
  }

  // Adds ops in 'loads' and 'stores' to node at 'id'.
  void addToNode(unsigned id, const SmallVectorImpl<OperationInst *> &loads,
                 const SmallVectorImpl<OperationInst *> &stores) {
    Node *node = getNode(id);
    for (auto *loadOpInst : loads)
      node->loads.push_back(loadOpInst);
    for (auto *storeOpInst : stores)
      node->stores.push_back(storeOpInst);
  }

  void clearNodeLoadAndStores(unsigned id) {
    Node *node = getNode(id);
    node->loads.clear();
    node->stores.clear();
  }

  void print(raw_ostream &os) const {
    os << "\nMemRefDependenceGraph\n";
    os << "\nNodes:\n";
    for (auto &idAndNode : nodes) {
      os << "Node: " << idAndNode.first << "\n";
      auto it = inEdges.find(idAndNode.first);
      if (it != inEdges.end()) {
        for (const auto &e : it->second)
          os << "  InEdge: " << e.id << " " << e.memref << "\n";
      }
      it = outEdges.find(idAndNode.first);
      if (it != outEdges.end()) {
        for (const auto &e : it->second)
          os << "  OutEdge: " << e.id << " " << e.memref << "\n";
      }
    }
  }
  void dump() const { print(llvm::errs()); }
};

// Initializes the data dependence graph by walking instructions in 'f'.
// Assigns each node in the graph a node id based on program order in 'f'.
// TODO(andydavis) Add support for taking a Block arg to construct the
// dependence graph at a different depth.
bool MemRefDependenceGraph::init(Function *f) {
  unsigned id = 0;
  DenseMap<Value *, SetVector<unsigned>> memrefAccesses;

  // TODO: support multi-block functions.
  if (f->getBlocks().size() != 1)
    return false;

  for (auto &inst : f->front()) {
    if (auto *forInst = dyn_cast<ForInst>(&inst)) {
      // Create graph node 'id' to represent top-level 'forInst' and record
      // all load and store accesses it contains.
      LoopNestStateCollector collector;
      collector.walkForInst(forInst);
      // Return false if IfInsts are found (not currently supported).
      if (collector.hasIfInst)
        return false;
      Node node(id++, &inst);
      for (auto *opInst : collector.loadOpInsts) {
        node.loads.push_back(opInst);
        auto *memref = opInst->cast<LoadOp>()->getMemRef();
        memrefAccesses[memref].insert(node.id);
      }
      for (auto *opInst : collector.storeOpInsts) {
        node.stores.push_back(opInst);
        auto *memref = opInst->cast<StoreOp>()->getMemRef();
        memrefAccesses[memref].insert(node.id);
      }
      nodes.insert({node.id, node});
    }
    if (auto *opInst = dyn_cast<OperationInst>(&inst)) {
      if (auto loadOp = opInst->dyn_cast<LoadOp>()) {
        // Create graph node for top-level load op.
        Node node(id++, &inst);
        node.loads.push_back(opInst);
        auto *memref = opInst->cast<LoadOp>()->getMemRef();
        memrefAccesses[memref].insert(node.id);
        nodes.insert({node.id, node});
      }
      if (auto storeOp = opInst->dyn_cast<StoreOp>()) {
        // Create graph node for top-level store op.
        Node node(id++, &inst);
        node.stores.push_back(opInst);
        auto *memref = opInst->cast<StoreOp>()->getMemRef();
        memrefAccesses[memref].insert(node.id);
        nodes.insert({node.id, node});
      }
    }
    // Return false if IfInsts are found (not currently supported).
    if (isa<IfInst>(&inst))
      return false;
  }

  // Walk memref access lists and add graph edges between dependent nodes.
  for (auto &memrefAndList : memrefAccesses) {
    unsigned n = memrefAndList.second.size();
    for (unsigned i = 0; i < n; ++i) {
      unsigned srcId = memrefAndList.second[i];
      bool srcHasStore =
          getNode(srcId)->getStoreOpCount(memrefAndList.first) > 0;
      for (unsigned j = i + 1; j < n; ++j) {
        unsigned dstId = memrefAndList.second[j];
        bool dstHasStore =
            getNode(dstId)->getStoreOpCount(memrefAndList.first) > 0;
        if (srcHasStore || dstHasStore)
          addEdge(srcId, dstId, memrefAndList.first);
      }
    }
  }
  return true;
}

namespace {

// LoopNestStats aggregates various per-loop statistics (e.g. loop trip count
// and operation count) for a loop nest up until the innermost loop body.
struct LoopNestStats {
  // Map from ForInst to immediate child ForInsts in its loop body.
  DenseMap<ForInst *, SmallVector<ForInst *, 2>> loopMap;
  // Map from ForInst to count of operations in its loop body.
  DenseMap<ForInst *, uint64_t> opCountMap;
  // Map from ForInst to its constant trip count.
  DenseMap<ForInst *, uint64_t> tripCountMap;
};

// LoopNestStatsCollector walks a single loop nest and gathers per-loop
// trip count and operation count statistics and records them in 'stats'.
class LoopNestStatsCollector : public InstWalker<LoopNestStatsCollector> {
public:
  LoopNestStats *stats;
  bool hasLoopWithNonConstTripCount = false;

  LoopNestStatsCollector(LoopNestStats *stats) : stats(stats) {}

  void visitForInst(ForInst *forInst) {
    auto *parentInst = forInst->getParentInst();
    if (parentInst != nullptr) {
      assert(isa<ForInst>(parentInst) && "Expected parent ForInst");
      // Add mapping to 'forInst' from its parent ForInst.
      stats->loopMap[cast<ForInst>(parentInst)].push_back(forInst);
    }
    // Record the number of op instructions in the body of 'forInst'.
    unsigned count = 0;
    stats->opCountMap[forInst] = 0;
    for (auto &inst : *forInst->getBody()) {
      if (isa<OperationInst>(&inst))
        ++count;
    }
    stats->opCountMap[forInst] = count;
    // Record trip count for 'forInst'. Set flag if trip count is not constant.
    Optional<uint64_t> maybeConstTripCount = getConstantTripCount(*forInst);
    if (!maybeConstTripCount.hasValue()) {
      hasLoopWithNonConstTripCount = true;
      return;
    }
    stats->tripCountMap[forInst] = maybeConstTripCount.getValue();
  }
};

// Computes the total cost of the loop nest rooted at 'forInst'.
// Currently, the total cost is computed by counting the total operation
// instance count (i.e. total number of operations in the loop body * loop
// trip count) for the entire loop nest.
// If 'tripCountOverrideMap' is non-null, overrides the trip count for loops
// specified in the map when computing the total op instance count.
// NOTE: this is used to compute the cost of computation slices, which are
// sliced along the iteration dimension, and thus reduce the trip count.
// If 'computeCostMap' is non-null, the total op count for forInsts specified
// in the map is increased (not overridden) by adding the op count from the
// map to the existing op count for the for loop. This is done before
// multiplying by the loop's trip count, and is used to model the cost of
// inserting a sliced loop nest of known cost into the loop's body.
// NOTE: this is used to compute the cost of fusing a slice of some loop nest
// within another loop.
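// For example (an illustration, not from this file): a 2-d nest with trip
// counts 100 and 200 whose innermost body contains 2 operations (and whose
// outer loop body contains no other operations) has total cost
// 100 * (200 * 2) = 40000 dynamic operation instances.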
static int64_t getComputeCost(
    ForInst *forInst, LoopNestStats *stats,
    llvm::SmallDenseMap<ForInst *, uint64_t, 8> *tripCountOverrideMap,
    DenseMap<ForInst *, int64_t> *computeCostMap) {
  // 'opCount' is the total number of operations in one iteration of the body
  // of 'forInst'.
  int64_t opCount = stats->opCountMap[forInst];
  if (stats->loopMap.count(forInst) > 0) {
    for (auto *childForInst : stats->loopMap[forInst]) {
      opCount += getComputeCost(childForInst, stats, tripCountOverrideMap,
                                computeCostMap);
    }
  }
  // Add in additional op instances from slice (if specified in map).
  if (computeCostMap != nullptr) {
    auto it = computeCostMap->find(forInst);
    if (it != computeCostMap->end()) {
      opCount += it->second;
    }
  }
  // Override trip count (if specified in map).
  int64_t tripCount = stats->tripCountMap[forInst];
  if (tripCountOverrideMap != nullptr) {
    auto it = tripCountOverrideMap->find(forInst);
    if (it != tripCountOverrideMap->end()) {
      tripCount = it->second;
    }
  }
  // Returns the total number of dynamic instances of operations in loop body.
  return tripCount * opCount;
}

} // end anonymous namespace

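// Returns the difference (ub - lb) between the single-result bound maps
// 'lbMap' and 'ubMap' if it simplifies to a constant, None otherwise.
// For example (an illustration): with lbMap '(d0) -> (d0)' and
// ubMap '(d0) -> (d0 + 32)', the difference simplifies to the constant 32
// for any value of d0.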
static Optional<uint64_t> getConstDifference(AffineMap lbMap, AffineMap ubMap) {
  assert(lbMap.getNumResults() == 1 && "expected single result bound map");
  assert(ubMap.getNumResults() == 1 && "expected single result bound map");
  assert(lbMap.getNumDims() == ubMap.getNumDims());
  assert(lbMap.getNumSymbols() == ubMap.getNumSymbols());
  // TODO(andydavis) Merge this code with 'mlir::getTripCountExpr'.
  // ub_expr - lb_expr
  AffineExpr lbExpr(lbMap.getResult(0));
  AffineExpr ubExpr(ubMap.getResult(0));
  auto loopSpanExpr = simplifyAffineExpr(ubExpr - lbExpr, lbMap.getNumDims(),
                                         lbMap.getNumSymbols());
  auto cExpr = loopSpanExpr.dyn_cast<AffineConstantExpr>();
  if (!cExpr)
    return None;
  return cExpr.getValue();
}

// Builds a map 'tripCountMap' from ForInst to constant trip count for the
// loop nest surrounding 'srcOpInst', utilizing slice loop bounds in
// 'sliceState'. Returns true on success, false otherwise (if a non-constant
// trip count was encountered).
// TODO(andydavis) Make this work with non-unit step loops.
static bool buildSliceTripCountMap(
    OperationInst *srcOpInst, ComputationSliceState *sliceState,
    llvm::SmallDenseMap<ForInst *, uint64_t, 8> *tripCountMap) {
  SmallVector<ForInst *, 4> srcLoopIVs;
  getLoopIVs(*srcOpInst, &srcLoopIVs);
  unsigned numSrcLoopIVs = srcLoopIVs.size();
  // Populate map from ForInst -> trip count.
  for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
    AffineMap lbMap = sliceState->lbs[i];
    AffineMap ubMap = sliceState->ubs[i];
    if (lbMap == AffineMap::Null() || ubMap == AffineMap::Null()) {
      // The iteration of src loop IV 'i' was not sliced. Use full loop bounds.
      if (srcLoopIVs[i]->hasConstantLowerBound() &&
          srcLoopIVs[i]->hasConstantUpperBound()) {
        (*tripCountMap)[srcLoopIVs[i]] =
            srcLoopIVs[i]->getConstantUpperBound() -
            srcLoopIVs[i]->getConstantLowerBound();
        continue;
      }
      return false;
    }
    Optional<uint64_t> tripCount = getConstDifference(lbMap, ubMap);
    if (!tripCount.hasValue())
      return false;
    (*tripCountMap)[srcLoopIVs[i]] = tripCount.getValue();
  }
  return true;
}

// Removes load operations from 'srcLoads' which operate on 'memref', and
// adds them to 'dstLoads'.
static void
moveLoadsAccessingMemrefTo(Value *memref,
                           SmallVectorImpl<OperationInst *> *srcLoads,
                           SmallVectorImpl<OperationInst *> *dstLoads) {
  dstLoads->clear();
  SmallVector<OperationInst *, 4> srcLoadsToKeep;
  for (auto *load : *srcLoads) {
    if (load->cast<LoadOp>()->getMemRef() == memref)
      dstLoads->push_back(load);
    else
      srcLoadsToKeep.push_back(load);
  }
  srcLoads->swap(srcLoadsToKeep);
}

// Returns the innermost common loop depth for the set of operations in 'ops'.
static unsigned getInnermostCommonLoopDepth(ArrayRef<OperationInst *> ops) {
  unsigned numOps = ops.size();
  assert(numOps > 0);

  std::vector<SmallVector<ForInst *, 4>> loops(numOps);
  unsigned loopDepthLimit = std::numeric_limits<unsigned>::max();
  for (unsigned i = 0; i < numOps; ++i) {
    getLoopIVs(*ops[i], &loops[i]);
    loopDepthLimit =
        std::min(loopDepthLimit, static_cast<unsigned>(loops[i].size()));
  }

  unsigned loopDepth = 0;
  for (unsigned d = 0; d < loopDepthLimit; ++d) {
    unsigned i;
    for (i = 1; i < numOps; ++i) {
      if (loops[i - 1][d] != loops[i][d]) {
        break;
      }
    }
    if (i != numOps)
      break;
    ++loopDepth;
  }
  return loopDepth;
}

// Returns the slice union of 'sliceStateA' and 'sliceStateB' in 'sliceStateB'
// using a rectangular bounding box.
// TODO(andydavis) This function assumes that lower bounds for 'sliceStateA'
// and 'sliceStateB' are aligned.
// Specifically, when taking the union of overlapping intervals, it assumes
// that both intervals start at zero. Support needs to be added to take into
// account interval start offset when computing the union.
// TODO(andydavis) Move this function to an analysis library.
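// For example (an illustration): if dimension 'i' of 'sliceStateA' has bounds
// [0, 16) and the same dimension of 'sliceStateB' has bounds [0, 8) with
// identical lower bound maps, the union keeps the bounds of 'sliceStateA' in
// 'sliceStateB', since its trip count (16) is larger.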
static bool getSliceUnion(const ComputationSliceState &sliceStateA,
                          ComputationSliceState *sliceStateB) {
  assert(sliceStateA.lbs.size() == sliceStateB->lbs.size());
  assert(sliceStateA.ubs.size() == sliceStateB->ubs.size());

  for (unsigned i = 0, e = sliceStateA.lbs.size(); i < e; ++i) {
    AffineMap lbMapA = sliceStateA.lbs[i];
    AffineMap ubMapA = sliceStateA.ubs[i];
    if (lbMapA == AffineMap::Null()) {
      assert(ubMapA == AffineMap::Null());
      continue;
    }
    assert(ubMapA && "expected non-null ub map");

    AffineMap lbMapB = sliceStateB->lbs[i];
    AffineMap ubMapB = sliceStateB->ubs[i];
    if (lbMapB == AffineMap::Null()) {
      assert(ubMapB == AffineMap::Null());
      // Union 'sliceStateB' does not have a bound for 'i' so copy from A.
      sliceStateB->lbs[i] = lbMapA;
      sliceStateB->ubs[i] = ubMapA;
      continue;
    }

    // TODO(andydavis) Change this code to take the min across all lower bounds
    // and max across all upper bounds for each dimension. This code can fail
    // for cases where a unique min or max could not be statically determined.

    // Assumption: both lower bounds are the same.
    if (lbMapA != lbMapB)
      return false;

    // Add bound with the largest trip count to union.
    Optional<uint64_t> tripCountA = getConstDifference(lbMapA, ubMapA);
    Optional<uint64_t> tripCountB = getConstDifference(lbMapB, ubMapB);
    if (!tripCountA.hasValue() || !tripCountB.hasValue())
      return false;

    if (tripCountA.getValue() > tripCountB.getValue()) {
      sliceStateB->lbs[i] = lbMapA;
      sliceStateB->ubs[i] = ubMapA;
    }
  }
  return true;
}

// Creates and returns a private (single-user) memref for the fused loop
// rooted at 'forInst', with a (potentially reduced) memref size based on the
// MemRefRegion written to by 'srcStoreOpInst' at depth 'dstLoopDepth'.
// TODO(bondhugula): consider refactoring the common code from generateDma and
// this one.
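// For example (an illustration, not from this file): if the fused slice
// writes only a 128-element window of a memref<1024xf32> per iteration of
// the destination loop at 'dstLoopDepth', a private buffer of type
// memref<128xf32> is allocated, and accesses to the old memref are remapped
// into it by subtracting the region's lower bound offsets.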
static Value *createPrivateMemRef(ForInst *forInst,
                                  OperationInst *srcStoreOpInst,
                                  unsigned dstLoopDepth) {
  // Create builder to insert alloc op just before 'forInst'.
  FuncBuilder b(forInst);
  // Builder to create constants at the top level.
  FuncBuilder top(forInst->getFunction());
  // Create new memref type based on slice bounds.
  auto *oldMemRef = srcStoreOpInst->cast<StoreOp>()->getMemRef();
  auto oldMemRefType = oldMemRef->getType().cast<MemRefType>();
  unsigned rank = oldMemRefType.getRank();

  // Compute MemRefRegion for 'srcStoreOpInst' at depth 'dstLoopDepth'.
  MemRefRegion region;
  getMemRefRegion(srcStoreOpInst, dstLoopDepth, &region);
  SmallVector<int64_t, 4> newShape;
  std::vector<SmallVector<int64_t, 4>> lbs;
  SmallVector<int64_t, 8> lbDivisors;
  lbs.reserve(rank);
  // Query 'region' for 'newShape' and lower bounds of MemRefRegion accessed
  // by 'srcStoreOpInst' at depth 'dstLoopDepth'.
  Optional<int64_t> numElements =
      region.getConstantBoundingSizeAndShape(&newShape, &lbs, &lbDivisors);
  assert(numElements.hasValue());

  const FlatAffineConstraints *cst = region.getConstraints();
  // 'outerIVs' holds the values that this memory region is symbolic/parametric
  // on; this would correspond to loop IVs surrounding the level at which the
  // slice is being materialized.
  SmallVector<Value *, 8> outerIVs;
  cst->getIdValues(rank, cst->getNumIds(), &outerIVs);

  // Build 'rank' AffineExprs from MemRefRegion 'lbs'.
  SmallVector<AffineExpr, 4> offsets;
  offsets.reserve(rank);
  for (unsigned d = 0; d < rank; ++d) {
    assert(lbs[d].size() == cst->getNumCols() - rank && "incorrect bound size");

    AffineExpr offset = top.getAffineConstantExpr(0);
    for (unsigned j = 0, e = cst->getNumCols() - rank - 1; j < e; j++) {
      offset = offset + lbs[d][j] * top.getAffineDimExpr(j);
    }
    assert(lbDivisors[d] > 0);
    offset =
        (offset + lbs[d][cst->getNumCols() - 1 - rank]).floorDiv(lbDivisors[d]);
    offsets.push_back(offset);
  }

  // Create 'newMemRefType' using 'newShape' from MemRefRegion accessed
  // by 'srcStoreOpInst'.
  auto newMemRefType = b.getMemRefType(newShape, oldMemRefType.getElementType(),
                                       {}, oldMemRefType.getMemorySpace());
  // Gather alloc operands for the dynamic dimensions of the memref.
  SmallVector<Value *, 4> allocOperands;
  unsigned dynamicDimCount = 0;
  for (auto dimSize : oldMemRefType.getShape()) {
    if (dimSize == -1)
      allocOperands.push_back(
          b.create<DimOp>(forInst->getLoc(), oldMemRef, dynamicDimCount++));
  }

  // Create new private memref for fused loop 'forInst'.
  Value *newMemRef =
      b.create<AllocOp>(forInst->getLoc(), newMemRefType, allocOperands);

  // Build an AffineMap to remap access functions based on lower bound offsets.
  SmallVector<AffineExpr, 4> remapExprs;
  remapExprs.reserve(rank);
  unsigned zeroOffsetCount = 0;
  for (unsigned i = 0; i < rank; i++) {
    if (auto constExpr = offsets[i].dyn_cast<AffineConstantExpr>())
      if (constExpr.getValue() == 0)
        ++zeroOffsetCount;
    auto dimExpr = b.getAffineDimExpr(outerIVs.size() + i);

    auto remapExpr =
        simplifyAffineExpr(dimExpr - offsets[i], outerIVs.size() + rank, 0);
    remapExprs.push_back(remapExpr);
  }
  auto indexRemap =
      zeroOffsetCount == rank
          ? AffineMap::Null()
          : b.getAffineMap(outerIVs.size() + rank, 0, remapExprs, {});
  // Replace all users of 'oldMemRef' with 'newMemRef'.
  bool ret =
      replaceAllMemRefUsesWith(oldMemRef, newMemRef, {}, indexRemap,
                               /*extraOperands=*/outerIVs,
                               /*domInstFilter=*/&*forInst->getBody()->begin());
  assert(ret && "replaceAllMemrefUsesWith should always succeed here");
  (void)ret;
  return newMemRef;
}

// Returns the total iteration count of the slice, i.e. the product of the
// trip counts in 'sliceTripCountMap'; a count of 1 means the slice has a
// single iteration (enabling store/load forwarding).
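// For example (an illustration): with per-loop slice trip counts {4, 8}, the
// returned iteration count is 4 * 8 = 32.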
static uint64_t getSliceIterationCount(
    const llvm::SmallDenseMap<ForInst *, uint64_t, 8> &sliceTripCountMap) {
  uint64_t iterCount = 1;
  for (const auto &count : sliceTripCountMap) {
    iterCount *= count.second;
  }
  return iterCount;
}

// Checks the profitability of fusing a backwards slice of the loop nest
// surrounding 'srcOpInst' into the loop nest surrounding 'dstOpInsts'.
// Returns true if it is profitable to fuse the candidate loop nests. Returns
// false otherwise.
// The profitability model executes the following steps:
// *) Computes the backward computation slice at 'srcOpInst'. This
//    computation slice of the loop nest surrounding 'srcOpInst' is
//    represented by modified src loop bounds in 'sliceState', which are
//    functions of loop IVs in the loop nest surrounding 'srcOpInst'.
// *) Computes the cost of unfused src/dst loop nests (currently the cost of a
//    loop nest is the total number of dynamic operation instances in the loop
//    nest).
// *) Computes the cost of fusing a slice of the src loop nest into the dst
//    loop nest at various values of dst loop depth, attempting to fuse
//    the largest computation slice at the maximal dst loop depth (closest to
//    the load) to minimize reuse distance and potentially enable subsequent
//    load/store forwarding.
//    NOTE: If the dst loop nest includes multiple loads in 'dstOpInsts' for
//    the same memref as is written by 'srcOpInst', then the union of slice
//    loop bounds is used to compute the slice and associated slice cost.
//    NOTE: 'dstLoopDepth' refers to the loop depth within the destination
//    loop nest, at which the src computation slice is inserted/fused.
//    NOTE: We attempt to maximize the dst loop depth, but there are cases
//    where a particular setting for 'dstLoopDepth' might fuse an unsliced
//    loop (within the src computation slice) at a depth which results in
//    excessive recomputation (see unit tests for examples).
// *) Compares the total cost of the unfused loop nests to the min cost fused
//    loop nest computed in the previous step, and returns true if the latter
//    is lower.
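// For example (an illustration with made-up costs): if the unfused src and
// dst nests cost 40000 dynamic op instances each, and the dst nest with the
// slice fused at the chosen depth costs 50000, then
//   additionalComputeFraction = 50000 / (40000 + 40000) - 1 = -0.375,
// i.e. fusion removes 37.5% of the dynamic op instances and trivially
// satisfies the default 30% tolerance (kComputeToleranceThreshold).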
static bool isFusionProfitable(OperationInst *srcOpInst,
                               ArrayRef<OperationInst *> dstOpInsts,
                               ComputationSliceState *sliceState,
                               unsigned *dstLoopDepth) {
  LLVM_DEBUG(llvm::dbgs() << "Checking whether fusion is profitable between:\n";
             llvm::dbgs() << " "; srcOpInst->dump(); llvm::dbgs() << " and \n";
             for (auto dstOpInst : dstOpInsts) {
               llvm::dbgs() << " ";
               dstOpInst->dump();
             });

  // Compute cost of sliced and unsliced src loop nest.
  SmallVector<ForInst *, 4> srcLoopIVs;
  getLoopIVs(*srcOpInst, &srcLoopIVs);
  unsigned numSrcLoopIVs = srcLoopIVs.size();

  // Walk src loop nest and collect stats.
  LoopNestStats srcLoopNestStats;
  LoopNestStatsCollector srcStatsCollector(&srcLoopNestStats);
  srcStatsCollector.walk(srcLoopIVs[0]);
  // Currently only constant trip count loop nests are supported.
  if (srcStatsCollector.hasLoopWithNonConstTripCount)
    return false;

  // Compute cost of dst loop nest.
  SmallVector<ForInst *, 4> dstLoopIVs;
  getLoopIVs(*dstOpInsts[0], &dstLoopIVs);

  LoopNestStats dstLoopNestStats;
  LoopNestStatsCollector dstStatsCollector(&dstLoopNestStats);
  dstStatsCollector.walk(dstLoopIVs[0]);
  // Currently only constant trip count loop nests are supported.
  if (dstStatsCollector.hasLoopWithNonConstTripCount)
    return false;

  // Compute the innermost common loop depth for ops in 'dstOpInsts'.
  unsigned maxDstLoopDepth = getInnermostCommonLoopDepth(dstOpInsts);
  if (maxDstLoopDepth == 0)
    return false;

  // Search for min cost value for 'dstLoopDepth'. At each value of
  // 'dstLoopDepth' from 'maxDstLoopDepth' to '1', compute computation slice
  // bounds between 'srcOpInst' and each op in 'dstOpInsts' (taking the union
  // of these bounds). Next the union slice bounds are used to calculate
  // the cost of the slice and the cost of the slice inserted into the dst
  // loop nest at 'dstLoopDepth'.
  uint64_t minFusedLoopNestComputeCost = std::numeric_limits<uint64_t>::max();
  uint64_t maxStorageReduction = 0;
  Optional<uint64_t> sliceMemEstimate = None;

  SmallVector<ComputationSliceState, 4> sliceStates;
  sliceStates.resize(maxDstLoopDepth);
  // The best loop depth at which to materialize the slice.
  Optional<unsigned> bestDstLoopDepth = None;

  // Compute op instance count for the src loop nest without iteration slicing.
  uint64_t srcLoopNestCost = getComputeCost(srcLoopIVs[0], &srcLoopNestStats,
                                            /*tripCountOverrideMap=*/nullptr,
                                            /*computeCostMap=*/nullptr);

  // Compute op instance count for the dst loop nest.
  uint64_t dstLoopNestCost = getComputeCost(dstLoopIVs[0], &dstLoopNestStats,
                                            /*tripCountOverrideMap=*/nullptr,
                                            /*computeCostMap=*/nullptr);

  llvm::SmallDenseMap<ForInst *, uint64_t, 8> sliceTripCountMap;
  DenseMap<ForInst *, int64_t> computeCostMap;
  for (unsigned i = maxDstLoopDepth; i >= 1; --i) {
    MemRefAccess srcAccess(srcOpInst);
    // Handle the common case of one dst load without a copy.
    if (!mlir::getBackwardComputationSliceState(
            srcAccess, MemRefAccess(dstOpInsts[0]), i, &sliceStates[i - 1]))
      return false;
    // Compute the union of slice bounds of all ops in 'dstOpInsts'.
    for (int j = 1, e = dstOpInsts.size(); j < e; ++j) {
      MemRefAccess dstAccess(dstOpInsts[j]);
      ComputationSliceState tmpSliceState;
      if (!mlir::getBackwardComputationSliceState(srcAccess, dstAccess, i,
                                                  &tmpSliceState))
        return false;
      // Compute slice bound union of 'tmpSliceState' and 'sliceStates[i - 1]'.
      getSliceUnion(tmpSliceState, &sliceStates[i - 1]);
    }
    // Build trip count map for computation slice.
    sliceTripCountMap.clear();
    if (!buildSliceTripCountMap(srcOpInst, &sliceStates[i - 1],
                                &sliceTripCountMap))
      // Skip cases where the trip count was non-constant.
      continue;

    // Checks whether a store to load forwarding will happen.
    int64_t sliceIterationCount = getSliceIterationCount(sliceTripCountMap);
    bool storeLoadFwdGuaranteed = (sliceIterationCount == 1);

    assert(sliceIterationCount > 0);

    // Compute cost of fusion for this dest loop depth.

    computeCostMap.clear();

    // The store and loads to this memref will disappear.
    if (storeLoadFwdGuaranteed) {
      // A single store disappears: -1 for that.
      computeCostMap[srcLoopIVs[numSrcLoopIVs - 1]] = -1;
      for (auto *loadOp : dstOpInsts) {
        if (auto *loadLoop = dyn_cast_or_null<ForInst>(loadOp->getParentInst()))
          computeCostMap[loadLoop] = -1;
      }
    }

    // Compute op instance count for the src loop nest with iteration slicing.
    int64_t sliceComputeCost =
        getComputeCost(srcLoopIVs[0], &srcLoopNestStats,
                       /*tripCountOverrideMap=*/&sliceTripCountMap,
                       /*computeCostMap=*/&computeCostMap);

    // Compute cost of fusion for this depth.
    computeCostMap[dstLoopIVs[i - 1]] = sliceComputeCost;

    int64_t fusedLoopNestComputeCost =
        getComputeCost(dstLoopIVs[0], &dstLoopNestStats,
                       /*tripCountOverrideMap=*/nullptr, &computeCostMap);

    double additionalComputeFraction =
        fusedLoopNestComputeCost /
            (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
        1;

    // TODO(bondhugula): This is an ugly approximation. Fix this by finding a
    // good way to calculate the footprint of the memref in the slice and
    // divide it by the total memory footprint of the fused computation.
    double storageReduction =
        static_cast<double>(srcLoopNestCost) / sliceIterationCount;

    LLVM_DEBUG(
        std::stringstream msg;
        msg << "  evaluating fusion profitability at depth : " << i << "\n"
            << std::setprecision(2) << "   additional compute fraction: "
            << 100.0 * additionalComputeFraction << "%\n"
            << "   storage reduction factor: " << storageReduction << "x\n"
            << "   fused nest cost: " << fusedLoopNestComputeCost << "\n"
            << "   slice iteration count: " << sliceIterationCount << "\n";
        llvm::dbgs() << msg.str());

    double computeToleranceThreshold =
        clFusionAddlComputeTolerance.getNumOccurrences() > 0
            ? clFusionAddlComputeTolerance
            : LoopFusion::kComputeToleranceThreshold;

    // TODO(b/123247369): This is a placeholder cost model.
    // Among all choices that add an acceptable amount of redundant computation
    // (as per computeToleranceThreshold), we will simply pick the one that
    // reduces the intermediary size the most.
    if ((storageReduction > maxStorageReduction) &&
        (clMaximalLoopFusion ||
         (additionalComputeFraction < computeToleranceThreshold))) {
      maxStorageReduction = storageReduction;
      bestDstLoopDepth = i;
      minFusedLoopNestComputeCost = fusedLoopNestComputeCost;
      // TODO(bondhugula,andydavis): find a good way to compute the memory
      // footprint of the materialized slice.
      // Approximating this to the compute cost of the slice. This could be an
      // under-approximation or an over-approximation, but in many cases
      // accurate.
      sliceMemEstimate = sliceIterationCount;
    }
  }

  // A simple cost model: fuse if it reduces the memory footprint. If
  // -maximal-fusion is set, fuse nevertheless.

  if (!clMaximalLoopFusion && !bestDstLoopDepth.hasValue()) {
    LLVM_DEBUG(llvm::dbgs()
               << "All fusion choices involve more than the threshold amount "
                  "of redundant computation; NOT fusing.\n");
    return false;
  }

  assert(bestDstLoopDepth.hasValue() &&
         "expected to have a value per logic above");

  // Set dstLoopDepth based on best values from search.
  *dstLoopDepth = bestDstLoopDepth.getValue();

  LLVM_DEBUG(
      llvm::dbgs() << " LoopFusion fusion stats:\n"
                   << "\n  Best loop depth: " << bestDstLoopDepth
                   << "\n  src loop nest compute cost: " << srcLoopNestCost
                   << "\n  dst loop nest compute cost: " << dstLoopNestCost
                   << "\n  fused loop nest compute cost: "
                   << minFusedLoopNestComputeCost << "\n");

  auto dstMemSize = getMemoryFootprintBytes(*dstLoopIVs[0]);
  auto srcMemSize = getMemoryFootprintBytes(*srcLoopIVs[0]);

  Optional<double> storageReduction = None;

  if (!clMaximalLoopFusion) {
    if (!dstMemSize.hasValue() || !srcMemSize.hasValue()) {
      LLVM_DEBUG(
          llvm::dbgs()
          << " fusion memory benefit cannot be evaluated; NOT fusing.\n");
      return false;
    }

    auto srcMemSizeVal = srcMemSize.getValue();
    auto dstMemSizeVal = dstMemSize.getValue();

    assert(sliceMemEstimate.hasValue() && "expected value");
    // This is an inaccurate estimate since 'sliceMemEstimate' is only an
    // approximation.
    auto fusedMem = dstMemSizeVal + sliceMemEstimate.getValue();

    LLVM_DEBUG(llvm::dbgs() << "  src mem: " << srcMemSizeVal << "\n"
                            << "  dst mem: " << dstMemSizeVal << "\n"
                            << "  fused mem: " << fusedMem << "\n"
                            << "  slice mem: " << sliceMemEstimate << "\n");

    if (fusedMem > srcMemSizeVal + dstMemSizeVal) {
      LLVM_DEBUG(llvm::dbgs() << "Fusion is not profitable; NOT fusing.\n");
      return false;
    }
    storageReduction =
        100.0 *
        (1.0 - fusedMem / (static_cast<double>(srcMemSizeVal) + dstMemSizeVal));
  }

  double additionalComputeFraction =
      100.0 * (minFusedLoopNestComputeCost /
                   (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
               1);

  std::stringstream msg;
  msg << "  fusion is most profitable at depth " << *dstLoopDepth << " with "
      << std::setprecision(2) << additionalComputeFraction
      << "% redundant computation and a ";
  msg << (storageReduction.hasValue()
              ? std::to_string(storageReduction.getValue())
              : "<unknown>");
  msg << "% storage reduction.\n";
  LLVM_DEBUG(llvm::dbgs() << msg.str());

  // Update return parameter 'sliceState' with 'bestSliceState'.
  ComputationSliceState *bestSliceState = &sliceStates[*dstLoopDepth - 1];
  sliceState->lbs = bestSliceState->lbs;
  sliceState->ubs = bestSliceState->ubs;
  sliceState->lbOperands = bestSliceState->lbOperands;
  sliceState->ubOperands = bestSliceState->ubOperands;

  // Canonicalize slice bound affine maps.
  for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
    if (sliceState->lbs[i] != AffineMap::Null()) {
      canonicalizeMapAndOperands(&sliceState->lbs[i],
                                 &sliceState->lbOperands[i]);
    }
    if (sliceState->ubs[i] != AffineMap::Null()) {
      canonicalizeMapAndOperands(&sliceState->ubs[i],
                                 &sliceState->ubOperands[i]);
    }
  }
  return true;
}

// GreedyFusion greedily fuses loop nests which have a producer/consumer
// relationship on a memref, with the goal of improving locality. Currently,
// the producer/consumer relationship is required to be unique in the
// Function (there are TODOs to relax this constraint in the future).
//
// The steps of the algorithm are as follows:
//
// *) A worklist is initialized with node ids from the dependence graph.
// *) For each node id in the worklist:
//    *) Pop a ForInst off the worklist. This 'dstForInst' will be a candidate
//       destination ForInst into which fusion will be attempted.
//    *) Add each LoadOp currently in 'dstForInst' into list 'dstLoadOps'.
//    *) For each LoadOp in 'dstLoadOps' do:
//       *) Look up dependent loop nests at earlier positions in the Function
//          which have a single store op to the same memref.
//       *) Check if dependences would be violated by the fusion. For example,
//          the src loop nest may load from memrefs which are different than
//          the producer-consumer memref between src and dest loop nests.
//       *) Get a computation slice of 'srcLoopNest', which adjusts its loop
//          bounds to be functions of 'dstLoopNest' IVs and symbols.
//       *) Fuse the 'srcLoopNest' computation slice into the 'dstLoopNest',
//          just before the dst load op user.
//       *) Add the newly fused load/store operation instructions to the state,
//          and also add newly fused load ops to 'dstLoadOps' to be considered
//          as fusion dst load ops in another iteration.
//       *) Remove old src loop nest and its associated state.
//
// Given a graph where top-level instructions are vertices in the set 'V' and
// edges in the set 'E' are dependences between vertices, this algorithm
// takes O(V) time for initialization, and has runtime O(V + E).
//
// This greedy algorithm is not 'maximal' due to the current restriction of
// fusing along single producer consumer edges, but there is a TODO to fix
// this.
//
// TODO(andydavis) Experiment with other fusion policies.
// TODO(andydavis) Add support for fusing for input reuse (perhaps by
// constructing a graph with edges which represent loads from the same memref
// in two different loop nests).
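// For example (illustrative IR in the style of this pass's era, not taken
// from an actual test): starting from
//
//   for %i = 0 to 100 {
//     store %v, %m[%i] : memref<100xf32>
//   }
//   for %j = 0 to 100 {
//     %u = load %m[%j] : memref<100xf32>
//   }
//
// the source slice is fused at the load's depth, and %m is privatized to a
// single-element buffer (assuming %m has no other users):
//
//   for %j = 0 to 100 {
//     store %v, %m0[%c0] : memref<1xf32>
//     %u = load %m0[%c0] : memref<1xf32>
//   }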
struct GreedyFusion {
public:
  MemRefDependenceGraph *mdg;
  SmallVector<unsigned, 4> worklist;

  GreedyFusion(MemRefDependenceGraph *mdg) : mdg(mdg) {
    // Initialize worklist with nodes from 'mdg'.
    worklist.resize(mdg->nodes.size());
    std::iota(worklist.begin(), worklist.end(), 0);
  }

  void run() {
    while (!worklist.empty()) {
      unsigned dstId = worklist.back();
      worklist.pop_back();
      // Skip if this node was removed (fused into another node).
      if (mdg->nodes.count(dstId) == 0)
        continue;
      // Get 'dstNode' into which to attempt fusion.
      auto *dstNode = mdg->getNode(dstId);
      // Skip if 'dstNode' is not a loop nest.
      if (!isa<ForInst>(dstNode->inst))
        continue;

      SmallVector<OperationInst *, 4> loads = dstNode->loads;
      SmallVector<OperationInst *, 4> dstLoadOpInsts;
      DenseSet<Value *> visitedMemrefs;
      while (!loads.empty()) {
        // Get memref of load on top of the stack.
        auto *memref = loads.back()->cast<LoadOp>()->getMemRef();
        if (visitedMemrefs.count(memref) > 0)
          continue;
        visitedMemrefs.insert(memref);
        // Move all loads in 'loads' accessing 'memref' to 'dstLoadOpInsts'.
        moveLoadsAccessingMemrefTo(memref, &loads, &dstLoadOpInsts);
        // Skip if no input edges along which to fuse.
        if (mdg->inEdges.count(dstId) == 0)
          continue;
        // Iterate through in edges for 'dstId'.
        for (auto &srcEdge : mdg->inEdges[dstId]) {
          // Skip 'srcEdge' if not for 'memref'.
          if (srcEdge.memref != memref)
            continue;

          auto *srcNode = mdg->getNode(srcEdge.id);
          // Skip if 'srcNode' is not a loop nest.
          if (!isa<ForInst>(srcNode->inst))
            continue;
          // Skip if 'srcNode' has more than one store to any memref.
          // TODO(andydavis) Support fusing multi-output src loop nests.
          if (srcNode->stores.size() != 1)
            continue;

          // Skip 'srcNode' if it has in dependence edges. NOTE: This is overly
          // conservative.
          // TODO(andydavis) Track dependence type with edges, and just check
          // for WAW dependence edge here.
          if (mdg->getInEdgeCount(srcNode->id, memref) != 0)
            continue;

          // Skip if 'srcNode' has out edges to other memrefs after 'dstId'.
          if (mdg->getMinOutEdgeNodeId(srcNode->id, memref) < dstId)
            continue;

          // Check if fusion would be profitable.
          // Get unique 'srcNode' store op.
          auto *srcStoreOpInst = srcNode->stores.front();
          unsigned dstLoopDepth;
          mlir::ComputationSliceState sliceState;
          if (!isFusionProfitable(srcStoreOpInst, dstLoadOpInsts, &sliceState,
                                  &dstLoopDepth))
            continue;

          // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'.
          auto *sliceLoopNest = mlir::insertBackwardComputationSlice(
              srcStoreOpInst, dstLoadOpInsts[0], dstLoopDepth, &sliceState);
          if (sliceLoopNest != nullptr) {
            // Update edges between 'srcNode' and 'dstNode'.
            mdg->updateEdges(srcNode->id, dstNode->id);

            // Collect slice loop stats.
            LoopNestStateCollector sliceCollector;
            sliceCollector.walkForInst(sliceLoopNest);
            // Promote single iteration slice loops to single IV value.
            for (auto *forInst : sliceCollector.forInsts) {
              promoteIfSingleIteration(forInst);
            }

            // Create private memref for 'memref' in 'dstForInst'.
            auto *dstForInst = cast<ForInst>(dstNode->inst);
            SmallVector<OperationInst *, 4> storesForMemref;
            for (auto *storeOpInst : sliceCollector.storeOpInsts) {
              if (storeOpInst->cast<StoreOp>()->getMemRef() == memref)
                storesForMemref.push_back(storeOpInst);
            }
            assert(storesForMemref.size() == 1);
            auto *newMemRef = createPrivateMemRef(
                dstForInst, storesForMemref[0], dstLoopDepth);
            visitedMemrefs.insert(newMemRef);

            // Collect dst loop stats after memref privatization
            // transformation.
            LoopNestStateCollector dstLoopCollector;
            dstLoopCollector.walkForInst(dstForInst);

            // Add new load ops to current Node load op list 'loads' to
            // continue fusing based on new operands.
            for (auto *loadOpInst : dstLoopCollector.loadOpInsts) {
              auto *loadMemRef = loadOpInst->cast<LoadOp>()->getMemRef();
              if (visitedMemrefs.count(loadMemRef) == 0)
                loads.push_back(loadOpInst);
            }

            // Clear and add back loads and stores.
            mdg->clearNodeLoadAndStores(dstNode->id);
            mdg->addToNode(dstId, dstLoopCollector.loadOpInsts,
                           dstLoopCollector.storeOpInsts);
            // Remove old src loop nest if it no longer has outgoing dependence
            // edges, and it does not write to a memref which escapes the
            // function.
            if (!mdg->hasOutEdges(srcNode->id) &&
                !mdg->writesToLiveInOrEscapingMemrefs(srcNode->id)) {
              mdg->removeNode(srcNode->id);
              cast<ForInst>(srcNode->inst)->erase();
            }
          }
        }
      }
    }
    // Clean up any allocs with no users.
    for (auto &pair : mdg->memrefEdgeCount) {
      if (pair.second > 0)
        continue;
      auto *memref = pair.first;
      // Skip if there exist other uses (return instruction or function calls).
      if (!memref->use_empty())
        continue;
      // Use list expected to match the dep graph info.
      auto *inst = memref->getDefiningInst();
      auto *opInst = dyn_cast_or_null<OperationInst>(inst);
      if (opInst && opInst->isa<AllocOp>())
        opInst->erase();
    }
  }
};

} // end anonymous namespace

PassResult LoopFusion::runOnFunction(Function *f) {
  MemRefDependenceGraph g;
  if (g.init(f))
    GreedyFusion(&g).run();
  return success();
}

static PassRegistration<LoopFusion> pass("loop-fusion", "Fuse loop nests");