//===- LoopFusion.cpp - Code to perform loop fusion -----------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements loop fusion.
//
//===----------------------------------------------------------------------===//

#include "mlir/AffineOps/AffineOps.h"
#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/AffineStructures.h"
#include "mlir/Analysis/LoopAnalysis.h"
#include "mlir/Analysis/Utils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/Pass/Pass.h"
#include "mlir/StandardOps/Ops.h"
#include "mlir/Transforms/LoopUtils.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/Utils.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <iomanip>
#define DEBUG_TYPE "affine-loop-fusion"

using llvm::SetVector;

using namespace mlir;

static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options");

/// Disables the fusion profitability check and fuses if valid. Ignores any
/// additional (redundant) computation tolerance threshold
/// that would have prevented fusion.
static llvm::cl::opt<bool>
    clMaximalLoopFusion("fusion-maximal",
                        llvm::cl::desc("Enables maximal loop fusion"),
                        llvm::cl::cat(clOptionsCategory));

/// A threshold in percent of additional computation allowed when fusing.
static llvm::cl::opt<double> clFusionAddlComputeTolerance(
    "fusion-compute-tolerance",
    llvm::cl::desc("Fractional increase in additional "
                   "computation tolerated while fusing"),
    llvm::cl::cat(clOptionsCategory));
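
// Illustrative usage (hypothetical values, for exposition only): passing
// -fusion-compute-tolerance=0.5 would tolerate fused nests performing up to
// 50% more computation than the unfused nests, while -fusion-maximal fuses
// whenever valid, regardless of this tolerance.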

static llvm::cl::opt<unsigned> clFusionFastMemorySpace(
    "fusion-fast-mem-space",
    llvm::cl::desc("Faster memory space number to promote fusion buffers to"),
    llvm::cl::cat(clOptionsCategory));

// A local buffer of size less than or equal to this size is automatically
// promoted to fast memory after producer-consumer fusion.
static llvm::cl::opt<unsigned long long> clFusionLocalBufThreshold(
    "fusion-local-buf-threshold",
    llvm::cl::desc("Threshold size (KiB) for promoting local buffers to fast "
                   "memory space"),
    llvm::cl::cat(clOptionsCategory));

namespace {

/// Loop fusion pass. This pass currently supports a greedy fusion policy,
/// which fuses loop nests with single-writer/single-reader memref dependences
/// with the goal of improving locality.

// TODO(andydavis) Support fusion of source loop nests which write to multiple
// memrefs, where each memref can have multiple users (if profitable).
// TODO(andydavis) Extend this pass to check for fusion preventing dependences,
// and add support for more general loop fusion algorithms.

struct LoopFusion : public FunctionPass<LoopFusion> {
  LoopFusion(unsigned fastMemorySpace = 0, uint64_t localBufSizeThreshold = 0,
             bool maximalFusion = false)
      : localBufSizeThreshold(localBufSizeThreshold),
        fastMemorySpace(fastMemorySpace), maximalFusion(maximalFusion) {}

  void runOnFunction() override;

  // Any local buffers smaller than this size (in bytes) will be created in
  // `fastMemorySpace` if provided.
  uint64_t localBufSizeThreshold;
  Optional<unsigned> fastMemorySpace = None;
  // If true, ignore any additional (redundant) computation tolerance threshold
  // that would have prevented fusion.
  bool maximalFusion;

  // The amount of additional computation that is tolerated while fusing
  // pair-wise as a fraction of the total computation.
  constexpr static double kComputeToleranceThreshold = 0.30f;
};

} // end anonymous namespace

FunctionPassBase *mlir::createLoopFusionPass(unsigned fastMemorySpace,
                                             uint64_t localBufSizeThreshold,
                                             bool maximalFusion) {
  return new LoopFusion(fastMemorySpace, localBufSizeThreshold, maximalFusion);
}

namespace {

// LoopNestStateCollector walks loop nests and collects load and store
// operations, and whether or not a region-holding operation other than
// 'affine.for' was encountered in the loop nest.
struct LoopNestStateCollector {
  SmallVector<AffineForOp, 4> forOps;
  SmallVector<Operation *, 4> loadOpInsts;
  SmallVector<Operation *, 4> storeOpInsts;
  bool hasNonForRegion = false;

  void collect(Operation *opToWalk) {
    opToWalk->walk([&](Operation *op) {
      if (op->isa<AffineForOp>())
        forOps.push_back(op->cast<AffineForOp>());
      else if (op->getNumRegions() != 0)
        hasNonForRegion = true;
      else if (op->isa<LoadOp>())
        loadOpInsts.push_back(op);
      else if (op->isa<StoreOp>())
        storeOpInsts.push_back(op);
    });
  }
};

// TODO(b/117228571) Replace when this is modeled through side-effects/op traits
static bool isMemRefDereferencingOp(Operation &op) {
  if (op.isa<LoadOp>() || op.isa<StoreOp>() || op.isa<DmaStartOp>() ||
      op.isa<DmaWaitOp>())
    return true;
  return false;
}

// MemRefDependenceGraph is a graph data structure where graph nodes are
// top-level operations in a Function which contain load/store ops, and edges
// are memref dependences between the nodes.
// TODO(andydavis) Add a more flexible dependence graph representation.
// TODO(andydavis) Add a depth parameter to dependence graph construction.
struct MemRefDependenceGraph {
public:
  // Node represents a node in the graph. A Node is either an entire loop nest
  // rooted at the top level which contains loads/stores, or a top level
  // load/store.
  struct Node {
    // The unique identifier of this node in the graph.
    unsigned id;
    // The top-level operation which is (or contains) a load/store.
    Operation *op;
    // List of load operations.
    SmallVector<Operation *, 4> loads;
    // List of store operations.
    SmallVector<Operation *, 4> stores;
    Node(unsigned id, Operation *op) : id(id), op(op) {}

    // Returns the load op count for 'memref'.
    unsigned getLoadOpCount(Value *memref) {
      unsigned loadOpCount = 0;
      for (auto *loadOpInst : loads) {
        if (memref == loadOpInst->cast<LoadOp>().getMemRef())
          ++loadOpCount;
      }
      return loadOpCount;
    }

    // Returns the store op count for 'memref'.
    unsigned getStoreOpCount(Value *memref) {
      unsigned storeOpCount = 0;
      for (auto *storeOpInst : stores) {
        if (memref == storeOpInst->cast<StoreOp>().getMemRef())
          ++storeOpCount;
      }
      return storeOpCount;
    }

    // Collects in 'storeOps' all store ops in this node which access 'memref'.
    void getStoreOpsForMemref(Value *memref,
                              SmallVectorImpl<Operation *> *storeOps) {
      for (auto *storeOpInst : stores) {
        if (memref == storeOpInst->cast<StoreOp>().getMemRef())
          storeOps->push_back(storeOpInst);
      }
    }

    // Collects in 'loadOps' all load ops in this node which access 'memref'.
    void getLoadOpsForMemref(Value *memref,
                             SmallVectorImpl<Operation *> *loadOps) {
      for (auto *loadOpInst : loads) {
        if (memref == loadOpInst->cast<LoadOp>().getMemRef())
          loadOps->push_back(loadOpInst);
      }
    }

    // Collects in 'loadAndStoreMemrefSet' all memrefs for which this node
    // has at least one load and one store operation.
    void getLoadAndStoreMemrefSet(DenseSet<Value *> *loadAndStoreMemrefSet) {
      llvm::SmallDenseSet<Value *, 2> loadMemrefs;
      for (auto *loadOpInst : loads) {
        loadMemrefs.insert(loadOpInst->cast<LoadOp>().getMemRef());
      }
      for (auto *storeOpInst : stores) {
        auto *memref = storeOpInst->cast<StoreOp>().getMemRef();
        if (loadMemrefs.count(memref) > 0)
          loadAndStoreMemrefSet->insert(memref);
      }
    }
  };

  // Edge represents a data dependence between nodes in the graph.
  struct Edge {
    // The id of the node at the other end of the edge.
    // If this edge is stored in Edge = Node.inEdges[i], then
    // 'Node.inEdges[i].id' is the identifier of the source node of the edge.
    // If this edge is stored in Edge = Node.outEdges[i], then
    // 'Node.outEdges[i].id' is the identifier of the dest node of the edge.
    unsigned id;
    // The SSA value on which this edge represents a dependence.
    // If the value is a memref, then the dependence is between graph nodes
    // which contain accesses to the same memref 'value'. If the value is a
    // non-memref value, then the dependence is between a graph node which
    // defines an SSA value and another graph node which uses the SSA value
    // (e.g. a constant operation defining a value which is used inside a loop
    // nest).
    Value *value;
  };
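
  // Illustrative example (hypothetical node ids, for exposition only): if
  // node 0 is a loop nest which stores to memref %m and node 1 is a loop nest
  // which loads from %m, then 'outEdges[0]' would hold an Edge {id: 1,
  // value: %m} and 'inEdges[1]' would hold an Edge {id: 0, value: %m}.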

  // Map from node id to Node.
  DenseMap<unsigned, Node> nodes;
  // Map from node id to list of input edges.
  DenseMap<unsigned, SmallVector<Edge, 2>> inEdges;
  // Map from node id to list of output edges.
  DenseMap<unsigned, SmallVector<Edge, 2>> outEdges;
  // Map from memref to a count on the dependence edges associated with that
  // memref.
  DenseMap<Value *, unsigned> memrefEdgeCount;
  // The next unique identifier to use for newly created graph nodes.
  unsigned nextNodeId = 0;

  MemRefDependenceGraph() {}

  // Initializes the dependence graph based on operations in 'f'.
  // Returns true on success, false otherwise.
  bool init(Function &f);

  // Returns the graph node for 'id'.
  Node *getNode(unsigned id) {
    auto it = nodes.find(id);
    assert(it != nodes.end());
    return &it->second;
  }

  // Returns the graph node for 'forOp'.
  Node *getForOpNode(AffineForOp forOp) {
    for (auto &idAndNode : nodes)
      if (idAndNode.second.op == forOp.getOperation())
        return &idAndNode.second;
    return nullptr;
  }

  // Adds a node with 'op' to the graph and returns its unique identifier.
  unsigned addNode(Operation *op) {
    Node node(nextNodeId++, op);
    nodes.insert({node.id, node});
    return node.id;
  }

  // Remove node 'id' (and its associated edges) from graph.
  void removeNode(unsigned id) {
    // Remove each edge in 'inEdges[id]'.
    if (inEdges.count(id) > 0) {
      SmallVector<Edge, 2> oldInEdges = inEdges[id];
      for (auto &inEdge : oldInEdges) {
        removeEdge(inEdge.id, id, inEdge.value);
      }
    }
    // Remove each edge in 'outEdges[id]'.
    if (outEdges.count(id) > 0) {
      SmallVector<Edge, 2> oldOutEdges = outEdges[id];
      for (auto &outEdge : oldOutEdges) {
        removeEdge(id, outEdge.id, outEdge.value);
      }
    }
    // Erase remaining node state.
    inEdges.erase(id);
    outEdges.erase(id);
    nodes.erase(id);
  }

  // Returns true if node 'id' writes to any memref which escapes (or is an
  // argument to) the function/block. Returns false otherwise.
  bool writesToLiveInOrEscapingMemrefs(unsigned id) {
    Node *node = getNode(id);
    for (auto *storeOpInst : node->stores) {
      auto *memref = storeOpInst->cast<StoreOp>().getMemRef();
      auto *op = memref->getDefiningOp();
      // Return true if 'memref' is a block argument.
      if (!op)
        return true;
      // Return true if any use of 'memref' escapes the function.
      for (auto &use : memref->getUses())
        if (!isMemRefDereferencingOp(*use.getOwner()))
          return true;
    }
    return false;
  }

  // Returns true if node 'id' can be removed from the graph. Returns false
  // otherwise. A node can be removed from the graph iff the following
  // conditions are met:
  // *) The node does not write to any memref which escapes (or is a
  //    function/block argument).
  // *) The node has no successors in the dependence graph.
  bool canRemoveNode(unsigned id) {
    if (writesToLiveInOrEscapingMemrefs(id))
      return false;
    Node *node = getNode(id);
    for (auto *storeOpInst : node->stores) {
      // Return false if there exist out edges from 'id' on 'memref'.
      if (getOutEdgeCount(id, storeOpInst->cast<StoreOp>().getMemRef()) > 0)
        return false;
    }
    return true;
  }

  // Returns true iff there is an edge from node 'srcId' to node 'dstId' which
  // is for 'value' if non-null, or for any value otherwise. Returns false
  // otherwise.
  bool hasEdge(unsigned srcId, unsigned dstId, Value *value = nullptr) {
    if (outEdges.count(srcId) == 0 || inEdges.count(dstId) == 0) {
      return false;
    }
    bool hasOutEdge = llvm::any_of(outEdges[srcId], [=](Edge &edge) {
      return edge.id == dstId && (!value || edge.value == value);
    });
    bool hasInEdge = llvm::any_of(inEdges[dstId], [=](Edge &edge) {
      return edge.id == srcId && (!value || edge.value == value);
    });
    return hasOutEdge && hasInEdge;
  }

  // Adds an edge from node 'srcId' to node 'dstId' for 'value'.
  void addEdge(unsigned srcId, unsigned dstId, Value *value) {
    if (!hasEdge(srcId, dstId, value)) {
      outEdges[srcId].push_back({dstId, value});
      inEdges[dstId].push_back({srcId, value});
      if (value->getType().isa<MemRefType>())
        memrefEdgeCount[value]++;
    }
  }

  // Removes an edge from node 'srcId' to node 'dstId' for 'value'.
  void removeEdge(unsigned srcId, unsigned dstId, Value *value) {
    assert(inEdges.count(dstId) > 0);
    assert(outEdges.count(srcId) > 0);
    if (value->getType().isa<MemRefType>()) {
      assert(memrefEdgeCount.count(value) > 0);
      memrefEdgeCount[value]--;
    }
    // Remove 'srcId' from 'inEdges[dstId]'.
    for (auto it = inEdges[dstId].begin(); it != inEdges[dstId].end(); ++it) {
      if ((*it).id == srcId && (*it).value == value) {
        inEdges[dstId].erase(it);
        break;
      }
    }
    // Remove 'dstId' from 'outEdges[srcId]'.
    for (auto it = outEdges[srcId].begin(); it != outEdges[srcId].end(); ++it) {
      if ((*it).id == dstId && (*it).value == value) {
        outEdges[srcId].erase(it);
        break;
      }
    }
  }

  // Returns true if there is a path in the dependence graph from node 'srcId'
  // to node 'dstId'. Returns false otherwise.
  bool hasDependencePath(unsigned srcId, unsigned dstId) {
    // Worklist state is: <node-id, next-output-edge-index-to-visit>
    SmallVector<std::pair<unsigned, unsigned>, 4> worklist;
    worklist.push_back({srcId, 0});
    // Run DFS traversal to see if 'dstId' is reachable from 'srcId'.
    while (!worklist.empty()) {
      auto &idAndIndex = worklist.back();
      // Return true if we have reached 'dstId'.
      if (idAndIndex.first == dstId)
        return true;
      // Pop and continue if node has no out edges, or if all out edges have
      // already been visited.
      if (outEdges.count(idAndIndex.first) == 0 ||
          idAndIndex.second == outEdges[idAndIndex.first].size()) {
        worklist.pop_back();
        continue;
      }
      // Get graph edge to traverse.
      Edge edge = outEdges[idAndIndex.first][idAndIndex.second];
      // Increment next output edge index for 'idAndIndex'.
      ++idAndIndex.second;
      // Add node at 'edge.id' to worklist.
      worklist.push_back({edge.id, 0});
    }
    return false;
  }

  // Returns the input edge count for node 'id' and 'memref' from src nodes
  // which access 'memref' with a store operation.
  unsigned getIncomingMemRefAccesses(unsigned id, Value *memref) {
    unsigned inEdgeCount = 0;
    if (inEdges.count(id) > 0)
      for (auto &inEdge : inEdges[id])
        if (inEdge.value == memref) {
          Node *srcNode = getNode(inEdge.id);
          // Only count in edges from 'srcNode' if 'srcNode' accesses 'memref'
          if (srcNode->getStoreOpCount(memref) > 0)
            ++inEdgeCount;
        }
    return inEdgeCount;
  }

  // Returns the output edge count for node 'id' and 'memref' (if non-null),
  // otherwise returns the total output edge count from node 'id'.
  unsigned getOutEdgeCount(unsigned id, Value *memref = nullptr) {
    unsigned outEdgeCount = 0;
    if (outEdges.count(id) > 0)
      for (auto &outEdge : outEdges[id])
        if (!memref || outEdge.value == memref)
          ++outEdgeCount;
    return outEdgeCount;
  }

  // Computes and returns an insertion point operation, before which the
  // fused <srcId, dstId> loop nest can be inserted while preserving
  // dependences. Returns nullptr if no such insertion point is found.
  Operation *getFusedLoopNestInsertionPoint(unsigned srcId, unsigned dstId) {
    if (outEdges.count(srcId) == 0)
      return getNode(dstId)->op;

    // Build set of insts in range (srcId, dstId) which depend on 'srcId'.
    SmallPtrSet<Operation *, 2> srcDepInsts;
    for (auto &outEdge : outEdges[srcId])
      if (outEdge.id != dstId)
        srcDepInsts.insert(getNode(outEdge.id)->op);

    // Build set of insts in range (srcId, dstId) on which 'dstId' depends.
    SmallPtrSet<Operation *, 2> dstDepInsts;
    for (auto &inEdge : inEdges[dstId])
      if (inEdge.id != srcId)
        dstDepInsts.insert(getNode(inEdge.id)->op);

    Operation *srcNodeInst = getNode(srcId)->op;
    Operation *dstNodeInst = getNode(dstId)->op;

    // Computing insertion point:
    // *) Walk all operation positions in Block operation list in the
    //    range (src, dst). For each operation 'op' visited in this search:
    //   *) Store in 'firstSrcDepPos' the first position where 'op' has a
    //      dependence edge from 'srcNode'.
    //   *) Store in 'lastDstDepPos' the last position where 'op' has a
    //      dependence edge to 'dstNode'.
    // *) Compare 'firstSrcDepPos' and 'lastDstDepPos' to determine the
    //    operation insertion point (or return null pointer if no such
    //    insertion point exists: 'firstSrcDepPos' <= 'lastDstDepPos').
    SmallVector<Operation *, 2> depInsts;
    Optional<unsigned> firstSrcDepPos;
    Optional<unsigned> lastDstDepPos;
    unsigned pos = 0;
    for (Block::iterator it = std::next(Block::iterator(srcNodeInst));
         it != Block::iterator(dstNodeInst); ++it) {
      Operation *op = &(*it);
      if (srcDepInsts.count(op) > 0 && firstSrcDepPos == None)
        firstSrcDepPos = pos;
      if (dstDepInsts.count(op) > 0)
        lastDstDepPos = pos;
      depInsts.push_back(op);
      ++pos;
    }

    if (firstSrcDepPos.hasValue()) {
      if (lastDstDepPos.hasValue()) {
        if (firstSrcDepPos.getValue() <= lastDstDepPos.getValue()) {
          // No valid insertion point exists which preserves dependences.
          return nullptr;
        }
      }
      // Return the insertion point at 'firstSrcDepPos'.
      return depInsts[firstSrcDepPos.getValue()];
    }
    // No dependence targets in range (or only dst deps in range), return
    // 'dstNodeInst' insertion point.
    return dstNodeInst;
  }
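
  // Illustrative example (hypothetical positions, for exposition only):
  // suppose three operations at positions [0, 1, 2] lie between 'srcNodeInst'
  // and 'dstNodeInst', the op at position 0 has a dependence edge to 'dstId'
  // (lastDstDepPos = 0), and the op at position 2 has a dependence edge from
  // 'srcId' (firstSrcDepPos = 2). Since 2 > 0, a valid insertion point
  // exists: the fused nest is inserted just before the op at position 2.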

  // Updates edge mappings from node 'srcId' to node 'dstId' after 'oldMemRef'
  // has been replaced in node at 'dstId' by a private memref.
  void updateEdges(unsigned srcId, unsigned dstId, Value *oldMemRef) {
    // For each edge in 'inEdges[srcId]': add new edge remapping to 'dstId'.
    if (inEdges.count(srcId) > 0) {
      SmallVector<Edge, 2> oldInEdges = inEdges[srcId];
      for (auto &inEdge : oldInEdges) {
        // Add edge from 'inEdge.id' to 'dstId' if not for 'oldMemRef'.
        if (inEdge.value != oldMemRef)
          addEdge(inEdge.id, dstId, inEdge.value);
      }
    }
    // For each edge in 'outEdges[srcId]': remove edge from 'srcId' to 'dstId'.
    if (outEdges.count(srcId) > 0) {
      SmallVector<Edge, 2> oldOutEdges = outEdges[srcId];
      for (auto &outEdge : oldOutEdges) {
        // Remove any out edges from 'srcId' to 'dstId' across memrefs.
        if (outEdge.id == dstId)
          removeEdge(srcId, outEdge.id, outEdge.value);
      }
    }
    // Remove any edges in 'inEdges[dstId]' on 'oldMemRef' (which is being
    // replaced by a private memref). These edges could come from nodes
    // other than 'srcId' which were removed in the previous step.
    if (inEdges.count(dstId) > 0) {
      SmallVector<Edge, 2> oldInEdges = inEdges[dstId];
      for (auto &inEdge : oldInEdges)
        if (inEdge.value == oldMemRef)
          removeEdge(inEdge.id, dstId, inEdge.value);
    }
  }

  // Update edge mappings for nodes 'sibId' and 'dstId' to reflect fusion
  // of sibling node 'sibId' into node 'dstId'.
  void updateEdges(unsigned sibId, unsigned dstId) {
    // For each edge in 'inEdges[sibId]':
    // *) Add new edge from source node 'inEdge.id' to 'dstNode'.
    // *) Remove edge from source node 'inEdge.id' to 'sibNode'.
    if (inEdges.count(sibId) > 0) {
      SmallVector<Edge, 2> oldInEdges = inEdges[sibId];
      for (auto &inEdge : oldInEdges) {
        addEdge(inEdge.id, dstId, inEdge.value);
        removeEdge(inEdge.id, sibId, inEdge.value);
      }
    }

    // For each edge in 'outEdges[sibId]' to node 'id':
    // *) Add new edge from 'dstId' to 'outEdge.id'.
    // *) Remove edge from 'sibId' to 'outEdge.id'.
    if (outEdges.count(sibId) > 0) {
      SmallVector<Edge, 2> oldOutEdges = outEdges[sibId];
      for (auto &outEdge : oldOutEdges) {
        addEdge(dstId, outEdge.id, outEdge.value);
        removeEdge(sibId, outEdge.id, outEdge.value);
      }
    }
  }

  // Adds ops in 'loads' and 'stores' to node at 'id'.
  void addToNode(unsigned id, const SmallVectorImpl<Operation *> &loads,
                 const SmallVectorImpl<Operation *> &stores) {
    Node *node = getNode(id);
    for (auto *loadOpInst : loads)
      node->loads.push_back(loadOpInst);
    for (auto *storeOpInst : stores)
      node->stores.push_back(storeOpInst);
  }

  void clearNodeLoadAndStores(unsigned id) {
    Node *node = getNode(id);
    node->loads.clear();
    node->stores.clear();
  }

  // Calls 'callback' for each input edge incident to node 'id' which carries a
  // memref dependence.
  void forEachMemRefInputEdge(unsigned id,
                              const std::function<void(Edge)> &callback) {
    if (inEdges.count(id) > 0)
      forEachMemRefEdge(inEdges[id], callback);
  }

  // Calls 'callback' for each output edge from node 'id' which carries a
  // memref dependence.
  void forEachMemRefOutputEdge(unsigned id,
                               const std::function<void(Edge)> &callback) {
    if (outEdges.count(id) > 0)
      forEachMemRefEdge(outEdges[id], callback);
  }

  // Calls 'callback' for each edge in 'edges' which carries a memref
  // dependence.
  void forEachMemRefEdge(ArrayRef<Edge> edges,
                         const std::function<void(Edge)> &callback) {
    for (auto &edge : edges) {
      // Skip if 'edge' is not a memref dependence edge.
      if (!edge.value->getType().isa<MemRefType>())
        continue;
      assert(nodes.count(edge.id) > 0);
      // Skip if 'edge.id' is not a loop nest.
      if (!getNode(edge.id)->op->isa<AffineForOp>())
        continue;
      // Visit current input edge 'edge'.
      callback(edge);
    }
  }

  void print(raw_ostream &os) const {
    os << "\nMemRefDependenceGraph\n";
    os << "\nNodes:\n";
    for (auto &idAndNode : nodes) {
      os << "Node: " << idAndNode.first << "\n";
      auto it = inEdges.find(idAndNode.first);
      if (it != inEdges.end()) {
        for (const auto &e : it->second)
          os << "  InEdge: " << e.id << " " << e.value << "\n";
      }
      it = outEdges.find(idAndNode.first);
      if (it != outEdges.end()) {
        for (const auto &e : it->second)
          os << "  OutEdge: " << e.id << " " << e.value << "\n";
      }
    }
  }
  void dump() const { print(llvm::errs()); }
};

// Initializes the data dependence graph by walking operations in 'f'.
// Assigns each node in the graph a node id based on program order in 'f'.
// TODO(andydavis) Add support for taking a Block arg to construct the
// dependence graph at a different depth.
bool MemRefDependenceGraph::init(Function &f) {
  DenseMap<Value *, SetVector<unsigned>> memrefAccesses;

  // TODO: support multi-block functions.
  if (f.getBlocks().size() != 1)
    return false;

  DenseMap<Operation *, unsigned> forToNodeMap;
  for (auto &op : f.front()) {
    if (auto forOp = dyn_cast<AffineForOp>(op)) {
      // Create graph node 'id' to represent top-level 'forOp' and record
      // all load and store accesses it contains.
      LoopNestStateCollector collector;
      collector.collect(&op);
      // Return false if a non 'affine.for' region was found (not currently
      // supported).
      if (collector.hasNonForRegion)
        return false;
      Node node(nextNodeId++, &op);
      for (auto *opInst : collector.loadOpInsts) {
        node.loads.push_back(opInst);
        auto *memref = opInst->cast<LoadOp>().getMemRef();
        memrefAccesses[memref].insert(node.id);
      }
      for (auto *opInst : collector.storeOpInsts) {
        node.stores.push_back(opInst);
        auto *memref = opInst->cast<StoreOp>().getMemRef();
        memrefAccesses[memref].insert(node.id);
      }
      forToNodeMap[&op] = node.id;
      nodes.insert({node.id, node});
    } else if (auto loadOp = dyn_cast<LoadOp>(op)) {
      // Create graph node for top-level load op.
      Node node(nextNodeId++, &op);
      node.loads.push_back(&op);
      auto *memref = op.cast<LoadOp>().getMemRef();
      memrefAccesses[memref].insert(node.id);
      nodes.insert({node.id, node});
    } else if (auto storeOp = dyn_cast<StoreOp>(op)) {
      // Create graph node for top-level store op.
      Node node(nextNodeId++, &op);
      node.stores.push_back(&op);
      auto *memref = op.cast<StoreOp>().getMemRef();
      memrefAccesses[memref].insert(node.id);
      nodes.insert({node.id, node});
    } else if (op.getNumRegions() != 0) {
      // Return false if another region is found (not currently supported).
      return false;
    } else if (op.getNumResults() > 0 && !op.use_empty()) {
      // Create graph node for top-level producer of SSA values, which
      // could be used by loop nest nodes.
      Node node(nextNodeId++, &op);
      nodes.insert({node.id, node});
    }
  }

  // Add dependence edges between nodes which produce SSA values and their
  // users.
  for (auto &idAndNode : nodes) {
    const Node &node = idAndNode.second;
    if (!node.loads.empty() || !node.stores.empty())
      continue;
    auto *opInst = node.op;
    for (auto *value : opInst->getResults()) {
      for (auto &use : value->getUses()) {
        SmallVector<AffineForOp, 4> loops;
        getLoopIVs(*use.getOwner(), &loops);
        if (loops.empty())
          continue;
        assert(forToNodeMap.count(loops[0].getOperation()) > 0);
        unsigned userLoopNestId = forToNodeMap[loops[0].getOperation()];
        addEdge(node.id, userLoopNestId, value);
      }
    }
  }

  // Walk memref access lists and add graph edges between dependent nodes.
  for (auto &memrefAndList : memrefAccesses) {
    unsigned n = memrefAndList.second.size();
    for (unsigned i = 0; i < n; ++i) {
      unsigned srcId = memrefAndList.second[i];
      bool srcHasStore =
          getNode(srcId)->getStoreOpCount(memrefAndList.first) > 0;
      for (unsigned j = i + 1; j < n; ++j) {
        unsigned dstId = memrefAndList.second[j];
        bool dstHasStore =
            getNode(dstId)->getStoreOpCount(memrefAndList.first) > 0;
        if (srcHasStore || dstHasStore)
          addEdge(srcId, dstId, memrefAndList.first);
      }
    }
  }
  return true;
}
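
// Illustrative example (assumed input, for exposition only): for a function
// whose body holds two top-level 'affine.for' nests, where the first nest
// stores to memref %m and the second loads from %m, init() creates two graph
// nodes (ids 0 and 1) and a single memref dependence edge 0 -> 1 on %m, since
// at least one of the two accesses to %m is a store.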

namespace {

// LoopNestStats aggregates various per-loop statistics (e.g. loop trip count
// and operation count) for a loop nest up until the innermost loop body.
struct LoopNestStats {
  // Map from AffineForOp to immediate child AffineForOps in its loop body.
  DenseMap<Operation *, SmallVector<AffineForOp, 2>> loopMap;
  // Map from AffineForOp to count of operations in its loop body.
  DenseMap<Operation *, uint64_t> opCountMap;
  // Map from AffineForOp to its constant trip count.
  DenseMap<Operation *, uint64_t> tripCountMap;
};

// LoopNestStatsCollector walks a single loop nest and gathers per-loop
// trip count and operation count statistics and records them in 'stats'.
struct LoopNestStatsCollector {
  LoopNestStats *stats;
  bool hasLoopWithNonConstTripCount = false;

  LoopNestStatsCollector(LoopNestStats *stats) : stats(stats) {}

  void collect(Operation *op) {
    op->walk<AffineForOp>([&](AffineForOp forOp) {
      auto *forInst = forOp.getOperation();
      auto *parentInst = forOp.getOperation()->getParentOp();
      if (parentInst != nullptr) {
        assert(parentInst->isa<AffineForOp>() && "Expected parent AffineForOp");
        // Add mapping to 'forOp' from its parent AffineForOp.
        stats->loopMap[parentInst].push_back(forOp);
      }

      // Record the number of operations in the body of 'forOp', excluding
      // 'affine.for' and 'affine.if' ops.
      unsigned count = 0;
      stats->opCountMap[forInst] = 0;
      for (auto &op : *forOp.getBody()) {
        if (!op.isa<AffineForOp>() && !op.isa<AffineIfOp>())
          ++count;
      }
      stats->opCountMap[forInst] = count;
      // Record trip count for 'forOp'. Set flag if trip count is not
      // constant.
      Optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp);
      if (!maybeConstTripCount.hasValue()) {
        hasLoopWithNonConstTripCount = true;
        return;
      }
      stats->tripCountMap[forInst] = maybeConstTripCount.getValue();
    });
  }
};

// Computes the total cost of the loop nest rooted at 'forOp'.
// Currently, the total cost is computed by counting the total operation
// instance count (i.e. total number of operations in the loop body * loop
// trip count) for the entire loop nest.
// If 'tripCountOverrideMap' is non-null, overrides the trip count for loops
// specified in the map when computing the total op instance count.
// Notes: 1) This is used to compute the cost of computation slices, which are
// sliced along the iteration dimension, and thus reduce the trip count.
// If 'computeCostMap' is non-null, the total op count for forOps specified
// in the map is increased (not overridden) by adding the op count from the
// map to the existing op count for the for loop. This is done before
// multiplying by the loop's trip count, and is used to model the cost of
// inserting a sliced loop nest of known cost into the loop's body.
// 2) This is also used to compute the cost of fusing a slice of some loop nest
// within another loop.
static int64_t getComputeCost(
    Operation *forInst, LoopNestStats *stats,
    llvm::SmallDenseMap<Operation *, uint64_t, 8> *tripCountOverrideMap,
    DenseMap<Operation *, int64_t> *computeCostMap) {
  // 'opCount' is the total number of operations in one iteration of 'forOp'
  // body.
  int64_t opCount = stats->opCountMap[forInst];
  if (stats->loopMap.count(forInst) > 0) {
    for (auto childForOp : stats->loopMap[forInst]) {
      opCount += getComputeCost(childForOp.getOperation(), stats,
                                tripCountOverrideMap, computeCostMap);
    }
  }
  // Add in additional op instances from slice (if specified in map).
  if (computeCostMap != nullptr) {
    auto it = computeCostMap->find(forInst);
    if (it != computeCostMap->end()) {
      opCount += it->second;
    }
  }
  // Override trip count (if specified in map).
  int64_t tripCount = stats->tripCountMap[forInst];
  if (tripCountOverrideMap != nullptr) {
    auto it = tripCountOverrideMap->find(forInst);
    if (it != tripCountOverrideMap->end()) {
      tripCount = it->second;
    }
  }
  // Returns the total number of dynamic instances of operations in loop body.
  return tripCount * opCount;
}
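
// Illustrative example (hypothetical trip counts, for exposition only): for a
// two-deep nest whose outer loop has trip count 100, whose inner loop has
// trip count 50, and whose inner loop body holds 2 ops, the outer body holds
// no non-loop ops, so the result is 100 * (50 * 2) = 10000 dynamic op
// instances.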

} // end anonymous namespace

// TODO(andydavis,b/126426796): extend this to handle multiple result maps.
static Optional<uint64_t> getConstDifference(AffineMap lbMap, AffineMap ubMap) {
  assert(lbMap.getNumResults() == 1 && "expected single result bound map");
  assert(ubMap.getNumResults() == 1 && "expected single result bound map");
  assert(lbMap.getNumDims() == ubMap.getNumDims());
  assert(lbMap.getNumSymbols() == ubMap.getNumSymbols());
  AffineExpr lbExpr(lbMap.getResult(0));
  AffineExpr ubExpr(ubMap.getResult(0));
  auto loopSpanExpr = simplifyAffineExpr(ubExpr - lbExpr, lbMap.getNumDims(),
                                         lbMap.getNumSymbols());
  auto cExpr = loopSpanExpr.dyn_cast<AffineConstantExpr>();
  if (!cExpr)
    return None;
  return cExpr.getValue();
}
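
// Illustrative example (assumed maps, for exposition only): with
// lbMap = (d0) -> (d0) and ubMap = (d0) -> (d0 + 8), the simplified
// difference is the constant 8, so 8 is returned. With
// lbMap = (d0, d1) -> (d0) and ubMap = (d0, d1) -> (d0 + d1), the difference
// 'd1' is not constant, so None is returned.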

// Builds a map 'tripCountMap' from AffineForOp to constant trip count for loop
// nest surrounding 'srcOpInst' utilizing slice loop bounds in 'sliceState'.
// Returns true on success, false otherwise (if a non-constant trip count
// was encountered).
// TODO(andydavis) Make this work with non-unit step loops.
static bool buildSliceTripCountMap(
    Operation *srcOpInst, ComputationSliceState *sliceState,
    llvm::SmallDenseMap<Operation *, uint64_t, 8> *tripCountMap) {
  SmallVector<AffineForOp, 4> srcLoopIVs;
  getLoopIVs(*srcOpInst, &srcLoopIVs);
  unsigned numSrcLoopIVs = srcLoopIVs.size();
  // Populate map from AffineForOp -> trip count
  for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
    AffineMap lbMap = sliceState->lbs[i];
    AffineMap ubMap = sliceState->ubs[i];
    if (lbMap == AffineMap() || ubMap == AffineMap()) {
      // The iteration of src loop IV 'i' was not sliced. Use full loop bounds.
      if (srcLoopIVs[i].hasConstantLowerBound() &&
          srcLoopIVs[i].hasConstantUpperBound()) {
        (*tripCountMap)[srcLoopIVs[i].getOperation()] =
            srcLoopIVs[i].getConstantUpperBound() -
            srcLoopIVs[i].getConstantLowerBound();
        continue;
      }
      return false;
    }
    Optional<uint64_t> tripCount = getConstDifference(lbMap, ubMap);
    if (!tripCount.hasValue())
      return false;
    (*tripCountMap)[srcLoopIVs[i].getOperation()] = tripCount.getValue();
  }
  return true;
}

// Removes load operations from 'srcLoads' which operate on 'memref', and
// adds them to 'dstLoads'.
static void moveLoadsAccessingMemrefTo(Value *memref,
                                       SmallVectorImpl<Operation *> *srcLoads,
                                       SmallVectorImpl<Operation *> *dstLoads) {
  dstLoads->clear();
  SmallVector<Operation *, 4> srcLoadsToKeep;
  for (auto *load : *srcLoads) {
    if (load->cast<LoadOp>().getMemRef() == memref)
      dstLoads->push_back(load);
    else
      srcLoadsToKeep.push_back(load);
  }
  srcLoads->swap(srcLoadsToKeep);
}

// Returns the innermost common loop depth for the set of operations in 'ops'.
static unsigned getInnermostCommonLoopDepth(ArrayRef<Operation *> ops) {
  unsigned numOps = ops.size();
  assert(numOps > 0);

  std::vector<SmallVector<AffineForOp, 4>> loops(numOps);
  unsigned loopDepthLimit = std::numeric_limits<unsigned>::max();
  for (unsigned i = 0; i < numOps; ++i) {
    getLoopIVs(*ops[i], &loops[i]);
    loopDepthLimit =
        std::min(loopDepthLimit, static_cast<unsigned>(loops[i].size()));
  }

  unsigned loopDepth = 0;
  for (unsigned d = 0; d < loopDepthLimit; ++d) {
    unsigned i;
    for (i = 1; i < numOps; ++i) {
      if (loops[i - 1][d] != loops[i][d])
        break;
    }
    if (i != numOps)
      break;
    ++loopDepth;
  }
  return loopDepth;
}
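
// Illustrative example (assumed nest, for exposition only): for two load ops
// which share surrounding loops %i and %j but sit in different innermost
// loops %k0 and %k1, the loop IV chains agree at depths 0 and 1 and diverge
// at depth 2, so 2 is returned.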

// Returns the maximum loop depth at which no dependences between 'loadOpInsts'
// and 'storeOpInsts' are satisfied.
static unsigned getMaxLoopDepth(ArrayRef<Operation *> loadOpInsts,
                                ArrayRef<Operation *> storeOpInsts) {
  // Merge loads and stores into the same array.
  SmallVector<Operation *, 2> ops(loadOpInsts.begin(), loadOpInsts.end());
  ops.append(storeOpInsts.begin(), storeOpInsts.end());

  // Compute the innermost common loop depth for loads and stores.
  unsigned loopDepth = getInnermostCommonLoopDepth(ops);

  // Return common loop depth for loads if there are no store ops.
  if (storeOpInsts.empty())
    return loopDepth;

  // Check dependences on all pairs of ops in 'ops' and store the minimum
  // loop depth at which a dependence is satisfied.
  for (unsigned i = 0, e = ops.size(); i < e; ++i) {
    auto *srcOpInst = ops[i];
    MemRefAccess srcAccess(srcOpInst);
    for (unsigned j = 0; j < e; ++j) {
      auto *dstOpInst = ops[j];
      MemRefAccess dstAccess(dstOpInst);

      unsigned numCommonLoops =
          getNumCommonSurroundingLoops(*srcOpInst, *dstOpInst);
      for (unsigned d = 1; d <= numCommonLoops + 1; ++d) {
        FlatAffineConstraints dependenceConstraints;
        // TODO(andydavis) Cache dependence analysis results, check cache here.
        if (checkMemrefAccessDependence(srcAccess, dstAccess, d,
                                        &dependenceConstraints,
                                        /*dependenceComponents=*/nullptr)) {
          // Store minimum loop depth and break because we want the min 'd' at
          // which there is a dependence.
          loopDepth = std::min(loopDepth, d - 1);
          break;
        }
      }
    }
  }
  return loopDepth;
}

// Compute loop interchange permutation:
// *) Computes dependence components between all op pairs of ops in loop nest
//    rooted at 'loops[0]', for loop depths in range [1, 'maxLoopDepth'].
// *) Classifies the outermost 'maxLoopDepth' loops surrounding 'ops' as either
//    parallel or sequential.
// *) Computes the loop permutation which sinks sequential loops deeper into
//    the loop nest, while preserving the relative order between other loops.
// *) Checks each dependence component against the permutation to see if the
//    desired loop interchange would violate dependences by making the
//    dependence component lexicographically negative.
// TODO(andydavis) Move this function to LoopUtils.
static bool
computeLoopInterchangePermutation(ArrayRef<AffineForOp> loops,
                                  SmallVectorImpl<unsigned> *loopPermMap) {
  assert(loops.size() > 1);
  // Gather dependence components for dependences between all ops in loop nest
  // rooted at 'loops[0]', at loop depths in range [1, maxLoopDepth].
  unsigned maxLoopDepth = loops.size();
  std::vector<llvm::SmallVector<DependenceComponent, 2>> depCompsVec;
  getDependenceComponents(loops[0], maxLoopDepth, &depCompsVec);
  // Mark loops as either parallel or sequential.
  llvm::SmallVector<bool, 8> isParallelLoop(maxLoopDepth, true);
  for (unsigned i = 0, e = depCompsVec.size(); i < e; ++i) {
    llvm::SmallVector<DependenceComponent, 2> &depComps = depCompsVec[i];
    assert(depComps.size() >= maxLoopDepth);
    for (unsigned j = 0; j < maxLoopDepth; ++j) {
      DependenceComponent &depComp = depComps[j];
      assert(depComp.lb.hasValue() && depComp.ub.hasValue());
      if (depComp.lb.getValue() != 0 || depComp.ub.getValue() != 0)
        isParallelLoop[j] = false;
    }
  }

  // Count the number of parallel loops.
  unsigned numParallelLoops = 0;
  for (unsigned i = 0, e = isParallelLoop.size(); i < e; ++i)
    if (isParallelLoop[i])
      ++numParallelLoops;

  // Compute permutation of loops that sinks sequential loops (and thus raises
  // parallel loops) while preserving relative order.
  llvm::SmallVector<unsigned, 4> loopPermMapInv;
  loopPermMapInv.resize(maxLoopDepth);
  loopPermMap->resize(maxLoopDepth);
  unsigned nextSequentialLoop = numParallelLoops;
  unsigned nextParallelLoop = 0;
  for (unsigned i = 0; i < maxLoopDepth; ++i) {
    if (isParallelLoop[i]) {
      (*loopPermMap)[i] = nextParallelLoop;
      loopPermMapInv[nextParallelLoop++] = i;
    } else {
      (*loopPermMap)[i] = nextSequentialLoop;
      loopPermMapInv[nextSequentialLoop++] = i;
    }
  }

  // Check each dependence component against the permutation to see if the
  // desired loop interchange permutation would make the dependence vectors
  // lexicographically negative.
  // Example 1: [-1, 1][0, 0]
  // Example 2: [0, 0][-1, 1]
  for (unsigned i = 0, e = depCompsVec.size(); i < e; ++i) {
    llvm::SmallVector<DependenceComponent, 2> &depComps = depCompsVec[i];
    assert(depComps.size() >= maxLoopDepth);
    // Check if the first non-zero dependence component is positive.
    for (unsigned j = 0; j < maxLoopDepth; ++j) {
      unsigned permIndex = loopPermMapInv[j];
      assert(depComps[permIndex].lb.hasValue());
      int64_t depCompLb = depComps[permIndex].lb.getValue();
      if (depCompLb > 0)
        break;
      if (depCompLb < 0)
        return false;
    }
  }
  return true;
}

// Sinks all sequential loops to the innermost levels (while preserving
// relative order among them) and moves all parallel loops to the
// outermost (while again preserving relative order among them).
// This can increase the loop depth at which we can fuse a slice, since we are
// pushing loop-carried dependences to a greater depth in the loop nest.
static void sinkSequentialLoops(MemRefDependenceGraph::Node *node) {
  assert(node->op->isa<AffineForOp>());
  SmallVector<AffineForOp, 4> loops;
  AffineForOp curr = node->op->cast<AffineForOp>();
  getPerfectlyNestedLoops(loops, curr);
  if (loops.size() < 2)
    return;

  // Compute loop permutation in 'loopPermMap'.
  llvm::SmallVector<unsigned, 4> loopPermMap;
  if (!computeLoopInterchangePermutation(loops, &loopPermMap))
    return;

  int loopNestRootIndex = -1;
  for (int i = loops.size() - 1; i >= 0; --i) {
    int permIndex = static_cast<int>(loopPermMap[i]);
    // Store the index of the for loop which will be the new loop nest root.
    if (permIndex == 0)
      loopNestRootIndex = i;
    if (permIndex > i) {
      // Sink loop 'i' by 'permIndex - i' levels deeper into the loop nest.
      sinkLoop(loops[i], permIndex - i);
    }
  }
  assert(loopNestRootIndex != -1 && "invalid root index");
  node->op = loops[loopNestRootIndex].getOperation();
}
1078
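// For example (illustrative only), given a perfect nest where %i is
// sequential and %j is parallel:
//
//   affine.for %i = 0 to 10 {        // sequential (carries a dependence)
//     affine.for %j = 0 to 100 {     // parallel
//       ...
//     }
//   }
//
// sinkSequentialLoops would interchange the two loops so that %j becomes the
// new loop nest root and %i is sunk innermost, and 'node->op' is updated to
// point at the new root.
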
// TODO(mlir-team): improve/complete this when we have target data.
unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
  auto elementType = memRefType.getElementType();

  unsigned sizeInBits;
  if (elementType.isIntOrFloat()) {
    sizeInBits = elementType.getIntOrFloatBitWidth();
  } else {
    auto vectorType = elementType.cast<VectorType>();
    sizeInBits =
        vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
  }
  return llvm::divideCeil(sizeInBits, 8);
}

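// Worked examples (illustrative): an 'f32' element is 32 bits, so the result
// is divideCeil(32, 8) = 4 bytes; a 'vector<4xf32>' element is 4 * 32 = 128
// bits, i.e. 16 bytes; an 'i1' element rounds up to divideCeil(1, 8) = 1 byte.
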
// Creates and returns a private (single-user) memref for fused loop rooted
// at 'forOp', with (potentially reduced) memref size based on the
// MemRefRegion written to by 'srcStoreOpInst' at depth 'dstLoopDepth'.
// TODO(bondhugula): consider refactoring the common code from generateDma and
// this one.
static Value *createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
                                  unsigned dstLoopDepth,
                                  Optional<unsigned> fastMemorySpace,
                                  uint64_t localBufSizeThreshold) {
  auto *forInst = forOp.getOperation();

  // Create builder to insert alloc op just before 'forOp'.
  FuncBuilder b(forInst);
  // Builder to create constants at the top level.
  FuncBuilder top(forInst->getFunction());
  // Create new memref type based on slice bounds.
  auto *oldMemRef = srcStoreOpInst->cast<StoreOp>().getMemRef();
  auto oldMemRefType = oldMemRef->getType().cast<MemRefType>();
  unsigned rank = oldMemRefType.getRank();

  // Compute MemRefRegion for 'srcStoreOpInst' at depth 'dstLoopDepth'.
  MemRefRegion region(srcStoreOpInst->getLoc());
  bool validRegion = succeeded(region.compute(srcStoreOpInst, dstLoopDepth));
  (void)validRegion;
  assert(validRegion && "unexpected memref region failure");
  SmallVector<int64_t, 4> newShape;
  std::vector<SmallVector<int64_t, 4>> lbs;
  SmallVector<int64_t, 8> lbDivisors;
  lbs.reserve(rank);
  // Query 'region' for 'newShape' and lower bounds of MemRefRegion accessed
  // by 'srcStoreOpInst' at depth 'dstLoopDepth'.
  Optional<int64_t> numElements =
      region.getConstantBoundingSizeAndShape(&newShape, &lbs, &lbDivisors);
  assert(numElements.hasValue() &&
         "non-constant number of elts in local buffer");

  const FlatAffineConstraints *cst = region.getConstraints();
  // 'outerIVs' holds the values that this memory region is symbolic/parametric
  // on; this would correspond to loop IVs surrounding the level at which the
  // slice is being materialized.
  SmallVector<Value *, 8> outerIVs;
  cst->getIdValues(rank, cst->getNumIds(), &outerIVs);

  // Build 'rank' AffineExprs from MemRefRegion 'lbs'.
  SmallVector<AffineExpr, 4> offsets;
  offsets.reserve(rank);
  for (unsigned d = 0; d < rank; ++d) {
    assert(lbs[d].size() == cst->getNumCols() - rank && "incorrect bound size");

    AffineExpr offset = top.getAffineConstantExpr(0);
    for (unsigned j = 0, e = cst->getNumCols() - rank - 1; j < e; j++) {
      offset = offset + lbs[d][j] * top.getAffineDimExpr(j);
    }
    assert(lbDivisors[d] > 0);
    offset =
        (offset + lbs[d][cst->getNumCols() - 1 - rank]).floorDiv(lbDivisors[d]);
    offsets.push_back(offset);
  }

  // Create 'newMemRefType' using 'newShape' from MemRefRegion accessed
  // by 'srcStoreOpInst'.
  uint64_t bufSize =
      getMemRefEltSizeInBytes(oldMemRefType) * numElements.getValue();
  unsigned newMemSpace;
  if (bufSize <= localBufSizeThreshold && fastMemorySpace.hasValue()) {
    newMemSpace = fastMemorySpace.getValue();
  } else {
    newMemSpace = oldMemRefType.getMemorySpace();
  }
  auto newMemRefType = top.getMemRefType(
      newShape, oldMemRefType.getElementType(), {}, newMemSpace);
  // Gather alloc operands for the dynamic dimensions of the memref.
  SmallVector<Value *, 4> allocOperands;
  unsigned dynamicDimCount = 0;
  for (auto dimSize : oldMemRefType.getShape()) {
    if (dimSize == -1)
      allocOperands.push_back(
          top.create<DimOp>(forOp.getLoc(), oldMemRef, dynamicDimCount++));
  }

  // Create new private memref for fused loop 'forOp'.
  // TODO(andydavis) Create/move alloc ops for private memrefs closer to their
  // consumer loop nests to reduce their live range. Currently they are added
  // at the beginning of the function, because loop nests can be reordered
  // during the fusion pass.
  Value *newMemRef =
      top.create<AllocOp>(forOp.getLoc(), newMemRefType, allocOperands);

  // Build an AffineMap to remap access functions based on lower bound offsets.
  SmallVector<AffineExpr, 4> remapExprs;
  remapExprs.reserve(rank);
  unsigned zeroOffsetCount = 0;
  for (unsigned i = 0; i < rank; i++) {
    if (auto constExpr = offsets[i].dyn_cast<AffineConstantExpr>())
      if (constExpr.getValue() == 0)
        ++zeroOffsetCount;
    auto dimExpr = b.getAffineDimExpr(outerIVs.size() + i);

    auto remapExpr =
        simplifyAffineExpr(dimExpr - offsets[i], outerIVs.size() + rank, 0);
    remapExprs.push_back(remapExpr);
  }
  auto indexRemap =
      zeroOffsetCount == rank
          ? AffineMap()
          : b.getAffineMap(outerIVs.size() + rank, 0, remapExprs, {});
  // Replace all users of 'oldMemRef' with 'newMemRef'.
  bool ret =
      replaceAllMemRefUsesWith(oldMemRef, newMemRef, {}, indexRemap,
                               /*extraOperands=*/outerIVs,
                               /*domInstFilter=*/&*forOp.getBody()->begin());
  assert(ret && "replaceAllMemrefUsesWith should always succeed here");
  (void)ret;
  return newMemRef;
}

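// Illustrative example (not from the original source): if the source nest
// stores to 'memref<100xf32>' but, at the chosen 'dstLoopDepth', only the
// window [%i, %i + 9] is written, the private buffer becomes 'memref<10xf32>'
// and, with one outer IV %i, accesses are remapped through the affine map
// (d0, d1) -> (d1 - d0), i.e. the lower bound offset %i is subtracted from
// the original index.
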
// Return the number of iterations in the given slice.
static uint64_t getSliceIterationCount(
    const llvm::SmallDenseMap<Operation *, uint64_t, 8> &sliceTripCountMap) {
  uint64_t iterCount = 1;
  for (const auto &count : sliceTripCountMap) {
    iterCount *= count.second;
  }
  return iterCount;
}

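// For example (illustrative), a slice trip count map of {forOp1: 4, forOp2: 8}
// yields 4 * 8 = 32 iterations; an empty map yields 1, the multiplicative
// identity.
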
// Checks if node 'srcId' (which writes to a live out memref) can be safely
// fused into node 'dstId'. Returns true if the following conditions are met:
// *) 'srcNode' only writes to live out 'memref'.
// *) 'srcNode' has exactly one output edge on 'memref' (which is to 'dstId').
// *) 'dstNode's read/write region to 'memref' is a superset of 'srcNode's
//    write region to 'memref'.
// TODO(andydavis) Generalize this to handle more live in/out cases.
static bool canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId,
                                           Value *memref,
                                           MemRefDependenceGraph *mdg) {
  auto *srcNode = mdg->getNode(srcId);
  auto *dstNode = mdg->getNode(dstId);

  // Gather all memrefs from 'srcNode' store ops.
  DenseSet<Value *> storeMemrefs;
  for (auto *storeOpInst : srcNode->stores) {
    storeMemrefs.insert(storeOpInst->cast<StoreOp>().getMemRef());
  }
  // Return false if any of the following are true:
  // *) 'srcNode' writes to a live in/out memref other than 'memref'.
  // *) 'srcNode' has more than one output edge on 'memref'.
  // Check that all stores are to the same memref.
  if (storeMemrefs.size() != 1 ||
      mdg->getOutEdgeCount(srcNode->id, memref) != 1)
    return false;
  // Compute MemRefRegion 'srcWriteRegion' for 'srcStoreOpInst' on 'memref'.
  auto *srcStoreOpInst = srcNode->stores.front();
  MemRefRegion srcWriteRegion(srcStoreOpInst->getLoc());
  if (failed(srcWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0))) {
    LLVM_DEBUG(llvm::dbgs()
               << "Unable to compute MemRefRegion for source operation.\n");
    return false;
  }
  SmallVector<int64_t, 4> srcShape;
  // Query 'srcWriteRegion' for 'srcShape' and 'srcNumElements' written to
  // by 'srcStoreOpInst'.
  Optional<int64_t> srcNumElements =
      srcWriteRegion.getConstantBoundingSizeAndShape(&srcShape);
  if (!srcNumElements.hasValue())
    return false;

  // Compute MemRefRegion 'dstRegion' for 'dstStore/LoadOpInst' on 'memref'.
  // TODO(andydavis) Compute 'unionBoundingBox' of all write regions (one for
  // each store op in 'dstStoreOps').
  SmallVector<Operation *, 2> dstStoreOps;
  dstNode->getStoreOpsForMemref(memref, &dstStoreOps);
  SmallVector<Operation *, 2> dstLoadOps;
  dstNode->getLoadOpsForMemref(memref, &dstLoadOps);

  auto *dstOpInst = dstStoreOps.empty() ? dstLoadOps[0] : dstStoreOps[0];
  MemRefRegion dstRegion(dstOpInst->getLoc());
  if (failed(dstRegion.compute(dstOpInst, /*loopDepth=*/0))) {
    LLVM_DEBUG(llvm::dbgs()
               << "Unable to compute MemRefRegion for dest operation.\n");
    return false;
  }
  SmallVector<int64_t, 4> dstShape;
  // Query 'dstRegion' for 'dstShape' and 'dstNumElements' accessed by
  // 'dstOpInst'.
  Optional<int64_t> dstNumElements =
      dstRegion.getConstantBoundingSizeAndShape(&dstShape);
  if (!dstNumElements.hasValue())
    return false;

  // Return false if write region is not a superset of 'srcNode's write
  // region to 'memref'.
  // TODO(andydavis) Check the shape and lower bounds here too.
  if (srcNumElements != dstNumElements)
    return false;
  return true;
}

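// Illustrative scenario (not from the original source): if 'srcNode' writes
// all of %m (a live-out memref) and 'dstNode' reads all of %m, the element
// counts of the two regions match and fusion is allowed, since the fused
// 'dstNode' will still produce every element the rest of the function may
// observe. If 'dstNode' only read half of %m, the counts would differ and the
// check would reject fusion, because removing 'srcNode' would lose writes.
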
// Computes the union of all slice bounds computed between 'srcOpInst'
// and each load op in 'dstLoadOpInsts' at 'dstLoopDepth', and returns
// the union in 'sliceState'. Returns true on success, false otherwise.
// TODO(andydavis) Move this to a loop fusion utility function.
static bool getSliceUnion(Operation *srcOpInst,
                          ArrayRef<Operation *> dstLoadOpInsts,
                          unsigned numSrcLoopIVs, unsigned dstLoopDepth,
                          ComputationSliceState *sliceState) {
  MemRefAccess srcAccess(srcOpInst);
  unsigned numDstLoadOpInsts = dstLoadOpInsts.size();
  assert(numDstLoadOpInsts > 0);
  // Compute the slice bounds between 'srcOpInst' and 'dstLoadOpInsts[0]'.
  if (failed(mlir::getBackwardComputationSliceState(
          srcAccess, MemRefAccess(dstLoadOpInsts[0]), dstLoopDepth,
          sliceState)))
    return false;
  // Handle the common case of one dst load without a copy.
  if (numDstLoadOpInsts == 1)
    return true;

  // Initialize 'sliceUnionCst' with the bounds computed in previous step.
  FlatAffineConstraints sliceUnionCst;
  if (failed(sliceState->getAsConstraints(&sliceUnionCst))) {
    LLVM_DEBUG(llvm::dbgs() << "Unable to compute slice bound constraints.\n");
    return false;
  }

  // Compute the union of slice bounds between 'srcOpInst' and each load
  // in 'dstLoadOpInsts' in range [1, numDstLoadOpInsts), in 'sliceUnionCst'.
  for (unsigned i = 1; i < numDstLoadOpInsts; ++i) {
    MemRefAccess dstAccess(dstLoadOpInsts[i]);
    // Compute slice bounds for 'srcOpInst' and 'dstLoadOpInsts[i]'.
    ComputationSliceState tmpSliceState;
    if (failed(mlir::getBackwardComputationSliceState(
            srcAccess, dstAccess, dstLoopDepth, &tmpSliceState))) {
      LLVM_DEBUG(llvm::dbgs() << "Unable to compute slice bounds.\n");
      return false;
    }

    // Compute constraints for 'tmpSliceState' in 'tmpSliceCst'.
    FlatAffineConstraints tmpSliceCst;
    if (failed(tmpSliceState.getAsConstraints(&tmpSliceCst))) {
      LLVM_DEBUG(llvm::dbgs()
                 << "Unable to compute slice bound constraints.\n");
      return false;
    }
    // Compute union bounding box of 'sliceUnionCst' and 'tmpSliceCst'.
    if (failed(sliceUnionCst.unionBoundingBox(tmpSliceCst))) {
      LLVM_DEBUG(llvm::dbgs()
                 << "Unable to compute union bounding box of slice bounds.\n");
      return false;
    }
  }

  // Convert any dst loop IVs which are symbol identifiers to dim identifiers.
  sliceUnionCst.convertLoopIVSymbolsToDims();

  sliceState->clearBounds();
  sliceState->lbs.resize(numSrcLoopIVs, AffineMap());
  sliceState->ubs.resize(numSrcLoopIVs, AffineMap());

  // Get slice bounds from slice union constraints 'sliceUnionCst'.
  sliceUnionCst.getSliceBounds(numSrcLoopIVs, srcOpInst->getContext(),
                               &sliceState->lbs, &sliceState->ubs);
  // Add slice bound operands of union.
  SmallVector<Value *, 4> sliceBoundOperands;
  sliceUnionCst.getIdValues(numSrcLoopIVs,
                            sliceUnionCst.getNumDimAndSymbolIds(),
                            &sliceBoundOperands);
  // Give each bound its own copy of 'sliceBoundOperands' for subsequent
  // canonicalization.
  sliceState->lbOperands.resize(numSrcLoopIVs, sliceBoundOperands);
  sliceState->ubOperands.resize(numSrcLoopIVs, sliceBoundOperands);
  return true;
}

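// Illustrative example (hypothetical bounds): if slicing against one dst load
// yields src loop bounds [%i, %i + 2) and slicing against a second dst load
// yields [%i + 1, %i + 3), the union bounding box kept in 'sliceUnionCst' is
// [%i, %i + 3), which is what a single slice must cover to satisfy both loads.
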
// Checks the profitability of fusing a backwards slice of the loop nest
// surrounding 'srcOpInst' into the loop nest surrounding 'dstLoadOpInsts'.
// The argument 'srcStoreOpInst' is used to calculate the storage reduction on
// the memref being produced and consumed, which is an input to the cost model.
// For producer-consumer fusion, 'srcStoreOpInst' will be the same as
// 'srcOpInst', as we are slicing w.r.t. that producer.
// For input-reuse fusion, 'srcOpInst' will be the src loop nest LoadOp which
// reads from the same memref as dst loop nest load ops, and 'srcStoreOpInst'
// will be the unique store op in the src node, which will be used to check
// that the write region is the same after input-reuse fusion.
// Returns true if it is profitable to fuse the candidate loop nests. Returns
// false otherwise. `dstLoopDepth` is set to the most profitable depth at which
// to materialize the source loop nest slice.
// The profitability model executes the following steps:
// *) Computes the backward computation slice at 'srcOpInst'. This
//    computation slice of the loop nest surrounding 'srcOpInst' is
//    represented by modified src loop bounds in 'sliceState', which are
//    functions of loop IVs in the loop nest surrounding 'srcOpInst'.
// *) Computes the cost of unfused src/dst loop nests (currently the cost of a
//    loop nest is the total number of dynamic operation instances in the loop
//    nest).
// *) Computes the cost of fusing a slice of the src loop nest into the dst
//    loop nest at various values of dst loop depth, attempting to fuse
//    the largest computation slice at the maximal dst loop depth (closest to
//    the load) to minimize reuse distance and potentially enable subsequent
//    load/store forwarding.
// NOTE: If the dst loop nest includes multiple loads in 'dstLoadOpInsts' for
// the same memref as is written by 'srcOpInst', then the union of slice
// loop bounds is used to compute the slice and associated slice cost.
// NOTE: 'dstLoopDepth' refers to the loop depth within the destination loop
// nest, at which the src computation slice is inserted/fused.
// NOTE: We attempt to maximize the dst loop depth, but there are cases
// where a particular setting for 'dstLoopDepth' might fuse an unsliced
// loop (within the src computation slice) at a depth which results in
// excessive recomputation (see unit tests for examples).
// *) Compares the total cost of the unfused loop nests to the min cost fused
//    loop nest computed in the previous step, and returns true if the latter
//    is lower. (A worked numeric example of the cost computation follows this
//    function.)
static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
                               ArrayRef<Operation *> dstLoadOpInsts,
                               ArrayRef<Operation *> dstStoreOpInsts,
                               ComputationSliceState *sliceState,
                               unsigned *dstLoopDepth, bool maximalFusion) {
  LLVM_DEBUG({
    llvm::dbgs() << "Checking whether fusion is profitable between:\n";
    llvm::dbgs() << " " << *srcOpInst << " and\n";
    for (auto dstOpInst : dstLoadOpInsts) {
      llvm::dbgs() << " " << *dstOpInst << "\n";
    }
  });

  // Compute cost of sliced and unsliced src loop nest.
  SmallVector<AffineForOp, 4> srcLoopIVs;
  getLoopIVs(*srcOpInst, &srcLoopIVs);
  unsigned numSrcLoopIVs = srcLoopIVs.size();

  // Walk src loop nest and collect stats.
  LoopNestStats srcLoopNestStats;
  LoopNestStatsCollector srcStatsCollector(&srcLoopNestStats);
  srcStatsCollector.collect(srcLoopIVs[0].getOperation());
  // Currently only constant trip count loop nests are supported.
  if (srcStatsCollector.hasLoopWithNonConstTripCount) {
    LLVM_DEBUG(llvm::dbgs() << "Non-constant trip count loops unsupported.\n");
    return false;
  }
  // Compute cost of dst loop nest.
  SmallVector<AffineForOp, 4> dstLoopIVs;
  getLoopIVs(*dstLoadOpInsts[0], &dstLoopIVs);

  LoopNestStats dstLoopNestStats;
  LoopNestStatsCollector dstStatsCollector(&dstLoopNestStats);
  dstStatsCollector.collect(dstLoopIVs[0].getOperation());
  // Currently only constant trip count loop nests are supported.
  if (dstStatsCollector.hasLoopWithNonConstTripCount) {
    LLVM_DEBUG(llvm::dbgs() << "Non-constant trip count loops unsupported.\n");
    return false;
  }

  // Compute the maximum loop depth at which we can insert the src slice
  // and still satisfy dest loop nest dependences, for producer-consumer
  // fusion.
  unsigned maxDstLoopDepth =
      (srcOpInst == srcStoreOpInst)
          ? getMaxLoopDepth(dstLoadOpInsts, dstStoreOpInsts)
          : dstLoopIVs.size();
  if (maxDstLoopDepth == 0) {
    LLVM_DEBUG(llvm::dbgs() << "Can't fuse: maxDstLoopDepth == 0.\n");
    return false;
  }

  // Search for min cost value for 'dstLoopDepth'. At each value of
  // 'dstLoopDepth' from 'maxDstLoopDepth' to '1', compute computation slice
  // bounds between 'srcOpInst' and each op in 'dstLoadOpInsts' (taking the
  // union of these bounds). Next the union slice bounds are used to calculate
  // the cost of the slice and the cost of the slice inserted into the dst
  // loop nest at 'dstLoopDepth'.
  uint64_t minFusedLoopNestComputeCost = std::numeric_limits<uint64_t>::max();
  double maxStorageReduction = 0.0;
  Optional<uint64_t> sliceMemEstimate = None;

  SmallVector<ComputationSliceState, 4> sliceStates;
  sliceStates.resize(maxDstLoopDepth);
  // The best loop depth at which to materialize the slice.
  Optional<unsigned> bestDstLoopDepth = None;

  // Compute op instance count for the src loop nest without iteration slicing.
  uint64_t srcLoopNestCost =
      getComputeCost(srcLoopIVs[0].getOperation(), &srcLoopNestStats,
                     /*tripCountOverrideMap=*/nullptr,
                     /*computeCostMap=*/nullptr);

  // Compute src loop nest write region size.
  MemRefRegion srcWriteRegion(srcStoreOpInst->getLoc());
  if (failed(srcWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0))) {
    LLVM_DEBUG(llvm::dbgs()
               << "Unable to compute MemRefRegion for source operation.\n");
    return false;
  }

  Optional<int64_t> maybeSrcWriteRegionSizeBytes =
      srcWriteRegion.getRegionSize();
  if (!maybeSrcWriteRegionSizeBytes.hasValue())
    return false;
  int64_t srcWriteRegionSizeBytes = maybeSrcWriteRegionSizeBytes.getValue();

  // Compute op instance count for the dst loop nest.
  uint64_t dstLoopNestCost =
      getComputeCost(dstLoopIVs[0].getOperation(), &dstLoopNestStats,
                     /*tripCountOverrideMap=*/nullptr,
                     /*computeCostMap=*/nullptr);

  // Evaluate all depth choices for materializing the slice in the destination
  // loop nest.
  llvm::SmallDenseMap<Operation *, uint64_t, 8> sliceTripCountMap;
  DenseMap<Operation *, int64_t> computeCostMap;
  for (unsigned i = maxDstLoopDepth; i >= 1; --i) {
    // Compute the union of slice bounds of all ops in 'dstLoadOpInsts'.
    if (!getSliceUnion(srcOpInst, dstLoadOpInsts, numSrcLoopIVs, i,
                       &sliceStates[i - 1])) {
      LLVM_DEBUG(llvm::dbgs()
                 << "getSliceUnion failed for loopDepth: " << i << "\n");
      continue;
    }

    // Build trip count map for computation slice. We'll skip cases where the
    // trip count was non-constant.
    sliceTripCountMap.clear();
    if (!buildSliceTripCountMap(srcOpInst, &sliceStates[i - 1],
                                &sliceTripCountMap)) {
      LLVM_DEBUG(llvm::dbgs() << "Unable to build slice trip count map.\n");
      continue;
    }

    // Checks whether a store to load forwarding will happen.
    int64_t sliceIterationCount = getSliceIterationCount(sliceTripCountMap);
    assert(sliceIterationCount > 0);
    bool storeLoadFwdGuaranteed = (sliceIterationCount == 1);

    // Compute cost of fusion for this dest loop depth.

    computeCostMap.clear();

    // The store and loads to this memref will disappear.
    // TODO(andydavis) Add load coalescing to memref data flow opt pass.
    if (storeLoadFwdGuaranteed) {
      // A single store disappears: -1 for that.
      computeCostMap[srcLoopIVs[numSrcLoopIVs - 1].getOperation()] = -1;
      for (auto *loadOp : dstLoadOpInsts)
        if (auto forOp = dyn_cast_or_null<AffineForOp>(loadOp->getParentOp()))
          computeCostMap[forOp] = -1;
    }

    // Compute op instance count for the src loop nest with iteration slicing.
    int64_t sliceComputeCost =
        getComputeCost(srcLoopIVs[0].getOperation(), &srcLoopNestStats,
                       /*tripCountOverrideMap=*/&sliceTripCountMap,
                       /*computeCostMap=*/&computeCostMap);

    // Compute cost of fusion for this depth.
    computeCostMap[dstLoopIVs[i - 1].getOperation()] = sliceComputeCost;

    int64_t fusedLoopNestComputeCost =
        getComputeCost(dstLoopIVs[0].getOperation(), &dstLoopNestStats,
                       /*tripCountOverrideMap=*/nullptr, &computeCostMap);

    double additionalComputeFraction =
        fusedLoopNestComputeCost /
            (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
        1;

    // Determine what the slice write MemRefRegion would be, if the src loop
    // nest slice 'sliceStates[i - 1]' were to be inserted into the dst loop
    // nest at loop depth 'i'.
    MemRefRegion sliceWriteRegion(srcStoreOpInst->getLoc());
    if (failed(sliceWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0,
                                        &sliceStates[i - 1]))) {
      LLVM_DEBUG(llvm::dbgs()
                 << "Failed to compute slice write region at loopDepth: " << i
                 << "\n");
      continue;
    }

    Optional<int64_t> maybeSliceWriteRegionSizeBytes =
        sliceWriteRegion.getRegionSize();
    if (!maybeSliceWriteRegionSizeBytes.hasValue() ||
        maybeSliceWriteRegionSizeBytes.getValue() == 0) {
      LLVM_DEBUG(llvm::dbgs()
                 << "Failed to get slice write region size at loopDepth: " << i
                 << "\n");
      continue;
    }
    int64_t sliceWriteRegionSizeBytes =
        maybeSliceWriteRegionSizeBytes.getValue();

    // If we are fusing for reuse, check that write regions remain the same.
    // TODO(andydavis) Write region check should check sizes and offsets in
    // each dimension, so that we are sure they are covering the same memref
    // region. Also, move this out to a isMemRefRegionSuperSet helper function.
    if (srcOpInst != srcStoreOpInst &&
        sliceWriteRegionSizeBytes != srcWriteRegionSizeBytes)
      continue;

    double storageReduction = static_cast<double>(srcWriteRegionSizeBytes) /
                              static_cast<double>(sliceWriteRegionSizeBytes);

    LLVM_DEBUG({
      std::stringstream msg;
      msg << " evaluating fusion profitability at depth : " << i << "\n"
          << std::fixed << std::setprecision(2)
          << " additional compute fraction: "
          << 100.0 * additionalComputeFraction << "%\n"
          << " storage reduction factor: " << storageReduction << "x\n"
          << " fused nest cost: " << fusedLoopNestComputeCost << "\n"
          << " slice iteration count: " << sliceIterationCount << "\n"
          << " src write region size: " << srcWriteRegionSizeBytes << "\n"
          << " slice write region size: " << sliceWriteRegionSizeBytes
          << "\n";
      llvm::dbgs() << msg.str();
    });

    double computeToleranceThreshold =
        clFusionAddlComputeTolerance.getNumOccurrences() > 0
            ? clFusionAddlComputeTolerance
            : LoopFusion::kComputeToleranceThreshold;

    // TODO(b/123247369): This is a placeholder cost model.
    // Among all choices that add an acceptable amount of redundant computation
    // (as per computeToleranceThreshold), we will simply pick the one that
    // reduces the intermediary size the most.
    if ((storageReduction > maxStorageReduction) &&
        (maximalFusion ||
         (additionalComputeFraction < computeToleranceThreshold))) {
      maxStorageReduction = storageReduction;
      bestDstLoopDepth = i;
      minFusedLoopNestComputeCost = fusedLoopNestComputeCost;
      sliceMemEstimate = sliceWriteRegionSizeBytes;
    }
  }

  // A simple cost model: fuse if it reduces the memory footprint. If
  // 'fusion-maximal' is set, fuse nevertheless.

  if (!maximalFusion && !bestDstLoopDepth.hasValue()) {
    LLVM_DEBUG(
        llvm::dbgs()
        << "All fusion choices involve more than the threshold amount of "
           "redundant computation; NOT fusing.\n");
    return false;
  }

  if (!bestDstLoopDepth.hasValue()) {
    LLVM_DEBUG(llvm::dbgs() << "no fusion depth could be evaluated.\n");
    return false;
  }

  // Set dstLoopDepth based on best values from search.
  *dstLoopDepth = bestDstLoopDepth.getValue();

  LLVM_DEBUG(
      llvm::dbgs() << " LoopFusion fusion stats:"
                   << "\n best loop depth: " << bestDstLoopDepth
                   << "\n src loop nest compute cost: " << srcLoopNestCost
                   << "\n dst loop nest compute cost: " << dstLoopNestCost
                   << "\n fused loop nest compute cost: "
                   << minFusedLoopNestComputeCost << "\n");

  auto dstMemSize = getMemoryFootprintBytes(dstLoopIVs[0]);
  auto srcMemSize = getMemoryFootprintBytes(srcLoopIVs[0]);

  Optional<double> storageReduction = None;

  if (!maximalFusion) {
    if (!dstMemSize.hasValue() || !srcMemSize.hasValue()) {
      LLVM_DEBUG(
          llvm::dbgs()
          << " fusion memory benefit cannot be evaluated; NOT fusing.\n");
      return false;
    }

    auto srcMemSizeVal = srcMemSize.getValue();
    auto dstMemSizeVal = dstMemSize.getValue();

    assert(sliceMemEstimate.hasValue() && "expected value");
    auto fusedMem = dstMemSizeVal + sliceMemEstimate.getValue();

    LLVM_DEBUG(llvm::dbgs() << " src mem: " << srcMemSizeVal << "\n"
                            << " dst mem: " << dstMemSizeVal << "\n"
                            << " fused mem: " << fusedMem << "\n"
                            << " slice mem: " << sliceMemEstimate << "\n");

    if (static_cast<long>(fusedMem) > srcMemSizeVal + dstMemSizeVal) {
      LLVM_DEBUG(llvm::dbgs() << "Fusion is not profitable; NOT fusing.\n");
      return false;
    }
    storageReduction =
        100.0 *
        (1.0 - fusedMem / (static_cast<double>(srcMemSizeVal) + dstMemSizeVal));
  }

  double additionalComputeFraction =
      100.0 * (minFusedLoopNestComputeCost /
                   (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
               1);
  (void)additionalComputeFraction;
  LLVM_DEBUG({
    std::stringstream msg;
    msg << " fusion is most profitable at depth " << *dstLoopDepth << " with "
        << std::setprecision(2) << additionalComputeFraction
        << "% redundant computation and a ";
    msg << (storageReduction.hasValue()
                ? std::to_string(storageReduction.getValue())
                : "<unknown>");
    msg << "% storage reduction.\n";
    llvm::dbgs() << msg.str();
  });

  // Update return parameter 'sliceState' with 'bestSliceState'.
  ComputationSliceState *bestSliceState = &sliceStates[*dstLoopDepth - 1];
  sliceState->lbs = bestSliceState->lbs;
  sliceState->ubs = bestSliceState->ubs;
  sliceState->lbOperands = bestSliceState->lbOperands;
  sliceState->ubOperands = bestSliceState->ubOperands;

  // Canonicalize slice bound affine maps.
  for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
    if (sliceState->lbs[i] != AffineMap()) {
      canonicalizeMapAndOperands(&sliceState->lbs[i],
                                 &sliceState->lbOperands[i]);
    }
    if (sliceState->ubs[i] != AffineMap()) {
      canonicalizeMapAndOperands(&sliceState->ubs[i],
                                 &sliceState->ubOperands[i]);
    }
  }
  return true;
}

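// Worked example of the cost computation above (all numbers hypothetical):
// with srcLoopNestCost = 100, dstLoopNestCost = 1000, and a candidate depth
// whose fusedLoopNestComputeCost = 1150, the additional compute fraction is
//   1150 / (100 + 1000) - 1 = 0.045, i.e. 4.5% redundant computation,
// which is accepted whenever the compute tolerance threshold exceeds 0.045.
// If the src write region is 4096 bytes and the slice write region at that
// depth is 256 bytes, the storage reduction factor is 4096 / 256 = 16x.
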
// GreedyFusion greedily fuses loop nests which have a producer/consumer or
// input-reuse relationship on a memref, with the goal of improving locality.
//
// The steps of the producer-consumer fusion algorithm are as follows:
//
// *) A worklist is initialized with node ids from the dependence graph.
// *) For each node id in the worklist:
//   *) Pop an AffineForOp off the worklist. This 'dstAffineForOp' will be a
//      candidate destination AffineForOp into which fusion will be attempted.
//   *) Add each LoadOp currently in 'dstAffineForOp' into list 'dstLoadOps'.
//   *) For each LoadOp in 'dstLoadOps' do:
//      *) Look up dependent loop nests which have a single store op to the
//         same memref.
//      *) Check if dependences would be violated by the fusion.
//      *) Get a computation slice of 'srcLoopNest', which adjusts its loop
//         bounds to be functions of 'dstLoopNest' IVs and symbols.
//      *) Fuse the 'srcLoopNest' computation slice into the 'dstLoopNest',
//         at a loop depth determined by the cost model in
//         'isFusionProfitable'.
//      *) Add the newly fused load/store operations to the state,
//         and also add newly fused load ops to 'dstLoopOps' to be considered
//         as fusion dst load ops in another iteration.
//      *) Remove old src loop nest and its associated state.
//
// The steps of the input-reuse fusion algorithm are as follows:
//
// *) Initialize 'worklist' with node ids from the dependence graph.
// *) For each 'dstNode' in the worklist:
//   *) Find a candidate sibling node 'sibNode' to fuse with 'dstNode' which
//      loads from the same memref, but which has no dependence paths to/from.
//   *) Get a computation slice of 'sibLoopNest', which adjusts its loop
//      bounds to be functions of 'dstLoopNest' IVs and symbols.
//   *) Fuse the 'sibLoopNest' computation slice into the 'dstLoopNest',
//      at a loop depth determined by the cost model in 'isFusionProfitable'.
//      This function also checks that the memref write region of
//      'sibLoopNest' is preserved in the fused loop nest.
//   *) Update graph state to reflect the fusion of 'sibNode' into 'dstNode'.
//
// Given a graph where top-level operations are vertices in the set 'V' and
// edges in the set 'E' are dependences between vertices, this algorithm
// takes O(V) time for initialization, and has runtime O(V + E).
//
// This greedy algorithm is not 'maximal' due to the current restriction of
// fusing along single producer-consumer edges, but there is a TODO to fix
// this.
//
// TODO(andydavis) Experiment with other fusion policies.
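//
// Illustrative producer-consumer example (hypothetical IR, simplified):
// before fusion, %m is a temporary produced by the first nest and consumed by
// the second:
//
//   affine.for %i = 0 to 10 {
//     %v = ...
//     store %v, %m[%i] : memref<10xf32>
//   }
//   affine.for %j = 0 to 10 {
//     %w = load %m[%j] : memref<10xf32>
//     "use"(%w) : (f32) -> ()
//   }
//
// After fusion and memref privatization, a single nest computes and consumes
// each element through a private single-element buffer:
//
//   affine.for %j = 0 to 10 {
//     %v = ...
//     store %v, %pm[0] : memref<1xf32>
//     %w = load %pm[0] : memref<1xf32>
//     "use"(%w) : (f32) -> ()
//   }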
struct GreedyFusion {
public:
  // The data dependence graph to traverse during fusion.
  MemRefDependenceGraph *mdg;
  // Worklist of graph nodes visited during the fusion pass.
  SmallVector<unsigned, 8> worklist;
  // Set of graph nodes which are present on the worklist.
  llvm::SmallDenseSet<unsigned, 16> worklistSet;
  // Parameter for local buffer size threshold.
  unsigned localBufSizeThreshold;
  // Parameter for fast memory space.
  Optional<unsigned> fastMemorySpace;
  // If true, ignore any additional (redundant) computation tolerance threshold
  // that would have prevented fusion.
  bool maximalFusion;

  using Node = MemRefDependenceGraph::Node;

  GreedyFusion(MemRefDependenceGraph *mdg, unsigned localBufSizeThreshold,
               Optional<unsigned> fastMemorySpace, bool maximalFusion)
      : mdg(mdg), localBufSizeThreshold(localBufSizeThreshold),
        fastMemorySpace(fastMemorySpace), maximalFusion(maximalFusion) {}

  // Initializes 'worklist' with nodes from 'mdg'.
  void init() {
    // TODO(andydavis) Add a priority queue for prioritizing nodes by different
    // metrics (e.g. arithmetic intensity/flops-to-bytes ratio).
    worklist.clear();
    worklistSet.clear();
    for (auto &idAndNode : mdg->nodes) {
      const Node &node = idAndNode.second;
      worklist.push_back(node.id);
      worklistSet.insert(node.id);
    }
  }

  // Run the GreedyFusion pass.
  // *) First pass through the nodes fuses single-use producer nodes into their
  //    unique consumer.
  // *) Second pass fuses sibling nodes which share no dependence edges.
  // *) Third pass fuses any remaining producer nodes into their users.
  void run() {
    // TODO(andydavis) Run this repeatedly until a fixed-point is reached.
    fuseProducerConsumerNodes(/*maxSrcUserCount=*/1);
    fuseSiblingNodes();
    fuseProducerConsumerNodes(
        /*maxSrcUserCount=*/std::numeric_limits<unsigned>::max());
    eraseUnusedMemRefAllocations();
  }

  void fuseProducerConsumerNodes(unsigned maxSrcUserCount) {
    init();
    while (!worklist.empty()) {
      unsigned dstId = worklist.back();
      worklist.pop_back();
      worklistSet.erase(dstId);

      // Skip if this node was removed (fused into another node).
      if (mdg->nodes.count(dstId) == 0)
        continue;
      // Get 'dstNode' into which to attempt fusion.
      auto *dstNode = mdg->getNode(dstId);
      // Skip if 'dstNode' is not a loop nest.
      if (!dstNode->op->isa<AffineForOp>())
        continue;
      // Sink sequential loops in 'dstNode' (and thus raise parallel loops)
      // while preserving relative order. This can increase the maximum loop
      // depth at which we can fuse a slice of a producer loop nest into a
      // consumer loop nest.
      sinkSequentialLoops(dstNode);

      SmallVector<Operation *, 4> loads = dstNode->loads;
      SmallVector<Operation *, 4> dstLoadOpInsts;
      DenseSet<Value *> visitedMemrefs;
      while (!loads.empty()) {
        // Get memref of load on top of the stack.
        auto *memref = loads.back()->cast<LoadOp>().getMemRef();
        if (visitedMemrefs.count(memref) > 0)
          continue;
        visitedMemrefs.insert(memref);
        // Move all loads in 'loads' accessing 'memref' to 'dstLoadOpInsts'.
        moveLoadsAccessingMemrefTo(memref, &loads, &dstLoadOpInsts);
        // Skip if no input edges along which to fuse.
        if (mdg->inEdges.count(dstId) == 0)
          continue;
        // Iterate through in-edges for 'dstId' and collect the src node ids
        // of any edges on 'memref'.
        SmallVector<unsigned, 2> srcNodeIds;
        for (auto &srcEdge : mdg->inEdges[dstId]) {
          // Skip 'srcEdge' if not for 'memref'.
          if (srcEdge.value != memref)
            continue;
          srcNodeIds.push_back(srcEdge.id);
        }
        for (unsigned srcId : srcNodeIds) {
          // Skip if this node was removed (fused into another node).
          if (mdg->nodes.count(srcId) == 0)
            continue;
          // Get 'srcNode' from which to attempt fusion into 'dstNode'.
          auto *srcNode = mdg->getNode(srcId);
          // Skip if 'srcNode' is not a loop nest.
          if (!srcNode->op->isa<AffineForOp>())
            continue;
          // Skip if 'srcNode' has more than one store to any memref.
          // TODO(andydavis) Support fusing multi-output src loop nests.
          if (srcNode->stores.size() != 1)
            continue;

          // Skip if 'srcNode' writes to any live in or escaping memrefs,
          // and cannot be fused.
          bool writesToLiveInOrOut =
              mdg->writesToLiveInOrEscapingMemrefs(srcNode->id);
          if (writesToLiveInOrOut &&
              !canFuseSrcWhichWritesToLiveOut(srcId, dstId, memref, mdg))
            continue;

          // Skip if 'srcNode' out edge count on 'memref' > 'maxSrcUserCount'.
          if (mdg->getOutEdgeCount(srcNode->id, memref) > maxSrcUserCount)
            continue;

          // Compute an operation list insertion point for the fused loop
          // nest which preserves dependences.
          Operation *insertPointInst =
              mdg->getFusedLoopNestInsertionPoint(srcNode->id, dstNode->id);
          if (insertPointInst == nullptr)
            continue;

          // Get unique 'srcNode' store op.
          auto *srcStoreOpInst = srcNode->stores.front();
          // Gather 'dstNode' store ops to 'memref'.
          SmallVector<Operation *, 2> dstStoreOpInsts;
          for (auto *storeOpInst : dstNode->stores)
            if (storeOpInst->cast<StoreOp>().getMemRef() == memref)
              dstStoreOpInsts.push_back(storeOpInst);

          unsigned bestDstLoopDepth;
          mlir::ComputationSliceState sliceState;
          // Check if fusion would be profitable.
          if (!isFusionProfitable(srcStoreOpInst, srcStoreOpInst,
                                  dstLoadOpInsts, dstStoreOpInsts, &sliceState,
                                  &bestDstLoopDepth, maximalFusion))
            continue;

          // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'.
          auto sliceLoopNest = mlir::insertBackwardComputationSlice(
              srcStoreOpInst, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState);
          if (sliceLoopNest) {
            LLVM_DEBUG(llvm::dbgs() << "\tslice loop nest:\n"
                                    << *sliceLoopNest.getOperation() << "\n");
            // Move 'dstAffineForOp' before 'insertPointInst' if needed.
            auto dstAffineForOp = dstNode->op->cast<AffineForOp>();
            if (insertPointInst != dstAffineForOp.getOperation()) {
              dstAffineForOp.getOperation()->moveBefore(insertPointInst);
            }
            // Update edges between 'srcNode' and 'dstNode'.
            mdg->updateEdges(srcNode->id, dstNode->id, memref);

            // Collect slice loop stats.
            LoopNestStateCollector sliceCollector;
            sliceCollector.collect(sliceLoopNest.getOperation());
            // Promote single iteration slice loops to single IV value.
            for (auto forOp : sliceCollector.forOps) {
              promoteIfSingleIteration(forOp);
            }
            if (!writesToLiveInOrOut) {
              // Create private memref for 'memref' in 'dstAffineForOp'.
              SmallVector<Operation *, 4> storesForMemref;
              for (auto *storeOpInst : sliceCollector.storeOpInsts) {
                if (storeOpInst->cast<StoreOp>().getMemRef() == memref)
                  storesForMemref.push_back(storeOpInst);
              }
              assert(storesForMemref.size() == 1);
              auto *newMemRef = createPrivateMemRef(
                  dstAffineForOp, storesForMemref[0], bestDstLoopDepth,
                  fastMemorySpace, localBufSizeThreshold);
              visitedMemrefs.insert(newMemRef);
              // Create new node in dependence graph for 'newMemRef' alloc op.
              unsigned newMemRefNodeId =
                  mdg->addNode(newMemRef->getDefiningOp());
              // Add edge from 'newMemRef' node to dstNode.
              mdg->addEdge(newMemRefNodeId, dstId, newMemRef);
            }

            // Collect dst loop stats after memref privatization
            // transformation.
            LoopNestStateCollector dstLoopCollector;
            dstLoopCollector.collect(dstAffineForOp.getOperation());

            // Add new load ops to current Node load op list 'loads' to
            // continue fusing based on new operands.
            for (auto *loadOpInst : dstLoopCollector.loadOpInsts) {
              auto *loadMemRef = loadOpInst->cast<LoadOp>().getMemRef();
              if (visitedMemrefs.count(loadMemRef) == 0)
                loads.push_back(loadOpInst);
            }

            // Clear and add back loads and stores.
            mdg->clearNodeLoadAndStores(dstNode->id);
            mdg->addToNode(dstId, dstLoopCollector.loadOpInsts,
                           dstLoopCollector.storeOpInsts);
            // Remove old src loop nest if it no longer has outgoing dependence
            // edges, and if it does not write to a memref which escapes the
            // function. If 'writesToLiveInOrOut' is true, then 'srcNode' has
            // been fused into 'dstNode' and write region of 'dstNode' covers
            // the write region of 'srcNode', and 'srcNode' has no other users
            // so it is safe to remove.
            if (writesToLiveInOrOut || mdg->canRemoveNode(srcNode->id)) {
              mdg->removeNode(srcNode->id);
              srcNode->op->erase();
            } else {
              // Add remaining users of 'oldMemRef' back on the worklist (if
              // not already there), as its replacement with a local/private
              // memref has reduced dependences on 'oldMemRef' which may have
              // created new fusion opportunities.
              if (mdg->outEdges.count(srcNode->id) > 0) {
                SmallVector<MemRefDependenceGraph::Edge, 2> oldOutEdges =
                    mdg->outEdges[srcNode->id];
                for (auto &outEdge : oldOutEdges) {
                  if (outEdge.value == memref &&
                      worklistSet.count(outEdge.id) == 0) {
                    worklist.push_back(outEdge.id);
                    worklistSet.insert(outEdge.id);
                  }
                }
              }
            }
          }
        }
      }
    }
  }

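  // Illustrative note (hypothetical): with maxSrcUserCount == 1, a producer
  // nest writing %m that is read by exactly one consumer nest is fused on the
  // first pass in run(); a producer with two consumer nests is skipped there
  // and only reconsidered by the final pass, where maxSrcUserCount is
  // unbounded.
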
  // Visits each node in the graph, and for each node, attempts to fuse it with
  // its sibling nodes (nodes which share a parent, but no dependence edges).
  void fuseSiblingNodes() {
    init();
    while (!worklist.empty()) {
      unsigned dstId = worklist.back();
      worklist.pop_back();
      worklistSet.erase(dstId);

      // Skip if this node was removed (fused into another node).
      if (mdg->nodes.count(dstId) == 0)
        continue;
      // Get 'dstNode' into which to attempt fusion.
      auto *dstNode = mdg->getNode(dstId);
      // Skip if 'dstNode' is not a loop nest.
      if (!dstNode->op->isa<AffineForOp>())
        continue;
      // Attempt to fuse 'dstNode' with its sibling nodes in the graph.
      fuseWithSiblingNodes(dstNode);
    }
  }

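  // Illustrative sibling (input-reuse) example (hypothetical IR): two loop
  // nests that both load from %m but have no dependence path between them,
  //
  //   affine.for %i = 0 to 10 { %a = load %m[%i] ... }
  //   affine.for %j = 0 to 10 { %b = load %m[%j] ... }
  //
  // can be fused into one nest so each element of %m is loaded once and
  // reused, improving locality without changing what is computed.
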
2022 // Attempt to fuse 'dstNode' with sibling nodes in the graph.
2023 void fuseWithSiblingNodes(Node *dstNode) {
2024 DenseSet<unsigned> visitedSibNodeIds;
2025 std::pair<unsigned, Value *> idAndMemref;
2026 while (findSiblingNodeToFuse(dstNode, &visitedSibNodeIds, &idAndMemref)) {
2027 unsigned sibId = idAndMemref.first;
2028 Value *memref = idAndMemref.second;
2029 // TODO(andydavis) Check that 'sibStoreOpInst' post-dominates all other
2030 // stores to the same memref in 'sibNode' loop nest.
2031 auto *sibNode = mdg->getNode(sibId);
River Riddle99b87c92019-03-27 21:02:022032 // Compute an operation list insertion point for the fused loop
MLIR Teamd038e342019-03-01 19:50:252033 // nest which preserves dependences.
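      // ('sibNode' and 'dstNode' are siblings, so their loop nests live in
      // the same block.)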
      assert(sibNode->op->getBlock() == dstNode->op->getBlock());
      Operation *insertPointInst =
          sibNode->op->isBeforeInBlock(dstNode->op)
              ? mdg->getFusedLoopNestInsertionPoint(sibNode->id, dstNode->id)
              : mdg->getFusedLoopNestInsertionPoint(dstNode->id, sibNode->id);
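      // A null insertion point means there is no position in the block at
      // which the fused nest would preserve all dependences; skip this pair.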
      if (insertPointInst == nullptr)
        continue;

      // Check if fusion would be profitable and at what depth.

      // Get the unique 'sibNode' load op from 'memref'.
      SmallVector<Operation *, 2> sibLoadOpInsts;
      sibNode->getLoadOpsForMemref(memref, &sibLoadOpInsts);
      // Currently, findSiblingNodeToFuse only returns siblings with one load.
      assert(sibLoadOpInsts.size() == 1);
      Operation *sibLoadOpInst = sibLoadOpInsts[0];
      assert(!sibNode->stores.empty());
      // TODO(andydavis) Choose the store which postdominates all other stores.
      auto *sibStoreOpInst = sibNode->stores.back();

      // Gather 'dstNode' load ops from 'memref'.
      SmallVector<Operation *, 2> dstLoadOpInsts;
      dstNode->getLoadOpsForMemref(memref, &dstLoadOpInsts);

      // Gather 'dstNode' store ops to 'memref'.
      SmallVector<Operation *, 2> dstStoreOpInsts;
      dstNode->getStoreOpsForMemref(memref, &dstStoreOpInsts);

      unsigned bestDstLoopDepth;
      mlir::ComputationSliceState sliceState;

      // Check if fusion would be profitable.
      if (!isFusionProfitable(sibLoadOpInst, sibStoreOpInst, dstLoadOpInsts,
                              dstStoreOpInsts, &sliceState, &bestDstLoopDepth,
                              maximalFusion))
        continue;

      // Fuse computation slice of 'sibLoopNest' into 'dstLoopNest'.
      auto sliceLoopNest = mlir::insertBackwardComputationSlice(
          sibLoadOpInst, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState);
      if (sliceLoopNest != nullptr) {
        auto dstForInst = dstNode->op->cast<AffineForOp>();
        // Update operation position of the fused loop nest (if needed).
        if (insertPointInst != dstForInst.getOperation()) {
          dstForInst.getOperation()->moveBefore(insertPointInst);
        }
        // Update data dependence graph state post fusion.
        updateStateAfterSiblingFusion(sliceLoopNest, sibNode, dstNode);
      }
    }
  }

  // Searches function argument uses and the graph from 'dstNode' looking for a
  // fusion candidate sibling node which shares no dependences with 'dstNode'
  // but which loads from the same memref. Returns true and sets
  // 'idAndMemrefToFuse' on success. Returns false otherwise.
  bool findSiblingNodeToFuse(Node *dstNode,
                             DenseSet<unsigned> *visitedSibNodeIds,
                             std::pair<unsigned, Value *> *idAndMemrefToFuse) {
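    // The search below proceeds in two phases: first, scan load ops on
    // function arguments for a sibling loop nest which loads from a memref
    // that 'dstNode' also loads from; second, follow read-after-write edges
    // from 'dstNode' back to an intermediate producer node and out to other
    // consumers of the same memref.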
    // Returns true if 'sibNode' can be fused with 'dstNode' for input reuse
    // on 'memref'.
    auto canFuseWithSibNode = [&](Node *sibNode, Value *memref) {
      // Skip if 'sibNode' does not load from 'memref' exactly once.
      // TODO(andydavis) Remove the single-load-op restriction.
      if (sibNode->getLoadOpCount(memref) != 1)
        return false;
      // Skip if there exists a path of dependent edges between
      // 'sibNode' and 'dstNode'.
      if (mdg->hasDependencePath(sibNode->id, dstNode->id) ||
          mdg->hasDependencePath(dstNode->id, sibNode->id))
        return false;
      // Skip 'sibNode' if it loads from (and stores to) a memref on which it
      // also has an input dependence edge.
      DenseSet<Value *> loadAndStoreMemrefSet;
      sibNode->getLoadAndStoreMemrefSet(&loadAndStoreMemrefSet);
      if (llvm::any_of(loadAndStoreMemrefSet, [=](Value *memref) {
            return mdg->getIncomingMemRefAccesses(sibNode->id, memref) > 0;
          }))
        return false;

      // Check that all stores are to the same memref.
      DenseSet<Value *> storeMemrefs;
      for (auto *storeOpInst : sibNode->stores) {
        storeMemrefs.insert(storeOpInst->cast<StoreOp>().getMemRef());
      }
      if (storeMemrefs.size() != 1)
        return false;
      return true;
    };

    // Search for siblings which load the same memref function argument.
    auto *fn = dstNode->op->getFunction();
    for (unsigned i = 0, e = fn->getNumArguments(); i != e; ++i) {
      for (auto &use : fn->getArgument(i)->getUses()) {
        if (auto loadOp = dyn_cast<LoadOp>(use.getOwner())) {
          // Gather loops surrounding 'use'.
          SmallVector<AffineForOp, 4> loops;
          getLoopIVs(*use.getOwner(), &loops);
          // Skip 'use' if it is not within a loop nest.
          if (loops.empty())
            continue;
          Node *sibNode = mdg->getForOpNode(loops[0]);
          assert(sibNode != nullptr);
          // Skip 'use' if it is not a sibling to 'dstNode'.
          if (sibNode->id == dstNode->id)
            continue;
          // Skip 'use' if it has already been visited.
          if (visitedSibNodeIds->count(sibNode->id) > 0)
            continue;
          // Skip 'use' if it does not load from the same memref as 'dstNode'.
          auto *memref = loadOp.getMemRef();
          if (dstNode->getLoadOpCount(memref) == 0)
            continue;
          // Check if 'sibNode/dstNode' can be input-reuse fused on 'memref'.
          if (canFuseWithSibNode(sibNode, memref)) {
            visitedSibNodeIds->insert(sibNode->id);
            idAndMemrefToFuse->first = sibNode->id;
            idAndMemrefToFuse->second = memref;
            return true;
          }
        }
      }
    }

    // Search for siblings by following edges through an intermediate src node.
    // Collect candidate 'dstNode' input edges in 'inEdges'.
    SmallVector<MemRefDependenceGraph::Edge, 2> inEdges;
    mdg->forEachMemRefInputEdge(
        dstNode->id, [&](MemRefDependenceGraph::Edge inEdge) {
          // Add 'inEdge' if it is a read-after-write dependence.
          if (dstNode->getLoadOpCount(inEdge.value) > 0 &&
              mdg->getNode(inEdge.id)->getStoreOpCount(inEdge.value) > 0)
            inEdges.push_back(inEdge);
        });

    // Search for sibling nodes to fuse by visiting output edges from each
    // input edge in 'inEdges'.
    for (auto &inEdge : inEdges) {
      // Collect candidate output edges from each node 'inEdge.id' in 'inEdges'.
      SmallVector<MemRefDependenceGraph::Edge, 2> outEdges;
      mdg->forEachMemRefOutputEdge(
          inEdge.id, [&](MemRefDependenceGraph::Edge outEdge) {
            unsigned sibNodeId = outEdge.id;
            if (visitedSibNodeIds->count(sibNodeId) > 0)
              return;
            // Skip output edge if not a sibling using the same memref.
            if (outEdge.id == dstNode->id || outEdge.value != inEdge.value)
              return;
            auto *sibNode = mdg->getNode(sibNodeId);
            if (!sibNode->op->isa<AffineForOp>())
              return;
            // Check if 'sibNode/dstNode' can be input-reuse fused on 'memref'.
            if (canFuseWithSibNode(sibNode, outEdge.value)) {
              // Add candidate 'outEdge' to sibling node.
              outEdges.push_back(outEdge);
            }
          });

      // Add first candidate if any were returned.
      if (!outEdges.empty()) {
        visitedSibNodeIds->insert(outEdges[0].id);
        idAndMemrefToFuse->first = outEdges[0].id;
        idAndMemrefToFuse->second = outEdges[0].value;
        return true;
      }
    }
    return false;
  }

  void updateStateAfterSiblingFusion(AffineForOp sliceLoopNest, Node *sibNode,
                                     Node *dstNode) {
    // Update 'sibNode' and 'dstNode' input/output edges to reflect fusion.
    mdg->updateEdges(sibNode->id, dstNode->id);

    // Collect slice loop stats.
    LoopNestStateCollector sliceCollector;
    sliceCollector.collect(sliceLoopNest.getOperation());
    // Promote single iteration slice loops to a single IV value.
    for (auto forOp : sliceCollector.forOps) {
      promoteIfSingleIteration(forOp);
    }

    // Collect dst loop stats after the memref privatization transformation.
    auto dstForInst = dstNode->op->cast<AffineForOp>();
    LoopNestStateCollector dstLoopCollector;
    dstLoopCollector.collect(dstForInst.getOperation());
    // Clear and add back loads and stores.
    mdg->clearNodeLoadAndStores(dstNode->id);
    mdg->addToNode(dstNode->id, dstLoopCollector.loadOpInsts,
                   dstLoopCollector.storeOpInsts);
    // Remove the old sibling loop nest if it no longer has outgoing dependence
    // edges, and it does not write to a memref which escapes the function.
    if (mdg->getOutEdgeCount(sibNode->id) == 0) {
      mdg->removeNode(sibNode->id);
      sibNode->op->cast<AffineForOp>().erase();
    }
  }

  // Clean up any allocs with no users.
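  // For example, once a producer loop nest has been fused into its consumer
  // and its accesses rerouted to a private buffer, the original alloc may be
  // left without users and can be deleted.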
  void eraseUnusedMemRefAllocations() {
    for (auto &pair : mdg->memrefEdgeCount) {
      if (pair.second > 0)
        continue;
      auto *memref = pair.first;
      // Skip if there exist other uses (return operation or function calls).
      if (!memref->use_empty())
        continue;
      // Use list expected to match the dep graph info.
      auto *op = memref->getDefiningOp();
      if (isa_and_nonnull<AllocOp>(op))
        op->erase();
    }
  }
};

} // end anonymous namespace

void LoopFusion::runOnFunction() {
  // Override if a command line argument was provided.
  if (clFusionFastMemorySpace.getNumOccurrences() > 0) {
    fastMemorySpace = clFusionFastMemorySpace.getValue();
  }

  // Override if a command line argument was provided.
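  // ('clFusionLocalBufThreshold' is specified in KiB; it is converted to
  // bytes below.)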
  if (clFusionLocalBufThreshold.getNumOccurrences() > 0) {
    localBufSizeThreshold = clFusionLocalBufThreshold * 1024;
  }

  if (clMaximalLoopFusion.getNumOccurrences() > 0)
    maximalFusion = clMaximalLoopFusion;

  MemRefDependenceGraph g;
  if (g.init(getFunction()))
    GreedyFusion(&g, localBufSizeThreshold, fastMemorySpace, maximalFusion)
        .run();
}

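// Register this pass under the 'affine-loop-fusion' command-line flag.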
static PassRegistration<LoopFusion> pass("affine-loop-fusion",
                                         "Fuse loop nests");