//===- LoopFusion.cpp - Code to perform loop fusion -----------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements loop fusion.
//
//===----------------------------------------------------------------------===//

#include "mlir/AffineOps/AffineOps.h"
#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/AffineStructures.h"
#include "mlir/Analysis/LoopAnalysis.h"
#include "mlir/Analysis/Utils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/Pass.h"
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/Transforms/LoopUtils.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/Utils.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <iomanip>

#define DEBUG_TYPE "loop-fusion"

using llvm::SetVector;

using namespace mlir;

static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options");

/// Disables fusion profitability check and fuses if valid.
static llvm::cl::opt<bool>
    clMaximalLoopFusion("fusion-maximal", llvm::cl::Hidden,
                        llvm::cl::desc("Enables maximal loop fusion"),
                        llvm::cl::cat(clOptionsCategory));

/// A threshold in percent of additional computation allowed when fusing.
static llvm::cl::opt<double> clFusionAddlComputeTolerance(
    "fusion-compute-tolerance", llvm::cl::Hidden,
    llvm::cl::desc("Fractional increase in additional"
                   " computation tolerated while fusing"),
    llvm::cl::cat(clOptionsCategory));

static llvm::cl::opt<unsigned> clFusionFastMemorySpace(
    "fusion-fast-mem-space", llvm::cl::Hidden,
    llvm::cl::desc("Faster memory space number to promote fusion buffers to"),
    llvm::cl::cat(clOptionsCategory));

static llvm::cl::opt<unsigned> clFusionLocalBufThreshold(
    "fusion-local-buf-threshold", llvm::cl::Hidden,
    llvm::cl::desc("Threshold size (bytes) for promoting local buffers to fast "
                   "memory space"),
    llvm::cl::cat(clOptionsCategory));

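// For illustration only: a hypothetical command line exercising the flags
// declared above (assuming this pass is registered under 'loop-fusion' in the
// hosting tool, e.g. mlir-opt):
//
//   mlir-opt -loop-fusion -fusion-maximal input.mlir
//   mlir-opt -loop-fusion -fusion-compute-tolerance=0.5 input.mlir
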
namespace {

/// Loop fusion pass. This pass currently supports a greedy fusion policy,
/// which fuses loop nests with single-writer/single-reader memref dependences
/// with the goal of improving locality.

// TODO(andydavis) Support fusion of source loop nests which write to multiple
// memrefs, where each memref can have multiple users (if profitable).
// TODO(andydavis) Extend this pass to check for fusion preventing dependences,
// and add support for more general loop fusion algorithms.

struct LoopFusion : public FunctionPass {
  LoopFusion() : FunctionPass(&LoopFusion::passID) {}

  PassResult runOnFunction(Function *f) override;
  static char passID;

  // Any local buffers smaller than this size will be created in
  // `fastMemorySpace` if provided.
  unsigned localBufSizeThreshold = 1024;
  Optional<unsigned> fastMemorySpace = None;

  // The amount of additional computation that is tolerated while fusing
  // pair-wise as a fraction of the total computation.
  constexpr static double kComputeToleranceThreshold = 0.30f;
};

} // end anonymous namespace

char LoopFusion::passID = 0;

FunctionPass *mlir::createLoopFusionPass() { return new LoopFusion; }

namespace {

// LoopNestStateCollector walks loop nests and collects load and store
// operations, and records whether a region-holding instruction other than an
// AffineForOp (e.g. an 'if' instruction) was encountered in the loop nest.
class LoopNestStateCollector : public InstWalker<LoopNestStateCollector> {
public:
  SmallVector<OpPointer<AffineForOp>, 4> forOps;
  SmallVector<Instruction *, 4> loadOpInsts;
  SmallVector<Instruction *, 4> storeOpInsts;
  bool hasNonForRegion = false;

  void visitInstruction(Instruction *opInst) {
    if (opInst->isa<AffineForOp>())
      forOps.push_back(opInst->cast<AffineForOp>());
    else if (opInst->getNumBlockLists() != 0)
      hasNonForRegion = true;
    else if (opInst->isa<LoadOp>())
      loadOpInsts.push_back(opInst);
    else if (opInst->isa<StoreOp>())
      storeOpInsts.push_back(opInst);
  }
};

// TODO(b/117228571) Replace when this is modeled through side-effects/op traits
static bool isMemRefDereferencingOp(const Instruction &op) {
  if (op.isa<LoadOp>() || op.isa<StoreOp>() || op.isa<DmaStartOp>() ||
      op.isa<DmaWaitOp>())
    return true;
  return false;
}
// MemRefDependenceGraph is a graph data structure where graph nodes are
// top-level instructions in a Function which contain load/store ops, and edges
// are memref dependences between the nodes.
// TODO(andydavis) Add a more flexible dependence graph representation.
// TODO(andydavis) Add a depth parameter to dependence graph construction.
struct MemRefDependenceGraph {
public:
  // Node represents a node in the graph. A Node is either an entire loop nest
  // rooted at the top level which contains loads/stores, or a top level
  // load/store.
  struct Node {
    // The unique identifier of this node in the graph.
    unsigned id;
    // The top-level statement which is (or contains) loads/stores.
    Instruction *inst;
    // List of load operations.
    SmallVector<Instruction *, 4> loads;
    // List of store op insts.
    SmallVector<Instruction *, 4> stores;
    Node(unsigned id, Instruction *inst) : id(id), inst(inst) {}

    // Returns the load op count for 'memref'.
    unsigned getLoadOpCount(Value *memref) {
      unsigned loadOpCount = 0;
      for (auto *loadOpInst : loads) {
        if (memref == loadOpInst->cast<LoadOp>()->getMemRef())
          ++loadOpCount;
      }
      return loadOpCount;
    }

    // Returns the store op count for 'memref'.
    unsigned getStoreOpCount(Value *memref) {
      unsigned storeOpCount = 0;
      for (auto *storeOpInst : stores) {
        if (memref == storeOpInst->cast<StoreOp>()->getMemRef())
          ++storeOpCount;
      }
      return storeOpCount;
    }
  };

  // Edge represents a data dependence between nodes in the graph.
  struct Edge {
    // The id of the node at the other end of the edge.
    // If this edge is stored in Edge = Node.inEdges[i], then
    // 'Node.inEdges[i].id' is the identifier of the source node of the edge.
    // If this edge is stored in Edge = Node.outEdges[i], then
    // 'Node.outEdges[i].id' is the identifier of the dest node of the edge.
    unsigned id;
    // The SSA value on which this edge represents a dependence.
    // If the value is a memref, then the dependence is between graph nodes
    // which contain accesses to the same memref 'value'. If the value is a
    // non-memref value, then the dependence is between a graph node which
    // defines an SSA value and another graph node which uses the SSA value
    // (e.g. a constant instruction defining a value which is used inside a loop
    // nest).
    Value *value;
  };
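
  // For illustration, consider two top-level loop nests accessing memref %m
  // (a hypothetical sketch; 'init' assigns node ids in program order):
  //
  //   for %i0 = 0 to 100 { store %v, %m[%i0] }    // node 0
  //   for %i1 = 0 to 100 { %u = load %m[%i1] }    // node 1
  //
  // The graph records an Edge on value %m in both directions: outEdges[0]
  // holds {id: 1, value: %m} and inEdges[1] holds {id: 0, value: %m}.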

  // Map from node id to Node.
  DenseMap<unsigned, Node> nodes;
  // Map from node id to list of input edges.
  DenseMap<unsigned, SmallVector<Edge, 2>> inEdges;
  // Map from node id to list of output edges.
  DenseMap<unsigned, SmallVector<Edge, 2>> outEdges;
  // Map from memref to a count on the dependence edges associated with that
  // memref.
  DenseMap<Value *, unsigned> memrefEdgeCount;
  // The next unique identifier to use for newly created graph nodes.
  unsigned nextNodeId = 0;

  MemRefDependenceGraph() {}

  // Initializes the dependence graph based on operations in 'f'.
  // Returns true on success, false otherwise.
  bool init(Function *f);

  // Returns the graph node for 'id'.
  Node *getNode(unsigned id) {
    auto it = nodes.find(id);
    assert(it != nodes.end());
    return &it->second;
  }

  // Adds a node with 'inst' to the graph and returns its unique identifier.
  unsigned addNode(Instruction *inst) {
    Node node(nextNodeId++, inst);
    nodes.insert({node.id, node});
    return node.id;
  }

  // Remove node 'id' (and its associated edges) from graph.
  void removeNode(unsigned id) {
    // Remove each edge in 'inEdges[id]'.
    if (inEdges.count(id) > 0) {
      SmallVector<Edge, 2> oldInEdges = inEdges[id];
      for (auto &inEdge : oldInEdges) {
        removeEdge(inEdge.id, id, inEdge.value);
      }
    }
    // Remove each edge in 'outEdges[id]'.
    if (outEdges.count(id) > 0) {
      SmallVector<Edge, 2> oldOutEdges = outEdges[id];
      for (auto &outEdge : oldOutEdges) {
        removeEdge(id, outEdge.id, outEdge.value);
      }
    }
    // Erase remaining node state.
    inEdges.erase(id);
    outEdges.erase(id);
    nodes.erase(id);
  }

  // Returns true if node 'id' writes to any memref which escapes (or is an
  // argument to) the function/block. Returns false otherwise.
  bool writesToLiveInOrEscapingMemrefs(unsigned id) {
    Node *node = getNode(id);
    for (auto *storeOpInst : node->stores) {
      auto *memref = storeOpInst->cast<StoreOp>()->getMemRef();
      auto *inst = memref->getDefiningInst();
      // Return true if 'memref' is a block argument.
      if (!inst)
        return true;
      // Return true if any use of 'memref' escapes the function.
      for (auto &use : memref->getUses())
        if (!isMemRefDereferencingOp(*use.getOwner()))
          return true;
    }
    return false;
  }

  // Returns true if node 'id' can be removed from the graph. Returns false
  // otherwise. A node can be removed from the graph iff the following
  // conditions are met:
  // *) The node does not write to any memref which escapes (or is a
  //    function/block argument).
  // *) The node has no successors in the dependence graph.
  bool canRemoveNode(unsigned id) {
    if (writesToLiveInOrEscapingMemrefs(id))
      return false;
    Node *node = getNode(id);
    for (auto *storeOpInst : node->stores) {
      // Return false if there exist out edges from 'id' on 'memref'.
      if (getOutEdgeCount(id, storeOpInst->cast<StoreOp>()->getMemRef()) > 0)
        return false;
    }
    return true;
  }

  // Returns true iff there is an edge from node 'srcId' to node 'dstId' for
  // 'value'. Returns false otherwise.
  bool hasEdge(unsigned srcId, unsigned dstId, Value *value) {
    if (outEdges.count(srcId) == 0 || inEdges.count(dstId) == 0) {
      return false;
    }
    bool hasOutEdge = llvm::any_of(outEdges[srcId], [=](Edge &edge) {
      return edge.id == dstId && edge.value == value;
    });
    bool hasInEdge = llvm::any_of(inEdges[dstId], [=](Edge &edge) {
      return edge.id == srcId && edge.value == value;
    });
    return hasOutEdge && hasInEdge;
  }

  // Adds an edge from node 'srcId' to node 'dstId' for 'value'.
  void addEdge(unsigned srcId, unsigned dstId, Value *value) {
    if (!hasEdge(srcId, dstId, value)) {
      outEdges[srcId].push_back({dstId, value});
      inEdges[dstId].push_back({srcId, value});
      if (value->getType().isa<MemRefType>())
        memrefEdgeCount[value]++;
    }
  }

  // Removes an edge from node 'srcId' to node 'dstId' for 'value'.
  void removeEdge(unsigned srcId, unsigned dstId, Value *value) {
    assert(inEdges.count(dstId) > 0);
    assert(outEdges.count(srcId) > 0);
    if (value->getType().isa<MemRefType>()) {
      assert(memrefEdgeCount.count(value) > 0);
      memrefEdgeCount[value]--;
    }
    // Remove 'srcId' from 'inEdges[dstId]'.
    for (auto it = inEdges[dstId].begin(); it != inEdges[dstId].end(); ++it) {
      if ((*it).id == srcId && (*it).value == value) {
        inEdges[dstId].erase(it);
        break;
      }
    }
    // Remove 'dstId' from 'outEdges[srcId]'.
    for (auto it = outEdges[srcId].begin(); it != outEdges[srcId].end(); ++it) {
      if ((*it).id == dstId && (*it).value == value) {
        outEdges[srcId].erase(it);
        break;
      }
    }
  }

  // Returns the input edge count for node 'id' and 'memref' from src nodes
  // which access 'memref'.
  unsigned getIncomingMemRefAccesses(unsigned id, Value *memref) {
    unsigned inEdgeCount = 0;
    if (inEdges.count(id) > 0)
      for (auto &inEdge : inEdges[id])
        if (inEdge.value == memref) {
          Node *srcNode = getNode(inEdge.id);
          // Only count in edges from 'srcNode' if 'srcNode' accesses 'memref'.
          if (srcNode->getLoadOpCount(memref) > 0 ||
              srcNode->getStoreOpCount(memref) > 0)
            ++inEdgeCount;
        }
    return inEdgeCount;
  }

  // Returns the output edge count for node 'id' and 'memref'.
  unsigned getOutEdgeCount(unsigned id, Value *memref) {
    unsigned outEdgeCount = 0;
    if (outEdges.count(id) > 0)
      for (auto &outEdge : outEdges[id])
        if (outEdge.value == memref)
          ++outEdgeCount;
    return outEdgeCount;
  }

  // Computes and returns an insertion point instruction, before which the
  // fused <srcId, dstId> loop nest can be inserted while preserving
  // dependences. Returns nullptr if no such insertion point is found.
  Instruction *getFusedLoopNestInsertionPoint(unsigned srcId, unsigned dstId,
                                              Value *memrefToSkip) {
    if (outEdges.count(srcId) == 0)
      return getNode(dstId)->inst;

    // Build set of insts in range (srcId, dstId) which depend on 'srcId'.
    SmallPtrSet<Instruction *, 2> srcDepInsts;
    for (auto &outEdge : outEdges[srcId])
      if (outEdge.id != dstId && outEdge.value != memrefToSkip)
        srcDepInsts.insert(getNode(outEdge.id)->inst);

    // Build set of insts in range (srcId, dstId) on which 'dstId' depends.
    SmallPtrSet<Instruction *, 2> dstDepInsts;
    for (auto &inEdge : inEdges[dstId])
      if (inEdge.id != srcId && inEdge.value != memrefToSkip)
        dstDepInsts.insert(getNode(inEdge.id)->inst);

    Instruction *srcNodeInst = getNode(srcId)->inst;
    Instruction *dstNodeInst = getNode(dstId)->inst;

    // Computing insertion point:
    // *) Walk all instruction positions in Block instruction list in the
    //    range (src, dst). For each instruction 'inst' visited in this search:
    //    *) Store in 'firstSrcDepPos' the first position where 'inst' has a
    //       dependence edge from 'srcNode'.
    //    *) Store in 'lastDstDepPos' the last position where 'inst' has a
    //       dependence edge to 'dstNode'.
    // *) Compare 'firstSrcDepPos' and 'lastDstDepPos' to determine the
    //    instruction insertion point (or return null pointer if no such
    //    insertion point exists: 'firstSrcDepPos' <= 'lastDstDepPos').
    SmallVector<Instruction *, 2> depInsts;
    Optional<unsigned> firstSrcDepPos;
    Optional<unsigned> lastDstDepPos;
    unsigned pos = 0;
    for (Block::iterator it = std::next(Block::iterator(srcNodeInst));
         it != Block::iterator(dstNodeInst); ++it) {
      Instruction *inst = &(*it);
      if (srcDepInsts.count(inst) > 0 && firstSrcDepPos == None)
        firstSrcDepPos = pos;
      if (dstDepInsts.count(inst) > 0)
        lastDstDepPos = pos;
      depInsts.push_back(inst);
      ++pos;
    }

    if (firstSrcDepPos.hasValue()) {
      if (lastDstDepPos.hasValue()) {
        if (firstSrcDepPos.getValue() <= lastDstDepPos.getValue()) {
          // No valid insertion point exists which preserves dependences.
          return nullptr;
        }
      }
      // Return the insertion point at 'firstSrcDepPos'.
      return depInsts[firstSrcDepPos.getValue()];
    }
    // No dependence targets in range (or only dst deps in range), return
    // 'dstNodeInst' insertion point.
    return dstNodeInst;
  }
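
  // For illustration, a hypothetical top-level block layout:
  //
  //   inst0   <- srcId
  //   inst1   <- has a dependence edge from srcId (srcDepInsts), position 0
  //   inst2   <- dstId has a dependence edge from it (dstDepInsts), position 1
  //   inst3   <- dstId
  //
  // Here firstSrcDepPos (0) <= lastDstDepPos (1), so no insertion point
  // preserves both dependences and nullptr is returned. If inst1 and inst2
  // were swapped, the routine would return the instruction at
  // firstSrcDepPos, and the fused nest would be inserted before it.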

  // Updates edge mappings from node 'srcId' to node 'dstId' after 'oldMemRef'
  // has been replaced in node at 'dstId' by a private memref.
  void updateEdges(unsigned srcId, unsigned dstId, Value *oldMemRef) {
    // For each edge in 'inEdges[srcId]': add new edge remapping to 'dstId'.
    if (inEdges.count(srcId) > 0) {
      SmallVector<Edge, 2> oldInEdges = inEdges[srcId];
      for (auto &inEdge : oldInEdges) {
        // Add edge from 'inEdge.id' to 'dstId' if not for 'oldMemRef'.
        if (inEdge.value != oldMemRef)
          addEdge(inEdge.id, dstId, inEdge.value);
      }
    }
    // For each edge in 'outEdges[srcId]': remove edge from 'srcId' to 'dstId'.
    if (outEdges.count(srcId) > 0) {
      SmallVector<Edge, 2> oldOutEdges = outEdges[srcId];
      for (auto &outEdge : oldOutEdges) {
        // Remove any out edges from 'srcId' to 'dstId' across memrefs.
        if (outEdge.id == dstId)
          removeEdge(srcId, outEdge.id, outEdge.value);
      }
    }
    // Remove any edges in 'inEdges[dstId]' on 'oldMemRef' (which is being
    // replaced by a private memref). These edges could come from nodes
    // other than 'srcId' which were removed in the previous step.
    if (inEdges.count(dstId) > 0) {
      SmallVector<Edge, 2> oldInEdges = inEdges[dstId];
      for (auto &inEdge : oldInEdges)
        if (inEdge.value == oldMemRef)
          removeEdge(inEdge.id, dstId, inEdge.value);
    }
  }

  // Adds ops in 'loads' and 'stores' to node at 'id'.
  void addToNode(unsigned id, const SmallVectorImpl<Instruction *> &loads,
                 const SmallVectorImpl<Instruction *> &stores) {
    Node *node = getNode(id);
    for (auto *loadOpInst : loads)
      node->loads.push_back(loadOpInst);
    for (auto *storeOpInst : stores)
      node->stores.push_back(storeOpInst);
  }

  void clearNodeLoadAndStores(unsigned id) {
    Node *node = getNode(id);
    node->loads.clear();
    node->stores.clear();
  }

  void print(raw_ostream &os) const {
    os << "\nMemRefDependenceGraph\n";
    os << "\nNodes:\n";
    for (auto &idAndNode : nodes) {
      os << "Node: " << idAndNode.first << "\n";
      auto it = inEdges.find(idAndNode.first);
      if (it != inEdges.end()) {
        for (const auto &e : it->second)
          os << "  InEdge: " << e.id << " " << e.value << "\n";
      }
      it = outEdges.find(idAndNode.first);
      if (it != outEdges.end()) {
        for (const auto &e : it->second)
          os << "  OutEdge: " << e.id << " " << e.value << "\n";
      }
    }
  }
  void dump() const { print(llvm::errs()); }
};

// Initializes the data dependence graph by walking instructions in 'f'.
// Assigns each node in the graph a node id based on program order in 'f'.
// TODO(andydavis) Add support for taking a Block arg to construct the
// dependence graph at a different depth.
bool MemRefDependenceGraph::init(Function *f) {
  DenseMap<Value *, SetVector<unsigned>> memrefAccesses;

  // TODO: support multi-block functions.
  if (f->getBlocks().size() != 1)
    return false;

  DenseMap<Instruction *, unsigned> forToNodeMap;
  for (auto &inst : f->front()) {
    if (auto forOp = inst.dyn_cast<AffineForOp>()) {
      // Create graph node 'id' to represent top-level 'forOp' and record
      // all loads and store accesses it contains.
      LoopNestStateCollector collector;
      collector.walk(&inst);
      // Return false if a non 'for' region was found (not currently supported).
      if (collector.hasNonForRegion)
        return false;
      Node node(nextNodeId++, &inst);
      for (auto *opInst : collector.loadOpInsts) {
        node.loads.push_back(opInst);
        auto *memref = opInst->cast<LoadOp>()->getMemRef();
        memrefAccesses[memref].insert(node.id);
      }
      for (auto *opInst : collector.storeOpInsts) {
        node.stores.push_back(opInst);
        auto *memref = opInst->cast<StoreOp>()->getMemRef();
        memrefAccesses[memref].insert(node.id);
      }
      forToNodeMap[&inst] = node.id;
      nodes.insert({node.id, node});
    } else if (auto loadOp = inst.dyn_cast<LoadOp>()) {
      // Create graph node for top-level load op.
      Node node(nextNodeId++, &inst);
      node.loads.push_back(&inst);
      auto *memref = inst.cast<LoadOp>()->getMemRef();
      memrefAccesses[memref].insert(node.id);
      nodes.insert({node.id, node});
    } else if (auto storeOp = inst.dyn_cast<StoreOp>()) {
      // Create graph node for top-level store op.
      Node node(nextNodeId++, &inst);
      node.stores.push_back(&inst);
      auto *memref = inst.cast<StoreOp>()->getMemRef();
      memrefAccesses[memref].insert(node.id);
      nodes.insert({node.id, node});
    } else if (inst.getNumBlockLists() != 0) {
      // Return false if another region is found (not currently supported).
      return false;
    } else if (inst.getNumResults() > 0 && !inst.use_empty()) {
      // Create graph node for top-level producer of SSA values, which
      // could be used by loop nest nodes.
      Node node(nextNodeId++, &inst);
      nodes.insert({node.id, node});
    }
  }

  // Add dependence edges between nodes which produce SSA values and their
  // users.
  for (auto &idAndNode : nodes) {
    const Node &node = idAndNode.second;
    if (!node.loads.empty() || !node.stores.empty())
      continue;
    auto *opInst = node.inst;
    for (auto *value : opInst->getResults()) {
      for (auto &use : value->getUses()) {
        SmallVector<OpPointer<AffineForOp>, 4> loops;
        getLoopIVs(*use.getOwner(), &loops);
        if (loops.empty())
          continue;
        assert(forToNodeMap.count(loops[0]->getInstruction()) > 0);
        unsigned userLoopNestId = forToNodeMap[loops[0]->getInstruction()];
        addEdge(node.id, userLoopNestId, value);
      }
    }
  }

  // Walk memref access lists and add graph edges between dependent nodes.
  for (auto &memrefAndList : memrefAccesses) {
    unsigned n = memrefAndList.second.size();
    for (unsigned i = 0; i < n; ++i) {
      unsigned srcId = memrefAndList.second[i];
      bool srcHasStore =
          getNode(srcId)->getStoreOpCount(memrefAndList.first) > 0;
      for (unsigned j = i + 1; j < n; ++j) {
        unsigned dstId = memrefAndList.second[j];
        bool dstHasStore =
            getNode(dstId)->getStoreOpCount(memrefAndList.first) > 0;
        if (srcHasStore || dstHasStore)
          addEdge(srcId, dstId, memrefAndList.first);
      }
    }
  }
  return true;
}
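
// For illustration only, a minimal usage sketch of the graph built above
// (hypothetical; the actual driver is this pass's runOnFunction):
//
//   MemRefDependenceGraph g;
//   if (!g.init(f))
//     return; // Multi-block or unsupported region-holding functions.
//   LLVM_DEBUG(g.dump());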

namespace {

// LoopNestStats aggregates various per-loop statistics (e.g. loop trip count
// and operation count) for a loop nest up until the innermost loop body.
struct LoopNestStats {
  // Map from AffineForOp to immediate child AffineForOps in its loop body.
  DenseMap<Instruction *, SmallVector<OpPointer<AffineForOp>, 2>> loopMap;
  // Map from AffineForOp to count of operations in its loop body.
  DenseMap<Instruction *, uint64_t> opCountMap;
  // Map from AffineForOp to its constant trip count.
  DenseMap<Instruction *, uint64_t> tripCountMap;
};

// LoopNestStatsCollector walks a single loop nest and gathers per-loop
// trip count and operation count statistics and records them in 'stats'.
class LoopNestStatsCollector : public InstWalker<LoopNestStatsCollector> {
public:
  LoopNestStats *stats;
  bool hasLoopWithNonConstTripCount = false;

  LoopNestStatsCollector(LoopNestStats *stats) : stats(stats) {}

  void visitInstruction(Instruction *opInst) {
    auto forOp = opInst->dyn_cast<AffineForOp>();
    if (!forOp)
      return;

    auto *forInst = forOp->getInstruction();
    auto *parentInst = forOp->getInstruction()->getParentInst();
    if (parentInst != nullptr) {
      assert(parentInst->isa<AffineForOp>() && "Expected parent AffineForOp");
      // Add mapping to 'forOp' from its parent AffineForOp.
      stats->loopMap[parentInst].push_back(forOp);
    }

    // Record the number of op instructions in the body of 'forOp'.
    unsigned count = 0;
    stats->opCountMap[forInst] = 0;
    for (auto &inst : *forOp->getBody()) {
      if (!(inst.isa<AffineForOp>() || inst.isa<AffineIfOp>()))
        ++count;
    }
    stats->opCountMap[forInst] = count;
    // Record trip count for 'forOp'. Set flag if trip count is not constant.
    Optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp);
    if (!maybeConstTripCount.hasValue()) {
      hasLoopWithNonConstTripCount = true;
      return;
    }
    stats->tripCountMap[forInst] = maybeConstTripCount.getValue();
  }
};

// Computes the total cost of the loop nest rooted at 'forOp'.
// Currently, the total cost is computed by counting the total operation
// instance count (i.e. total number of operations in the loop body * loop
// trip count) for the entire loop nest.
// If 'tripCountOverrideMap' is non-null, overrides the trip count for loops
// specified in the map when computing the total op instance count.
// NOTE: this is used to compute the cost of computation slices, which are
// sliced along the iteration dimension, and thus reduce the trip count.
// If 'computeCostMap' is non-null, the total op count for forOps specified
// in the map is increased (not overridden) by adding the op count from the
// map to the existing op count for the for loop. This is done before
// multiplying by the loop's trip count, and is used to model the cost of
// inserting a sliced loop nest of known cost into the loop's body.
// NOTE: this is used to compute the cost of fusing a slice of some loop nest
// within another loop.
static int64_t getComputeCost(
    Instruction *forInst, LoopNestStats *stats,
    llvm::SmallDenseMap<Instruction *, uint64_t, 8> *tripCountOverrideMap,
    DenseMap<Instruction *, int64_t> *computeCostMap) {
  // 'opCount' is the total number of operations in one iteration of the body
  // of 'forOp'.
  int64_t opCount = stats->opCountMap[forInst];
  if (stats->loopMap.count(forInst) > 0) {
    for (auto childForOp : stats->loopMap[forInst]) {
      opCount += getComputeCost(childForOp->getInstruction(), stats,
                                tripCountOverrideMap, computeCostMap);
    }
  }
  // Add in additional op instances from slice (if specified in map).
  if (computeCostMap != nullptr) {
    auto it = computeCostMap->find(forInst);
    if (it != computeCostMap->end()) {
      opCount += it->second;
    }
  }
  // Override trip count (if specified in map).
  int64_t tripCount = stats->tripCountMap[forInst];
  if (tripCountOverrideMap != nullptr) {
    auto it = tripCountOverrideMap->find(forInst);
    if (it != tripCountOverrideMap->end()) {
      tripCount = it->second;
    }
  }
  // Returns the total number of dynamic instances of operations in loop body.
  return tripCount * opCount;
}
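
// For illustration, given the nest below with constant trip counts:
//
//   for %i = 0 to 100 {    // opCountMap: 0 non-for ops; child cost added
//     for %j = 0 to 100 {  // opCountMap: 2
//       <2 ops>
//     }
//   }
//
// the inner loop costs 100 * 2 = 200 op instances, and the outer loop scales
// that by its trip count, so getComputeCost returns 100 * 200 = 20000.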

} // end anonymous namespace

static Optional<uint64_t> getConstDifference(AffineMap lbMap, AffineMap ubMap) {
  assert(lbMap.getNumResults() == 1 && "expected single result bound map");
  assert(ubMap.getNumResults() == 1 && "expected single result bound map");
  assert(lbMap.getNumDims() == ubMap.getNumDims());
  assert(lbMap.getNumSymbols() == ubMap.getNumSymbols());
  // TODO(andydavis) Merge this code with 'mlir::getTripCountExpr'.
  // ub_expr - lb_expr
  AffineExpr lbExpr(lbMap.getResult(0));
  AffineExpr ubExpr(ubMap.getResult(0));
  auto loopSpanExpr = simplifyAffineExpr(ubExpr - lbExpr, lbMap.getNumDims(),
                                         lbMap.getNumSymbols());
  auto cExpr = loopSpanExpr.dyn_cast<AffineConstantExpr>();
  if (!cExpr)
    return None;
  return cExpr.getValue();
}
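
// For illustration: with lbMap = (d0) -> (d0) and ubMap = (d0) -> (d0 + 4),
// 'ub_expr - lb_expr' simplifies to the constant 4, so 4 is returned. With
// ubMap = (d0) -> (d0 * 2), the difference stays symbolic in d0 and None is
// returned.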

// Builds a map 'tripCountMap' from AffineForOp to constant trip count for loop
// nest surrounding 'srcOpInst' utilizing slice loop bounds in 'sliceState'.
// Returns true on success, false otherwise (if a non-constant trip count
// was encountered).
// TODO(andydavis) Make this work with non-unit step loops.
static bool buildSliceTripCountMap(
    Instruction *srcOpInst, ComputationSliceState *sliceState,
    llvm::SmallDenseMap<Instruction *, uint64_t, 8> *tripCountMap) {
  SmallVector<OpPointer<AffineForOp>, 4> srcLoopIVs;
  getLoopIVs(*srcOpInst, &srcLoopIVs);
  unsigned numSrcLoopIVs = srcLoopIVs.size();
  // Populate map from AffineForOp -> trip count.
  for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
    AffineMap lbMap = sliceState->lbs[i];
    AffineMap ubMap = sliceState->ubs[i];
    if (lbMap == AffineMap() || ubMap == AffineMap()) {
      // The iteration of src loop IV 'i' was not sliced. Use full loop bounds.
      if (srcLoopIVs[i]->hasConstantLowerBound() &&
          srcLoopIVs[i]->hasConstantUpperBound()) {
        (*tripCountMap)[srcLoopIVs[i]->getInstruction()] =
            srcLoopIVs[i]->getConstantUpperBound() -
            srcLoopIVs[i]->getConstantLowerBound();
        continue;
      }
      return false;
    }
    Optional<uint64_t> tripCount = getConstDifference(lbMap, ubMap);
    if (!tripCount.hasValue())
      return false;
    (*tripCountMap)[srcLoopIVs[i]->getInstruction()] = tripCount.getValue();
  }
  return true;
}

// Removes load operations from 'srcLoads' which operate on 'memref', and
// adds them to 'dstLoads'.
static void
moveLoadsAccessingMemrefTo(Value *memref,
                           SmallVectorImpl<Instruction *> *srcLoads,
                           SmallVectorImpl<Instruction *> *dstLoads) {
  dstLoads->clear();
  SmallVector<Instruction *, 4> srcLoadsToKeep;
  for (auto *load : *srcLoads) {
    if (load->cast<LoadOp>()->getMemRef() == memref)
      dstLoads->push_back(load);
    else
      srcLoadsToKeep.push_back(load);
  }
  srcLoads->swap(srcLoadsToKeep);
}

// Returns the innermost common loop depth for the set of operations in 'ops'.
static unsigned getInnermostCommonLoopDepth(ArrayRef<Instruction *> ops) {
  unsigned numOps = ops.size();
  assert(numOps > 0);

  std::vector<SmallVector<OpPointer<AffineForOp>, 4>> loops(numOps);
  unsigned loopDepthLimit = std::numeric_limits<unsigned>::max();
  for (unsigned i = 0; i < numOps; ++i) {
    getLoopIVs(*ops[i], &loops[i]);
    loopDepthLimit =
        std::min(loopDepthLimit, static_cast<unsigned>(loops[i].size()));
  }

  unsigned loopDepth = 0;
  for (unsigned d = 0; d < loopDepthLimit; ++d) {
    unsigned i;
    for (i = 1; i < numOps; ++i) {
      if (loops[i - 1][d] != loops[i][d])
        break;
    }
    if (i != numOps)
      break;
    ++loopDepth;
  }
  return loopDepth;
}
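
// For illustration: if one op in 'ops' is nested under loops (%i, %j) and
// another under (%i, %k), the innermost common loop depth is 1, since only
// the outermost loop %i is shared by all of them.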

// Returns the maximum loop depth at which no dependences between 'loadOpInsts'
// and 'storeOpInsts' are satisfied.
static unsigned getMaxLoopDepth(ArrayRef<Instruction *> loadOpInsts,
                                ArrayRef<Instruction *> storeOpInsts) {
  // Merge loads and stores into the same array.
  SmallVector<Instruction *, 2> ops(loadOpInsts.begin(), loadOpInsts.end());
  ops.append(storeOpInsts.begin(), storeOpInsts.end());

  // Compute the innermost common loop depth for loads and stores.
  unsigned loopDepth = getInnermostCommonLoopDepth(ops);

  // Return common loop depth for loads if there are no store ops.
  if (storeOpInsts.empty())
    return loopDepth;

  // Check dependences on all pairs of ops in 'ops' and store the minimum
  // loop depth at which a dependence is satisfied.
  for (unsigned i = 0, e = ops.size(); i < e; ++i) {
    auto *srcOpInst = ops[i];
    MemRefAccess srcAccess(srcOpInst);
    for (unsigned j = 0; j < e; ++j) {
      auto *dstOpInst = ops[j];
      MemRefAccess dstAccess(dstOpInst);

      unsigned numCommonLoops =
          getNumCommonSurroundingLoops(*srcOpInst, *dstOpInst);
      for (unsigned d = 1; d <= numCommonLoops + 1; ++d) {
        FlatAffineConstraints dependenceConstraints;
        // TODO(andydavis) Cache dependence analysis results, check cache here.
        if (checkMemrefAccessDependence(srcAccess, dstAccess, d,
                                        &dependenceConstraints,
                                        /*dependenceComponents=*/nullptr)) {
          // Store minimum loop depth and break because we want the min 'd' at
          // which there is a dependence.
          loopDepth = std::min(loopDepth, d - 1);
          break;
        }
      }
    }
  }
  return loopDepth;
}

// Returns the slice union of 'sliceStateA' and 'sliceStateB' in 'sliceStateB'
// using a rectangular bounding box.
// TODO(andydavis) This function assumes that lower bounds for 'sliceStateA'
// and 'sliceStateB' are aligned.
// Specifically, when taking the union of overlapping intervals, it assumes
// that both intervals start at zero. Support needs to be added to take into
// account interval start offset when computing the union.
// TODO(andydavis) Move this function to an analysis library.
static bool getSliceUnion(const ComputationSliceState &sliceStateA,
                          ComputationSliceState *sliceStateB) {
  assert(sliceStateA.lbs.size() == sliceStateB->lbs.size());
  assert(sliceStateA.ubs.size() == sliceStateB->ubs.size());

  for (unsigned i = 0, e = sliceStateA.lbs.size(); i < e; ++i) {
    AffineMap lbMapA = sliceStateA.lbs[i];
    AffineMap ubMapA = sliceStateA.ubs[i];
    if (lbMapA == AffineMap()) {
      assert(ubMapA == AffineMap());
      continue;
    }
    assert(ubMapA && "expected non-null ub map");

    AffineMap lbMapB = sliceStateB->lbs[i];
    AffineMap ubMapB = sliceStateB->ubs[i];
    if (lbMapB == AffineMap()) {
      assert(ubMapB == AffineMap());
      // Union 'sliceStateB' does not have a bound for 'i' so copy from A.
      sliceStateB->lbs[i] = lbMapA;
      sliceStateB->ubs[i] = ubMapA;
      continue;
    }

    // TODO(andydavis) Change this code to take the min across all lower bounds
    // and max across all upper bounds for each dimension. This code can fail
    // for cases where a unique min or max could not be statically determined.

    // Assumption: both lower bounds are the same.
    if (lbMapA != lbMapB)
      return false;

    // Add bound with the largest trip count to union.
    Optional<uint64_t> tripCountA = getConstDifference(lbMapA, ubMapA);
    Optional<uint64_t> tripCountB = getConstDifference(lbMapB, ubMapB);
    if (!tripCountA.hasValue() || !tripCountB.hasValue())
      return false;

    if (tripCountA.getValue() > tripCountB.getValue()) {
      sliceStateB->lbs[i] = lbMapA;
      sliceStateB->ubs[i] = ubMapA;
    }
  }
  return true;
}
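
// For illustration: if slice A bounds a loop by [(d0) -> (d0), (d0) -> (d0 + 4)]
// (trip count 4) and slice B bounds it by [(d0) -> (d0), (d0) -> (d0 + 8)]
// (trip count 8), the union in 'sliceStateB' keeps B's wider bounds; had A
// the larger trip count, A's bounds would overwrite B's.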

// TODO(mlir-team): improve/complete this when we have target data.
unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
  auto elementType = memRefType.getElementType();

  unsigned sizeInBits;
  if (elementType.isIntOrFloat()) {
    sizeInBits = elementType.getIntOrFloatBitWidth();
  } else {
    auto vectorType = elementType.cast<VectorType>();
    sizeInBits =
        vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
  }
  return llvm::divideCeil(sizeInBits, 8);
}
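
// For illustration: memref<16x4xf32> has an element size of 4 bytes (f32 is
// 32 bits), while a memref with vector<4xf32> elements has an element size
// of 16 bytes (4 * 32 bits, divided by 8 and rounded up).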

// Creates and returns a private (single-user) memref for fused loop rooted
// at 'forOp', with (potentially reduced) memref size based on the
// MemRefRegion written to by 'srcStoreOpInst' at depth 'dstLoopDepth'.
// TODO(bondhugula): consider refactoring the common code from generateDma and
// this one.
static Value *createPrivateMemRef(OpPointer<AffineForOp> forOp,
                                  Instruction *srcStoreOpInst,
                                  unsigned dstLoopDepth,
                                  Optional<unsigned> fastMemorySpace,
                                  unsigned localBufSizeThreshold) {
  auto *forInst = forOp->getInstruction();

  // Create builder to insert alloc op just before 'forOp'.
  FuncBuilder b(forInst);
  // Builder to create constants at the top level.
  FuncBuilder top(forInst->getFunction());
  // Create new memref type based on slice bounds.
  auto *oldMemRef = srcStoreOpInst->cast<StoreOp>()->getMemRef();
  auto oldMemRefType = oldMemRef->getType().cast<MemRefType>();
  unsigned rank = oldMemRefType.getRank();

  // Compute MemRefRegion for 'srcStoreOpInst' at depth 'dstLoopDepth'.
  auto region = getMemRefRegion(srcStoreOpInst, dstLoopDepth);
  SmallVector<int64_t, 4> newShape;
  std::vector<SmallVector<int64_t, 4>> lbs;
  SmallVector<int64_t, 8> lbDivisors;
  lbs.reserve(rank);
  // Query 'region' for 'newShape' and lower bounds of MemRefRegion accessed
  // by 'srcStoreOpInst' at depth 'dstLoopDepth'.
  Optional<int64_t> numElements =
      region->getConstantBoundingSizeAndShape(&newShape, &lbs, &lbDivisors);
  assert(numElements.hasValue() &&
         "non-constant number of elts in local buffer");

  const FlatAffineConstraints *cst = region->getConstraints();
  // 'outerIVs' holds the values that this memory region is symbolic/parametric
  // on; this would correspond to loop IVs surrounding the level at which the
  // slice is being materialized.
  SmallVector<Value *, 8> outerIVs;
  cst->getIdValues(rank, cst->getNumIds(), &outerIVs);

  // Build 'rank' AffineExprs from MemRefRegion 'lbs'.
  SmallVector<AffineExpr, 4> offsets;
  offsets.reserve(rank);
  for (unsigned d = 0; d < rank; ++d) {
    assert(lbs[d].size() == cst->getNumCols() - rank && "incorrect bound size");

    AffineExpr offset = top.getAffineConstantExpr(0);
    for (unsigned j = 0, e = cst->getNumCols() - rank - 1; j < e; j++) {
      offset = offset + lbs[d][j] * top.getAffineDimExpr(j);
    }
    assert(lbDivisors[d] > 0);
    offset =
        (offset + lbs[d][cst->getNumCols() - 1 - rank]).floorDiv(lbDivisors[d]);
    offsets.push_back(offset);
  }

  // Create 'newMemRefType' using 'newShape' from MemRefRegion accessed
  // by 'srcStoreOpInst'.
  uint64_t bufSize =
      getMemRefEltSizeInBytes(oldMemRefType) * numElements.getValue();
  unsigned newMemSpace;
  if (bufSize < localBufSizeThreshold && fastMemorySpace.hasValue()) {
    newMemSpace = fastMemorySpace.getValue();
  } else {
    newMemSpace = oldMemRefType.getMemorySpace();
  }
  auto newMemRefType = top.getMemRefType(
      newShape, oldMemRefType.getElementType(), {}, newMemSpace);
  // Gather alloc operands for the dynamic dimensions of the memref.
  SmallVector<Value *, 4> allocOperands;
  unsigned dynamicDimCount = 0;
  for (auto dimSize : oldMemRefType.getShape()) {
    if (dimSize == -1)
      allocOperands.push_back(
          top.create<DimOp>(forOp->getLoc(), oldMemRef, dynamicDimCount++));
  }

  // Create new private memref for fused loop 'forOp'.
  // TODO(andydavis) Create/move alloc ops for private memrefs closer to their
  // consumer loop nests to reduce their live range. Currently they are added
  // at the beginning of the function, because loop nests can be reordered
  // during the fusion pass.
  Value *newMemRef =
      top.create<AllocOp>(forOp->getLoc(), newMemRefType, allocOperands);

  // Build an AffineMap to remap access functions based on lower bound offsets.
  SmallVector<AffineExpr, 4> remapExprs;
  remapExprs.reserve(rank);
  unsigned zeroOffsetCount = 0;
  for (unsigned i = 0; i < rank; i++) {
    if (auto constExpr = offsets[i].dyn_cast<AffineConstantExpr>())
      if (constExpr.getValue() == 0)
        ++zeroOffsetCount;
    auto dimExpr = b.getAffineDimExpr(outerIVs.size() + i);

    auto remapExpr =
        simplifyAffineExpr(dimExpr - offsets[i], outerIVs.size() + rank, 0);
    remapExprs.push_back(remapExpr);
  }
  auto indexRemap =
      zeroOffsetCount == rank
          ? AffineMap()
          : b.getAffineMap(outerIVs.size() + rank, 0, remapExprs, {});
  // Replace all users of 'oldMemRef' with 'newMemRef'.
  bool ret =
      replaceAllMemRefUsesWith(oldMemRef, newMemRef, {}, indexRemap,
                               /*extraOperands=*/outerIVs,
                               /*domInstFilter=*/&*forOp->getBody()->begin());
  assert(ret && "replaceAllMemrefUsesWith should always succeed here");
  (void)ret;
  return newMemRef;
}
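
// For illustration: if 'srcStoreOpInst' only writes elements [5, 10) of a
// memref<100xf32> at the given depth, the private buffer becomes
// memref<5xf32> and the index remap sends (d0) to (d0 - 5), so old index 5
// maps to new index 0 (ignoring any outer IV symbols for simplicity).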

// Does the slice have a single iteration? Returns the total iteration count
// of the slice; a count of 1 means the slice has a single iteration.
static uint64_t getSliceIterationCount(
    const llvm::SmallDenseMap<Instruction *, uint64_t, 8> &sliceTripCountMap) {
  uint64_t iterCount = 1;
  for (const auto &count : sliceTripCountMap) {
    iterCount *= count.second;
  }
  return iterCount;
}
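
// For illustration: a slice whose loops have trip counts {4, 8} has
// 4 * 8 = 32 iterations; only a result of 1 indicates a single iteration.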
1025
MLIR Team27d067e2019-01-16 17:55:021026// Checks the profitability of fusing a backwards slice of the loop nest
MLIR Teamd7c82442019-01-30 23:53:411027// surrounding 'srcOpInst' into the loop nest surrounding 'dstLoadOpInsts'.
Uday Bondhugulab4a14432019-01-26 00:00:501028// Returns true if it is profitable to fuse the candidate loop nests. Returns
1029// false otherwise. `dstLoopDepth` is set to the most profitable depth at which
1030// to materialize the source loop nest slice.
MLIR Team38c2fe32019-01-14 19:26:251031// The profitability model executes the following steps:
MLIR Team27d067e2019-01-16 17:55:021032// *) Computes the backward computation slice at 'srcOpInst'. This
1033// computation slice of the loop nest surrounding 'srcOpInst' is
MLIR Team38c2fe32019-01-14 19:26:251034// represented by modified src loop bounds in 'sliceState', which are
MLIR Team27d067e2019-01-16 17:55:021035// functions of loop IVs in the loop nest surrounding 'srcOpInst'.
MLIR Team38c2fe32019-01-14 19:26:251036// *) Computes the cost of unfused src/dst loop nests (currently the cost of a
1037// loop nest is the total number of dynamic operation instances in the loop
1038// nest).
1039// *) Computes the cost of fusing a slice of the src loop nest into the dst
MLIR Team27d067e2019-01-16 17:55:021040// loop nest at various values of dst loop depth, attempting to fuse
1041// the largest compution slice at the maximal dst loop depth (closest to the
1042// load) to minimize reuse distance and potentially enable subsequent
1043// load/store forwarding.
MLIR Teamd7c82442019-01-30 23:53:411044// NOTE: If the dst loop nest includes multiple loads in 'dstLoadOpInsts' for
MLIR Team27d067e2019-01-16 17:55:021045// the same memref as is written by 'srcOpInst', then the union of slice
1046// loop bounds is used to compute the slice and associated slice cost.
Uday Bondhugulab4a14432019-01-26 00:00:501047// NOTE: 'dstLoopDepth' refers to the loop depth within the destination loop
MLIR Team38c2fe32019-01-14 19:26:251048// nest, at which the src computation slice is inserted/fused.
MLIR Team27d067e2019-01-16 17:55:021049// NOTE: We attempt to maximize the dst loop depth, but there are cases
1050// where a particular setting for 'dstLoopNest' might fuse an unsliced
MLIR Team38c2fe32019-01-14 19:26:251051// loop (within the src computation slice) at a depth which results in
1052// execessive recomputation (see unit tests for examples).
1053// *) Compares the total cost of the unfused loop nests to the min cost fused
1054// loop nest computed in the previous step, and returns true if the latter
1055// is lower.
River Riddleb4992772019-02-04 18:38:471056static bool isFusionProfitable(Instruction *srcOpInst,
1057 ArrayRef<Instruction *> dstLoadOpInsts,
1058 ArrayRef<Instruction *> dstStoreOpInsts,
MLIR Team38c2fe32019-01-14 19:26:251059 ComputationSliceState *sliceState,
MLIR Team27d067e2019-01-16 17:55:021060 unsigned *dstLoopDepth) {
Uday Bondhugula06d21d92019-01-25 01:01:491061 LLVM_DEBUG({
1062 llvm::dbgs() << "Checking whether fusion is profitable between:\n";
1063 llvm::dbgs() << " ";
1064 srcOpInst->dump();
1065 llvm::dbgs() << " and \n";
MLIR Teamd7c82442019-01-30 23:53:411066 for (auto dstOpInst : dstLoadOpInsts) {
Uday Bondhugula06d21d92019-01-25 01:01:491067 llvm::dbgs() << " ";
1068 dstOpInst->dump();
1069 };
1070 });
Uday Bondhugula864d9e02019-01-23 17:16:241071
MLIR Team38c2fe32019-01-14 19:26:251072 // Compute cost of sliced and unsliced src loop nest.
River Riddle5052bd82019-02-02 00:42:181073 SmallVector<OpPointer<AffineForOp>, 4> srcLoopIVs;
MLIR Team27d067e2019-01-16 17:55:021074 getLoopIVs(*srcOpInst, &srcLoopIVs);
MLIR Team38c2fe32019-01-14 19:26:251075 unsigned numSrcLoopIVs = srcLoopIVs.size();
1076
1077 // Walk src loop nest and collect stats.
1078 LoopNestStats srcLoopNestStats;
1079 LoopNestStatsCollector srcStatsCollector(&srcLoopNestStats);
River Riddle5052bd82019-02-02 00:42:181080 srcStatsCollector.walk(srcLoopIVs[0]->getInstruction());
MLIR Team38c2fe32019-01-14 19:26:251081 // Currently only constant trip count loop nests are supported.
1082 if (srcStatsCollector.hasLoopWithNonConstTripCount)
1083 return false;
1084
1085 // Compute cost of dst loop nest.
River Riddle5052bd82019-02-02 00:42:181086 SmallVector<OpPointer<AffineForOp>, 4> dstLoopIVs;
MLIR Teamd7c82442019-01-30 23:53:411087 getLoopIVs(*dstLoadOpInsts[0], &dstLoopIVs);
MLIR Team38c2fe32019-01-14 19:26:251088
1089 LoopNestStats dstLoopNestStats;
1090 LoopNestStatsCollector dstStatsCollector(&dstLoopNestStats);
River Riddle5052bd82019-02-02 00:42:181091 dstStatsCollector.walk(dstLoopIVs[0]->getInstruction());
MLIR Team38c2fe32019-01-14 19:26:251092 // Currently only constant trip count loop nests are supported.
1093 if (dstStatsCollector.hasLoopWithNonConstTripCount)
1094 return false;
1095
MLIR Teamd7c82442019-01-30 23:53:411096 // Compute the maximum loop depth at which we can insert the src slice
1097 // and still satisfy dest loop nest dependences.
1098 unsigned maxDstLoopDepth = getMaxLoopDepth(dstLoadOpInsts, dstStoreOpInsts);
MLIR Team27d067e2019-01-16 17:55:021099 if (maxDstLoopDepth == 0)
1100 return false;
1101
1102 // Search for min cost value for 'dstLoopDepth'. At each value of
1103 // 'dstLoopDepth' from 'maxDstLoopDepth' to '1', compute the computation slice
1104 // bounds between 'srcOpInst' and each op in 'dstLoadOpInsts' (taking the union
1105 // of these bounds). Next, the union of the slice bounds is used to calculate
1106 // the cost of the slice and the cost of the slice inserted into the dst
1107 // loop nest at 'dstLoopDepth'.
Uday Bondhugula864d9e02019-01-23 17:16:241108 uint64_t minFusedLoopNestComputeCost = std::numeric_limits<uint64_t>::max();
1109 uint64_t maxStorageReduction = 0;
1110 Optional<uint64_t> sliceMemEstimate = None;
1111
MLIR Team27d067e2019-01-16 17:55:021112 SmallVector<ComputationSliceState, 4> sliceStates;
1113 sliceStates.resize(maxDstLoopDepth);
Uday Bondhugula864d9e02019-01-23 17:16:241114 // The best loop depth at which to materialize the slice.
1115 Optional<unsigned> bestDstLoopDepth = None;
1116
1117 // Compute op instance count for the src loop nest without iteration slicing.
River Riddle5052bd82019-02-02 00:42:181118 uint64_t srcLoopNestCost =
1119 getComputeCost(srcLoopIVs[0]->getInstruction(), &srcLoopNestStats,
1120 /*tripCountOverrideMap=*/nullptr,
1121 /*computeCostMap=*/nullptr);
Uday Bondhugula864d9e02019-01-23 17:16:241122
1123 // Compute op instance count for the src loop nest.
River Riddle5052bd82019-02-02 00:42:181124 uint64_t dstLoopNestCost =
1125 getComputeCost(dstLoopIVs[0]->getInstruction(), &dstLoopNestStats,
1126 /*tripCountOverrideMap=*/nullptr,
1127 /*computeCostMap=*/nullptr);
MLIR Team27d067e2019-01-16 17:55:021128
River Riddle5052bd82019-02-02 00:42:181129 llvm::SmallDenseMap<Instruction *, uint64_t, 8> sliceTripCountMap;
1130 DenseMap<Instruction *, int64_t> computeCostMap;
MLIR Team27d067e2019-01-16 17:55:021131 for (unsigned i = maxDstLoopDepth; i >= 1; --i) {
1132 MemRefAccess srcAccess(srcOpInst);
1133 // Handle the common case of one dst load without a copy.
1134 if (!mlir::getBackwardComputationSliceState(
MLIR Teamd7c82442019-01-30 23:53:411135 srcAccess, MemRefAccess(dstLoadOpInsts[0]), i, &sliceStates[i - 1]))
MLIR Team27d067e2019-01-16 17:55:021136 return false;
MLIR Teamd7c82442019-01-30 23:53:411137 // Compute the union of slice bounds of all ops in 'dstLoadOpInsts'.
1138 for (int j = 1, e = dstLoadOpInsts.size(); j < e; ++j) {
1139 MemRefAccess dstAccess(dstLoadOpInsts[j]);
MLIR Team27d067e2019-01-16 17:55:021140 ComputationSliceState tmpSliceState;
1141 if (!mlir::getBackwardComputationSliceState(srcAccess, dstAccess, i,
1142 &tmpSliceState))
1143 return false;
1144 // Compute slice bound union of 'tmpSliceState' and 'sliceStates[i - 1]'.
Uday Bondhugulac1ca23e2019-01-16 21:13:001145 getSliceUnion(tmpSliceState, &sliceStates[i - 1]);
MLIR Team38c2fe32019-01-14 19:26:251146 }
Uday Bondhugulab4a14432019-01-26 00:00:501147 // Build trip count map for computation slice. We'll skip cases where the
1148 // trip count was non-constant.
MLIR Team27d067e2019-01-16 17:55:021149 sliceTripCountMap.clear();
1150 if (!buildSliceTripCountMap(srcOpInst, &sliceStates[i - 1],
1151 &sliceTripCountMap))
Uday Bondhugula864d9e02019-01-23 17:16:241152 continue;
1153
1154 // Check whether store-to-load forwarding will happen.
1155 int64_t sliceIterationCount = getSliceIterationCount(sliceTripCountMap);
Uday Bondhugula864d9e02019-01-23 17:16:241156 assert(sliceIterationCount > 0);
Uday Bondhugulab4a14432019-01-26 00:00:501157 bool storeLoadFwdGuaranteed = (sliceIterationCount == 1);
Uday Bondhugula864d9e02019-01-23 17:16:241158
1159 // Compute cost of fusion for this dest loop depth.
1160
1161 computeCostMap.clear();
1162
1163 // The store and loads to this memref will disappear.
1164 if (storeLoadFwdGuaranteed) {
1165 // A single store disappears: -1 for that.
River Riddle5052bd82019-02-02 00:42:181166 computeCostMap[srcLoopIVs[numSrcLoopIVs - 1]->getInstruction()] = -1;
MLIR Teamd7c82442019-01-30 23:53:411167 for (auto *loadOp : dstLoadOpInsts) {
River Riddle5052bd82019-02-02 00:42:181168 auto *parentInst = loadOp->getParentInst();
River Riddleb4992772019-02-04 18:38:471169 if (parentInst && parentInst->isa<AffineForOp>())
River Riddle5052bd82019-02-02 00:42:181170 computeCostMap[parentInst] = -1;
Uday Bondhugula864d9e02019-01-23 17:16:241171 }
1172 }
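// (Interpretation, based on how 'computeCostMap' is consumed below: each -1
// entry asks 'getComputeCost' to count one fewer op instance per iteration
// of the mapped loop, modeling the forwarded store, and the loads it
// reaches, vanishing from the fused nest.)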
MLIR Team27d067e2019-01-16 17:55:021173
MLIR Team38c2fe32019-01-14 19:26:251174 // Compute op instance count for the src loop nest with iteration slicing.
Uday Bondhugula864d9e02019-01-23 17:16:241175 int64_t sliceComputeCost =
River Riddle5052bd82019-02-02 00:42:181176 getComputeCost(srcLoopIVs[0]->getInstruction(), &srcLoopNestStats,
Uday Bondhugula864d9e02019-01-23 17:16:241177 /*tripCountOverrideMap=*/&sliceTripCountMap,
1178 /*computeCostMap=*/&computeCostMap);
MLIR Team38c2fe32019-01-14 19:26:251179
Uday Bondhugula864d9e02019-01-23 17:16:241180 // Compute cost of fusion for this depth.
River Riddle5052bd82019-02-02 00:42:181181 computeCostMap[dstLoopIVs[i - 1]->getInstruction()] = sliceComputeCost;
Uday Bondhugula864d9e02019-01-23 17:16:241182
1183 int64_t fusedLoopNestComputeCost =
River Riddle5052bd82019-02-02 00:42:181184 getComputeCost(dstLoopIVs[0]->getInstruction(), &dstLoopNestStats,
MLIR Team27d067e2019-01-16 17:55:021185 /*tripCountOverrideMap=*/nullptr, &computeCostMap);
Uday Bondhugula864d9e02019-01-23 17:16:241186
1187 double additionalComputeFraction =
1188 fusedLoopNestComputeCost /
1189 (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
1190 1;
1191
1192 // TODO(bondhugula): This is an ugly approximation. Fix this by finding a
1193 // good way to calculate the footprint of the memref in the slice and
1194 // divide it by the total memory footprint of the fused computation.
1195 double storageReduction =
1196 static_cast<double>(srcLoopNestCost) / sliceIterationCount;
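// Worked example with hypothetical values: srcLoopNestCost = 1000,
// dstLoopNestCost = 4000 and fusedLoopNestComputeCost = 5500 give
// additionalComputeFraction = 5500 / 5000 - 1 = 0.10 (10% extra compute);
// with sliceIterationCount = 10, storageReduction = 1000 / 10 = 100.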
1197
Uday Bondhugula06d21d92019-01-25 01:01:491198 LLVM_DEBUG({
1199 std::stringstream msg;
1200 msg << " evaluating fusion profitability at depth : " << i << "\n"
1201 << std::setprecision(2) << " additional compute fraction: "
1202 << 100.0 * additionalComputeFraction << "%\n"
1203 << " storage reduction factor: " << storageReduction << "x\n"
1204 << " fused nest cost: " << fusedLoopNestComputeCost << "\n"
1205 << " slice iteration count: " << sliceIterationCount << "\n";
1206 llvm::dbgs() << msg.str();
1207 });
Uday Bondhugula864d9e02019-01-23 17:16:241208
1209 double computeToleranceThreshold =
1210 clFusionAddlComputeTolerance.getNumOccurrences() > 0
1211 ? clFusionAddlComputeTolerance
1212 : LoopFusion::kComputeToleranceThreshold;
1213
1214 // TODO(b/123247369): This is a placeholder cost model.
1215 // Among all choices that add an acceptable amount of redundant computation
1216 // (as per computeToleranceThreshold), we will simply pick the one that
1217 // reduces the intermediary size the most.
1218 if ((storageReduction > maxStorageReduction) &&
1219 (clMaximalLoopFusion ||
1220 (additionalComputeFraction < computeToleranceThreshold))) {
1221 maxStorageReduction = storageReduction;
MLIR Team27d067e2019-01-16 17:55:021222 bestDstLoopDepth = i;
Uday Bondhugula864d9e02019-01-23 17:16:241223 minFusedLoopNestComputeCost = fusedLoopNestComputeCost;
1224 // TODO(bondhugula,andydavis): find a good way to compute the memory
1225 // footprint of the materialized slice.
1226 // Approximating this to the compute cost of the slice. This could be an
1227 // under-approximation or an over-approximation, but in many cases
1228 // accurate.
1229 sliceMemEstimate = sliceIterationCount;
MLIR Team38c2fe32019-01-14 19:26:251230 }
1231 }
1232
Uday Bondhugula864d9e02019-01-23 17:16:241233 // A simple cost model: fuse if it reduces the memory footprint. If
1234 // -fusion-maximal is set, fuse nevertheless.
MLIR Team38c2fe32019-01-14 19:26:251235
Uday Bondhugula864d9e02019-01-23 17:16:241236 if (!clMaximalLoopFusion && !bestDstLoopDepth.hasValue()) {
1237 LLVM_DEBUG(llvm::dbgs()
1238 << "All fusion choices involve more than the threshold amount of "
1239 "redundant computation; NOT fusing.\n");
MLIR Team38c2fe32019-01-14 19:26:251240 return false;
Uday Bondhugula864d9e02019-01-23 17:16:241241 }
1242
1243 assert(bestDstLoopDepth.hasValue() &&
1244 "expected to have a value per logic above");
1245
1246 // Set dstLoopDepth based on best values from search.
1247 *dstLoopDepth = bestDstLoopDepth.getValue();
1248
1249 LLVM_DEBUG(
Uday Bondhugula06d21d92019-01-25 01:01:491250 llvm::dbgs() << " LoopFusion fusion stats:"
1251 << "\n best loop depth: " << bestDstLoopDepth
Uday Bondhugula864d9e02019-01-23 17:16:241252 << "\n src loop nest compute cost: " << srcLoopNestCost
1253 << "\n dst loop nest compute cost: " << dstLoopNestCost
1254 << "\n fused loop nest compute cost: "
1255 << minFusedLoopNestComputeCost << "\n");
1256
River Riddle5052bd82019-02-02 00:42:181257 auto dstMemSize = getMemoryFootprintBytes(dstLoopIVs[0]);
1258 auto srcMemSize = getMemoryFootprintBytes(srcLoopIVs[0]);
Uday Bondhugula864d9e02019-01-23 17:16:241259
1260 Optional<double> storageReduction = None;
1261
1262 if (!clMaximalLoopFusion) {
1263 if (!dstMemSize.hasValue() || !srcMemSize.hasValue()) {
1264 LLVM_DEBUG(
1265 llvm::dbgs()
1266 << " fusion memory benefit cannot be evaluated; NOT fusing.\n");
1267 return false;
1268 }
1269
1270 auto srcMemSizeVal = srcMemSize.getValue();
1271 auto dstMemSizeVal = dstMemSize.getValue();
1272
1273 assert(sliceMemEstimate.hasValue() && "expected value");
1274 // This is an inaccurate estimate since 'sliceMemEstimate' is approximate.
1275 auto fusedMem = dstMemSizeVal + sliceMemEstimate.getValue();
1276
1277 LLVM_DEBUG(llvm::dbgs() << " src mem: " << srcMemSizeVal << "\n"
1278 << " dst mem: " << dstMemSizeVal << "\n"
1279 << " fused mem: " << fusedMem << "\n"
1280 << " slice mem: " << sliceMemEstimate << "\n");
1281
1282 if (fusedMem > srcMemSizeVal + dstMemSizeVal) {
1283 LLVM_DEBUG(llvm::dbgs() << "Fusion is not profitable; NOT fusing.\n");
1284 return false;
1285 }
1286 storageReduction =
1287 100.0 *
1288 (1.0 - fusedMem / (static_cast<double>(srcMemSizeVal) + dstMemSizeVal));
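// E.g. with hypothetical sizes srcMemSizeVal = 400, dstMemSizeVal = 600 and
// fusedMem = 700: 100.0 * (1.0 - 700.0 / 1000.0) = 30% storage reduction.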
1289 }
1290
1291 double additionalComputeFraction =
1292 100.0 * (minFusedLoopNestComputeCost /
1293 (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
1294 1);
MLIR Team5c5739d2019-01-25 06:27:401295 (void)additionalComputeFraction;
Uday Bondhugula06d21d92019-01-25 01:01:491296 LLVM_DEBUG({
1297 std::stringstream msg;
1298 msg << " fusion is most profitable at depth " << *dstLoopDepth << " with "
1299 << std::setprecision(2) << additionalComputeFraction
1300 << "% redundant computation and a ";
1301 msg << (storageReduction.hasValue()
1302 ? std::to_string(storageReduction.getValue())
1303 : "<unknown>");
1304 msg << "% storage reduction.\n";
1305 llvm::dbgs() << msg.str();
1306 });
Uday Bondhugula864d9e02019-01-23 17:16:241307
MLIR Team27d067e2019-01-16 17:55:021308 // Update return parameter 'sliceState' with 'bestSliceState'.
Uday Bondhugula864d9e02019-01-23 17:16:241309 ComputationSliceState *bestSliceState = &sliceStates[*dstLoopDepth - 1];
MLIR Team27d067e2019-01-16 17:55:021310 sliceState->lbs = bestSliceState->lbs;
1311 sliceState->ubs = bestSliceState->ubs;
1312 sliceState->lbOperands = bestSliceState->lbOperands;
1313 sliceState->ubOperands = bestSliceState->ubOperands;
Uday Bondhugula864d9e02019-01-23 17:16:241314
MLIR Team27d067e2019-01-16 17:55:021315 // Canonicalize slice bound affine maps.
MLIR Team38c2fe32019-01-14 19:26:251316 for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
Nicolas Vasilache0e7a8a92019-01-26 18:41:171317 if (sliceState->lbs[i] != AffineMap()) {
MLIR Team27d067e2019-01-16 17:55:021318 canonicalizeMapAndOperands(&sliceState->lbs[i],
1319 &sliceState->lbOperands[i]);
1320 }
Nicolas Vasilache0e7a8a92019-01-26 18:41:171321 if (sliceState->ubs[i] != AffineMap()) {
MLIR Team27d067e2019-01-16 17:55:021322 canonicalizeMapAndOperands(&sliceState->ubs[i],
1323 &sliceState->ubOperands[i]);
MLIR Team38c2fe32019-01-14 19:26:251324 }
1325 }
1326 return true;
1327}
1328
MLIR Team6892ffb2018-12-20 04:42:551329// GreedyFusion greedily fuses loop nests which have a producer/consumer
MLIR Team3b692302018-12-17 17:57:141330// relationship on a memref, with the goal of improving locality. Currently,
1331// the producer/consumer relationship is required to be unique in the
Chris Lattner69d9e992018-12-28 16:48:091332// Function (there are TODOs to relax this constraint in the future).
MLIR Teamf28e4df2018-11-01 14:26:001333//
MLIR Team3b692302018-12-17 17:57:141334// The steps of the algorithm are as follows:
1335//
MLIR Team6892ffb2018-12-20 04:42:551336// *) A worklist is initialized with node ids from the dependence graph.
1337// *) For each node id in the worklist:
River Riddle5052bd82019-02-02 00:42:181338// *) Pop an AffineForOp off the worklist. This 'dstAffineForOp' will be a
1339// candidate destination AffineForOp into which fusion will be attempted.
1340// *) Add each LoadOp currently in 'dstAffineForOp' into list 'dstLoadOps'.
MLIR Team3b692302018-12-17 17:57:141341// *) For each LoadOp in 'dstLoadOps' do:
Chris Lattner69d9e992018-12-28 16:48:091342// *) Lookup dependent loop nests at earlier positions in the Function
MLIR Team3b692302018-12-17 17:57:141343// which have a single store op to the same memref.
1344// *) Check if dependences would be violated by the fusion. For example,
1345// the src loop nest may load from memrefs which are different than
1346// the producer-consumer memref between src and dest loop nests.
MLIR Team6892ffb2018-12-20 04:42:551347// *) Get a computation slice of 'srcLoopNest', which adjusts its loop
MLIR Team3b692302018-12-17 17:57:141348// bounds to be functions of 'dstLoopNest' IVs and symbols.
1349// *) Fuse the 'srcLoopNest' computation slice into the 'dstLoopNest',
1350// just before the dst load op user.
Chris Lattner456ad6a2018-12-29 00:05:351351// *) Add the newly fused load/store operation instructions to the state,
MLIR Team3b692302018-12-17 17:57:141352// and also add newly fuse load ops to 'dstLoopOps' to be considered
1353// as fusion dst load ops in another iteration.
1354// *) Remove old src loop nest and its associated state.
1355//
Chris Lattner456ad6a2018-12-29 00:05:351356// Given a graph where top-level instructions are vertices in the set 'V' and
MLIR Team3b692302018-12-17 17:57:141357// edges in the set 'E' are dependences between vertices, this algorithm
MLIR Team6892ffb2018-12-20 04:42:551358// takes O(V) time for initialization, and has runtime O(V + E).
MLIR Team3b692302018-12-17 17:57:141359//
MLIR Team6892ffb2018-12-20 04:42:551360// This greedy algorithm is not 'maximal' due to the current restriction of
1361// fusing along single producer consumer edges, but there is a TODO to fix this.
MLIR Team3b692302018-12-17 17:57:141362//
1363// TODO(andydavis) Experiment with other fusion policies.
MLIR Team6892ffb2018-12-20 04:42:551364// TODO(andydavis) Add support for fusing for input reuse (perhaps by
1365// constructing a graph with edges which represent loads from the same memref
MLIR Team5c5739d2019-01-25 06:27:401366// in two different loop nests).
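//
// Schematic producer/consumer example (simplified; not exact IR syntax):
//
//   for %i = 0 to 100 { store %v, %A[%i] }    // src nest: writes %A
//   for %j = 0 to 100 { %u = load %A[%j] }    // dst nest: reads %A
// =>
//   for %j = 0 to 100 {
//     store %v, %newA[0]   // fused slice writes a private memref
//     %u = load %newA[0]   // load is satisfied within the same iteration
//   }
//
// where '%newA' is the slice-sized private memref built by
// 'createPrivateMemRef', and '%A' is erased once it has no remaining uses.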
MLIR Team6892ffb2018-12-20 04:42:551367struct GreedyFusion {
1368public:
1369 MemRefDependenceGraph *mdg;
MLIR Team3b692302018-12-17 17:57:141370 SmallVector<unsigned, 4> worklist;
MLIR Teamf28e4df2018-11-01 14:26:001371
MLIR Team6892ffb2018-12-20 04:42:551372 GreedyFusion(MemRefDependenceGraph *mdg) : mdg(mdg) {
1373 // Initialize worklist with nodes from 'mdg'.
1374 worklist.resize(mdg->nodes.size());
1375 std::iota(worklist.begin(), worklist.end(), 0);
1376 }
MLIR Team3b692302018-12-17 17:57:141377
Uday Bondhugula8be26272019-02-02 01:06:221378 void run(unsigned localBufSizeThreshold, Optional<unsigned> fastMemorySpace) {
MLIR Team3b692302018-12-17 17:57:141379 while (!worklist.empty()) {
MLIR Team6892ffb2018-12-20 04:42:551380 unsigned dstId = worklist.back();
MLIR Team3b692302018-12-17 17:57:141381 worklist.pop_back();
MLIR Team6892ffb2018-12-20 04:42:551382 // Skip if this node was removed (fused into another node).
1383 if (mdg->nodes.count(dstId) == 0)
MLIR Team3b692302018-12-17 17:57:141384 continue;
MLIR Team6892ffb2018-12-20 04:42:551385 // Get 'dstNode' into which to attempt fusion.
1386 auto *dstNode = mdg->getNode(dstId);
1387 // Skip if 'dstNode' is not a loop nest.
River Riddleb4992772019-02-04 18:38:471388 if (!dstNode->inst->isa<AffineForOp>())
MLIR Team3b692302018-12-17 17:57:141389 continue;
1390
River Riddleb4992772019-02-04 18:38:471391 SmallVector<Instruction *, 4> loads = dstNode->loads;
1392 SmallVector<Instruction *, 4> dstLoadOpInsts;
MLIR Teamc4237ae2019-01-18 16:56:271393 DenseSet<Value *> visitedMemrefs;
MLIR Team6892ffb2018-12-20 04:42:551394 while (!loads.empty()) {
MLIR Team27d067e2019-01-16 17:55:021395 // Get memref of load on top of the stack.
1396 auto *memref = loads.back()->cast<LoadOp>()->getMemRef();
MLIR Teamc4237ae2019-01-18 16:56:271397 if (visitedMemrefs.count(memref) > 0)
1398 continue;
1399 visitedMemrefs.insert(memref);
MLIR Team27d067e2019-01-16 17:55:021400 // Move all loads in 'loads' accessing 'memref' to 'dstLoadOpInsts'.
1401 moveLoadsAccessingMemrefTo(memref, &loads, &dstLoadOpInsts);
MLIR Team6892ffb2018-12-20 04:42:551402 // Skip if no input edges along which to fuse.
1403 if (mdg->inEdges.count(dstId) == 0)
MLIR Team3b692302018-12-17 17:57:141404 continue;
MLIR Team1e851912019-01-31 00:01:461405 // Iterate through in-edges for 'dstId' and collect src node ids of any
1406 // edges on 'memref'.
1407 SmallVector<unsigned, 2> srcNodeIds;
MLIR Team6892ffb2018-12-20 04:42:551408 for (auto &srcEdge : mdg->inEdges[dstId]) {
1409 // Skip 'srcEdge' if not for 'memref'.
MLIR Teama0f3db402019-01-29 17:36:411410 if (srcEdge.value != memref)
MLIR Team6892ffb2018-12-20 04:42:551411 continue;
MLIR Team1e851912019-01-31 00:01:461412 srcNodeIds.push_back(srcEdge.id);
1413 }
1414 for (unsigned srcId : srcNodeIds) {
1415 // Skip if this node was removed (fused into another node).
1416 if (mdg->nodes.count(srcId) == 0)
1417 continue;
1418 // Get 'srcNode' from which to attempt fusion into 'dstNode'.
1419 auto *srcNode = mdg->getNode(srcId);
MLIR Team6892ffb2018-12-20 04:42:551420 // Skip if 'srcNode' is not a loop nest.
River Riddleb4992772019-02-04 18:38:471421 if (!srcNode->inst->isa<AffineForOp>())
MLIR Team6892ffb2018-12-20 04:42:551422 continue;
MLIR Teamb28009b2019-01-23 19:11:431423 // Skip if 'srcNode' has more than one store to any memref.
1424 // TODO(andydavis) Support fusing multi-output src loop nests.
1425 if (srcNode->stores.size() != 1)
MLIR Team6892ffb2018-12-20 04:42:551426 continue;
Uday Bondhugula864d9e02019-01-23 17:16:241427
MLIR Teama0f3db402019-01-29 17:36:411428 // Skip 'srcNode' if it has in edges on 'memref'.
MLIR Team6892ffb2018-12-20 04:42:551429 // TODO(andydavis) Track dependence type with edges, and just check
MLIR Teama0f3db402019-01-29 17:36:411430 // for WAW dependence edge here. Note that this check is overly
1431 // conservative and will be removed in the future.
1432 if (mdg->getIncomingMemRefAccesses(srcNode->id, memref) != 0)
MLIR Team6892ffb2018-12-20 04:42:551433 continue;
Uday Bondhugula864d9e02019-01-23 17:16:241434
MLIR Teamd7c82442019-01-30 23:53:411435 // Skip if 'srcNode' writes to any live in or escaping memrefs.
1436 if (mdg->writesToLiveInOrEscapingMemrefs(srcNode->id))
1437 continue;
1438
MLIR Teama0f3db402019-01-29 17:36:411439 // Compute an instruction list insertion point for the fused loop
1440 // nest which preserves dependences.
1441 Instruction *insertPointInst = mdg->getFusedLoopNestInsertionPoint(
1442 srcNode->id, dstNode->id, memref);
1443 if (insertPointInst == nullptr)
MLIR Team6892ffb2018-12-20 04:42:551444 continue;
Uday Bondhugula864d9e02019-01-23 17:16:241445
MLIR Team6892ffb2018-12-20 04:42:551446 // Get unique 'srcNode' store op.
Chris Lattner456ad6a2018-12-29 00:05:351447 auto *srcStoreOpInst = srcNode->stores.front();
MLIR Teamd7c82442019-01-30 23:53:411448 // Gather 'dstNode' store ops to 'memref'.
River Riddleb4992772019-02-04 18:38:471449 SmallVector<Instruction *, 2> dstStoreOpInsts;
MLIR Teamd7c82442019-01-30 23:53:411450 for (auto *storeOpInst : dstNode->stores)
1451 if (storeOpInst->cast<StoreOp>()->getMemRef() == memref)
1452 dstStoreOpInsts.push_back(storeOpInst);
1453
Uday Bondhugulab4a14432019-01-26 00:00:501454 unsigned bestDstLoopDepth;
MLIR Team38c2fe32019-01-14 19:26:251455 mlir::ComputationSliceState sliceState;
MLIR Teama0f3db402019-01-29 17:36:411456 // Check if fusion would be profitable.
MLIR Teamd7c82442019-01-30 23:53:411457 if (!isFusionProfitable(srcStoreOpInst, dstLoadOpInsts,
1458 dstStoreOpInsts, &sliceState,
Uday Bondhugulab4a14432019-01-26 00:00:501459 &bestDstLoopDepth))
MLIR Team38c2fe32019-01-14 19:26:251460 continue;
Uday Bondhugula864d9e02019-01-23 17:16:241461
MLIR Team6892ffb2018-12-20 04:42:551462 // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'.
River Riddle5052bd82019-02-02 00:42:181463 auto sliceLoopNest = mlir::insertBackwardComputationSlice(
Uday Bondhugulab4a14432019-01-26 00:00:501464 srcStoreOpInst, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState);
MLIR Team6892ffb2018-12-20 04:42:551465 if (sliceLoopNest != nullptr) {
River Riddle5052bd82019-02-02 00:42:181466 // Move 'dstAffineForOp' before 'insertPointInst' if needed.
River Riddleb4992772019-02-04 18:38:471467 auto dstAffineForOp = dstNode->inst->cast<AffineForOp>();
River Riddle5052bd82019-02-02 00:42:181468 if (insertPointInst != dstAffineForOp->getInstruction()) {
1469 dstAffineForOp->getInstruction()->moveBefore(insertPointInst);
MLIR Teama0f3db402019-01-29 17:36:411470 }
MLIR Teamc4237ae2019-01-18 16:56:271471 // Update edges between 'srcNode' and 'dstNode'.
MLIR Teama0f3db402019-01-29 17:36:411472 mdg->updateEdges(srcNode->id, dstNode->id, memref);
MLIR Teamc4237ae2019-01-18 16:56:271473
1474 // Collect slice loop stats.
1475 LoopNestStateCollector sliceCollector;
River Riddle5052bd82019-02-02 00:42:181476 sliceCollector.walk(sliceLoopNest->getInstruction());
MLIR Teamc4237ae2019-01-18 16:56:271477 // Promote single iteration slice loops to single IV value.
River Riddle5052bd82019-02-02 00:42:181478 for (auto forOp : sliceCollector.forOps) {
1479 promoteIfSingleIteration(forOp);
MLIR Team6892ffb2018-12-20 04:42:551480 }
River Riddle5052bd82019-02-02 00:42:181481 // Create private memref for 'memref' in 'dstAffineForOp'.
River Riddleb4992772019-02-04 18:38:471482 SmallVector<Instruction *, 4> storesForMemref;
MLIR Teamc4237ae2019-01-18 16:56:271483 for (auto *storeOpInst : sliceCollector.storeOpInsts) {
1484 if (storeOpInst->cast<StoreOp>()->getMemRef() == memref)
1485 storesForMemref.push_back(storeOpInst);
1486 }
1487 assert(storesForMemref.size() == 1);
Uday Bondhugula94a03f82019-01-22 21:58:521488 auto *newMemRef = createPrivateMemRef(
Uday Bondhugula8be26272019-02-02 01:06:221489 dstAffineForOp, storesForMemref[0], bestDstLoopDepth,
1490 fastMemorySpace, localBufSizeThreshold);
MLIR Teamc4237ae2019-01-18 16:56:271491 visitedMemrefs.insert(newMemRef);
MLIR Teama0f3db402019-01-29 17:36:411492 // Create new node in dependence graph for 'newMemRef' alloc op.
1493 unsigned newMemRefNodeId =
1494 mdg->addNode(newMemRef->getDefiningInst());
1495 // Add edge from 'newMemRef' node to dstNode.
1496 mdg->addEdge(newMemRefNodeId, dstId, newMemRef);
MLIR Teamc4237ae2019-01-18 16:56:271497
1498 // Collect dst loop stats after memref privatization transformation.
1499 LoopNestStateCollector dstLoopCollector;
River Riddle5052bd82019-02-02 00:42:181500 dstLoopCollector.walk(dstAffineForOp->getInstruction());
MLIR Teamc4237ae2019-01-18 16:56:271501
1502 // Add new load ops to current Node load op list 'loads' to
1503 // continue fusing based on new operands.
1504 for (auto *loadOpInst : dstLoopCollector.loadOpInsts) {
1505 auto *loadMemRef = loadOpInst->cast<LoadOp>()->getMemRef();
1506 if (visitedMemrefs.count(loadMemRef) == 0)
1507 loads.push_back(loadOpInst);
1508 }
1509
1510 // Clear and add back loads and stores.
1511 mdg->clearNodeLoadAndStores(dstNode->id);
1512 mdg->addToNode(dstId, dstLoopCollector.loadOpInsts,
1513 dstLoopCollector.storeOpInsts);
MLIR Team71495d52019-01-22 21:23:371514 // Remove old src loop nest if it no longer has outgoing dependence
1515 // edges, and it does not write to a memref which escapes the
1516 // function.
MLIR Teama0f3db402019-01-29 17:36:411517 if (mdg->canRemoveNode(srcNode->id)) {
MLIR Teamc4237ae2019-01-18 16:56:271518 mdg->removeNode(srcNode->id);
River Riddle5052bd82019-02-02 00:42:181519 srcNode->inst->erase();
MLIR Teamc4237ae2019-01-18 16:56:271520 }
MLIR Team3b692302018-12-17 17:57:141521 }
MLIR Team3b692302018-12-17 17:57:141522 }
1523 }
1524 }
MLIR Teamc4237ae2019-01-18 16:56:271525 // Clean up any allocs with no users.
1526 for (auto &pair : mdg->memrefEdgeCount) {
1527 if (pair.second > 0)
1528 continue;
1529 auto *memref = pair.first;
MLIR Team71495d52019-01-22 21:23:371530 // Skip if there exist other uses (return instruction or function calls).
1531 if (!memref->use_empty())
1532 continue;
MLIR Teamc4237ae2019-01-18 16:56:271533 // Use list expected to match the dep graph info.
MLIR Teamc4237ae2019-01-18 16:56:271534 auto *inst = memref->getDefiningInst();
River Riddleb4992772019-02-04 18:38:471535 if (inst && inst->isa<AllocOp>())
1536 inst->erase();
MLIR Teamc4237ae2019-01-18 16:56:271537 }
MLIR Teamf28e4df2018-11-01 14:26:001538 }
MLIR Team3b692302018-12-17 17:57:141539};
1540
1541} // end anonymous namespace
MLIR Teamf28e4df2018-11-01 14:26:001542
Chris Lattner79748892018-12-31 07:10:351543PassResult LoopFusion::runOnFunction(Function *f) {
Uday Bondhugula8be26272019-02-02 01:06:221544 if (clFusionFastMemorySpace.getNumOccurrences() > 0) {
1545 fastMemorySpace = clFusionFastMemorySpace.getValue();
1546 }
1547
MLIR Team6892ffb2018-12-20 04:42:551548 MemRefDependenceGraph g;
1549 if (g.init(f))
Uday Bondhugula8be26272019-02-02 01:06:221550 GreedyFusion(&g).run(localBufSizeThreshold, fastMemorySpace);
MLIR Teamf28e4df2018-11-01 14:26:001551 return success();
1552}
Jacques Pienaar6f0fb222018-11-07 02:34:181553
1554static PassRegistration<LoopFusion> pass("loop-fusion", "Fuse loop nests");