//===- LoopFusion.cpp - Code to perform loop fusion -----------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements loop fusion.
//
//===----------------------------------------------------------------------===//

#include "mlir/AffineOps/AffineOps.h"
#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/AffineStructures.h"
#include "mlir/Analysis/LoopAnalysis.h"
#include "mlir/Analysis/Utils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass.h"
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/Transforms/LoopUtils.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/Utils.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <iomanip>

#define DEBUG_TYPE "loop-fusion"

using llvm::SetVector;

using namespace mlir;

static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options");

/// Disables the fusion profitability check and fuses whenever valid.
static llvm::cl::opt<bool>
    clMaximalLoopFusion("fusion-maximal", llvm::cl::Hidden,
                        llvm::cl::desc("Enables maximal loop fusion"),
                        llvm::cl::cat(clOptionsCategory));

/// A threshold, as a fraction of total computation, of additional computation
/// tolerated when fusing.
static llvm::cl::opt<double> clFusionAddlComputeTolerance(
    "fusion-compute-tolerance", llvm::cl::Hidden,
    llvm::cl::desc("Fractional increase in additional"
                   " computation tolerated while fusing"),
    llvm::cl::cat(clOptionsCategory));

static llvm::cl::opt<unsigned> clFusionFastMemorySpace(
    "fusion-fast-mem-space", llvm::cl::Hidden,
    llvm::cl::desc("Faster memory space number to promote fusion buffers to"),
    llvm::cl::cat(clOptionsCategory));

static llvm::cl::opt<unsigned> clFusionLocalBufThreshold(
    "fusion-local-buf-threshold", llvm::cl::Hidden,
    llvm::cl::desc("Threshold size (bytes) for promoting local buffers to fast "
                   "memory space"),
    llvm::cl::cat(clOptionsCategory));

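// Example invocation (illustrative; assumes this pass is registered under the
// 'loop-fusion' flag, which is not shown in this file):
//   mlir-opt -loop-fusion -fusion-maximal foo.mlir
//   mlir-opt -loop-fusion -fusion-fast-mem-space=2 \
//            -fusion-local-buf-threshold=1024 foo.mlir
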
namespace {

/// Loop fusion pass. This pass currently supports a greedy fusion policy,
/// which fuses loop nests with single-writer/single-reader memref dependences
/// with the goal of improving locality.

// TODO(andydavis) Support fusion of source loop nests which write to multiple
// memrefs, where each memref can have multiple users (if profitable).
// TODO(andydavis) Extend this pass to check for fusion-preventing dependences,
// and add support for more general loop fusion algorithms.

struct LoopFusion : public FunctionPass {
  LoopFusion() : FunctionPass(&LoopFusion::passID) {}

  PassResult runOnFunction(Function *f) override;
  static char passID;

  // Any local buffers smaller than this size (in bytes) will be created in
  // `fastMemorySpace` if provided.
  unsigned localBufSizeThreshold = 1024;
  Optional<unsigned> fastMemorySpace = None;

  // The amount of additional computation that is tolerated while fusing
  // pair-wise as a fraction of the total computation.
  constexpr static double kComputeToleranceThreshold = 0.30f;
};

} // end anonymous namespace

char LoopFusion::passID = 0;

FunctionPass *mlir::createLoopFusionPass() { return new LoopFusion; }

namespace {

// LoopNestStateCollector walks loop nests and collects load and store
// operations, and records whether an instruction holding a region other than
// an AffineForOp was encountered in the loop nest.
struct LoopNestStateCollector {
  SmallVector<OpPointer<AffineForOp>, 4> forOps;
  SmallVector<Instruction *, 4> loadOpInsts;
  SmallVector<Instruction *, 4> storeOpInsts;
  bool hasNonForRegion = false;

  void collect(Instruction *instToWalk) {
    instToWalk->walk([&](Instruction *opInst) {
      if (opInst->isa<AffineForOp>())
        forOps.push_back(opInst->cast<AffineForOp>());
      else if (opInst->getNumBlockLists() != 0)
        hasNonForRegion = true;
      else if (opInst->isa<LoadOp>())
        loadOpInsts.push_back(opInst);
      else if (opInst->isa<StoreOp>())
        storeOpInsts.push_back(opInst);
    });
  }
};

// TODO(b/117228571) Replace when this is modeled through side-effects/op traits
static bool isMemRefDereferencingOp(const Instruction &op) {
  if (op.isa<LoadOp>() || op.isa<StoreOp>() || op.isa<DmaStartOp>() ||
      op.isa<DmaWaitOp>())
    return true;
  return false;
}

// MemRefDependenceGraph is a graph data structure where graph nodes are
// top-level instructions in a Function which contain load/store ops, and edges
// are memref dependences between the nodes.
// TODO(andydavis) Add a more flexible dependence graph representation.
// TODO(andydavis) Add a depth parameter to dependence graph construction.
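//
// Example (illustrative): given
//   %m = alloc() : memref<10xf32>
//   for %i0 = 0 to 10 { store %v, %m[%i0] : memref<10xf32> }
//   for %i1 = 0 to 10 { %u = load %m[%i1] : memref<10xf32> }
// 'init' creates one graph node per top-level loop nest and a single edge on
// %m from the storing node to the loading node.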
struct MemRefDependenceGraph {
public:
  // Node represents a node in the graph. A Node is either an entire loop nest
  // rooted at the top level which contains loads/stores, or a top level
  // load/store.
  struct Node {
    // The unique identifier of this node in the graph.
    unsigned id;
    // The top-level statement which is (or contains) loads/stores.
    Instruction *inst;
    // List of load op insts.
    SmallVector<Instruction *, 4> loads;
    // List of store op insts.
    SmallVector<Instruction *, 4> stores;
    Node(unsigned id, Instruction *inst) : id(id), inst(inst) {}

    // Returns the load op count for 'memref'.
    unsigned getLoadOpCount(Value *memref) {
      unsigned loadOpCount = 0;
      for (auto *loadOpInst : loads) {
        if (memref == loadOpInst->cast<LoadOp>()->getMemRef())
          ++loadOpCount;
      }
      return loadOpCount;
    }

    // Returns the store op count for 'memref'.
    unsigned getStoreOpCount(Value *memref) {
      unsigned storeOpCount = 0;
      for (auto *storeOpInst : stores) {
        if (memref == storeOpInst->cast<StoreOp>()->getMemRef())
          ++storeOpCount;
      }
      return storeOpCount;
    }
  };

  // Edge represents a data dependence between nodes in the graph.
  struct Edge {
    // The id of the node at the other end of the edge.
    // If this edge is stored in Edge = Node.inEdges[i], then
    // 'Node.inEdges[i].id' is the identifier of the source node of the edge.
    // If this edge is stored in Edge = Node.outEdges[i], then
    // 'Node.outEdges[i].id' is the identifier of the dest node of the edge.
    unsigned id;
    // The SSA value on which this edge represents a dependence.
    // If the value is a memref, then the dependence is between graph nodes
    // which contain accesses to the same memref 'value'. If the value is a
    // non-memref value, then the dependence is between a graph node which
    // defines an SSA value and another graph node which uses the SSA value
    // (e.g. a constant instruction defining a value which is used inside a
    // loop nest).
    Value *value;
  };

  // Map from node id to Node.
  DenseMap<unsigned, Node> nodes;
  // Map from node id to list of input edges.
  DenseMap<unsigned, SmallVector<Edge, 2>> inEdges;
  // Map from node id to list of output edges.
  DenseMap<unsigned, SmallVector<Edge, 2>> outEdges;
  // Map from memref to a count on the dependence edges associated with that
  // memref.
  DenseMap<Value *, unsigned> memrefEdgeCount;
  // The next unique identifier to use for newly created graph nodes.
  unsigned nextNodeId = 0;

  MemRefDependenceGraph() {}

  // Initializes the dependence graph based on operations in 'f'.
  // Returns true on success, false otherwise.
  bool init(Function *f);

  // Returns the graph node for 'id'.
  Node *getNode(unsigned id) {
    auto it = nodes.find(id);
    assert(it != nodes.end());
    return &it->second;
  }

  // Adds a node with 'inst' to the graph and returns its unique identifier.
  unsigned addNode(Instruction *inst) {
    Node node(nextNodeId++, inst);
    nodes.insert({node.id, node});
    return node.id;
  }

  // Remove node 'id' (and its associated edges) from graph.
  void removeNode(unsigned id) {
    // Remove each edge in 'inEdges[id]'.
    if (inEdges.count(id) > 0) {
      SmallVector<Edge, 2> oldInEdges = inEdges[id];
      for (auto &inEdge : oldInEdges) {
        removeEdge(inEdge.id, id, inEdge.value);
      }
    }
    // Remove each edge in 'outEdges[id]'.
    if (outEdges.count(id) > 0) {
      SmallVector<Edge, 2> oldOutEdges = outEdges[id];
      for (auto &outEdge : oldOutEdges) {
        removeEdge(id, outEdge.id, outEdge.value);
      }
    }
    // Erase remaining node state.
    inEdges.erase(id);
    outEdges.erase(id);
    nodes.erase(id);
  }

  // Returns true if node 'id' writes to any memref which escapes (or is an
  // argument to) the function/block. Returns false otherwise.
  bool writesToLiveInOrEscapingMemrefs(unsigned id) {
    Node *node = getNode(id);
    for (auto *storeOpInst : node->stores) {
      auto *memref = storeOpInst->cast<StoreOp>()->getMemRef();
      auto *inst = memref->getDefiningInst();
      // Return true if 'memref' is a block argument.
      if (!inst)
        return true;
      // Return true if any use of 'memref' escapes the function.
      for (auto &use : memref->getUses())
        if (!isMemRefDereferencingOp(*use.getOwner()))
          return true;
    }
    return false;
  }

  // Returns true if node 'id' can be removed from the graph. Returns false
  // otherwise. A node can be removed from the graph iff the following
  // conditions are met:
  // *) The node does not write to any memref which escapes (or is an
  //    argument to) the function/block.
  // *) The node has no successors in the dependence graph.
  bool canRemoveNode(unsigned id) {
    if (writesToLiveInOrEscapingMemrefs(id))
      return false;
    Node *node = getNode(id);
    for (auto *storeOpInst : node->stores) {
      // Return false if there exist out edges from 'id' on 'memref'.
      if (getOutEdgeCount(id, storeOpInst->cast<StoreOp>()->getMemRef()) > 0)
        return false;
    }
    return true;
  }

  // Returns true iff there is an edge from node 'srcId' to node 'dstId' for
  // 'value'. Returns false otherwise.
  bool hasEdge(unsigned srcId, unsigned dstId, Value *value) {
    if (outEdges.count(srcId) == 0 || inEdges.count(dstId) == 0) {
      return false;
    }
    bool hasOutEdge = llvm::any_of(outEdges[srcId], [=](Edge &edge) {
      return edge.id == dstId && edge.value == value;
    });
    bool hasInEdge = llvm::any_of(inEdges[dstId], [=](Edge &edge) {
      return edge.id == srcId && edge.value == value;
    });
    return hasOutEdge && hasInEdge;
  }

  // Adds an edge from node 'srcId' to node 'dstId' for 'value'.
  void addEdge(unsigned srcId, unsigned dstId, Value *value) {
    if (!hasEdge(srcId, dstId, value)) {
      outEdges[srcId].push_back({dstId, value});
      inEdges[dstId].push_back({srcId, value});
      if (value->getType().isa<MemRefType>())
        memrefEdgeCount[value]++;
    }
  }

  // Removes an edge from node 'srcId' to node 'dstId' for 'value'.
  void removeEdge(unsigned srcId, unsigned dstId, Value *value) {
    assert(inEdges.count(dstId) > 0);
    assert(outEdges.count(srcId) > 0);
    if (value->getType().isa<MemRefType>()) {
      assert(memrefEdgeCount.count(value) > 0);
      memrefEdgeCount[value]--;
    }
    // Remove 'srcId' from 'inEdges[dstId]'.
    for (auto it = inEdges[dstId].begin(); it != inEdges[dstId].end(); ++it) {
      if ((*it).id == srcId && (*it).value == value) {
        inEdges[dstId].erase(it);
        break;
      }
    }
    // Remove 'dstId' from 'outEdges[srcId]'.
    for (auto it = outEdges[srcId].begin(); it != outEdges[srcId].end(); ++it) {
      if ((*it).id == dstId && (*it).value == value) {
        outEdges[srcId].erase(it);
        break;
      }
    }
  }

  // Returns the input edge count for node 'id' and 'memref' from src nodes
  // which access 'memref'.
  unsigned getIncomingMemRefAccesses(unsigned id, Value *memref) {
    unsigned inEdgeCount = 0;
    if (inEdges.count(id) > 0)
      for (auto &inEdge : inEdges[id])
        if (inEdge.value == memref) {
          Node *srcNode = getNode(inEdge.id);
          // Only count in edges from 'srcNode' if 'srcNode' accesses 'memref'.
          if (srcNode->getLoadOpCount(memref) > 0 ||
              srcNode->getStoreOpCount(memref) > 0)
            ++inEdgeCount;
        }
    return inEdgeCount;
  }

  // Returns the output edge count for node 'id' and 'memref'.
  unsigned getOutEdgeCount(unsigned id, Value *memref) {
    unsigned outEdgeCount = 0;
    if (outEdges.count(id) > 0)
      for (auto &outEdge : outEdges[id])
        if (outEdge.value == memref)
          ++outEdgeCount;
    return outEdgeCount;
  }

  // Computes and returns an insertion point instruction, before which the
  // fused <srcId, dstId> loop nest can be inserted while preserving
  // dependences. Returns nullptr if no such insertion point is found. Edges
  // on 'memrefToSkip' are ignored, as that memref's dependence is the one
  // being fused.
  Instruction *getFusedLoopNestInsertionPoint(unsigned srcId, unsigned dstId,
                                              Value *memrefToSkip) {
    if (outEdges.count(srcId) == 0)
      return getNode(dstId)->inst;

    // Build set of insts in range (srcId, dstId) which depend on 'srcId'.
    SmallPtrSet<Instruction *, 2> srcDepInsts;
    for (auto &outEdge : outEdges[srcId])
      if (outEdge.id != dstId && outEdge.value != memrefToSkip)
        srcDepInsts.insert(getNode(outEdge.id)->inst);

    // Build set of insts in range (srcId, dstId) on which 'dstId' depends.
    SmallPtrSet<Instruction *, 2> dstDepInsts;
    for (auto &inEdge : inEdges[dstId])
      if (inEdge.id != srcId && inEdge.value != memrefToSkip)
        dstDepInsts.insert(getNode(inEdge.id)->inst);

    Instruction *srcNodeInst = getNode(srcId)->inst;
    Instruction *dstNodeInst = getNode(dstId)->inst;

    // Computing insertion point:
    // *) Walk all instruction positions in Block instruction list in the
    //    range (src, dst). For each instruction 'inst' visited in this search:
    //    *) Store in 'firstSrcDepPos' the first position where 'inst' has a
    //       dependence edge from 'srcNode'.
    //    *) Store in 'lastDstDepPos' the last position where 'inst' has a
    //       dependence edge to 'dstNode'.
    // *) Compare 'firstSrcDepPos' and 'lastDstDepPos' to determine the
    //    instruction insertion point (or return nullptr if no such insertion
    //    point exists, i.e. if 'firstSrcDepPos' <= 'lastDstDepPos').
    SmallVector<Instruction *, 2> depInsts;
    Optional<unsigned> firstSrcDepPos;
    Optional<unsigned> lastDstDepPos;
    unsigned pos = 0;
    for (Block::iterator it = std::next(Block::iterator(srcNodeInst));
         it != Block::iterator(dstNodeInst); ++it) {
      Instruction *inst = &(*it);
      if (srcDepInsts.count(inst) > 0 && firstSrcDepPos == None)
        firstSrcDepPos = pos;
      if (dstDepInsts.count(inst) > 0)
        lastDstDepPos = pos;
      depInsts.push_back(inst);
      ++pos;
    }

    if (firstSrcDepPos.hasValue()) {
      if (lastDstDepPos.hasValue()) {
        if (firstSrcDepPos.getValue() <= lastDstDepPos.getValue()) {
          // No valid insertion point exists which preserves dependences.
          return nullptr;
        }
      }
      // Return the insertion point at 'firstSrcDepPos'.
      return depInsts[firstSrcDepPos.getValue()];
    }
    // No dependence targets in range (or only dst deps in range), return
    // 'dstNodeInst' insertion point.
    return dstNodeInst;
  }

  // Updates edge mappings from node 'srcId' to node 'dstId' after 'oldMemRef'
  // has been replaced in node at 'dstId' by a private memref.
  void updateEdges(unsigned srcId, unsigned dstId, Value *oldMemRef) {
    // For each edge in 'inEdges[srcId]': add new edge remapping to 'dstId'.
    if (inEdges.count(srcId) > 0) {
      SmallVector<Edge, 2> oldInEdges = inEdges[srcId];
      for (auto &inEdge : oldInEdges) {
        // Add edge from 'inEdge.id' to 'dstId' if not for 'oldMemRef'.
        if (inEdge.value != oldMemRef)
          addEdge(inEdge.id, dstId, inEdge.value);
      }
    }
    // For each edge in 'outEdges[srcId]': remove edge from 'srcId' to 'dstId'.
    if (outEdges.count(srcId) > 0) {
      SmallVector<Edge, 2> oldOutEdges = outEdges[srcId];
      for (auto &outEdge : oldOutEdges) {
        // Remove any out edges from 'srcId' to 'dstId' across memrefs.
        if (outEdge.id == dstId)
          removeEdge(srcId, outEdge.id, outEdge.value);
      }
    }
    // Remove any edges in 'inEdges[dstId]' on 'oldMemRef' (which is being
    // replaced by a private memref). These edges could come from nodes
    // other than 'srcId' which were removed in the previous step.
    if (inEdges.count(dstId) > 0) {
      SmallVector<Edge, 2> oldInEdges = inEdges[dstId];
      for (auto &inEdge : oldInEdges)
        if (inEdge.value == oldMemRef)
          removeEdge(inEdge.id, dstId, inEdge.value);
    }
  }

  // Adds ops in 'loads' and 'stores' to node at 'id'.
  void addToNode(unsigned id, const SmallVectorImpl<Instruction *> &loads,
                 const SmallVectorImpl<Instruction *> &stores) {
    Node *node = getNode(id);
    for (auto *loadOpInst : loads)
      node->loads.push_back(loadOpInst);
    for (auto *storeOpInst : stores)
      node->stores.push_back(storeOpInst);
  }

  // Clears the loads and stores recorded for the node at 'id'.
  void clearNodeLoadAndStores(unsigned id) {
    Node *node = getNode(id);
    node->loads.clear();
    node->stores.clear();
  }

  void print(raw_ostream &os) const {
    os << "\nMemRefDependenceGraph\n";
    os << "\nNodes:\n";
    for (auto &idAndNode : nodes) {
      os << "Node: " << idAndNode.first << "\n";
      auto it = inEdges.find(idAndNode.first);
      if (it != inEdges.end()) {
        for (const auto &e : it->second)
          os << "  InEdge: " << e.id << " " << e.value << "\n";
      }
      it = outEdges.find(idAndNode.first);
      if (it != outEdges.end()) {
        for (const auto &e : it->second)
          os << "  OutEdge: " << e.id << " " << e.value << "\n";
      }
    }
  }
  void dump() const { print(llvm::errs()); }
};

// Initializes the data dependence graph by walking instructions in 'f'.
// Assigns each node in the graph a node id based on program order in 'f'.
// TODO(andydavis) Add support for taking a Block arg to construct the
// dependence graph at a different depth.
bool MemRefDependenceGraph::init(Function *f) {
  DenseMap<Value *, SetVector<unsigned>> memrefAccesses;

  // TODO: support multi-block functions.
  if (f->getBlocks().size() != 1)
    return false;

  DenseMap<Instruction *, unsigned> forToNodeMap;
  for (auto &inst : f->front()) {
    if (auto forOp = inst.dyn_cast<AffineForOp>()) {
      // Create graph node 'id' to represent top-level 'forOp' and record
      // all load and store accesses it contains.
      LoopNestStateCollector collector;
      collector.collect(&inst);
      // Return false if a non 'for' region was found (not currently
      // supported).
      if (collector.hasNonForRegion)
        return false;
      Node node(nextNodeId++, &inst);
      for (auto *opInst : collector.loadOpInsts) {
        node.loads.push_back(opInst);
        auto *memref = opInst->cast<LoadOp>()->getMemRef();
        memrefAccesses[memref].insert(node.id);
      }
      for (auto *opInst : collector.storeOpInsts) {
        node.stores.push_back(opInst);
        auto *memref = opInst->cast<StoreOp>()->getMemRef();
        memrefAccesses[memref].insert(node.id);
      }
      forToNodeMap[&inst] = node.id;
      nodes.insert({node.id, node});
    } else if (auto loadOp = inst.dyn_cast<LoadOp>()) {
      // Create graph node for top-level load op.
      Node node(nextNodeId++, &inst);
      node.loads.push_back(&inst);
      auto *memref = inst.cast<LoadOp>()->getMemRef();
      memrefAccesses[memref].insert(node.id);
      nodes.insert({node.id, node});
    } else if (auto storeOp = inst.dyn_cast<StoreOp>()) {
      // Create graph node for top-level store op.
      Node node(nextNodeId++, &inst);
      node.stores.push_back(&inst);
      auto *memref = inst.cast<StoreOp>()->getMemRef();
      memrefAccesses[memref].insert(node.id);
      nodes.insert({node.id, node});
    } else if (inst.getNumBlockLists() != 0) {
      // Return false if another region is found (not currently supported).
      return false;
    } else if (inst.getNumResults() > 0 && !inst.use_empty()) {
      // Create graph node for top-level producer of SSA values, which
      // could be used by loop nest nodes.
      Node node(nextNodeId++, &inst);
      nodes.insert({node.id, node});
    }
  }

  // Add dependence edges between nodes which produce SSA values and their
  // users.
  for (auto &idAndNode : nodes) {
    const Node &node = idAndNode.second;
    if (!node.loads.empty() || !node.stores.empty())
      continue;
    auto *opInst = node.inst;
    for (auto *value : opInst->getResults()) {
      for (auto &use : value->getUses()) {
        SmallVector<OpPointer<AffineForOp>, 4> loops;
        getLoopIVs(*use.getOwner(), &loops);
        if (loops.empty())
          continue;
        assert(forToNodeMap.count(loops[0]->getInstruction()) > 0);
        unsigned userLoopNestId = forToNodeMap[loops[0]->getInstruction()];
        addEdge(node.id, userLoopNestId, value);
      }
    }
  }

  // Walk memref access lists and add graph edges between dependent nodes.
  for (auto &memrefAndList : memrefAccesses) {
    unsigned n = memrefAndList.second.size();
    for (unsigned i = 0; i < n; ++i) {
      unsigned srcId = memrefAndList.second[i];
      bool srcHasStore =
          getNode(srcId)->getStoreOpCount(memrefAndList.first) > 0;
      for (unsigned j = i + 1; j < n; ++j) {
        unsigned dstId = memrefAndList.second[j];
        bool dstHasStore =
            getNode(dstId)->getStoreOpCount(memrefAndList.first) > 0;
        if (srcHasStore || dstHasStore)
          addEdge(srcId, dstId, memrefAndList.first);
      }
    }
  }
  return true;
}

namespace {

// LoopNestStats aggregates various per-loop statistics (e.g. loop trip count
// and operation count) for a loop nest up until the innermost loop body.
struct LoopNestStats {
  // Map from AffineForOp to immediate child AffineForOps in its loop body.
  DenseMap<Instruction *, SmallVector<OpPointer<AffineForOp>, 2>> loopMap;
  // Map from AffineForOp to count of operations in its loop body.
  DenseMap<Instruction *, uint64_t> opCountMap;
  // Map from AffineForOp to its constant trip count.
  DenseMap<Instruction *, uint64_t> tripCountMap;
};

// LoopNestStatsCollector walks a single loop nest and gathers per-loop
// trip count and operation count statistics and records them in 'stats'.
struct LoopNestStatsCollector {
  LoopNestStats *stats;
  bool hasLoopWithNonConstTripCount = false;

  LoopNestStatsCollector(LoopNestStats *stats) : stats(stats) {}

  void collect(Instruction *inst) {
    inst->walk<AffineForOp>([&](OpPointer<AffineForOp> forOp) {
      auto *forInst = forOp->getInstruction();
      auto *parentInst = forOp->getInstruction()->getParentInst();
      if (parentInst != nullptr) {
        assert(parentInst->isa<AffineForOp>() && "Expected parent AffineForOp");
        // Add mapping to 'forOp' from its parent AffineForOp.
        stats->loopMap[parentInst].push_back(forOp);
      }

      // Record the number of op instructions in the body of 'forOp'.
      unsigned count = 0;
      stats->opCountMap[forInst] = 0;
      for (auto &inst : *forOp->getBody()) {
        if (!(inst.isa<AffineForOp>() || inst.isa<AffineIfOp>()))
          ++count;
      }
      stats->opCountMap[forInst] = count;
      // Record trip count for 'forOp'. Set flag if trip count is not
      // constant.
      Optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp);
      if (!maybeConstTripCount.hasValue()) {
        hasLoopWithNonConstTripCount = true;
        return;
      }
      stats->tripCountMap[forInst] = maybeConstTripCount.getValue();
    });
  }
};

// Computes the total cost of the loop nest rooted at 'forInst'.
// Currently, the total cost is computed by counting the total operation
// instance count (i.e. total number of operations in the loop body * loop
// trip count) for the entire loop nest.
// If 'tripCountOverrideMap' is non-null, overrides the trip count for loops
// specified in the map when computing the total op instance count.
// NOTE: this is used to compute the cost of computation slices, which are
// sliced along the iteration dimension, and thus reduce the trip count.
// If 'computeCostMap' is non-null, the total op count for forOps specified
// in the map is increased (not overridden) by adding the op count from the
// map to the existing op count for the for loop. This is done before
// multiplying by the loop's trip count, and is used to model the cost of
// inserting a sliced loop nest of known cost into the loop's body.
// NOTE: this is used to compute the cost of fusing a slice of some loop nest
// within another loop.
static int64_t getComputeCost(
    Instruction *forInst, LoopNestStats *stats,
    llvm::SmallDenseMap<Instruction *, uint64_t, 8> *tripCountOverrideMap,
    DenseMap<Instruction *, int64_t> *computeCostMap) {
  // 'opCount' is the total number of operations in one iteration of the body
  // of 'forInst'.
  int64_t opCount = stats->opCountMap[forInst];
  if (stats->loopMap.count(forInst) > 0) {
    for (auto childForOp : stats->loopMap[forInst]) {
      opCount += getComputeCost(childForOp->getInstruction(), stats,
                                tripCountOverrideMap, computeCostMap);
    }
  }
  // Add in additional op instances from slice (if specified in map).
  if (computeCostMap != nullptr) {
    auto it = computeCostMap->find(forInst);
    if (it != computeCostMap->end()) {
      opCount += it->second;
    }
  }
  // Override trip count (if specified in map).
  int64_t tripCount = stats->tripCountMap[forInst];
  if (tripCountOverrideMap != nullptr) {
    auto it = tripCountOverrideMap->find(forInst);
    if (it != tripCountOverrideMap->end()) {
      tripCount = it->second;
    }
  }
  // Returns the total number of dynamic instances of operations in loop body.
  return tripCount * opCount;
}

} // end anonymous namespace

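// Returns the constant difference between the single results of 'ubMap' and
// 'lbMap' if 'ubMap - lbMap' simplifies to a constant, and None otherwise.
// For example (illustrative), lbMap = (d0) -> (d0) and
// ubMap = (d0) -> (d0 + 10) yield 10.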
static Optional<uint64_t> getConstDifference(AffineMap lbMap, AffineMap ubMap) {
  assert(lbMap.getNumResults() == 1 && "expected single result bound map");
  assert(ubMap.getNumResults() == 1 && "expected single result bound map");
  assert(lbMap.getNumDims() == ubMap.getNumDims());
  assert(lbMap.getNumSymbols() == ubMap.getNumSymbols());
  // TODO(andydavis) Merge this code with 'mlir::getTripCountExpr'.
  // ub_expr - lb_expr
  AffineExpr lbExpr(lbMap.getResult(0));
  AffineExpr ubExpr(ubMap.getResult(0));
  auto loopSpanExpr = simplifyAffineExpr(ubExpr - lbExpr, lbMap.getNumDims(),
                                         lbMap.getNumSymbols());
  auto cExpr = loopSpanExpr.dyn_cast<AffineConstantExpr>();
  if (!cExpr)
    return None;
  return cExpr.getValue();
}

// Builds a map 'tripCountMap' from AffineForOp to constant trip count for the
// loop nest surrounding 'srcOpInst', utilizing slice loop bounds in
// 'sliceState'. Returns true on success, false otherwise (if a non-constant
// trip count was encountered).
// TODO(andydavis) Make this work with non-unit step loops.
static bool buildSliceTripCountMap(
    Instruction *srcOpInst, ComputationSliceState *sliceState,
    llvm::SmallDenseMap<Instruction *, uint64_t, 8> *tripCountMap) {
  SmallVector<OpPointer<AffineForOp>, 4> srcLoopIVs;
  getLoopIVs(*srcOpInst, &srcLoopIVs);
  unsigned numSrcLoopIVs = srcLoopIVs.size();
  // Populate map from AffineForOp -> trip count.
  for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
    AffineMap lbMap = sliceState->lbs[i];
    AffineMap ubMap = sliceState->ubs[i];
    if (lbMap == AffineMap() || ubMap == AffineMap()) {
      // The iteration of src loop IV 'i' was not sliced. Use full loop bounds.
      if (srcLoopIVs[i]->hasConstantLowerBound() &&
          srcLoopIVs[i]->hasConstantUpperBound()) {
        (*tripCountMap)[srcLoopIVs[i]->getInstruction()] =
            srcLoopIVs[i]->getConstantUpperBound() -
            srcLoopIVs[i]->getConstantLowerBound();
        continue;
      }
      return false;
    }
    Optional<uint64_t> tripCount = getConstDifference(lbMap, ubMap);
    if (!tripCount.hasValue())
      return false;
    (*tripCountMap)[srcLoopIVs[i]->getInstruction()] = tripCount.getValue();
  }
  return true;
}

// Removes load operations from 'srcLoads' which operate on 'memref', and
// adds them to 'dstLoads'.
static void
moveLoadsAccessingMemrefTo(Value *memref,
                           SmallVectorImpl<Instruction *> *srcLoads,
                           SmallVectorImpl<Instruction *> *dstLoads) {
  dstLoads->clear();
  SmallVector<Instruction *, 4> srcLoadsToKeep;
  for (auto *load : *srcLoads) {
    if (load->cast<LoadOp>()->getMemRef() == memref)
      dstLoads->push_back(load);
    else
      srcLoadsToKeep.push_back(load);
  }
  srcLoads->swap(srcLoadsToKeep);
}

// Returns the innermost common loop depth for the set of operations in 'ops'.
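// For example (illustrative), if ops[0] is nested under loops (%i0, %i1) and
// ops[1] under loops (%i0, %i2), the innermost common loop depth is 1, since
// only the outermost loop %i0 is shared.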
static unsigned getInnermostCommonLoopDepth(ArrayRef<Instruction *> ops) {
  unsigned numOps = ops.size();
  assert(numOps > 0);

  std::vector<SmallVector<OpPointer<AffineForOp>, 4>> loops(numOps);
  unsigned loopDepthLimit = std::numeric_limits<unsigned>::max();
  for (unsigned i = 0; i < numOps; ++i) {
    getLoopIVs(*ops[i], &loops[i]);
    loopDepthLimit =
        std::min(loopDepthLimit, static_cast<unsigned>(loops[i].size()));
  }

  unsigned loopDepth = 0;
  for (unsigned d = 0; d < loopDepthLimit; ++d) {
    unsigned i;
    for (i = 1; i < numOps; ++i) {
      if (loops[i - 1][d] != loops[i][d])
        break;
    }
    if (i != numOps)
      break;
    ++loopDepth;
  }
  return loopDepth;
}

// Returns the maximum loop depth at which no dependences between 'loadOpInsts'
// and 'storeOpInsts' are satisfied.
static unsigned getMaxLoopDepth(ArrayRef<Instruction *> loadOpInsts,
                                ArrayRef<Instruction *> storeOpInsts) {
  // Merge loads and stores into the same array.
  SmallVector<Instruction *, 2> ops(loadOpInsts.begin(), loadOpInsts.end());
  ops.append(storeOpInsts.begin(), storeOpInsts.end());

  // Compute the innermost common loop depth for loads and stores.
  unsigned loopDepth = getInnermostCommonLoopDepth(ops);

  // Return common loop depth for loads if there are no store ops.
  if (storeOpInsts.empty())
    return loopDepth;

  // Check dependences on all pairs of ops in 'ops' and store the minimum
  // loop depth at which a dependence is satisfied.
  for (unsigned i = 0, e = ops.size(); i < e; ++i) {
    auto *srcOpInst = ops[i];
    MemRefAccess srcAccess(srcOpInst);
    for (unsigned j = 0; j < e; ++j) {
      auto *dstOpInst = ops[j];
      MemRefAccess dstAccess(dstOpInst);

      unsigned numCommonLoops =
          getNumCommonSurroundingLoops(*srcOpInst, *dstOpInst);
      for (unsigned d = 1; d <= numCommonLoops + 1; ++d) {
        FlatAffineConstraints dependenceConstraints;
        // TODO(andydavis) Cache dependence analysis results, check cache here.
        if (checkMemrefAccessDependence(srcAccess, dstAccess, d,
                                        &dependenceConstraints,
                                        /*dependenceComponents=*/nullptr)) {
          // Store minimum loop depth and break because we want the min 'd' at
          // which there is a dependence.
          loopDepth = std::min(loopDepth, d - 1);
          break;
        }
      }
    }
  }
  return loopDepth;
}

// Returns the slice union of 'sliceStateA' and 'sliceStateB' in 'sliceStateB'
// using a rectangular bounding box.
// TODO(andydavis) This function assumes that lower bounds for 'sliceStateA'
// and 'sliceStateB' are aligned.
// Specifically, when taking the union of overlapping intervals, it assumes
// that both intervals start at zero. Support needs to be added to take into
// account interval start offset when computing the union.
// TODO(andydavis) Move this function to an analysis library.
static bool getSliceUnion(const ComputationSliceState &sliceStateA,
                          ComputationSliceState *sliceStateB) {
  assert(sliceStateA.lbs.size() == sliceStateB->lbs.size());
  assert(sliceStateA.ubs.size() == sliceStateB->ubs.size());

  for (unsigned i = 0, e = sliceStateA.lbs.size(); i < e; ++i) {
    AffineMap lbMapA = sliceStateA.lbs[i];
    AffineMap ubMapA = sliceStateA.ubs[i];
    if (lbMapA == AffineMap()) {
      assert(ubMapA == AffineMap());
      continue;
    }
    assert(ubMapA && "expected non-null ub map");

    AffineMap lbMapB = sliceStateB->lbs[i];
    AffineMap ubMapB = sliceStateB->ubs[i];
    if (lbMapB == AffineMap()) {
      assert(ubMapB == AffineMap());
      // Union 'sliceStateB' does not have a bound for 'i' so copy from A.
      sliceStateB->lbs[i] = lbMapA;
      sliceStateB->ubs[i] = ubMapA;
      continue;
    }

    // TODO(andydavis) Change this code to take the min across all lower bounds
    // and max across all upper bounds for each dimension. This code can fail
    // for cases where a unique min or max could not be statically determined.

    // Assumption: both lower bounds are the same.
    if (lbMapA != lbMapB)
      return false;

    // Add bound with the largest trip count to union. For example, bounds
    // [0, 8) and [0, 16) (with equal lower bounds) union to [0, 16).
    Optional<uint64_t> tripCountA = getConstDifference(lbMapA, ubMapA);
    Optional<uint64_t> tripCountB = getConstDifference(lbMapB, ubMapB);
    if (!tripCountA.hasValue() || !tripCountB.hasValue())
      return false;

    if (tripCountA.getValue() > tripCountB.getValue()) {
      sliceStateB->lbs[i] = lbMapA;
      sliceStateB->ubs[i] = ubMapA;
    }
  }
  return true;
}

// Returns the size (in bytes) of one element of 'memRefType'.
// TODO(mlir-team): improve/complete this when we have target data.
unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
  auto elementType = memRefType.getElementType();

  unsigned sizeInBits;
  if (elementType.isIntOrFloat()) {
    sizeInBits = elementType.getIntOrFloatBitWidth();
  } else {
    auto vectorType = elementType.cast<VectorType>();
    sizeInBits =
        vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
  }
  return llvm::divideCeil(sizeInBits, 8);
}

// Creates and returns a private (single-user) memref for the fused loop
// rooted at 'forOp', with (potentially reduced) memref size based on the
// MemRefRegion written to by 'srcStoreOpInst' at depth 'dstLoopDepth'.
// TODO(bondhugula): consider refactoring the common code from generateDma and
// this one.
static Value *createPrivateMemRef(OpPointer<AffineForOp> forOp,
                                  Instruction *srcStoreOpInst,
                                  unsigned dstLoopDepth,
                                  Optional<unsigned> fastMemorySpace,
                                  unsigned localBufSizeThreshold) {
  auto *forInst = forOp->getInstruction();

  // Create builder to insert alloc op just before 'forOp'.
  FuncBuilder b(forInst);
  // Builder to create constants at the top level.
  FuncBuilder top(forInst->getFunction());
  // Create new memref type based on slice bounds.
  auto *oldMemRef = srcStoreOpInst->cast<StoreOp>()->getMemRef();
  auto oldMemRefType = oldMemRef->getType().cast<MemRefType>();
  unsigned rank = oldMemRefType.getRank();

  // Compute MemRefRegion for 'srcStoreOpInst' at depth 'dstLoopDepth'.
  MemRefRegion region(srcStoreOpInst->getLoc());
  region.compute(srcStoreOpInst, dstLoopDepth);
  SmallVector<int64_t, 4> newShape;
  std::vector<SmallVector<int64_t, 4>> lbs;
  SmallVector<int64_t, 8> lbDivisors;
  lbs.reserve(rank);
  // Query 'region' for 'newShape' and lower bounds of MemRefRegion accessed
  // by 'srcStoreOpInst' at depth 'dstLoopDepth'.
  Optional<int64_t> numElements =
      region.getConstantBoundingSizeAndShape(&newShape, &lbs, &lbDivisors);
  assert(numElements.hasValue() &&
         "non-constant number of elts in local buffer");

  const FlatAffineConstraints *cst = region.getConstraints();
  // 'outerIVs' holds the values that this memory region is symbolic/parametric
  // on; this would correspond to loop IVs surrounding the level at which the
  // slice is being materialized.
  SmallVector<Value *, 8> outerIVs;
  cst->getIdValues(rank, cst->getNumIds(), &outerIVs);

  // Build 'rank' AffineExprs from MemRefRegion 'lbs'.
  SmallVector<AffineExpr, 4> offsets;
  offsets.reserve(rank);
  for (unsigned d = 0; d < rank; ++d) {
    assert(lbs[d].size() == cst->getNumCols() - rank && "incorrect bound size");

    AffineExpr offset = top.getAffineConstantExpr(0);
    for (unsigned j = 0, e = cst->getNumCols() - rank - 1; j < e; j++) {
      offset = offset + lbs[d][j] * top.getAffineDimExpr(j);
    }
    assert(lbDivisors[d] > 0);
    offset =
        (offset + lbs[d][cst->getNumCols() - 1 - rank]).floorDiv(lbDivisors[d]);
    offsets.push_back(offset);
  }

  // Create 'newMemRefType' using 'newShape' from MemRefRegion accessed
  // by 'srcStoreOpInst'.
  uint64_t bufSize =
      getMemRefEltSizeInBytes(oldMemRefType) * numElements.getValue();
  unsigned newMemSpace;
  if (bufSize < localBufSizeThreshold && fastMemorySpace.hasValue()) {
    newMemSpace = fastMemorySpace.getValue();
  } else {
    newMemSpace = oldMemRefType.getMemorySpace();
  }
  auto newMemRefType = top.getMemRefType(
      newShape, oldMemRefType.getElementType(), {}, newMemSpace);
  // Gather alloc operands for the dynamic dimensions of the memref.
  SmallVector<Value *, 4> allocOperands;
  unsigned dynamicDimCount = 0;
  for (auto dimSize : oldMemRefType.getShape()) {
    if (dimSize == -1)
      allocOperands.push_back(
          top.create<DimOp>(forOp->getLoc(), oldMemRef, dynamicDimCount++));
  }

  // Create new private memref for fused loop 'forOp'.
  // TODO(andydavis) Create/move alloc ops for private memrefs closer to their
  // consumer loop nests to reduce their live range. Currently they are added
  // at the beginning of the function, because loop nests can be reordered
  // during the fusion pass.
  Value *newMemRef =
      top.create<AllocOp>(forOp->getLoc(), newMemRefType, allocOperands);

  // Build an AffineMap to remap access functions based on lower bound offsets.
  SmallVector<AffineExpr, 4> remapExprs;
  remapExprs.reserve(rank);
  unsigned zeroOffsetCount = 0;
  for (unsigned i = 0; i < rank; i++) {
    if (auto constExpr = offsets[i].dyn_cast<AffineConstantExpr>())
      if (constExpr.getValue() == 0)
        ++zeroOffsetCount;
    auto dimExpr = b.getAffineDimExpr(outerIVs.size() + i);

    auto remapExpr =
        simplifyAffineExpr(dimExpr - offsets[i], outerIVs.size() + rank, 0);
    remapExprs.push_back(remapExpr);
  }
  auto indexRemap =
      zeroOffsetCount == rank
          ? AffineMap()
          : b.getAffineMap(outerIVs.size() + rank, 0, remapExprs, {});
  // Replace all users of 'oldMemRef' with 'newMemRef'.
  bool ret =
      replaceAllMemRefUsesWith(oldMemRef, newMemRef, {}, indexRemap,
                               /*extraOperands=*/outerIVs,
                               /*domInstFilter=*/&*forOp->getBody()->begin());
  assert(ret && "replaceAllMemrefUsesWith should always succeed here");
  (void)ret;
  return newMemRef;
}

// Returns the iteration count of the slice: the product of the trip counts of
// all loops in 'sliceTripCountMap' (e.g. trip counts {4, 16} give 64).
static uint64_t getSliceIterationCount(
    const llvm::SmallDenseMap<Instruction *, uint64_t, 8> &sliceTripCountMap) {
  uint64_t iterCount = 1;
  for (const auto &count : sliceTripCountMap) {
    iterCount *= count.second;
  }
  return iterCount;
}

MLIR Team27d067e2019-01-16 17:55:021025// Checks the profitability of fusing a backwards slice of the loop nest
MLIR Teamd7c82442019-01-30 23:53:411026// surrounding 'srcOpInst' into the loop nest surrounding 'dstLoadOpInsts'.
Uday Bondhugulab4a14432019-01-26 00:00:501027// Returns true if it is profitable to fuse the candidate loop nests. Returns
1028// false otherwise. `dstLoopDepth` is set to the most profitable depth at which
1029// to materialize the source loop nest slice.
MLIR Team38c2fe32019-01-14 19:26:251030// The profitability model executes the following steps:
MLIR Team27d067e2019-01-16 17:55:021031// *) Computes the backward computation slice at 'srcOpInst'. This
1032// computation slice of the loop nest surrounding 'srcOpInst' is
MLIR Team38c2fe32019-01-14 19:26:251033// represented by modified src loop bounds in 'sliceState', which are
MLIR Team27d067e2019-01-16 17:55:021034// functions of loop IVs in the loop nest surrounding 'srcOpInst'.
MLIR Team38c2fe32019-01-14 19:26:251035// *) Computes the cost of unfused src/dst loop nests (currently the cost of a
1036// loop nest is the total number of dynamic operation instances in the loop
1037// nest).
1038// *) Computes the cost of fusing a slice of the src loop nest into the dst
MLIR Team27d067e2019-01-16 17:55:021039// loop nest at various values of dst loop depth, attempting to fuse
1040// the largest compution slice at the maximal dst loop depth (closest to the
1041// load) to minimize reuse distance and potentially enable subsequent
1042// load/store forwarding.
MLIR Teamd7c82442019-01-30 23:53:411043// NOTE: If the dst loop nest includes multiple loads in 'dstLoadOpInsts' for
MLIR Team27d067e2019-01-16 17:55:021044// the same memref as is written by 'srcOpInst', then the union of slice
1045// loop bounds is used to compute the slice and associated slice cost.
Uday Bondhugulab4a14432019-01-26 00:00:501046// NOTE: 'dstLoopDepth' refers to the loop depth within the destination loop
MLIR Team38c2fe32019-01-14 19:26:251047// nest, at which the src computation slice is inserted/fused.
MLIR Team27d067e2019-01-16 17:55:021048// NOTE: We attempt to maximize the dst loop depth, but there are cases
1049// where a particular setting for 'dstLoopNest' might fuse an unsliced
MLIR Team38c2fe32019-01-14 19:26:251050// loop (within the src computation slice) at a depth which results in
1051// execessive recomputation (see unit tests for examples).
1052// *) Compares the total cost of the unfused loop nests to the min cost fused
1053// loop nest computed in the previous step, and returns true if the latter
1054// is lower.
River Riddleb4992772019-02-04 18:38:471055static bool isFusionProfitable(Instruction *srcOpInst,
1056 ArrayRef<Instruction *> dstLoadOpInsts,
1057 ArrayRef<Instruction *> dstStoreOpInsts,
MLIR Team38c2fe32019-01-14 19:26:251058 ComputationSliceState *sliceState,
MLIR Team27d067e2019-01-16 17:55:021059 unsigned *dstLoopDepth) {
Uday Bondhugula06d21d92019-01-25 01:01:491060 LLVM_DEBUG({
1061 llvm::dbgs() << "Checking whether fusion is profitable between:\n";
1062 llvm::dbgs() << " ";
1063 srcOpInst->dump();
1064 llvm::dbgs() << " and \n";
MLIR Teamd7c82442019-01-30 23:53:411065 for (auto dstOpInst : dstLoadOpInsts) {
Uday Bondhugula06d21d92019-01-25 01:01:491066 llvm::dbgs() << " ";
1067 dstOpInst->dump();
1068 };
1069 });
Uday Bondhugula864d9e02019-01-23 17:16:241070
MLIR Team38c2fe32019-01-14 19:26:251071 // Compute cost of sliced and unsliced src loop nest.
River Riddle5052bd82019-02-02 00:42:181072 SmallVector<OpPointer<AffineForOp>, 4> srcLoopIVs;
MLIR Team27d067e2019-01-16 17:55:021073 getLoopIVs(*srcOpInst, &srcLoopIVs);
MLIR Team38c2fe32019-01-14 19:26:251074 unsigned numSrcLoopIVs = srcLoopIVs.size();
1075
1076 // Walk src loop nest and collect stats.
1077 LoopNestStats srcLoopNestStats;
1078 LoopNestStatsCollector srcStatsCollector(&srcLoopNestStats);
River Riddlebf9c3812019-02-05 00:24:441079 srcStatsCollector.collect(srcLoopIVs[0]->getInstruction());
MLIR Team38c2fe32019-01-14 19:26:251080 // Currently only constant trip count loop nests are supported.
1081 if (srcStatsCollector.hasLoopWithNonConstTripCount)
1082 return false;
1083
1084 // Compute cost of dst loop nest.
River Riddle5052bd82019-02-02 00:42:181085 SmallVector<OpPointer<AffineForOp>, 4> dstLoopIVs;
MLIR Teamd7c82442019-01-30 23:53:411086 getLoopIVs(*dstLoadOpInsts[0], &dstLoopIVs);
MLIR Team38c2fe32019-01-14 19:26:251087
1088 LoopNestStats dstLoopNestStats;
1089 LoopNestStatsCollector dstStatsCollector(&dstLoopNestStats);
River Riddlebf9c3812019-02-05 00:24:441090 dstStatsCollector.collect(dstLoopIVs[0]->getInstruction());
MLIR Team38c2fe32019-01-14 19:26:251091 // Currently only constant trip count loop nests are supported.
1092 if (dstStatsCollector.hasLoopWithNonConstTripCount)
1093 return false;
1094
MLIR Teamd7c82442019-01-30 23:53:411095 // Compute the maximum loop depth at which we can insert the src slice
1096 // and still satisfy dst loop nest dependences.
1097 unsigned maxDstLoopDepth = getMaxLoopDepth(dstLoadOpInsts, dstStoreOpInsts);
MLIR Team27d067e2019-01-16 17:55:021098 if (maxDstLoopDepth == 0)
1099 return false;
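  // As an illustration (hypothetical shapes): if the dst nest is three loops
  // deep but a dependence between an op in 'dstLoadOpInsts' and one in
  // 'dstStoreOpInsts' is only preserved when the slice is inserted above the
  // innermost loop, 'maxDstLoopDepth' is 2 and depth 3 is never considered.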
1100
1101 // Search for min cost value for 'dstLoopDepth'. At each value of
1102 // 'dstLoopDepth' from 'maxDstLoopDepth' to '1', compute computation slice
1103 // bounds between 'srcOpInst' and each op in 'dstLoadOpInsts' (taking the
1104 // union of these bounds). Next, the union slice bounds are used to calculate
1105 // the cost of the slice and the cost of the slice inserted into the dst
1106 // loop nest at 'dstLoopDepth'.
Uday Bondhugula864d9e02019-01-23 17:16:241107 uint64_t minFusedLoopNestComputeCost = std::numeric_limits<uint64_t>::max();
1108 uint64_t maxStorageReduction = 0;
1109 Optional<uint64_t> sliceMemEstimate = None;
1110
MLIR Team27d067e2019-01-16 17:55:021111 SmallVector<ComputationSliceState, 4> sliceStates;
1112 sliceStates.resize(maxDstLoopDepth);
Uday Bondhugula864d9e02019-01-23 17:16:241113 // The best loop depth at which to materialize the slice.
1114 Optional<unsigned> bestDstLoopDepth = None;
1115
1116 // Compute op instance count for the src loop nest without iteration slicing.
River Riddle5052bd82019-02-02 00:42:181117 uint64_t srcLoopNestCost =
1118 getComputeCost(srcLoopIVs[0]->getInstruction(), &srcLoopNestStats,
1119 /*tripCountOverrideMap=*/nullptr,
1120 /*computeCostMap=*/nullptr);
Uday Bondhugula864d9e02019-01-23 17:16:241121
1122 // Compute op instance count for the dst loop nest.
River Riddle5052bd82019-02-02 00:42:181123 uint64_t dstLoopNestCost =
1124 getComputeCost(dstLoopIVs[0]->getInstruction(), &dstLoopNestStats,
1125 /*tripCountOverrideMap=*/nullptr,
1126 /*computeCostMap=*/nullptr);
MLIR Team27d067e2019-01-16 17:55:021127
River Riddle5052bd82019-02-02 00:42:181128 llvm::SmallDenseMap<Instruction *, uint64_t, 8> sliceTripCountMap;
1129 DenseMap<Instruction *, int64_t> computeCostMap;
MLIR Team27d067e2019-01-16 17:55:021130 for (unsigned i = maxDstLoopDepth; i >= 1; --i) {
1131 MemRefAccess srcAccess(srcOpInst);
1132 // Handle the common case of one dst load without a copy.
1133 if (!mlir::getBackwardComputationSliceState(
MLIR Teamd7c82442019-01-30 23:53:411134 srcAccess, MemRefAccess(dstLoadOpInsts[0]), i, &sliceStates[i - 1]))
MLIR Team27d067e2019-01-16 17:55:021135 return false;
MLIR Teamd7c82442019-01-30 23:53:411136 // Compute the union of slice bound of all ops in 'dstLoadOpInsts'.
1137 for (int j = 1, e = dstLoadOpInsts.size(); j < e; ++j) {
1138 MemRefAccess dstAccess(dstLoadOpInsts[j]);
MLIR Team27d067e2019-01-16 17:55:021139 ComputationSliceState tmpSliceState;
1140 if (!mlir::getBackwardComputationSliceState(srcAccess, dstAccess, i,
1141 &tmpSliceState))
1142 return false;
1143 // Compute the slice bound union of 'tmpSliceState' and 'sliceStates[i - 1]'.
Uday Bondhugulac1ca23e2019-01-16 21:13:001144 getSliceUnion(tmpSliceState, &sliceStates[i - 1]);
MLIR Team38c2fe32019-01-14 19:26:251145 }
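    // Schematic example of the union (hypothetical bounds): if one dst load
    // yields slice bounds [%i2, %i2 + 2) for some src IV and another yields
    // [%i2, %i2 + 4), the union kept in 'sliceStates[i - 1]' covers
    // [%i2, %i2 + 4) so that the slice satisfies both loads.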
Uday Bondhugulab4a14432019-01-26 00:00:501146 // Build trip count map for computation slice. We'll skip cases where the
1147 // trip count was non-constant.
MLIR Team27d067e2019-01-16 17:55:021148 sliceTripCountMap.clear();
1149 if (!buildSliceTripCountMap(srcOpInst, &sliceStates[i - 1],
1150 &sliceTripCountMap))
Uday Bondhugula864d9e02019-01-23 17:16:241151 continue;
1152
1153 // Check whether store-to-load forwarding will happen.
1154 int64_t sliceIterationCount = getSliceIterationCount(sliceTripCountMap);
Uday Bondhugula864d9e02019-01-23 17:16:241155 assert(sliceIterationCount > 0);
Uday Bondhugulab4a14432019-01-26 00:00:501156 bool storeLoadFwdGuaranteed = (sliceIterationCount == 1);
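    // E.g., when every slice loop has trip count one, the sliced producer
    // writes exactly the value that the enclosing dst iteration reads, so the
    // store can be forwarded to the load and the intermediate buffer elided.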
Uday Bondhugula864d9e02019-01-23 17:16:241157
1158 // Compute cost of fusion for this dest loop depth.
1159
1160 computeCostMap.clear();
1161
1162 // The store and the loads of this memref will disappear.
1163 if (storeLoadFwdGuaranteed) {
1164 // A single store disappears: -1 for that.
River Riddle5052bd82019-02-02 00:42:181165 computeCostMap[srcLoopIVs[numSrcLoopIVs - 1]->getInstruction()] = -1;
MLIR Teamd7c82442019-01-30 23:53:411166 for (auto *loadOp : dstLoadOpInsts) {
River Riddle5052bd82019-02-02 00:42:181167 auto *parentInst = loadOp->getParentInst();
River Riddleb4992772019-02-04 18:38:471168 if (parentInst && parentInst->isa<AffineForOp>())
River Riddle5052bd82019-02-02 00:42:181169 computeCostMap[parentInst] = -1;
Uday Bondhugula864d9e02019-01-23 17:16:241170 }
1171 }
MLIR Team27d067e2019-01-16 17:55:021172
MLIR Team38c2fe32019-01-14 19:26:251173 // Compute op instance count for the src loop nest with iteration slicing.
Uday Bondhugula864d9e02019-01-23 17:16:241174 int64_t sliceComputeCost =
River Riddle5052bd82019-02-02 00:42:181175 getComputeCost(srcLoopIVs[0]->getInstruction(), &srcLoopNestStats,
Uday Bondhugula864d9e02019-01-23 17:16:241176 /*tripCountOverrideMap=*/&sliceTripCountMap,
1177 /*computeCostMap=*/&computeCostMap);
MLIR Team38c2fe32019-01-14 19:26:251178
Uday Bondhugula864d9e02019-01-23 17:16:241179 // Compute cost of fusion for this depth.
River Riddle5052bd82019-02-02 00:42:181180 computeCostMap[dstLoopIVs[i - 1]->getInstruction()] = sliceComputeCost;
Uday Bondhugula864d9e02019-01-23 17:16:241181
1182 int64_t fusedLoopNestComputeCost =
River Riddle5052bd82019-02-02 00:42:181183 getComputeCost(dstLoopIVs[0]->getInstruction(), &dstLoopNestStats,
MLIR Team27d067e2019-01-16 17:55:021184 /*tripCountOverrideMap=*/nullptr, &computeCostMap);
Uday Bondhugula864d9e02019-01-23 17:16:241185
1186 double additionalComputeFraction =
1187 fusedLoopNestComputeCost /
1188 (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
1189 1;
1190
1191 // TODO(bondhugula): This is an ugly approximation. Fix this by finding a
1192 // good way to calculate the footprint of the memref in the slice and
1193 // divide it by the total memory footprint of the fused computation.
1194 double storageReduction =
1195 static_cast<double>(srcLoopNestCost) / sliceIterationCount;
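    // Worked example with hypothetical costs: srcLoopNestCost = 100,
    // dstLoopNestCost = 100, and fusedLoopNestComputeCost = 250 give
    // additionalComputeFraction = 250 / (100 + 100) - 1 = 0.25 (25% extra
    // compute); with sliceIterationCount = 4, storageReduction = 100 / 4 = 25.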
1196
Uday Bondhugula06d21d92019-01-25 01:01:491197 LLVM_DEBUG({
1198 std::stringstream msg;
1199 msg << " evaluating fusion profitability at depth : " << i << "\n"
1200 << std::setprecision(2) << " additional compute fraction: "
1201 << 100.0 * additionalComputeFraction << "%\n"
1202 << " storage reduction factor: " << storageReduction << "x\n"
1203 << " fused nest cost: " << fusedLoopNestComputeCost << "\n"
1204 << " slice iteration count: " << sliceIterationCount << "\n";
1205 llvm::dbgs() << msg.str();
1206 });
Uday Bondhugula864d9e02019-01-23 17:16:241207
1208 double computeToleranceThreshold =
1209 clFusionAddlComputeTolerance.getNumOccurrences() > 0
1210 ? clFusionAddlComputeTolerance
1211 : LoopFusion::kComputeToleranceThreshold;
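    // E.g., with an assumed threshold of 0.30, only candidate depths whose
    // fused nest costs at most 30% more than the two unfused nests combined
    // remain eligible below.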
1212
1213 // TODO(b/123247369): This is a placeholder cost model.
1214 // Among all choices that add an acceptable amount of redundant computation
1215 // (as per computeToleranceThreshold), we will simply pick the one that
1216 // reduces the intermediary size the most.
1217 if ((storageReduction > maxStorageReduction) &&
1218 (clMaximalLoopFusion ||
1219 (additionalComputeFraction < computeToleranceThreshold))) {
1220 maxStorageReduction = storageReduction;
MLIR Team27d067e2019-01-16 17:55:021221 bestDstLoopDepth = i;
Uday Bondhugula864d9e02019-01-23 17:16:241222 minFusedLoopNestComputeCost = fusedLoopNestComputeCost;
1223 // TODO(bondhugula,andydavis): find a good way to compute the memory
1224 // footprint of the materialized slice.
1225 // Approximating this by the compute cost of the slice. This could be an
1226 // under-approximation or an over-approximation, but is accurate in many
1227 // cases.
1228 sliceMemEstimate = sliceIterationCount;
MLIR Team38c2fe32019-01-14 19:26:251229 }
1230 }
1231
Uday Bondhugula864d9e02019-01-23 17:16:241232 // A simple cost model: fuse if it reduces the memory footprint. If
1233 // 'clMaximalLoopFusion' is set, fuse regardless.
MLIR Team38c2fe32019-01-14 19:26:251234
Uday Bondhugula864d9e02019-01-23 17:16:241235 if (!clMaximalLoopFusion && !bestDstLoopDepth.hasValue()) {
1236 LLVM_DEBUG(llvm::dbgs()
1237 << "All fusion choices involve more than the threshold amount of"
1238 "redundant computation; NOT fusing.\n");
MLIR Team38c2fe32019-01-14 19:26:251239 return false;
Uday Bondhugula864d9e02019-01-23 17:16:241240 }
1241
1242 assert(bestDstLoopDepth.hasValue() &&
1243 "expected to have a value per logic above");
1244
1245 // Set dstLoopDepth based on best values from search.
1246 *dstLoopDepth = bestDstLoopDepth.getValue();
1247
1248 LLVM_DEBUG(
Uday Bondhugula06d21d92019-01-25 01:01:491249 llvm::dbgs() << " LoopFusion fusion stats:"
1250 << "\n best loop depth: " << bestDstLoopDepth
Uday Bondhugula864d9e02019-01-23 17:16:241251 << "\n src loop nest compute cost: " << srcLoopNestCost
1252 << "\n dst loop nest compute cost: " << dstLoopNestCost
1253 << "\n fused loop nest compute cost: "
1254 << minFusedLoopNestComputeCost << "\n");
1255
River Riddle5052bd82019-02-02 00:42:181256 auto dstMemSize = getMemoryFootprintBytes(dstLoopIVs[0]);
1257 auto srcMemSize = getMemoryFootprintBytes(srcLoopIVs[0]);
Uday Bondhugula864d9e02019-01-23 17:16:241258
1259 Optional<double> storageReduction = None;
1260
1261 if (!clMaximalLoopFusion) {
1262 if (!dstMemSize.hasValue() || !srcMemSize.hasValue()) {
1263 LLVM_DEBUG(
1264 llvm::dbgs()
1265 << " fusion memory benefit cannot be evaluated; NOT fusing.\n");
1266 return false;
1267 }
1268
1269 auto srcMemSizeVal = srcMemSize.getValue();
1270 auto dstMemSizeVal = dstMemSize.getValue();
1271
1272 assert(sliceMemEstimate.hasValue() && "expected value");
1273 // This is a rough estimate since 'sliceMemEstimate' is itself approximate.
1274 auto fusedMem = dstMemSizeVal + sliceMemEstimate.getValue();
1275
1276 LLVM_DEBUG(llvm::dbgs() << " src mem: " << srcMemSizeVal << "\n"
1277 << " dst mem: " << dstMemSizeVal << "\n"
1278 << " fused mem: " << fusedMem << "\n"
1279 << " slice mem: " << sliceMemEstimate << "\n");
1280
1281 if (fusedMem > srcMemSizeVal + dstMemSizeVal) {
1282 LLVM_DEBUG(llvm::dbgs() << "Fusion is not profitable; NOT fusing.\n");
1283 return false;
1284 }
1285 storageReduction =
1286 100.0 *
1287 (1.0 - fusedMem / (static_cast<double>(srcMemSizeVal) + dstMemSizeVal));
1288 }
1289
1290 double additionalComputeFraction =
1291 100.0 * (minFusedLoopNestComputeCost /
1292 (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
1293 1);
MLIR Team5c5739d2019-01-25 06:27:401294 (void)additionalComputeFraction;
Uday Bondhugula06d21d92019-01-25 01:01:491295 LLVM_DEBUG({
1296 std::stringstream msg;
1297 msg << " fusion is most profitable at depth " << *dstLoopDepth << " with "
1298 << std::setprecision(2) << additionalComputeFraction
1299 << "% redundant computation and a ";
1300 msg << (storageReduction.hasValue()
1301 ? std::to_string(storageReduction.getValue())
1302 : "<unknown>");
1303 msg << "% storage reduction.\n";
1304 llvm::dbgs() << msg.str();
1305 });
Uday Bondhugula864d9e02019-01-23 17:16:241306
MLIR Team27d067e2019-01-16 17:55:021307 // Update return parameter 'sliceState' with 'bestSliceState'.
Uday Bondhugula864d9e02019-01-23 17:16:241308 ComputationSliceState *bestSliceState = &sliceStates[*dstLoopDepth - 1];
MLIR Team27d067e2019-01-16 17:55:021309 sliceState->lbs = bestSliceState->lbs;
1310 sliceState->ubs = bestSliceState->ubs;
1311 sliceState->lbOperands = bestSliceState->lbOperands;
1312 sliceState->ubOperands = bestSliceState->ubOperands;
Uday Bondhugula864d9e02019-01-23 17:16:241313
MLIR Team27d067e2019-01-16 17:55:021314 // Canonicalize slice bound affine maps.
MLIR Team38c2fe32019-01-14 19:26:251315 for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
Nicolas Vasilache0e7a8a92019-01-26 18:41:171316 if (sliceState->lbs[i] != AffineMap()) {
MLIR Team27d067e2019-01-16 17:55:021317 canonicalizeMapAndOperands(&sliceState->lbs[i],
1318 &sliceState->lbOperands[i]);
1319 }
Nicolas Vasilache0e7a8a92019-01-26 18:41:171320 if (sliceState->ubs[i] != AffineMap()) {
MLIR Team27d067e2019-01-16 17:55:021321 canonicalizeMapAndOperands(&sliceState->ubs[i],
1322 &sliceState->ubOperands[i]);
MLIR Team38c2fe32019-01-14 19:26:251323 }
1324 }
1325 return true;
1326}
1327
MLIR Team6892ffb2018-12-20 04:42:551328// GreedyFusion greedily fuses loop nests which have a producer/consumer
MLIR Team3b692302018-12-17 17:57:141329// relationship on a memref, with the goal of improving locality. Currently,
1330// the producer/consumer relationship is required to be unique in the
Chris Lattner69d9e992018-12-28 16:48:091331// Function (there are TODOs to relax this constraint in the future).
MLIR Teamf28e4df2018-11-01 14:26:001332//
MLIR Team3b692302018-12-17 17:57:141333// The steps of the algorithm are as follows:
1334//
MLIR Team6892ffb2018-12-20 04:42:551335// *) A worklist is initialized with node ids from the dependence graph.
1336// *) For each node id in the worklist:
River Riddle5052bd82019-02-02 00:42:181337// *) Pop an AffineForOp off the worklist. This 'dstAffineForOp' will be a
1338// candidate destination AffineForOp into which fusion will be attempted.
1339// *) Add each LoadOp currently in 'dstAffineForOp' into list 'dstLoadOps'.
MLIR Team3b692302018-12-17 17:57:141340// *) For each LoadOp in 'dstLoadOps' do:
Chris Lattner69d9e992018-12-28 16:48:091341// *) Lookup dependent loop nests at earlier positions in the Function
MLIR Team3b692302018-12-17 17:57:141342// which have a single store op to the same memref.
1343// *) Check if dependences would be violated by the fusion. For example,
1344// the src loop nest may load from memrefs which are different than
1345// the producer-consumer memref between src and dest loop nests.
MLIR Team6892ffb2018-12-20 04:42:551346// *) Get a computation slice of 'srcLoopNest', which adjusts its loop
MLIR Team3b692302018-12-17 17:57:141347// bounds to be functions of 'dstLoopNest' IVs and symbols.
1348// *) Fuse the 'srcLoopNest' computation slice into the 'dstLoopNest',
1349// just before the dst load op user.
Chris Lattner456ad6a2018-12-29 00:05:351350// *) Add the newly fused load/store operation instructions to the state,
MLIR Team3b692302018-12-17 17:57:141351// and also add newly fused load ops to 'dstLoadOps' to be considered
1352// as fusion dst load ops in another iteration.
1353// *) Remove old src loop nest and its associated state.
1354//
Chris Lattner456ad6a2018-12-29 00:05:351355// Given a graph where top-level instructions are vertices in the set 'V' and
MLIR Team3b692302018-12-17 17:57:141356// edges in the set 'E' are dependences between vertices, this algorithm
MLIR Team6892ffb2018-12-20 04:42:551357// takes O(V) time for initialization, and has runtime O(V + E).
MLIR Team3b692302018-12-17 17:57:141358//
MLIR Team6892ffb2018-12-20 04:42:551359// This greedy algorithm is not 'maximal' due to the current restriction of
1360// fusing along single producer-consumer edges, but there is a TODO to fix this.
MLIR Team3b692302018-12-17 17:57:141361//
1362// TODO(andydavis) Experiment with other fusion policies.
MLIR Team6892ffb2018-12-20 04:42:551363// TODO(andydavis) Add support for fusing for input reuse (perhaps by
1364// constructing a graph with edges which represent loads from the same memref
MLIR Team5c5739d2019-01-25 06:27:401365// in two different loop nests).
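// A minimal before/after sketch on hypothetical IR (schematic only; the
// private memref placement, indices, and types the pass actually produces
// may differ):
//
//   %m = alloc() : memref<100xf32>
//   for %i0 = 0 to 100 {
//     store %v, %m[%i0] : memref<100xf32>   // producer
//   }
//   for %i1 = 0 to 100 {
//     %u = load %m[%i1] : memref<100xf32>   // consumer
//   }
//
// fuses, with the producer slice privatized to a single-element buffer,
// roughly into:
//
//   %pm = alloc() : memref<1xf32>
//   for %i1 = 0 to 100 {
//     store %v, %pm[...] : memref<1xf32>    // sliced producer at depth 1
//     %u = load %pm[...] : memref<1xf32>
//   }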
MLIR Team6892ffb2018-12-20 04:42:551366struct GreedyFusion {
1367public:
1368 MemRefDependenceGraph *mdg;
MLIR Team3b692302018-12-17 17:57:141369 SmallVector<unsigned, 4> worklist;
MLIR Teamf28e4df2018-11-01 14:26:001370
MLIR Team6892ffb2018-12-20 04:42:551371 GreedyFusion(MemRefDependenceGraph *mdg) : mdg(mdg) {
1372 // Initialize worklist with nodes from 'mdg'.
1373 worklist.resize(mdg->nodes.size());
1374 std::iota(worklist.begin(), worklist.end(), 0);
1375 }
MLIR Team3b692302018-12-17 17:57:141376
Uday Bondhugula8be26272019-02-02 01:06:221377 void run(unsigned localBufSizeThreshold, Optional<unsigned> fastMemorySpace) {
MLIR Team3b692302018-12-17 17:57:141378 while (!worklist.empty()) {
MLIR Team6892ffb2018-12-20 04:42:551379 unsigned dstId = worklist.back();
MLIR Team3b692302018-12-17 17:57:141380 worklist.pop_back();
MLIR Team6892ffb2018-12-20 04:42:551381 // Skip if this node was removed (fused into another node).
1382 if (mdg->nodes.count(dstId) == 0)
MLIR Team3b692302018-12-17 17:57:141383 continue;
MLIR Team6892ffb2018-12-20 04:42:551384 // Get 'dstNode' into which to attempt fusion.
1385 auto *dstNode = mdg->getNode(dstId);
1386 // Skip if 'dstNode' is not a loop nest.
River Riddleb4992772019-02-04 18:38:471387 if (!dstNode->inst->isa<AffineForOp>())
MLIR Team3b692302018-12-17 17:57:141388 continue;
1389
River Riddleb4992772019-02-04 18:38:471390 SmallVector<Instruction *, 4> loads = dstNode->loads;
1391 SmallVector<Instruction *, 4> dstLoadOpInsts;
MLIR Teamc4237ae2019-01-18 16:56:271392 DenseSet<Value *> visitedMemrefs;
MLIR Team6892ffb2018-12-20 04:42:551393 while (!loads.empty()) {
MLIR Team27d067e2019-01-16 17:55:021394 // Get memref of load on top of the stack.
1395 auto *memref = loads.back()->cast<LoadOp>()->getMemRef();
MLIR Teamc4237ae2019-01-18 16:56:271396 if (visitedMemrefs.count(memref) > 0) {
 // Pop this visited load before continuing; a bare 'continue' would
 // leave 'loads.back()' unchanged and this loop would never terminate.
 loads.pop_back();
1397 continue;
 }
1398 visitedMemrefs.insert(memref);
MLIR Team27d067e2019-01-16 17:55:021399 // Move all loads in 'loads' accessing 'memref' to 'dstLoadOpInsts'.
1400 moveLoadsAccessingMemrefTo(memref, &loads, &dstLoadOpInsts);
MLIR Team6892ffb2018-12-20 04:42:551401 // Skip if no input edges along which to fuse.
1402 if (mdg->inEdges.count(dstId) == 0)
MLIR Team3b692302018-12-17 17:57:141403 continue;
MLIR Team1e851912019-01-31 00:01:461404 // Iterate through in-edges for 'dstId', collecting the src node id
1405 // of each edge on 'memref'.
1406 SmallVector<unsigned, 2> srcNodeIds;
MLIR Team6892ffb2018-12-20 04:42:551407 for (auto &srcEdge : mdg->inEdges[dstId]) {
1408 // Skip 'srcEdge' if not for 'memref'.
MLIR Teama0f3db402019-01-29 17:36:411409 if (srcEdge.value != memref)
MLIR Team6892ffb2018-12-20 04:42:551410 continue;
MLIR Team1e851912019-01-31 00:01:461411 srcNodeIds.push_back(srcEdge.id);
1412 }
1413 for (unsigned srcId : srcNodeIds) {
1414 // Skip if this node was removed (fused into another node).
1415 if (mdg->nodes.count(srcId) == 0)
1416 continue;
1417 // Get 'srcNode' from which to attempt fusion into 'dstNode'.
1418 auto *srcNode = mdg->getNode(srcId);
MLIR Team6892ffb2018-12-20 04:42:551419 // Skip if 'srcNode' is not a loop nest.
River Riddleb4992772019-02-04 18:38:471420 if (!srcNode->inst->isa<AffineForOp>())
MLIR Team6892ffb2018-12-20 04:42:551421 continue;
MLIR Teamb28009b2019-01-23 19:11:431422 // Skip unless 'srcNode' has exactly one store.
1423 // TODO(andydavis) Support fusing multi-output src loop nests.
1424 if (srcNode->stores.size() != 1)
MLIR Team6892ffb2018-12-20 04:42:551425 continue;
Uday Bondhugula864d9e02019-01-23 17:16:241426
MLIR Teama0f3db402019-01-29 17:36:411427 // Skip 'srcNode' if it has in edges on 'memref'.
MLIR Team6892ffb2018-12-20 04:42:551428 // TODO(andydavis) Track dependence type with edges, and just check
MLIR Teama0f3db402019-01-29 17:36:411429 // for WAW dependence edge here. Note that this check is overly
1430 // conservative and will be removed in the future.
1431 if (mdg->getIncomingMemRefAccesses(srcNode->id, memref) != 0)
MLIR Team6892ffb2018-12-20 04:42:551432 continue;
Uday Bondhugula864d9e02019-01-23 17:16:241433
MLIR Teamd7c82442019-01-30 23:53:411434 // Skip if 'srcNode' writes to any live in or escaping memrefs.
1435 if (mdg->writesToLiveInOrEscapingMemrefs(srcNode->id))
1436 continue;
1437
MLIR Teama0f3db402019-01-29 17:36:411438 // Compute an instruction list insertion point for the fused loop
1439 // nest which preserves dependences.
1440 Instruction *insertPointInst = mdg->getFusedLoopNestInsertionPoint(
1441 srcNode->id, dstNode->id, memref);
1442 if (insertPointInst == nullptr)
MLIR Team6892ffb2018-12-20 04:42:551443 continue;
Uday Bondhugula864d9e02019-01-23 17:16:241444
MLIR Team6892ffb2018-12-20 04:42:551445 // Get unique 'srcNode' store op.
Chris Lattner456ad6a2018-12-29 00:05:351446 auto *srcStoreOpInst = srcNode->stores.front();
MLIR Teamd7c82442019-01-30 23:53:411447 // Gather 'dstNode' store ops to 'memref'.
River Riddleb4992772019-02-04 18:38:471448 SmallVector<Instruction *, 2> dstStoreOpInsts;
MLIR Teamd7c82442019-01-30 23:53:411449 for (auto *storeOpInst : dstNode->stores)
1450 if (storeOpInst->cast<StoreOp>()->getMemRef() == memref)
1451 dstStoreOpInsts.push_back(storeOpInst);
1452
Uday Bondhugulab4a14432019-01-26 00:00:501453 unsigned bestDstLoopDepth;
MLIR Team38c2fe32019-01-14 19:26:251454 mlir::ComputationSliceState sliceState;
MLIR Teama0f3db402019-01-29 17:36:411455 // Check if fusion would be profitable.
MLIR Teamd7c82442019-01-30 23:53:411456 if (!isFusionProfitable(srcStoreOpInst, dstLoadOpInsts,
1457 dstStoreOpInsts, &sliceState,
Uday Bondhugulab4a14432019-01-26 00:00:501458 &bestDstLoopDepth))
MLIR Team38c2fe32019-01-14 19:26:251459 continue;
Uday Bondhugula864d9e02019-01-23 17:16:241460
MLIR Team6892ffb2018-12-20 04:42:551461 // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'.
River Riddle5052bd82019-02-02 00:42:181462 auto sliceLoopNest = mlir::insertBackwardComputationSlice(
Uday Bondhugulab4a14432019-01-26 00:00:501463 srcStoreOpInst, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState);
MLIR Team6892ffb2018-12-20 04:42:551464 if (sliceLoopNest != nullptr) {
River Riddle5052bd82019-02-02 00:42:181465 // Move 'dstAffineForOp' before 'insertPointInst' if needed.
River Riddleb4992772019-02-04 18:38:471466 auto dstAffineForOp = dstNode->inst->cast<AffineForOp>();
River Riddle5052bd82019-02-02 00:42:181467 if (insertPointInst != dstAffineForOp->getInstruction()) {
1468 dstAffineForOp->getInstruction()->moveBefore(insertPointInst);
MLIR Teama0f3db402019-01-29 17:36:411469 }
MLIR Teamc4237ae2019-01-18 16:56:271470 // Update edges between 'srcNode' and 'dstNode'.
MLIR Teama0f3db402019-01-29 17:36:411471 mdg->updateEdges(srcNode->id, dstNode->id, memref);
MLIR Teamc4237ae2019-01-18 16:56:271472
1473 // Collect slice loop stats.
1474 LoopNestStateCollector sliceCollector;
River Riddlebf9c3812019-02-05 00:24:441475 sliceCollector.collect(sliceLoopNest->getInstruction());
MLIR Teamc4237ae2019-01-18 16:56:271476 // Promote single iteration slice loops to single IV value.
River Riddle5052bd82019-02-02 00:42:181477 for (auto forOp : sliceCollector.forOps) {
1478 promoteIfSingleIteration(forOp);
MLIR Team6892ffb2018-12-20 04:42:551479 }
River Riddle5052bd82019-02-02 00:42:181480 // Create private memref for 'memref' in 'dstAffineForOp'.
River Riddleb4992772019-02-04 18:38:471481 SmallVector<Instruction *, 4> storesForMemref;
MLIR Teamc4237ae2019-01-18 16:56:271482 for (auto *storeOpInst : sliceCollector.storeOpInsts) {
1483 if (storeOpInst->cast<StoreOp>()->getMemRef() == memref)
1484 storesForMemref.push_back(storeOpInst);
1485 }
1486 assert(storesForMemref.size() == 1);
Uday Bondhugula94a03f82019-01-22 21:58:521487 auto *newMemRef = createPrivateMemRef(
Uday Bondhugula8be26272019-02-02 01:06:221488 dstAffineForOp, storesForMemref[0], bestDstLoopDepth,
1489 fastMemorySpace, localBufSizeThreshold);
MLIR Teamc4237ae2019-01-18 16:56:271490 visitedMemrefs.insert(newMemRef);
MLIR Teama0f3db402019-01-29 17:36:411491 // Create new node in dependence graph for 'newMemRef' alloc op.
1492 unsigned newMemRefNodeId =
1493 mdg->addNode(newMemRef->getDefiningInst());
1494 // Add edge from 'newMemRef' node to dstNode.
1495 mdg->addEdge(newMemRefNodeId, dstId, newMemRef);
MLIR Teamc4237ae2019-01-18 16:56:271496
1497 // Collect dst loop stats after the memref privatization transformation.
1498 LoopNestStateCollector dstLoopCollector;
River Riddlebf9c3812019-02-05 00:24:441499 dstLoopCollector.collect(dstAffineForOp->getInstruction());
MLIR Teamc4237ae2019-01-18 16:56:271500
1501 // Add new load ops to the current node's load op list 'loads', to
1502 // continue fusing based on new operands.
1503 for (auto *loadOpInst : dstLoopCollector.loadOpInsts) {
1504 auto *loadMemRef = loadOpInst->cast<LoadOp>()->getMemRef();
1505 if (visitedMemrefs.count(loadMemRef) == 0)
1506 loads.push_back(loadOpInst);
1507 }
1508
1509 // Clear and add back the node's loads and stores.
1510 mdg->clearNodeLoadAndStores(dstNode->id);
1511 mdg->addToNode(dstId, dstLoopCollector.loadOpInsts,
1512 dstLoopCollector.storeOpInsts);
MLIR Team71495d52019-01-22 21:23:371513 // Remove old src loop nest if it no longer has outgoing dependence
1514 // edges, and it does not write to a memref which escapes the
1515 // function.
MLIR Teama0f3db402019-01-29 17:36:411516 if (mdg->canRemoveNode(srcNode->id)) {
MLIR Teamc4237ae2019-01-18 16:56:271517 mdg->removeNode(srcNode->id);
River Riddle5052bd82019-02-02 00:42:181518 srcNode->inst->erase();
MLIR Teamc4237ae2019-01-18 16:56:271519 }
MLIR Team3b692302018-12-17 17:57:141520 }
MLIR Team3b692302018-12-17 17:57:141521 }
1522 }
1523 }
MLIR Teamc4237ae2019-01-18 16:56:271524 // Clean up any allocs with no users.
1525 for (auto &pair : mdg->memrefEdgeCount) {
1526 if (pair.second > 0)
1527 continue;
1528 auto *memref = pair.first;
MLIR Team71495d52019-01-22 21:23:371529 // Skip if there exist other uses (e.g., a return instruction or a call).
1530 if (!memref->use_empty())
1531 continue;
MLIR Teamc4237ae2019-01-18 16:56:271532 // The memref's use list is expected to match the dependence graph info.
MLIR Teamc4237ae2019-01-18 16:56:271533 auto *inst = memref->getDefiningInst();
River Riddleb4992772019-02-04 18:38:471534 if (inst && inst->isa<AllocOp>())
1535 inst->erase();
MLIR Teamc4237ae2019-01-18 16:56:271536 }
MLIR Teamf28e4df2018-11-01 14:26:001537 }
MLIR Team3b692302018-12-17 17:57:141538};
1539
1540} // end anonymous namespace
MLIR Teamf28e4df2018-11-01 14:26:001541
Chris Lattner79748892018-12-31 07:10:351542PassResult LoopFusion::runOnFunction(Function *f) {
Uday Bondhugula8be26272019-02-02 01:06:221543 if (clFusionFastMemorySpace.getNumOccurrences() > 0) {
1544 fastMemorySpace = clFusionFastMemorySpace.getValue();
1545 }
1546
MLIR Team6892ffb2018-12-20 04:42:551547 MemRefDependenceGraph g;
1548 if (g.init(f))
Uday Bondhugula8be26272019-02-02 01:06:221549 GreedyFusion(&g).run(localBufSizeThreshold, fastMemorySpace);
MLIR Teamf28e4df2018-11-01 14:26:001550 return success();
1551}
Jacques Pienaar6f0fb222018-11-07 02:34:181552
1553static PassRegistration<LoopFusion> pass("loop-fusion", "Fuse loop nests");
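// A sketch of exercising the pass through the standard opt driver (assuming
// an mlir-opt binary with this pass registered):
//
//   mlir-opt input.mlir -loop-fusion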