//===- LoopFusion.cpp - Code to perform loop fusion -----------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements loop fusion.
//
//===----------------------------------------------------------------------===//

#include "mlir/AffineOps/AffineOps.h"
#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/AffineStructures.h"
#include "mlir/Analysis/LoopAnalysis.h"
#include "mlir/Analysis/Utils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/Pass.h"
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/Transforms/LoopUtils.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/Utils.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <iomanip>

#define DEBUG_TYPE "loop-fusion"

using llvm::SetVector;

using namespace mlir;

static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options");

/// Disables the fusion profitability check; fuses whenever valid.
static llvm::cl::opt<bool>
    clMaximalLoopFusion("fusion-maximal", llvm::cl::Hidden,
                        llvm::cl::desc("Enables maximal loop fusion"),
                        llvm::cl::cat(clOptionsCategory));

/// A threshold, as a fraction of the total computation, of additional
/// computation tolerated when fusing.
static llvm::cl::opt<double> clFusionAddlComputeTolerance(
    "fusion-compute-tolerance", llvm::cl::Hidden,
    llvm::cl::desc("Fractional increase in additional"
                   " computation tolerated while fusing"),
    llvm::cl::cat(clOptionsCategory));

static llvm::cl::opt<unsigned> clFusionFastMemorySpace(
    "fusion-fast-mem-space", llvm::cl::Hidden,
    llvm::cl::desc("Faster memory space number to promote fusion buffers to"),
    llvm::cl::cat(clOptionsCategory));

static llvm::cl::opt<unsigned> clFusionLocalBufThreshold(
    "fusion-local-buf-threshold", llvm::cl::Hidden,
    llvm::cl::desc("Threshold size (bytes) for promoting local buffers to fast "
                   "memory space"),
    llvm::cl::cat(clOptionsCategory));

namespace {

/// Loop fusion pass. This pass currently supports a greedy fusion policy,
/// which fuses loop nests with single-writer/single-reader memref dependences
/// with the goal of improving locality.

// TODO(andydavis) Support fusion of source loop nests which write to multiple
// memrefs, where each memref can have multiple users (if profitable).
// TODO(andydavis) Extend this pass to check for fusion-preventing dependences,
// and add support for more general loop fusion algorithms.

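// Illustrative sketch (not from the original source): given a producer nest
// storing to %A and a consumer nest loading from %A,
//
//   for %i0 = 0 to 100 {
//     store %v, %A[%i0]
//   }
//   for %i1 = 0 to 100 {
//     %w = load %A[%i1]
//   }
//
// the pass fuses the producer into the consumer and privatizes %A, yielding
// roughly:
//
//   for %i1 = 0 to 100 {
//     store %v, %A_private[0]
//     %w = load %A_private[0]
//   }
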
struct LoopFusion : public FunctionPass {
  LoopFusion() : FunctionPass(&LoopFusion::passID) {}

  PassResult runOnFunction(Function *f) override;
  static char passID;

  // Any local buffers smaller than this size will be created in
  // `fastMemorySpace` if provided.
  unsigned localBufSizeThreshold = 1024;
  Optional<unsigned> fastMemorySpace = None;

  // The amount of additional computation that is tolerated while fusing
  // pair-wise as a fraction of the total computation.
  constexpr static double kComputeToleranceThreshold = 0.30f;
};

} // end anonymous namespace

char LoopFusion::passID = 0;

FunctionPass *mlir::createLoopFusionPass() { return new LoopFusion; }

namespace {

// LoopNestStateCollector walks loop nests and collects load and store
// operations, and whether or not a region-holding op other than 'for' was
// encountered in the loop nest.
class LoopNestStateCollector : public InstWalker<LoopNestStateCollector> {
public:
  SmallVector<OpPointer<AffineForOp>, 4> forOps;
  SmallVector<OperationInst *, 4> loadOpInsts;
  SmallVector<OperationInst *, 4> storeOpInsts;
  bool hasNonForRegion = false;

  void visitInstruction(OperationInst *opInst) {
    if (opInst->isa<AffineForOp>())
      forOps.push_back(opInst->cast<AffineForOp>());
    else if (opInst->getNumBlockLists() != 0)
      hasNonForRegion = true;
    else if (opInst->isa<LoadOp>())
      loadOpInsts.push_back(opInst);
    else if (opInst->isa<StoreOp>())
      storeOpInsts.push_back(opInst);
  }
};

// TODO(b/117228571) Replace when this is modeled through side-effects/op traits
static bool isMemRefDereferencingOp(const OperationInst &op) {
  if (op.isa<LoadOp>() || op.isa<StoreOp>() || op.isa<DmaStartOp>() ||
      op.isa<DmaWaitOp>())
    return true;
  return false;
}

// MemRefDependenceGraph is a graph data structure where graph nodes are
// top-level instructions in a Function which contain load/store ops, and edges
// are memref dependences between the nodes.
// TODO(andydavis) Add a more flexible dependence graph representation.
// TODO(andydavis) Add a depth parameter to dependence graph construction.
struct MemRefDependenceGraph {
public:
  // Node represents a node in the graph. A Node is either an entire loop nest
  // rooted at the top level which contains loads/stores, or a top level
  // load/store.
  struct Node {
    // The unique identifier of this node in the graph.
    unsigned id;
    // The top-level statement which is (or contains) loads/stores.
    Instruction *inst;
    // List of load operations.
    SmallVector<OperationInst *, 4> loads;
    // List of store op insts.
    SmallVector<OperationInst *, 4> stores;
    Node(unsigned id, Instruction *inst) : id(id), inst(inst) {}

    // Returns the load op count for 'memref'.
    unsigned getLoadOpCount(Value *memref) {
      unsigned loadOpCount = 0;
      for (auto *loadOpInst : loads) {
        if (memref == loadOpInst->cast<LoadOp>()->getMemRef())
          ++loadOpCount;
      }
      return loadOpCount;
    }

    // Returns the store op count for 'memref'.
    unsigned getStoreOpCount(Value *memref) {
      unsigned storeOpCount = 0;
      for (auto *storeOpInst : stores) {
        if (memref == storeOpInst->cast<StoreOp>()->getMemRef())
          ++storeOpCount;
      }
      return storeOpCount;
    }
  };

  // Edge represents a data dependence between nodes in the graph.
  struct Edge {
    // The id of the node at the other end of the edge.
    // If this edge is stored in Edge = Node.inEdges[i], then
    // 'Node.inEdges[i].id' is the identifier of the source node of the edge.
    // If this edge is stored in Edge = Node.outEdges[i], then
    // 'Node.outEdges[i].id' is the identifier of the dest node of the edge.
    unsigned id;
    // The SSA value on which this edge represents a dependence.
    // If the value is a memref, then the dependence is between graph nodes
    // which contain accesses to the same memref 'value'. If the value is a
    // non-memref value, then the dependence is between a graph node which
    // defines an SSA value and another graph node which uses the SSA value
    // (e.g. a constant instruction defining a value which is used inside a loop
    // nest).
    Value *value;
  };

  // Map from node id to Node.
  DenseMap<unsigned, Node> nodes;
  // Map from node id to list of input edges.
  DenseMap<unsigned, SmallVector<Edge, 2>> inEdges;
  // Map from node id to list of output edges.
  DenseMap<unsigned, SmallVector<Edge, 2>> outEdges;
  // Map from memref to a count on the dependence edges associated with that
  // memref.
  DenseMap<Value *, unsigned> memrefEdgeCount;
  // The next unique identifier to use for newly created graph nodes.
  unsigned nextNodeId = 0;

  MemRefDependenceGraph() {}

  // Initializes the dependence graph based on operations in 'f'.
  // Returns true on success, false otherwise.
  bool init(Function *f);

  // Returns the graph node for 'id'.
  Node *getNode(unsigned id) {
    auto it = nodes.find(id);
    assert(it != nodes.end());
    return &it->second;
  }

  // Adds a node with 'inst' to the graph and returns its unique identifier.
  unsigned addNode(Instruction *inst) {
    Node node(nextNodeId++, inst);
    nodes.insert({node.id, node});
    return node.id;
  }

  // Remove node 'id' (and its associated edges) from graph.
  void removeNode(unsigned id) {
    // Remove each edge in 'inEdges[id]'.
    if (inEdges.count(id) > 0) {
      SmallVector<Edge, 2> oldInEdges = inEdges[id];
      for (auto &inEdge : oldInEdges) {
        removeEdge(inEdge.id, id, inEdge.value);
      }
    }
    // Remove each edge in 'outEdges[id]'.
    if (outEdges.count(id) > 0) {
      SmallVector<Edge, 2> oldOutEdges = outEdges[id];
      for (auto &outEdge : oldOutEdges) {
        removeEdge(id, outEdge.id, outEdge.value);
      }
    }
    // Erase remaining node state.
    inEdges.erase(id);
    outEdges.erase(id);
    nodes.erase(id);
  }

  // Returns true if node 'id' writes to any memref which escapes (or is an
  // argument to) the function/block. Returns false otherwise.
  bool writesToLiveInOrEscapingMemrefs(unsigned id) {
    Node *node = getNode(id);
    for (auto *storeOpInst : node->stores) {
      auto *memref = storeOpInst->cast<StoreOp>()->getMemRef();
      auto *inst = memref->getDefiningInst();
      auto *opInst = dyn_cast_or_null<OperationInst>(inst);
      // Return true if 'memref' is a function argument.
      if (opInst == nullptr)
        return true;
      // Return true if any use of 'memref' escapes the function.
      for (auto &use : memref->getUses()) {
        auto *user = dyn_cast<OperationInst>(use.getOwner());
        if (!user || !isMemRefDereferencingOp(*user))
          return true;
      }
    }
    return false;
  }

  // Returns true if node 'id' can be removed from the graph. Returns false
  // otherwise. A node can be removed from the graph iff the following
  // conditions are met:
  // *) The node does not write to any memref which escapes (or is a
  //    function/block argument).
  // *) The node has no successors in the dependence graph.
  bool canRemoveNode(unsigned id) {
    if (writesToLiveInOrEscapingMemrefs(id))
      return false;
    Node *node = getNode(id);
    for (auto *storeOpInst : node->stores) {
      // Return false if there exist out edges from 'id' on 'memref'.
      if (getOutEdgeCount(id, storeOpInst->cast<StoreOp>()->getMemRef()) > 0)
        return false;
    }
    return true;
  }

  // Returns true iff there is an edge from node 'srcId' to node 'dstId' for
  // 'value'. Returns false otherwise.
  bool hasEdge(unsigned srcId, unsigned dstId, Value *value) {
    if (outEdges.count(srcId) == 0 || inEdges.count(dstId) == 0) {
      return false;
    }
    bool hasOutEdge = llvm::any_of(outEdges[srcId], [=](Edge &edge) {
      return edge.id == dstId && edge.value == value;
    });
    bool hasInEdge = llvm::any_of(inEdges[dstId], [=](Edge &edge) {
      return edge.id == srcId && edge.value == value;
    });
    return hasOutEdge && hasInEdge;
  }

  // Adds an edge from node 'srcId' to node 'dstId' for 'value'.
  void addEdge(unsigned srcId, unsigned dstId, Value *value) {
    if (!hasEdge(srcId, dstId, value)) {
      outEdges[srcId].push_back({dstId, value});
      inEdges[dstId].push_back({srcId, value});
      if (value->getType().isa<MemRefType>())
        memrefEdgeCount[value]++;
    }
  }

  // Removes an edge from node 'srcId' to node 'dstId' for 'value'.
  void removeEdge(unsigned srcId, unsigned dstId, Value *value) {
    assert(inEdges.count(dstId) > 0);
    assert(outEdges.count(srcId) > 0);
    if (value->getType().isa<MemRefType>()) {
      assert(memrefEdgeCount.count(value) > 0);
      memrefEdgeCount[value]--;
    }
    // Remove 'srcId' from 'inEdges[dstId]'.
    for (auto it = inEdges[dstId].begin(); it != inEdges[dstId].end(); ++it) {
      if ((*it).id == srcId && (*it).value == value) {
        inEdges[dstId].erase(it);
        break;
      }
    }
    // Remove 'dstId' from 'outEdges[srcId]'.
    for (auto it = outEdges[srcId].begin(); it != outEdges[srcId].end(); ++it) {
      if ((*it).id == dstId && (*it).value == value) {
        outEdges[srcId].erase(it);
        break;
      }
    }
  }

  // Returns the input edge count for node 'id' and 'memref' from src nodes
  // which access 'memref'.
  unsigned getIncomingMemRefAccesses(unsigned id, Value *memref) {
    unsigned inEdgeCount = 0;
    if (inEdges.count(id) > 0)
      for (auto &inEdge : inEdges[id])
        if (inEdge.value == memref) {
          Node *srcNode = getNode(inEdge.id);
          // Only count in edges from 'srcNode' if 'srcNode' accesses 'memref'.
          if (srcNode->getLoadOpCount(memref) > 0 ||
              srcNode->getStoreOpCount(memref) > 0)
            ++inEdgeCount;
        }
    return inEdgeCount;
  }

  // Returns the output edge count for node 'id' and 'memref'.
  unsigned getOutEdgeCount(unsigned id, Value *memref) {
    unsigned outEdgeCount = 0;
    if (outEdges.count(id) > 0)
      for (auto &outEdge : outEdges[id])
        if (outEdge.value == memref)
          ++outEdgeCount;
    return outEdgeCount;
  }

  // Computes and returns an insertion point instruction, before which the
  // fused <srcId, dstId> loop nest can be inserted while preserving
  // dependences. Returns nullptr if no such insertion point is found.
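  // Example (illustrative, not from the original source): for a top-level
  // order {src, A, B, dst} where A depends on 'srcId' and 'dstId' depends on
  // B (on values other than 'memrefToSkip'), 'firstSrcDepPos' is A's position
  // and 'lastDstDepPos' is B's; the fused nest is inserted before A, which is
  // only valid if B appears before A.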
  Instruction *getFusedLoopNestInsertionPoint(unsigned srcId, unsigned dstId,
                                              Value *memrefToSkip) {
    if (outEdges.count(srcId) == 0)
      return getNode(dstId)->inst;

    // Build set of insts in range (srcId, dstId) which depend on 'srcId'.
    SmallPtrSet<Instruction *, 2> srcDepInsts;
    for (auto &outEdge : outEdges[srcId])
      if (outEdge.id != dstId && outEdge.value != memrefToSkip)
        srcDepInsts.insert(getNode(outEdge.id)->inst);

    // Build set of insts in range (srcId, dstId) on which 'dstId' depends.
    SmallPtrSet<Instruction *, 2> dstDepInsts;
    for (auto &inEdge : inEdges[dstId])
      if (inEdge.id != srcId && inEdge.value != memrefToSkip)
        dstDepInsts.insert(getNode(inEdge.id)->inst);

    Instruction *srcNodeInst = getNode(srcId)->inst;
    Instruction *dstNodeInst = getNode(dstId)->inst;

    // Computing insertion point:
    // *) Walk all instruction positions in Block instruction list in the
    //    range (src, dst). For each instruction 'inst' visited in this search:
    //    *) Store in 'firstSrcDepPos' the first position where 'inst' has a
    //       dependence edge from 'srcNode'.
    //    *) Store in 'lastDstDepPos' the last position where 'inst' has a
    //       dependence edge to 'dstNode'.
    // *) Compare 'firstSrcDepPos' and 'lastDstDepPos' to determine the
    //    instruction insertion point (or return nullptr if no such
    //    insertion point exists: 'firstSrcDepPos' <= 'lastDstDepPos').
    SmallVector<Instruction *, 2> depInsts;
    Optional<unsigned> firstSrcDepPos;
    Optional<unsigned> lastDstDepPos;
    unsigned pos = 0;
    for (Block::iterator it = std::next(Block::iterator(srcNodeInst));
         it != Block::iterator(dstNodeInst); ++it) {
      Instruction *inst = &(*it);
      if (srcDepInsts.count(inst) > 0 && firstSrcDepPos == None)
        firstSrcDepPos = pos;
      if (dstDepInsts.count(inst) > 0)
        lastDstDepPos = pos;
      depInsts.push_back(inst);
      ++pos;
    }

    if (firstSrcDepPos.hasValue()) {
      if (lastDstDepPos.hasValue()) {
        if (firstSrcDepPos.getValue() <= lastDstDepPos.getValue()) {
          // No valid insertion point exists which preserves dependences.
          return nullptr;
        }
      }
      // Return the insertion point at 'firstSrcDepPos'.
      return depInsts[firstSrcDepPos.getValue()];
    }
    // No dependence targets in range (or only dst deps in range), return
    // 'dstNodeInst' insertion point.
    return dstNodeInst;
  }

  // Updates edge mappings from node 'srcId' to node 'dstId' after 'oldMemRef'
  // has been replaced in node at 'dstId' by a private memref.
  void updateEdges(unsigned srcId, unsigned dstId, Value *oldMemRef) {
    // For each edge in 'inEdges[srcId]': add new edge remapping to 'dstId'.
    if (inEdges.count(srcId) > 0) {
      SmallVector<Edge, 2> oldInEdges = inEdges[srcId];
      for (auto &inEdge : oldInEdges) {
        // Add edge from 'inEdge.id' to 'dstId' if not for 'oldMemRef'.
        if (inEdge.value != oldMemRef)
          addEdge(inEdge.id, dstId, inEdge.value);
      }
    }
    // For each edge in 'outEdges[srcId]': remove edge from 'srcId' to 'dstId'.
    if (outEdges.count(srcId) > 0) {
      SmallVector<Edge, 2> oldOutEdges = outEdges[srcId];
      for (auto &outEdge : oldOutEdges) {
        // Remove any out edges from 'srcId' to 'dstId' across memrefs.
        if (outEdge.id == dstId)
          removeEdge(srcId, outEdge.id, outEdge.value);
      }
    }
    // Remove any edges in 'inEdges[dstId]' on 'oldMemRef' (which is being
    // replaced by a private memref). These edges could come from nodes
    // other than 'srcId' which were removed in the previous step.
    if (inEdges.count(dstId) > 0) {
      SmallVector<Edge, 2> oldInEdges = inEdges[dstId];
      for (auto &inEdge : oldInEdges)
        if (inEdge.value == oldMemRef)
          removeEdge(inEdge.id, dstId, inEdge.value);
    }
  }

  // Adds ops in 'loads' and 'stores' to node at 'id'.
  void addToNode(unsigned id, const SmallVectorImpl<OperationInst *> &loads,
                 const SmallVectorImpl<OperationInst *> &stores) {
    Node *node = getNode(id);
    for (auto *loadOpInst : loads)
      node->loads.push_back(loadOpInst);
    for (auto *storeOpInst : stores)
      node->stores.push_back(storeOpInst);
  }

  void clearNodeLoadAndStores(unsigned id) {
    Node *node = getNode(id);
    node->loads.clear();
    node->stores.clear();
  }

  void print(raw_ostream &os) const {
    os << "\nMemRefDependenceGraph\n";
    os << "\nNodes:\n";
    for (auto &idAndNode : nodes) {
      os << "Node: " << idAndNode.first << "\n";
      auto it = inEdges.find(idAndNode.first);
      if (it != inEdges.end()) {
        for (const auto &e : it->second)
          os << "  InEdge: " << e.id << " " << e.value << "\n";
      }
      it = outEdges.find(idAndNode.first);
      if (it != outEdges.end()) {
        for (const auto &e : it->second)
          os << "  OutEdge: " << e.id << " " << e.value << "\n";
      }
    }
  }
  void dump() const { print(llvm::errs()); }
};

// Initializes the data dependence graph by walking instructions in 'f'.
// Assigns each node in the graph a node id based on program order in 'f'.
// TODO(andydavis) Add support for taking a Block arg to construct the
// dependence graph at a different depth.
bool MemRefDependenceGraph::init(Function *f) {
  DenseMap<Value *, SetVector<unsigned>> memrefAccesses;

  // TODO: support multi-block functions.
  if (f->getBlocks().size() != 1)
    return false;

  DenseMap<Instruction *, unsigned> forToNodeMap;
  for (auto &inst : f->front()) {
    if (auto forOp = cast<OperationInst>(&inst)->dyn_cast<AffineForOp>()) {
      // Create graph node 'id' to represent top-level 'forOp' and record
      // all loads and store accesses it contains.
      LoopNestStateCollector collector;
      collector.walk(&inst);
      // Return false if a non 'for' region was found (not currently supported).
      if (collector.hasNonForRegion)
        return false;
      Node node(nextNodeId++, &inst);
      for (auto *opInst : collector.loadOpInsts) {
        node.loads.push_back(opInst);
        auto *memref = opInst->cast<LoadOp>()->getMemRef();
        memrefAccesses[memref].insert(node.id);
      }
      for (auto *opInst : collector.storeOpInsts) {
        node.stores.push_back(opInst);
        auto *memref = opInst->cast<StoreOp>()->getMemRef();
        memrefAccesses[memref].insert(node.id);
      }
      forToNodeMap[&inst] = node.id;
      nodes.insert({node.id, node});
    } else if (auto *opInst = dyn_cast<OperationInst>(&inst)) {
      if (auto loadOp = opInst->dyn_cast<LoadOp>()) {
        // Create graph node for top-level load op.
        Node node(nextNodeId++, &inst);
        node.loads.push_back(opInst);
        auto *memref = opInst->cast<LoadOp>()->getMemRef();
        memrefAccesses[memref].insert(node.id);
        nodes.insert({node.id, node});
      } else if (auto storeOp = opInst->dyn_cast<StoreOp>()) {
        // Create graph node for top-level store op.
        Node node(nextNodeId++, &inst);
        node.stores.push_back(opInst);
        auto *memref = opInst->cast<StoreOp>()->getMemRef();
        memrefAccesses[memref].insert(node.id);
        nodes.insert({node.id, node});
      } else if (opInst->getNumBlockLists() != 0) {
        // Return false if another region is found (not currently supported).
        return false;
      } else if (opInst->getNumResults() > 0 && !opInst->use_empty()) {
        // Create graph node for top-level producer of SSA values, which
        // could be used by loop nest nodes.
        Node node(nextNodeId++, &inst);
        nodes.insert({node.id, node});
      }
    }
  }

  // Add dependence edges between nodes which produce SSA values and their
  // users.
  for (auto &idAndNode : nodes) {
    const Node &node = idAndNode.second;
    if (!node.loads.empty() || !node.stores.empty())
      continue;
    auto *opInst = cast<OperationInst>(node.inst);
    for (auto *value : opInst->getResults()) {
      for (auto &use : value->getUses()) {
        auto *userOpInst = cast<OperationInst>(use.getOwner());
        SmallVector<OpPointer<AffineForOp>, 4> loops;
        getLoopIVs(*userOpInst, &loops);
        if (loops.empty())
          continue;
        assert(forToNodeMap.count(loops[0]->getInstruction()) > 0);
        unsigned userLoopNestId = forToNodeMap[loops[0]->getInstruction()];
        addEdge(node.id, userLoopNestId, value);
      }
    }
  }

  // Walk memref access lists and add graph edges between dependent nodes.
  for (auto &memrefAndList : memrefAccesses) {
    unsigned n = memrefAndList.second.size();
    for (unsigned i = 0; i < n; ++i) {
      unsigned srcId = memrefAndList.second[i];
      bool srcHasStore =
          getNode(srcId)->getStoreOpCount(memrefAndList.first) > 0;
      for (unsigned j = i + 1; j < n; ++j) {
        unsigned dstId = memrefAndList.second[j];
        bool dstHasStore =
            getNode(dstId)->getStoreOpCount(memrefAndList.first) > 0;
        if (srcHasStore || dstHasStore)
          addEdge(srcId, dstId, memrefAndList.first);
      }
    }
  }
  return true;
}

namespace {

// LoopNestStats aggregates various per-loop statistics (e.g. loop trip count
// and operation count) for a loop nest up until the innermost loop body.
struct LoopNestStats {
  // Map from AffineForOp to immediate child AffineForOps in its loop body.
  DenseMap<Instruction *, SmallVector<OpPointer<AffineForOp>, 2>> loopMap;
  // Map from AffineForOp to count of operations in its loop body.
  DenseMap<Instruction *, uint64_t> opCountMap;
  // Map from AffineForOp to its constant trip count.
  DenseMap<Instruction *, uint64_t> tripCountMap;
};

// LoopNestStatsCollector walks a single loop nest and gathers per-loop
// trip count and operation count statistics and records them in 'stats'.
class LoopNestStatsCollector : public InstWalker<LoopNestStatsCollector> {
public:
  LoopNestStats *stats;
  bool hasLoopWithNonConstTripCount = false;

  LoopNestStatsCollector(LoopNestStats *stats) : stats(stats) {}

  void visitInstruction(OperationInst *opInst) {
    auto forOp = opInst->dyn_cast<AffineForOp>();
    if (!forOp)
      return;

    auto *forInst = forOp->getInstruction();
    auto *parentInst = forOp->getInstruction()->getParentInst();
    if (parentInst != nullptr) {
      assert(cast<OperationInst>(parentInst)->isa<AffineForOp>() &&
             "Expected parent AffineForOp");
      // Add mapping to 'forOp' from its parent AffineForOp.
      stats->loopMap[parentInst].push_back(forOp);
    }

    // Record the number of op instructions in the body of 'forOp'.
    unsigned count = 0;
    stats->opCountMap[forInst] = 0;
    for (auto &inst : *forOp->getBody()) {
      if (!(cast<OperationInst>(inst).isa<AffineForOp>() ||
            cast<OperationInst>(inst).isa<AffineIfOp>()))
        ++count;
    }
    stats->opCountMap[forInst] = count;
    // Record trip count for 'forOp'. Set flag if trip count is not constant.
    Optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp);
    if (!maybeConstTripCount.hasValue()) {
      hasLoopWithNonConstTripCount = true;
      return;
    }
    stats->tripCountMap[forInst] = maybeConstTripCount.getValue();
  }
};

// Computes the total cost of the loop nest rooted at 'forOp'.
// Currently, the total cost is computed by counting the total operation
// instance count (i.e. loop body operation count * loop trip count) for the
// entire loop nest.
// If 'tripCountOverrideMap' is non-null, overrides the trip count for loops
// specified in the map when computing the total op instance count.
// NOTE: this is used to compute the cost of computation slices, which are
// sliced along the iteration dimension, and thus reduce the trip count.
// If 'computeCostMap' is non-null, the total op count for forOps specified
// in the map is increased (not overridden) by adding the op count from the
// map to the existing op count for the for loop. This is done before
// multiplying by the loop's trip count, and is used to model the cost of
// inserting a sliced loop nest of known cost into the loop's body.
// NOTE: this is used to compute the cost of fusing a slice of some loop nest
// within another loop.
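// Example (illustrative): for a nest "for %i = 0 to 100 { for %j = 0 to 50
// { <2 ops> } }", the inner loop costs 50 * 2 = 100 op instances and the
// total nest cost is 100 * 100 = 10000 (nested 'for'/'if' instructions are
// excluded from a loop body's own op count).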
static int64_t getComputeCost(
    Instruction *forInst, LoopNestStats *stats,
    llvm::SmallDenseMap<Instruction *, uint64_t, 8> *tripCountOverrideMap,
    DenseMap<Instruction *, int64_t> *computeCostMap) {
  // 'opCount' is the total number of operations in one iteration of 'forOp'
  // body.
  int64_t opCount = stats->opCountMap[forInst];
  if (stats->loopMap.count(forInst) > 0) {
    for (auto childForOp : stats->loopMap[forInst]) {
      opCount += getComputeCost(childForOp->getInstruction(), stats,
                                tripCountOverrideMap, computeCostMap);
    }
  }
  // Add in additional op instances from slice (if specified in map).
  if (computeCostMap != nullptr) {
    auto it = computeCostMap->find(forInst);
    if (it != computeCostMap->end()) {
      opCount += it->second;
    }
  }
  // Override trip count (if specified in map).
  int64_t tripCount = stats->tripCountMap[forInst];
  if (tripCountOverrideMap != nullptr) {
    auto it = tripCountOverrideMap->find(forInst);
    if (it != tripCountOverrideMap->end()) {
      tripCount = it->second;
    }
  }
  // Returns the total number of dynamic instances of operations in loop body.
  return tripCount * opCount;
}

} // end anonymous namespace

static Optional<uint64_t> getConstDifference(AffineMap lbMap, AffineMap ubMap) {
  assert(lbMap.getNumResults() == 1 && "expected single result bound map");
  assert(ubMap.getNumResults() == 1 && "expected single result bound map");
  assert(lbMap.getNumDims() == ubMap.getNumDims());
  assert(lbMap.getNumSymbols() == ubMap.getNumSymbols());
  // TODO(andydavis) Merge this code with 'mlir::getTripCountExpr'.
  // ub_expr - lb_expr
  AffineExpr lbExpr(lbMap.getResult(0));
  AffineExpr ubExpr(ubMap.getResult(0));
  auto loopSpanExpr = simplifyAffineExpr(ubExpr - lbExpr, lbMap.getNumDims(),
                                         lbMap.getNumSymbols());
  auto cExpr = loopSpanExpr.dyn_cast<AffineConstantExpr>();
  if (!cExpr)
    return None;
  return cExpr.getValue();
}

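// Example (illustrative): lbMap = (d0) -> (d0) and ubMap = (d0) -> (d0 + 4)
// yield a constant difference of 4; bound pairs whose difference does not
// simplify to a constant yield None.
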
// Builds a map 'tripCountMap' from AffineForOp to constant trip count for loop
// nest surrounding 'srcAccess' utilizing slice loop bounds in 'sliceState'.
// Returns true on success, false otherwise (if a non-constant trip count
// was encountered).
// TODO(andydavis) Make this work with non-unit step loops.
static bool buildSliceTripCountMap(
    OperationInst *srcOpInst, ComputationSliceState *sliceState,
    llvm::SmallDenseMap<Instruction *, uint64_t, 8> *tripCountMap) {
  SmallVector<OpPointer<AffineForOp>, 4> srcLoopIVs;
  getLoopIVs(*srcOpInst, &srcLoopIVs);
  unsigned numSrcLoopIVs = srcLoopIVs.size();
  // Populate map from AffineForOp -> trip count.
  for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
    AffineMap lbMap = sliceState->lbs[i];
    AffineMap ubMap = sliceState->ubs[i];
    if (lbMap == AffineMap() || ubMap == AffineMap()) {
      // The iteration of src loop IV 'i' was not sliced. Use full loop bounds.
      if (srcLoopIVs[i]->hasConstantLowerBound() &&
          srcLoopIVs[i]->hasConstantUpperBound()) {
        (*tripCountMap)[srcLoopIVs[i]->getInstruction()] =
            srcLoopIVs[i]->getConstantUpperBound() -
            srcLoopIVs[i]->getConstantLowerBound();
        continue;
      }
      return false;
    }
    Optional<uint64_t> tripCount = getConstDifference(lbMap, ubMap);
    if (!tripCount.hasValue())
      return false;
    (*tripCountMap)[srcLoopIVs[i]->getInstruction()] = tripCount.getValue();
  }
  return true;
}

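// Example (illustrative): if the slice pins src loop %i to one iteration
// (lb = (d0) -> (d0), ub = (d0) -> (d0 + 1)) and leaves loop %j unsliced
// with constant bounds 0 to 8, the resulting map is {%i: 1, %j: 8}.
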
// Removes load operations from 'srcLoads' which operate on 'memref', and
// adds them to 'dstLoads'.
static void
moveLoadsAccessingMemrefTo(Value *memref,
                           SmallVectorImpl<OperationInst *> *srcLoads,
                           SmallVectorImpl<OperationInst *> *dstLoads) {
  dstLoads->clear();
  SmallVector<OperationInst *, 4> srcLoadsToKeep;
  for (auto *load : *srcLoads) {
    if (load->cast<LoadOp>()->getMemRef() == memref)
      dstLoads->push_back(load);
    else
      srcLoadsToKeep.push_back(load);
  }
  srcLoads->swap(srcLoadsToKeep);
}

// Returns the innermost common loop depth for the set of operations in 'ops'.
static unsigned getInnermostCommonLoopDepth(ArrayRef<OperationInst *> ops) {
  unsigned numOps = ops.size();
  assert(numOps > 0);

  std::vector<SmallVector<OpPointer<AffineForOp>, 4>> loops(numOps);
  unsigned loopDepthLimit = std::numeric_limits<unsigned>::max();
  for (unsigned i = 0; i < numOps; ++i) {
    getLoopIVs(*ops[i], &loops[i]);
    loopDepthLimit =
        std::min(loopDepthLimit, static_cast<unsigned>(loops[i].size()));
  }

  unsigned loopDepth = 0;
  for (unsigned d = 0; d < loopDepthLimit; ++d) {
    unsigned i;
    for (i = 1; i < numOps; ++i) {
      if (loops[i - 1][d] != loops[i][d])
        break;
    }
    if (i != numOps)
      break;
    ++loopDepth;
  }
  return loopDepth;
}

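// Example (illustrative): if op0 and op1 are nested under loops (%i, %j) and
// op2 is nested under (%i, %k), the innermost common loop depth is 1 (only
// %i is shared by all three).
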
// Returns the maximum loop depth at which no dependences between 'loadOpInsts'
// and 'storeOpInsts' are satisfied.
static unsigned getMaxLoopDepth(ArrayRef<OperationInst *> loadOpInsts,
                                ArrayRef<OperationInst *> storeOpInsts) {
  // Merge loads and stores into the same array.
  SmallVector<OperationInst *, 2> ops(loadOpInsts.begin(), loadOpInsts.end());
  ops.append(storeOpInsts.begin(), storeOpInsts.end());

  // Compute the innermost common loop depth for loads and stores.
  unsigned loopDepth = getInnermostCommonLoopDepth(ops);

  // Return common loop depth for loads if there are no store ops.
  if (storeOpInsts.empty())
    return loopDepth;

  // Check dependences on all pairs of ops in 'ops' and store the minimum
  // loop depth at which a dependence is satisfied.
  for (unsigned i = 0, e = ops.size(); i < e; ++i) {
    auto *srcOpInst = ops[i];
    MemRefAccess srcAccess(srcOpInst);
    for (unsigned j = 0; j < e; ++j) {
      auto *dstOpInst = ops[j];
      MemRefAccess dstAccess(dstOpInst);

      unsigned numCommonLoops =
          getNumCommonSurroundingLoops(*srcOpInst, *dstOpInst);
      for (unsigned d = 1; d <= numCommonLoops + 1; ++d) {
        FlatAffineConstraints dependenceConstraints;
        // TODO(andydavis) Cache dependence analysis results, check cache here.
        if (checkMemrefAccessDependence(srcAccess, dstAccess, d,
                                        &dependenceConstraints,
                                        /*dependenceComponents=*/nullptr)) {
          // Store minimum loop depth and break because we want the min 'd' at
          // which there is a dependence.
          loopDepth = std::min(loopDepth, d - 1);
          break;
        }
      }
    }
  }
  return loopDepth;
}

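// Example (illustrative): if the innermost common loop depth of the merged
// ops is 2 but some load/store pair first has a dependence satisfied at
// depth 2, the result is min(2, 2 - 1) = 1.
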
// Returns the slice union of 'sliceStateA' and 'sliceStateB' in 'sliceStateB'
// using a rectangular bounding box.
// TODO(andydavis) This function assumes that lower bounds for 'sliceStateA'
// and 'sliceStateB' are aligned.
// Specifically, when taking the union of overlapping intervals, it assumes
// that both intervals start at zero. Support needs to be added to take into
// account interval start offset when computing the union.
// TODO(andydavis) Move this function to an analysis library.
static bool getSliceUnion(const ComputationSliceState &sliceStateA,
                          ComputationSliceState *sliceStateB) {
  assert(sliceStateA.lbs.size() == sliceStateB->lbs.size());
  assert(sliceStateA.ubs.size() == sliceStateB->ubs.size());

  for (unsigned i = 0, e = sliceStateA.lbs.size(); i < e; ++i) {
    AffineMap lbMapA = sliceStateA.lbs[i];
    AffineMap ubMapA = sliceStateA.ubs[i];
    if (lbMapA == AffineMap()) {
      assert(ubMapA == AffineMap());
      continue;
    }
    assert(ubMapA && "expected non-null ub map");

    AffineMap lbMapB = sliceStateB->lbs[i];
    AffineMap ubMapB = sliceStateB->ubs[i];
    if (lbMapB == AffineMap()) {
      assert(ubMapB == AffineMap());
      // Union 'sliceStateB' does not have a bound for 'i' so copy from A.
      sliceStateB->lbs[i] = lbMapA;
      sliceStateB->ubs[i] = ubMapA;
      continue;
    }

    // TODO(andydavis) Change this code to take the min across all lower bounds
    // and max across all upper bounds for each dimension. This code can fail
    // for cases where a unique min or max could not be statically determined.

    // Assumption: both lower bounds are the same.
    if (lbMapA != lbMapB)
      return false;

    // Add bound with the largest trip count to union.
    Optional<uint64_t> tripCountA = getConstDifference(lbMapA, ubMapA);
    Optional<uint64_t> tripCountB = getConstDifference(lbMapB, ubMapB);
    if (!tripCountA.hasValue() || !tripCountB.hasValue())
      return false;

    if (tripCountA.getValue() > tripCountB.getValue()) {
      sliceStateB->lbs[i] = lbMapA;
      sliceStateB->ubs[i] = ubMapA;
    }
  }
  return true;
}

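// Example (illustrative): with equal lower bounds, a slice A spanning 16
// iterations of some dimension and a slice B spanning 10 leaves B holding
// A's bounds for that dimension, i.e. the bounding-box span of 16.
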
// TODO(mlir-team): improve/complete this when we have target data.
unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
  auto elementType = memRefType.getElementType();

  unsigned sizeInBits;
  if (elementType.isIntOrFloat()) {
    sizeInBits = elementType.getIntOrFloatBitWidth();
  } else {
    auto vectorType = elementType.cast<VectorType>();
    sizeInBits =
        vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
  }
  return llvm::divideCeil(sizeInBits, 8);
}

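// Example (illustrative): an f32 element is 4 bytes, a vector<4xf32>
// element is 16 bytes, and an i1 element rounds up to 1 byte.
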
// Creates and returns a private (single-user) memref for fused loop rooted
// at 'forOp', with (potentially reduced) memref size based on the
// MemRefRegion written to by 'srcStoreOpInst' at depth 'dstLoopDepth'.
// TODO(bondhugula): consider refactoring the common code from generateDma and
// this one.
static Value *createPrivateMemRef(OpPointer<AffineForOp> forOp,
                                  OperationInst *srcStoreOpInst,
                                  unsigned dstLoopDepth,
                                  Optional<unsigned> fastMemorySpace,
                                  unsigned localBufSizeThreshold) {
  auto *forInst = forOp->getInstruction();

  // Create builder to insert alloc op just before 'forOp'.
  FuncBuilder b(forInst);
  // Builder to create constants at the top level.
  FuncBuilder top(forInst->getFunction());
  // Create new memref type based on slice bounds.
  auto *oldMemRef = srcStoreOpInst->cast<StoreOp>()->getMemRef();
  auto oldMemRefType = oldMemRef->getType().cast<MemRefType>();
  unsigned rank = oldMemRefType.getRank();

  // Compute MemRefRegion for 'srcStoreOpInst' at depth 'dstLoopDepth'.
  auto region = getMemRefRegion(srcStoreOpInst, dstLoopDepth);
  SmallVector<int64_t, 4> newShape;
  std::vector<SmallVector<int64_t, 4>> lbs;
  SmallVector<int64_t, 8> lbDivisors;
  lbs.reserve(rank);
  // Query 'region' for 'newShape' and lower bounds of MemRefRegion accessed
  // by 'srcStoreOpInst' at depth 'dstLoopDepth'.
  Optional<int64_t> numElements =
      region->getConstantBoundingSizeAndShape(&newShape, &lbs, &lbDivisors);
  assert(numElements.hasValue() &&
         "non-constant number of elts in local buffer");

  const FlatAffineConstraints *cst = region->getConstraints();
  // 'outerIVs' holds the values that this memory region is symbolic/parametric
  // on; this would correspond to loop IVs surrounding the level at which the
  // slice is being materialized.
  SmallVector<Value *, 8> outerIVs;
  cst->getIdValues(rank, cst->getNumIds(), &outerIVs);

  // Build 'rank' AffineExprs from MemRefRegion 'lbs'.
  SmallVector<AffineExpr, 4> offsets;
  offsets.reserve(rank);
  for (unsigned d = 0; d < rank; ++d) {
    assert(lbs[d].size() == cst->getNumCols() - rank && "incorrect bound size");

    AffineExpr offset = top.getAffineConstantExpr(0);
    for (unsigned j = 0, e = cst->getNumCols() - rank - 1; j < e; j++) {
      offset = offset + lbs[d][j] * top.getAffineDimExpr(j);
    }
    assert(lbDivisors[d] > 0);
    offset =
        (offset + lbs[d][cst->getNumCols() - 1 - rank]).floorDiv(lbDivisors[d]);
    offsets.push_back(offset);
  }

  // Create 'newMemRefType' using 'newShape' from MemRefRegion accessed
  // by 'srcStoreOpInst'.
  uint64_t bufSize =
      getMemRefEltSizeInBytes(oldMemRefType) * numElements.getValue();
  unsigned newMemSpace;
  if (bufSize < localBufSizeThreshold && fastMemorySpace.hasValue()) {
    newMemSpace = fastMemorySpace.getValue();
  } else {
    newMemSpace = oldMemRefType.getMemorySpace();
  }
  auto newMemRefType = top.getMemRefType(
      newShape, oldMemRefType.getElementType(), {}, newMemSpace);
  // Gather alloc operands for the dynamic dimensions of the memref.
  SmallVector<Value *, 4> allocOperands;
  unsigned dynamicDimCount = 0;
  for (auto dimSize : oldMemRefType.getShape()) {
    if (dimSize == -1)
      allocOperands.push_back(
          top.create<DimOp>(forOp->getLoc(), oldMemRef, dynamicDimCount++));
  }

  // Create new private memref for fused loop 'forOp'.
  // TODO(andydavis) Create/move alloc ops for private memrefs closer to their
  // consumer loop nests to reduce their live range. Currently they are added
  // at the beginning of the function, because loop nests can be reordered
  // during the fusion pass.
  Value *newMemRef =
      top.create<AllocOp>(forOp->getLoc(), newMemRefType, allocOperands);

  // Build an AffineMap to remap access functions based on lower bound offsets.
  SmallVector<AffineExpr, 4> remapExprs;
  remapExprs.reserve(rank);
  unsigned zeroOffsetCount = 0;
  for (unsigned i = 0; i < rank; i++) {
    if (auto constExpr = offsets[i].dyn_cast<AffineConstantExpr>())
      if (constExpr.getValue() == 0)
        ++zeroOffsetCount;
    auto dimExpr = b.getAffineDimExpr(outerIVs.size() + i);

    auto remapExpr =
        simplifyAffineExpr(dimExpr - offsets[i], outerIVs.size() + rank, 0);
    remapExprs.push_back(remapExpr);
  }
  auto indexRemap =
      zeroOffsetCount == rank
          ? AffineMap()
          : b.getAffineMap(outerIVs.size() + rank, 0, remapExprs, {});
  // Replace all users of 'oldMemRef' with 'newMemRef'.
  bool ret =
      replaceAllMemRefUsesWith(oldMemRef, newMemRef, {}, indexRemap,
                               /*extraOperands=*/outerIVs,
                               /*domInstFilter=*/&*forOp->getBody()->begin());
  assert(ret && "replaceAllMemrefUsesWith should always succeed here");
  (void)ret;
  return newMemRef;
}

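// Example (illustrative): if the store in the fused nest writes only
// %A[%i] for the current outer IV %i, the region at 'dstLoopDepth' covers a
// single element, so the private buffer becomes memref<1xf32> (for an f32
// memref) and accesses are remapped by (d0 - %i) to index 0.
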
// Returns the iteration count of the slice: the product of the trip counts
// in 'sliceTripCountMap'. Callers use this, e.g., to check whether the slice
// has a single iteration.
static uint64_t getSliceIterationCount(
    const llvm::SmallDenseMap<Instruction *, uint64_t, 8> &sliceTripCountMap) {
  uint64_t iterCount = 1;
  for (const auto &count : sliceTripCountMap) {
    iterCount *= count.second;
  }
  return iterCount;
}

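// Example (illustrative): a slice trip count map {%i: 4, %j: 8} yields an
// iteration count of 32; a map pinning every loop to one iteration yields 1.
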
// Checks the profitability of fusing a backwards slice of the loop nest
// surrounding 'srcOpInst' into the loop nest surrounding 'dstLoadOpInsts'.
// Returns true if it is profitable to fuse the candidate loop nests. Returns
// false otherwise. `dstLoopDepth` is set to the most profitable depth at which
// to materialize the source loop nest slice.
// The profitability model executes the following steps:
// *) Computes the backward computation slice at 'srcOpInst'. This
//    computation slice of the loop nest surrounding 'srcOpInst' is
//    represented by modified src loop bounds in 'sliceState', which are
//    functions of loop IVs in the loop nest surrounding 'srcOpInst'.
// *) Computes the cost of unfused src/dst loop nests (currently the cost of a
//    loop nest is the total number of dynamic operation instances in the loop
//    nest).
// *) Computes the cost of fusing a slice of the src loop nest into the dst
//    loop nest at various values of dst loop depth, attempting to fuse
//    the largest computation slice at the maximal dst loop depth (closest to
//    the load) to minimize reuse distance and potentially enable subsequent
//    load/store forwarding.
// NOTE: If the dst loop nest includes multiple loads in 'dstLoadOpInsts' for
//    the same memref as is written by 'srcOpInst', then the union of slice
//    loop bounds is used to compute the slice and associated slice cost.
// NOTE: 'dstLoopDepth' refers to the loop depth within the destination loop
//    nest, at which the src computation slice is inserted/fused.
// NOTE: We attempt to maximize the dst loop depth, but there are cases
//    where a particular setting for 'dstLoopDepth' might fuse an unsliced
//    loop (within the src computation slice) at a depth which results in
//    excessive recomputation (see unit tests for examples).
// *) Compares the total cost of the unfused loop nests to the min cost fused
//    loop nest computed in the previous step, and returns true if the latter
//    is lower.
static bool isFusionProfitable(OperationInst *srcOpInst,
                               ArrayRef<OperationInst *> dstLoadOpInsts,
                               ArrayRef<OperationInst *> dstStoreOpInsts,
                               ComputationSliceState *sliceState,
                               unsigned *dstLoopDepth) {
Uday Bondhugula06d21d92019-01-25 01:01:491069 LLVM_DEBUG({
1070 llvm::dbgs() << "Checking whether fusion is profitable between:\n";
1071 llvm::dbgs() << " ";
1072 srcOpInst->dump();
1073 llvm::dbgs() << " and \n";
MLIR Teamd7c82442019-01-30 23:53:411074 for (auto dstOpInst : dstLoadOpInsts) {
Uday Bondhugula06d21d92019-01-25 01:01:491075 llvm::dbgs() << " ";
1076 dstOpInst->dump();
1077 };
1078 });
Uday Bondhugula864d9e02019-01-23 17:16:241079
MLIR Team38c2fe32019-01-14 19:26:251080 // Compute cost of sliced and unsliced src loop nest.
River Riddle5052bd82019-02-02 00:42:181081 SmallVector<OpPointer<AffineForOp>, 4> srcLoopIVs;
MLIR Team27d067e2019-01-16 17:55:021082 getLoopIVs(*srcOpInst, &srcLoopIVs);
MLIR Team38c2fe32019-01-14 19:26:251083 unsigned numSrcLoopIVs = srcLoopIVs.size();
1084
1085 // Walk src loop nest and collect stats.
1086 LoopNestStats srcLoopNestStats;
1087 LoopNestStatsCollector srcStatsCollector(&srcLoopNestStats);
River Riddle5052bd82019-02-02 00:42:181088 srcStatsCollector.walk(srcLoopIVs[0]->getInstruction());
MLIR Team38c2fe32019-01-14 19:26:251089 // Currently only constant trip count loop nests are supported.
1090 if (srcStatsCollector.hasLoopWithNonConstTripCount)
1091 return false;
1092
1093 // Compute cost of dst loop nest.
River Riddle5052bd82019-02-02 00:42:181094 SmallVector<OpPointer<AffineForOp>, 4> dstLoopIVs;
MLIR Teamd7c82442019-01-30 23:53:411095 getLoopIVs(*dstLoadOpInsts[0], &dstLoopIVs);
MLIR Team38c2fe32019-01-14 19:26:251096
1097 LoopNestStats dstLoopNestStats;
1098 LoopNestStatsCollector dstStatsCollector(&dstLoopNestStats);
River Riddle5052bd82019-02-02 00:42:181099 dstStatsCollector.walk(dstLoopIVs[0]->getInstruction());
MLIR Team38c2fe32019-01-14 19:26:251100 // Currently only constant trip count loop nests are supported.
1101 if (dstStatsCollector.hasLoopWithNonConstTripCount)
1102 return false;
1103
  // Compute the maximum loop depth at which we can insert the src slice
  // and still satisfy dest loop nest dependences.
  unsigned maxDstLoopDepth = getMaxLoopDepth(dstLoadOpInsts, dstStoreOpInsts);
  if (maxDstLoopDepth == 0)
    return false;

  // Search for min cost value for 'dstLoopDepth'. At each value of
  // 'dstLoopDepth' from 'maxDstLoopDepth' to '1', compute computation slice
  // bounds between 'srcOpInst' and each op in 'dstLoadOpInsts' (taking the
  // union of these bounds). Next the union slice bounds are used to calculate
  // the cost of the slice and the cost of the slice inserted into the dst
  // loop nest at 'dstLoopDepth'.
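  // For example, with 'maxDstLoopDepth' == 3, the search below evaluates
  // candidate slice placements at depths 3, 2, and 1 in that order, so the
  // deepest (most local) legal placements are considered first;
  // 'sliceStates[i - 1]' holds the slice bounds computed for depth 'i'.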
  uint64_t minFusedLoopNestComputeCost = std::numeric_limits<uint64_t>::max();
  uint64_t maxStorageReduction = 0;
  Optional<uint64_t> sliceMemEstimate = None;

  SmallVector<ComputationSliceState, 4> sliceStates;
  sliceStates.resize(maxDstLoopDepth);
  // The best loop depth at which to materialize the slice.
  Optional<unsigned> bestDstLoopDepth = None;

  // Compute op instance count for the src loop nest without iteration slicing.
  uint64_t srcLoopNestCost =
      getComputeCost(srcLoopIVs[0]->getInstruction(), &srcLoopNestStats,
                     /*tripCountOverrideMap=*/nullptr,
                     /*computeCostMap=*/nullptr);

  // Compute op instance count for the dst loop nest.
  uint64_t dstLoopNestCost =
      getComputeCost(dstLoopIVs[0]->getInstruction(), &dstLoopNestStats,
                     /*tripCountOverrideMap=*/nullptr,
                     /*computeCostMap=*/nullptr);

  llvm::SmallDenseMap<Instruction *, uint64_t, 8> sliceTripCountMap;
  DenseMap<Instruction *, int64_t> computeCostMap;
  for (unsigned i = maxDstLoopDepth; i >= 1; --i) {
    MemRefAccess srcAccess(srcOpInst);
    // Handle the common case of one dst load without a copy.
    if (!mlir::getBackwardComputationSliceState(
            srcAccess, MemRefAccess(dstLoadOpInsts[0]), i, &sliceStates[i - 1]))
      return false;
    // Compute the union of the slice bounds of all ops in 'dstLoadOpInsts'.
    for (int j = 1, e = dstLoadOpInsts.size(); j < e; ++j) {
      MemRefAccess dstAccess(dstLoadOpInsts[j]);
      ComputationSliceState tmpSliceState;
      if (!mlir::getBackwardComputationSliceState(srcAccess, dstAccess, i,
                                                  &tmpSliceState))
        return false;
      // Compute slice bound union of 'tmpSliceState' and 'sliceStates[i - 1]'.
      getSliceUnion(tmpSliceState, &sliceStates[i - 1]);
    }
    // Build trip count map for computation slice. We'll skip cases where the
    // trip count was non-constant.
    sliceTripCountMap.clear();
    if (!buildSliceTripCountMap(srcOpInst, &sliceStates[i - 1],
                                &sliceTripCountMap))
      continue;

    // Check whether store-to-load forwarding will happen.
    int64_t sliceIterationCount = getSliceIterationCount(sliceTripCountMap);
    assert(sliceIterationCount > 0);
    bool storeLoadFwdGuaranteed = (sliceIterationCount == 1);
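    // (A slice iteration count of 1 means each dst iteration computes exactly
    // the one src value it consumes, so the intermediate store and loads are
    // expected to be forwarded and then dead; the '-1' entries added to
    // 'computeCostMap' below credit the cost model for their disappearance.)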

    // Compute cost of fusion for this dest loop depth.

    computeCostMap.clear();

    // The store and loads to this memref will disappear.
    if (storeLoadFwdGuaranteed) {
      // A single store disappears: -1 for that.
      computeCostMap[srcLoopIVs[numSrcLoopIVs - 1]->getInstruction()] = -1;
      for (auto *loadOp : dstLoadOpInsts) {
        auto *parentInst = loadOp->getParentInst();
        if (parentInst && cast<OperationInst>(parentInst)->isa<AffineForOp>())
          computeCostMap[parentInst] = -1;
      }
    }

    // Compute op instance count for the src loop nest with iteration slicing.
    int64_t sliceComputeCost =
        getComputeCost(srcLoopIVs[0]->getInstruction(), &srcLoopNestStats,
                       /*tripCountOverrideMap=*/&sliceTripCountMap,
                       /*computeCostMap=*/&computeCostMap);

    // Compute cost of fusion for this depth.
    computeCostMap[dstLoopIVs[i - 1]->getInstruction()] = sliceComputeCost;

    int64_t fusedLoopNestComputeCost =
        getComputeCost(dstLoopIVs[0]->getInstruction(), &dstLoopNestStats,
                       /*tripCountOverrideMap=*/nullptr, &computeCostMap);

    double additionalComputeFraction =
        fusedLoopNestComputeCost /
            (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
        1;

    // TODO(bondhugula): This is an ugly approximation. Fix this by finding a
    // good way to calculate the footprint of the memref in the slice and
    // divide it by the total memory footprint of the fused computation.
    double storageReduction =
        static_cast<double>(srcLoopNestCost) / sliceIterationCount;
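    // E.g. (illustrative numbers only): a src nest cost of 1000 with a slice
    // iteration count of 10 yields an estimated storage reduction factor of
    // 100x; see the TODO above on the roughness of this proxy.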

    LLVM_DEBUG({
      std::stringstream msg;
      msg << " evaluating fusion profitability at depth : " << i << "\n"
          << std::setprecision(2) << " additional compute fraction: "
          << 100.0 * additionalComputeFraction << "%\n"
          << " storage reduction factor: " << storageReduction << "x\n"
          << " fused nest cost: " << fusedLoopNestComputeCost << "\n"
          << " slice iteration count: " << sliceIterationCount << "\n";
      llvm::dbgs() << msg.str();
    });

    double computeToleranceThreshold =
        clFusionAddlComputeTolerance.getNumOccurrences() > 0
            ? clFusionAddlComputeTolerance
            : LoopFusion::kComputeToleranceThreshold;
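    // The threshold is overridable from the command line; e.g. (assuming the
    // standard mlir-opt driver):
    //   mlir-opt -loop-fusion -fusion-compute-tolerance=0.25 input.mlir
    // tolerates up to 25% additional computation, while -fusion-maximal
    // bypasses this check entirely.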

    // TODO(b/123247369): This is a placeholder cost model.
    // Among all choices that add an acceptable amount of redundant computation
    // (as per computeToleranceThreshold), we will simply pick the one that
    // reduces the intermediary size the most.
    if ((storageReduction > maxStorageReduction) &&
        (clMaximalLoopFusion ||
         (additionalComputeFraction < computeToleranceThreshold))) {
      maxStorageReduction = storageReduction;
      bestDstLoopDepth = i;
      minFusedLoopNestComputeCost = fusedLoopNestComputeCost;
      // TODO(bondhugula,andydavis): find a good way to compute the memory
      // footprint of the materialized slice.
      // Approximating this to the compute cost of the slice. This could be an
      // under-approximation or an over-approximation, but in many cases
      // accurate.
      sliceMemEstimate = sliceIterationCount;
    }
  }

  // A simple cost model: fuse if it reduces the memory footprint. If
  // '-fusion-maximal' is set, fuse regardless.

  if (!clMaximalLoopFusion && !bestDstLoopDepth.hasValue()) {
    LLVM_DEBUG(llvm::dbgs()
               << "All fusion choices involve more than the threshold amount "
                  "of redundant computation; NOT fusing.\n");
    return false;
  }

  assert(bestDstLoopDepth.hasValue() &&
         "expected to have a value per logic above");

  // Set dstLoopDepth based on best values from search.
  *dstLoopDepth = bestDstLoopDepth.getValue();

  LLVM_DEBUG(
      llvm::dbgs() << " LoopFusion fusion stats:"
                   << "\n best loop depth: " << bestDstLoopDepth
                   << "\n src loop nest compute cost: " << srcLoopNestCost
                   << "\n dst loop nest compute cost: " << dstLoopNestCost
                   << "\n fused loop nest compute cost: "
                   << minFusedLoopNestComputeCost << "\n");

  auto dstMemSize = getMemoryFootprintBytes(dstLoopIVs[0]);
  auto srcMemSize = getMemoryFootprintBytes(srcLoopIVs[0]);

  Optional<double> storageReduction = None;

  if (!clMaximalLoopFusion) {
    if (!dstMemSize.hasValue() || !srcMemSize.hasValue()) {
      LLVM_DEBUG(
          llvm::dbgs()
          << " fusion memory benefit cannot be evaluated; NOT fusing.\n");
      return false;
    }

    auto srcMemSizeVal = srcMemSize.getValue();
    auto dstMemSizeVal = dstMemSize.getValue();

    assert(sliceMemEstimate.hasValue() && "expected value");
    // This is a rough estimate since 'sliceMemEstimate' is itself approximate.
    auto fusedMem = dstMemSizeVal + sliceMemEstimate.getValue();

    LLVM_DEBUG(llvm::dbgs() << " src mem: " << srcMemSizeVal << "\n"
                            << " dst mem: " << dstMemSizeVal << "\n"
                            << " fused mem: " << fusedMem << "\n"
                            << " slice mem: " << sliceMemEstimate << "\n");

    if (fusedMem > srcMemSizeVal + dstMemSizeVal) {
      LLVM_DEBUG(llvm::dbgs() << "Fusion is not profitable; NOT fusing.\n");
      return false;
    }
    storageReduction =
        100.0 *
        (1.0 - fusedMem / (static_cast<double>(srcMemSizeVal) + dstMemSizeVal));
  }

  double additionalComputeFraction =
      100.0 * (minFusedLoopNestComputeCost /
                   (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
               1);
  (void)additionalComputeFraction;
  LLVM_DEBUG({
    std::stringstream msg;
    msg << " fusion is most profitable at depth " << *dstLoopDepth << " with "
        << std::setprecision(2) << additionalComputeFraction
        << "% redundant computation and a ";
    msg << (storageReduction.hasValue()
                ? std::to_string(storageReduction.getValue())
                : "<unknown>");
    msg << "% storage reduction.\n";
    llvm::dbgs() << msg.str();
  });

  // Update return parameter 'sliceState' with 'bestSliceState'.
  ComputationSliceState *bestSliceState = &sliceStates[*dstLoopDepth - 1];
  sliceState->lbs = bestSliceState->lbs;
  sliceState->ubs = bestSliceState->ubs;
  sliceState->lbOperands = bestSliceState->lbOperands;
  sliceState->ubOperands = bestSliceState->ubOperands;

  // Canonicalize slice bound affine maps.
  for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
    if (sliceState->lbs[i] != AffineMap()) {
      canonicalizeMapAndOperands(&sliceState->lbs[i],
                                 &sliceState->lbOperands[i]);
    }
    if (sliceState->ubs[i] != AffineMap()) {
      canonicalizeMapAndOperands(&sliceState->ubs[i],
                                 &sliceState->ubOperands[i]);
    }
  }
  return true;
}

// GreedyFusion greedily fuses loop nests which have a producer/consumer
// relationship on a memref, with the goal of improving locality. Currently,
// the producer/consumer relationship is required to be unique in the
// Function (there are TODOs to relax this constraint in the future).
//
// The steps of the algorithm are as follows:
//
// *) A worklist is initialized with node ids from the dependence graph.
// *) For each node id in the worklist:
//    *) Pop an AffineForOp off the worklist. This 'dstAffineForOp' will be a
//       candidate destination AffineForOp into which fusion will be attempted.
//    *) Add each LoadOp currently in 'dstAffineForOp' into list 'dstLoadOps'.
//    *) For each LoadOp in 'dstLoadOps' do:
//       *) Lookup dependent loop nests at earlier positions in the Function
//          which have a single store op to the same memref.
//       *) Check if dependences would be violated by the fusion. For example,
//          the src loop nest may load from memrefs which are different than
//          the producer-consumer memref between the src and dest loop nests.
//       *) Get a computation slice of 'srcLoopNest', which adjusts its loop
//          bounds to be functions of 'dstLoopNest' IVs and symbols.
//       *) Fuse the 'srcLoopNest' computation slice into the 'dstLoopNest',
//          just before the dst load op user.
//       *) Add the newly fused load/store operation instructions to the
//          state, and also add newly fused load ops to 'dstLoopOps' to be
//          considered as fusion dst load ops in another iteration.
//       *) Remove the old src loop nest and its associated state.
//
// Given a graph where top-level instructions are vertices in the set 'V' and
// edges in the set 'E' are dependences between vertices, this algorithm
// takes O(V) time for initialization, and has runtime O(V + E).
//
// This greedy algorithm is not 'maximal' due to the current restriction of
// fusing along single producer-consumer edges, but there is a TODO to fix
// this.
//
// TODO(andydavis) Experiment with other fusion policies.
// TODO(andydavis) Add support for fusing for input reuse (perhaps by
// constructing a graph with edges which represent loads from the same memref
// in two different loop nests).
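//
// A minimal before/after sketch on hypothetical IR (abbreviated syntax, not
// taken from an actual test case):
//
//   for %i = 0 to 100 {
//     store %v, %m[%i] : memref<100xf32>    // producer
//   }
//   for %j = 0 to 100 {
//     %u = load %m[%j] : memref<100xf32>    // consumer
//   }
//
// could become, after fusion at depth 1, privatization of %m, and promotion
// of the single-iteration slice loop:
//
//   for %j = 0 to 100 {
//     store %v, %m_priv[0] : memref<1xf32>
//     %u = load %m_priv[0] : memref<1xf32>
//   }
//
// where %m_priv stands in for the private memref built by
// createPrivateMemRef() (name hypothetical).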
struct GreedyFusion {
public:
  MemRefDependenceGraph *mdg;
  SmallVector<unsigned, 4> worklist;

  GreedyFusion(MemRefDependenceGraph *mdg) : mdg(mdg) {
    // Initialize worklist with nodes from 'mdg'.
    worklist.resize(mdg->nodes.size());
    std::iota(worklist.begin(), worklist.end(), 0);
  }

  void run(unsigned localBufSizeThreshold, Optional<unsigned> fastMemorySpace) {
    while (!worklist.empty()) {
      unsigned dstId = worklist.back();
      worklist.pop_back();
      // Skip if this node was removed (fused into another node).
      if (mdg->nodes.count(dstId) == 0)
        continue;
      // Get 'dstNode' into which to attempt fusion.
      auto *dstNode = mdg->getNode(dstId);
      // Skip if 'dstNode' is not a loop nest.
      if (!cast<OperationInst>(dstNode->inst)->isa<AffineForOp>())
        continue;

      SmallVector<OperationInst *, 4> loads = dstNode->loads;
      SmallVector<OperationInst *, 4> dstLoadOpInsts;
      DenseSet<Value *> visitedMemrefs;
      while (!loads.empty()) {
        // Get memref of load on top of the stack.
        auto *memref = loads.back()->cast<LoadOp>()->getMemRef();
        // Pop the load before skipping an already visited memref; leaving it
        // on the stack would spin this loop forever.
        if (visitedMemrefs.count(memref) > 0) {
          loads.pop_back();
          continue;
        }
        visitedMemrefs.insert(memref);
        // Move all loads in 'loads' accessing 'memref' to 'dstLoadOpInsts'.
        moveLoadsAccessingMemrefTo(memref, &loads, &dstLoadOpInsts);
        // Skip if no input edges along which to fuse.
        if (mdg->inEdges.count(dstId) == 0)
          continue;
        // Iterate through in-edges of 'dstId' and collect the src node id
        // of any edge on 'memref'.
        SmallVector<unsigned, 2> srcNodeIds;
        for (auto &srcEdge : mdg->inEdges[dstId]) {
          // Skip 'srcEdge' if not for 'memref'.
          if (srcEdge.value != memref)
            continue;
          srcNodeIds.push_back(srcEdge.id);
        }
        for (unsigned srcId : srcNodeIds) {
          // Skip if this node was removed (fused into another node).
          if (mdg->nodes.count(srcId) == 0)
            continue;
          // Get 'srcNode' from which to attempt fusion into 'dstNode'.
          auto *srcNode = mdg->getNode(srcId);
          // Skip if 'srcNode' is not a loop nest.
          if (!cast<OperationInst>(srcNode->inst)->isa<AffineForOp>())
            continue;
          // Skip if 'srcNode' has more than one store to any memref.
          // TODO(andydavis) Support fusing multi-output src loop nests.
          if (srcNode->stores.size() != 1)
            continue;

          // Skip 'srcNode' if it has in-edges on 'memref'.
          // TODO(andydavis) Track dependence type with edges, and just check
          // for WAW dependence edge here. Note that this check is overly
          // conservative and will be removed in the future.
          if (mdg->getIncomingMemRefAccesses(srcNode->id, memref) != 0)
            continue;

          // Skip if 'srcNode' writes to any live-in or escaping memrefs.
          if (mdg->writesToLiveInOrEscapingMemrefs(srcNode->id))
            continue;

          // Compute an instruction list insertion point for the fused loop
          // nest which preserves dependences.
          Instruction *insertPointInst = mdg->getFusedLoopNestInsertionPoint(
              srcNode->id, dstNode->id, memref);
          if (insertPointInst == nullptr)
            continue;

          // Get unique 'srcNode' store op.
          auto *srcStoreOpInst = srcNode->stores.front();
          // Gather 'dstNode' store ops to 'memref'.
          SmallVector<OperationInst *, 2> dstStoreOpInsts;
          for (auto *storeOpInst : dstNode->stores)
            if (storeOpInst->cast<StoreOp>()->getMemRef() == memref)
              dstStoreOpInsts.push_back(storeOpInst);

          unsigned bestDstLoopDepth;
          mlir::ComputationSliceState sliceState;
          // Check if fusion would be profitable.
          if (!isFusionProfitable(srcStoreOpInst, dstLoadOpInsts,
                                  dstStoreOpInsts, &sliceState,
                                  &bestDstLoopDepth))
            continue;

          // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'.
          auto sliceLoopNest = mlir::insertBackwardComputationSlice(
              srcStoreOpInst, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState);
          if (sliceLoopNest != nullptr) {
            // Move 'dstAffineForOp' before 'insertPointInst' if needed.
            auto dstAffineForOp =
                cast<OperationInst>(dstNode->inst)->cast<AffineForOp>();
            if (insertPointInst != dstAffineForOp->getInstruction()) {
              dstAffineForOp->getInstruction()->moveBefore(insertPointInst);
            }
            // Update edges between 'srcNode' and 'dstNode'.
            mdg->updateEdges(srcNode->id, dstNode->id, memref);

            // Collect slice loop stats.
            LoopNestStateCollector sliceCollector;
            sliceCollector.walk(sliceLoopNest->getInstruction());
            // Promote single iteration slice loops to single IV value.
            for (auto forOp : sliceCollector.forOps) {
              promoteIfSingleIteration(forOp);
            }
            // Create private memref for 'memref' in 'dstAffineForOp'.
            SmallVector<OperationInst *, 4> storesForMemref;
            for (auto *storeOpInst : sliceCollector.storeOpInsts) {
              if (storeOpInst->cast<StoreOp>()->getMemRef() == memref)
                storesForMemref.push_back(storeOpInst);
            }
            assert(storesForMemref.size() == 1);
            auto *newMemRef = createPrivateMemRef(
                dstAffineForOp, storesForMemref[0], bestDstLoopDepth,
                fastMemorySpace, localBufSizeThreshold);
            visitedMemrefs.insert(newMemRef);
            // Create new node in dependence graph for 'newMemRef' alloc op.
            unsigned newMemRefNodeId =
                mdg->addNode(newMemRef->getDefiningInst());
            // Add edge from 'newMemRef' node to dstNode.
            mdg->addEdge(newMemRefNodeId, dstId, newMemRef);

            // Collect dst loop stats after memref privatization
            // transformation.
            LoopNestStateCollector dstLoopCollector;
            dstLoopCollector.walk(dstAffineForOp->getInstruction());

            // Add new load ops to current Node load op list 'loads' to
            // continue fusing based on new operands.
            for (auto *loadOpInst : dstLoopCollector.loadOpInsts) {
              auto *loadMemRef = loadOpInst->cast<LoadOp>()->getMemRef();
              if (visitedMemrefs.count(loadMemRef) == 0)
                loads.push_back(loadOpInst);
            }

            // Clear and add back loads and stores.
            mdg->clearNodeLoadAndStores(dstNode->id);
            mdg->addToNode(dstId, dstLoopCollector.loadOpInsts,
                           dstLoopCollector.storeOpInsts);
            // Remove old src loop nest if it no longer has outgoing dependence
            // edges, and it does not write to a memref which escapes the
            // function.
            if (mdg->canRemoveNode(srcNode->id)) {
              mdg->removeNode(srcNode->id);
              srcNode->inst->erase();
            }
          }
        }
      }
    }
    // Clean up any allocs with no users.
    for (auto &pair : mdg->memrefEdgeCount) {
      if (pair.second > 0)
        continue;
      auto *memref = pair.first;
      // Skip if there exist other uses (return instruction or function calls).
      if (!memref->use_empty())
        continue;
      // Use list expected to match the dep graph info.
      auto *inst = memref->getDefiningInst();
      auto *opInst = dyn_cast_or_null<OperationInst>(inst);
      if (opInst && opInst->isa<AllocOp>())
        opInst->erase();
    }
  }
};

} // end anonymous namespace

PassResult LoopFusion::runOnFunction(Function *f) {
  if (clFusionFastMemorySpace.getNumOccurrences() > 0) {
    fastMemorySpace = clFusionFastMemorySpace.getValue();
  }

  MemRefDependenceGraph g;
  if (g.init(f))
    GreedyFusion(&g).run(localBufSizeThreshold, fastMemorySpace);
  return success();
}

static PassRegistration<LoopFusion> pass("loop-fusion", "Fuse loop nests");