//===- LoopFusion.cpp - Code to perform loop fusion -----------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements loop fusion.
//
//===----------------------------------------------------------------------===//
21
River Riddle75553832019-01-29 05:23:5322#include "mlir/AffineOps/AffineOps.h"
MLIR Teamf28e4df2018-11-01 14:26:0023#include "mlir/Analysis/AffineAnalysis.h"
Uday Bondhuguladfe07b72019-02-23 00:51:0824#include "mlir/Analysis/AffineStructures.h"
MLIR Teamf28e4df2018-11-01 14:26:0025#include "mlir/Analysis/LoopAnalysis.h"
MLIR Team3b692302018-12-17 17:57:1426#include "mlir/Analysis/Utils.h"
MLIR Teamf28e4df2018-11-01 14:26:0027#include "mlir/IR/AffineExpr.h"
28#include "mlir/IR/AffineMap.h"
29#include "mlir/IR/Builders.h"
River Riddle48ccae22019-02-20 01:17:4630#include "mlir/Pass/Pass.h"
Lei Zhang85d9b6c2019-03-01 21:48:2431#include "mlir/StandardOps/Ops.h"
MLIR Teamf28e4df2018-11-01 14:26:0032#include "mlir/Transforms/LoopUtils.h"
33#include "mlir/Transforms/Passes.h"
MLIR Teamc4237ae2019-01-18 16:56:2734#include "mlir/Transforms/Utils.h"
MLIR Teamf28e4df2018-11-01 14:26:0035#include "llvm/ADT/DenseMap.h"
MLIR Team3b692302018-12-17 17:57:1436#include "llvm/ADT/DenseSet.h"
37#include "llvm/ADT/SetVector.h"
MLIR Team4eef7952018-12-21 19:06:2338#include "llvm/Support/CommandLine.h"
MLIR Team38c2fe32019-01-14 19:26:2539#include "llvm/Support/Debug.h"
MLIR Team3b692302018-12-17 17:57:1440#include "llvm/Support/raw_ostream.h"
Uday Bondhugula864d9e02019-01-23 17:16:2441#include <iomanip>
MLIR Team3b692302018-12-17 17:57:1442
MLIR Team38c2fe32019-01-14 19:26:2543#define DEBUG_TYPE "loop-fusion"
44
MLIR Team3b692302018-12-17 17:57:1445using llvm::SetVector;
MLIR Teamf28e4df2018-11-01 14:26:0046
47using namespace mlir;
48
static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options");

/// Disables fusion profitability check and fuses if valid.
static llvm::cl::opt<bool>
    clMaximalLoopFusion("fusion-maximal",
                        llvm::cl::desc("Enables maximal loop fusion"),
                        llvm::cl::cat(clOptionsCategory));

/// A threshold in percent of additional computation allowed when fusing.
static llvm::cl::opt<double> clFusionAddlComputeTolerance(
    "fusion-compute-tolerance",
    llvm::cl::desc("Fractional increase in additional "
                   "computation tolerated while fusing"),
    llvm::cl::cat(clOptionsCategory));

static llvm::cl::opt<unsigned> clFusionFastMemorySpace(
    "fusion-fast-mem-space",
    llvm::cl::desc("Faster memory space number to promote fusion buffers to"),
    llvm::cl::cat(clOptionsCategory));

// A local buffer of size less than or equal to this size is promoted to fast
// memory.
static llvm::cl::opt<unsigned long long> clFusionLocalBufThreshold(
    "fusion-local-buf-threshold",
    llvm::cl::desc("Threshold size (KiB) for promoting local buffers to fast "
                   "memory space"),
    llvm::cl::cat(clOptionsCategory));

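// NOTE: Illustrative usage sketch only (it assumes the pass is registered
// under the "loop-fusion" name, consistent with DEBUG_TYPE above); the flag
// spellings are the ones declared by the cl::opt options above:
//
//   mlir-opt -loop-fusion input.mlir
//   mlir-opt -loop-fusion -fusion-maximal input.mlir
//   mlir-opt -loop-fusion -fusion-fast-mem-space=2 \
//            -fusion-local-buf-threshold=1024 input.mlir
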
namespace {

/// Loop fusion pass. This pass currently supports a greedy fusion policy,
/// which fuses loop nests with single-writer/single-reader memref dependences
/// with the goal of improving locality.

// TODO(andydavis) Support fusion of source loop nests which write to multiple
// memrefs, where each memref can have multiple users (if profitable).
// TODO(andydavis) Extend this pass to check for fusion preventing dependences,
// and add support for more general loop fusion algorithms.

struct LoopFusion : public FunctionPass<LoopFusion> {
  LoopFusion(unsigned fastMemorySpace = 0, uint64_t localBufSizeThreshold = 0)
      : localBufSizeThreshold(localBufSizeThreshold),
        fastMemorySpace(fastMemorySpace) {}

  void runOnFunction() override;

  // Any local buffers smaller than this size (in bytes) will be created in
  // `fastMemorySpace` if provided.
  uint64_t localBufSizeThreshold;
  Optional<unsigned> fastMemorySpace = None;

  // The amount of additional computation that is tolerated while fusing
  // pair-wise as a fraction of the total computation.
  constexpr static double kComputeToleranceThreshold = 0.30f;
};

} // end anonymous namespace

FunctionPassBase *mlir::createLoopFusionPass(unsigned fastMemorySpace,
                                             uint64_t localBufSizeThreshold) {
  return new LoopFusion(fastMemorySpace, localBufSizeThreshold);
}

namespace {

// LoopNestStateCollector walks loop nests and collects load and store
// operations, and whether or not a region-holding instruction other than
// AffineForOp (e.g. an IfInst) was encountered in the loop nest.
struct LoopNestStateCollector {
  SmallVector<OpPointer<AffineForOp>, 4> forOps;
  SmallVector<Instruction *, 4> loadOpInsts;
  SmallVector<Instruction *, 4> storeOpInsts;
  bool hasNonForRegion = false;

  void collect(Instruction *instToWalk) {
    instToWalk->walk([&](Instruction *opInst) {
      if (opInst->isa<AffineForOp>())
        forOps.push_back(opInst->cast<AffineForOp>());
      else if (opInst->getNumBlockLists() != 0)
        hasNonForRegion = true;
      else if (opInst->isa<LoadOp>())
        loadOpInsts.push_back(opInst);
      else if (opInst->isa<StoreOp>())
        storeOpInsts.push_back(opInst);
    });
  }
};

// TODO(b/117228571) Replace when this is modeled through side-effects/op traits
static bool isMemRefDereferencingOp(const Instruction &op) {
  if (op.isa<LoadOp>() || op.isa<StoreOp>() || op.isa<DmaStartOp>() ||
      op.isa<DmaWaitOp>())
    return true;
  return false;
}

// MemRefDependenceGraph is a graph data structure where graph nodes are
// top-level instructions in a Function which contain load/store ops, and edges
// are memref dependences between the nodes.
// TODO(andydavis) Add a more flexible dependence graph representation.
// TODO(andydavis) Add a depth parameter to dependence graph construction.
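// For example (an illustrative sketch only, not taken from the tests), for a
// function body such as:
//
//   %m = alloc() : memref<10xf32>
//   for %i0 = 0 to 10 {
//     store %v0, %m[%i0] : memref<10xf32>
//   }
//   for %i1 = 0 to 10 {
//     %v1 = load %m[%i1] : memref<10xf32>
//   }
//
// the graph would contain one node per top-level loop nest (plus one for the
// top-level alloc), with an edge keyed on memref '%m' from the storing nest to
// the loading nest, and edges from the alloc node to each nest using '%m'.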
struct MemRefDependenceGraph {
public:
  // Node represents a node in the graph. A Node is either an entire loop nest
  // rooted at the top level which contains loads/stores, or a top level
  // load/store.
  struct Node {
    // The unique identifier of this node in the graph.
    unsigned id;
    // The top-level statement which is (or contains) loads/stores.
    Instruction *inst;
    // List of load operations.
    SmallVector<Instruction *, 4> loads;
    // List of store op insts.
    SmallVector<Instruction *, 4> stores;
    Node(unsigned id, Instruction *inst) : id(id), inst(inst) {}

    // Returns the load op count for 'memref'.
    unsigned getLoadOpCount(Value *memref) {
      unsigned loadOpCount = 0;
      for (auto *loadOpInst : loads) {
        if (memref == loadOpInst->cast<LoadOp>()->getMemRef())
          ++loadOpCount;
      }
      return loadOpCount;
    }

    // Returns the store op count for 'memref'.
    unsigned getStoreOpCount(Value *memref) {
      unsigned storeOpCount = 0;
      for (auto *storeOpInst : stores) {
        if (memref == storeOpInst->cast<StoreOp>()->getMemRef())
          ++storeOpCount;
      }
      return storeOpCount;
    }

    // Collects in 'storeOps' all store ops in this node which access 'memref'.
    void getStoreOpsForMemref(Value *memref,
                              SmallVectorImpl<Instruction *> *storeOps) {
      for (auto *storeOpInst : stores) {
        if (memref == storeOpInst->cast<StoreOp>()->getMemRef())
          storeOps->push_back(storeOpInst);
      }
    }

    // Collects in 'loadOps' all load ops in this node which access 'memref'.
    void getLoadOpsForMemref(Value *memref,
                             SmallVectorImpl<Instruction *> *loadOps) {
      for (auto *loadOpInst : loads) {
        if (memref == loadOpInst->cast<LoadOp>()->getMemRef())
          loadOps->push_back(loadOpInst);
      }
    }

    // Collects in 'loadAndStoreMemrefSet' all memrefs for which this node
    // has at least one load and at least one store operation.
    void getLoadAndStoreMemrefSet(DenseSet<Value *> *loadAndStoreMemrefSet) {
      llvm::SmallDenseSet<Value *, 2> loadMemrefs;
      for (auto *loadOpInst : loads) {
        loadMemrefs.insert(loadOpInst->cast<LoadOp>()->getMemRef());
      }
      for (auto *storeOpInst : stores) {
        auto *memref = storeOpInst->cast<StoreOp>()->getMemRef();
        if (loadMemrefs.count(memref) > 0)
          loadAndStoreMemrefSet->insert(memref);
      }
    }
  };

  // Edge represents a data dependence between nodes in the graph.
  struct Edge {
    // The id of the node at the other end of the edge.
    // If this edge is stored in Edge = Node.inEdges[i], then
    // 'Node.inEdges[i].id' is the identifier of the source node of the edge.
    // If this edge is stored in Edge = Node.outEdges[i], then
    // 'Node.outEdges[i].id' is the identifier of the dest node of the edge.
    unsigned id;
    // The SSA value on which this edge represents a dependence.
    // If the value is a memref, then the dependence is between graph nodes
    // which contain accesses to the same memref 'value'. If the value is a
    // non-memref value, then the dependence is between a graph node which
    // defines an SSA value and another graph node which uses the SSA value
    // (e.g. a constant instruction defining a value which is used inside a loop
    // nest).
    Value *value;
  };

  // Map from node id to Node.
  DenseMap<unsigned, Node> nodes;
  // Map from node id to list of input edges.
  DenseMap<unsigned, SmallVector<Edge, 2>> inEdges;
  // Map from node id to list of output edges.
  DenseMap<unsigned, SmallVector<Edge, 2>> outEdges;
  // Map from memref to a count on the dependence edges associated with that
  // memref.
  DenseMap<Value *, unsigned> memrefEdgeCount;
  // The next unique identifier to use for newly created graph nodes.
  unsigned nextNodeId = 0;

  MemRefDependenceGraph() {}

  // Initializes the dependence graph based on operations in 'f'.
  // Returns true on success, false otherwise.
  bool init(Function *f);

  // Returns the graph node for 'id'.
  Node *getNode(unsigned id) {
    auto it = nodes.find(id);
    assert(it != nodes.end());
    return &it->second;
  }

  // Adds a node with 'inst' to the graph and returns its unique identifier.
  unsigned addNode(Instruction *inst) {
    Node node(nextNodeId++, inst);
    nodes.insert({node.id, node});
    return node.id;
  }

  // Remove node 'id' (and its associated edges) from graph.
  void removeNode(unsigned id) {
    // Remove each edge in 'inEdges[id]'.
    if (inEdges.count(id) > 0) {
      SmallVector<Edge, 2> oldInEdges = inEdges[id];
      for (auto &inEdge : oldInEdges) {
        removeEdge(inEdge.id, id, inEdge.value);
      }
    }
    // Remove each edge in 'outEdges[id]'.
    if (outEdges.count(id) > 0) {
      SmallVector<Edge, 2> oldOutEdges = outEdges[id];
      for (auto &outEdge : oldOutEdges) {
        removeEdge(id, outEdge.id, outEdge.value);
      }
    }
    // Erase remaining node state.
    inEdges.erase(id);
    outEdges.erase(id);
    nodes.erase(id);
  }

  // Returns true if node 'id' writes to any memref which escapes (or is an
  // argument to) the function/block. Returns false otherwise.
  bool writesToLiveInOrEscapingMemrefs(unsigned id) {
    Node *node = getNode(id);
    for (auto *storeOpInst : node->stores) {
      auto *memref = storeOpInst->cast<StoreOp>()->getMemRef();
      auto *inst = memref->getDefiningInst();
      // Return true if 'memref' is a block argument.
      if (!inst)
        return true;
      // Return true if any use of 'memref' escapes the function.
      for (auto &use : memref->getUses())
        if (!isMemRefDereferencingOp(*use.getOwner()))
          return true;
    }
    return false;
  }

  // Returns true if node 'id' can be removed from the graph. Returns false
  // otherwise. A node can be removed from the graph iff the following
  // conditions are met:
  // *) The node does not write to any memref which escapes (or is a
  //    function/block argument).
  // *) The node has no successors in the dependence graph.
  bool canRemoveNode(unsigned id) {
    if (writesToLiveInOrEscapingMemrefs(id))
      return false;
    Node *node = getNode(id);
    for (auto *storeOpInst : node->stores) {
      // Return false if there exist out edges from 'id' on 'memref'.
      if (getOutEdgeCount(id, storeOpInst->cast<StoreOp>()->getMemRef()) > 0)
        return false;
    }
    return true;
  }

  // Returns true iff there is an edge from node 'srcId' to node 'dstId' which
  // is for 'value' if non-null, or for any value otherwise. Returns false
  // otherwise.
  bool hasEdge(unsigned srcId, unsigned dstId, Value *value = nullptr) {
    if (outEdges.count(srcId) == 0 || inEdges.count(dstId) == 0) {
      return false;
    }
    bool hasOutEdge = llvm::any_of(outEdges[srcId], [=](Edge &edge) {
      return edge.id == dstId && (!value || edge.value == value);
    });
    bool hasInEdge = llvm::any_of(inEdges[dstId], [=](Edge &edge) {
      return edge.id == srcId && (!value || edge.value == value);
    });
    return hasOutEdge && hasInEdge;
  }

  // Adds an edge from node 'srcId' to node 'dstId' for 'value'.
  void addEdge(unsigned srcId, unsigned dstId, Value *value) {
    if (!hasEdge(srcId, dstId, value)) {
      outEdges[srcId].push_back({dstId, value});
      inEdges[dstId].push_back({srcId, value});
      if (value->getType().isa<MemRefType>())
        memrefEdgeCount[value]++;
    }
  }

  // Removes an edge from node 'srcId' to node 'dstId' for 'value'.
  void removeEdge(unsigned srcId, unsigned dstId, Value *value) {
    assert(inEdges.count(dstId) > 0);
    assert(outEdges.count(srcId) > 0);
    if (value->getType().isa<MemRefType>()) {
      assert(memrefEdgeCount.count(value) > 0);
      memrefEdgeCount[value]--;
    }
    // Remove 'srcId' from 'inEdges[dstId]'.
    for (auto it = inEdges[dstId].begin(); it != inEdges[dstId].end(); ++it) {
      if ((*it).id == srcId && (*it).value == value) {
        inEdges[dstId].erase(it);
        break;
      }
    }
    // Remove 'dstId' from 'outEdges[srcId]'.
    for (auto it = outEdges[srcId].begin(); it != outEdges[srcId].end(); ++it) {
      if ((*it).id == dstId && (*it).value == value) {
        outEdges[srcId].erase(it);
        break;
      }
    }
  }

  // Returns true if there is a path in the dependence graph from node 'srcId'
  // to node 'dstId'. Returns false otherwise.
  bool hasDependencePath(unsigned srcId, unsigned dstId) {
    // Worklist state is: <node-id, next-output-edge-index-to-visit>
    SmallVector<std::pair<unsigned, unsigned>, 4> worklist;
    worklist.push_back({srcId, 0});
    // Run DFS traversal to see if 'dstId' is reachable from 'srcId'.
    while (!worklist.empty()) {
      auto &idAndIndex = worklist.back();
      // Return true if we have reached 'dstId'.
      if (idAndIndex.first == dstId)
        return true;
      // Pop and continue if node has no out edges, or if all out edges have
      // already been visited.
      if (outEdges.count(idAndIndex.first) == 0 ||
          idAndIndex.second == outEdges[idAndIndex.first].size()) {
        worklist.pop_back();
        continue;
      }
      // Get graph edge to traverse.
      Edge edge = outEdges[idAndIndex.first][idAndIndex.second];
      // Increment next output edge index for 'idAndIndex'.
      ++idAndIndex.second;
      // Add node at 'edge.id' to worklist.
      worklist.push_back({edge.id, 0});
    }
    return false;
  }

  // Returns the input edge count for node 'id' and 'memref' from src nodes
  // which access 'memref' with a store operation.
  unsigned getIncomingMemRefAccesses(unsigned id, Value *memref) {
    unsigned inEdgeCount = 0;
    if (inEdges.count(id) > 0)
      for (auto &inEdge : inEdges[id])
        if (inEdge.value == memref) {
          Node *srcNode = getNode(inEdge.id);
          // Only count in edges from 'srcNode' if 'srcNode' accesses 'memref'
          // with a store operation.
          if (srcNode->getStoreOpCount(memref) > 0)
            ++inEdgeCount;
        }
    return inEdgeCount;
  }

  // Returns the output edge count for node 'id' and 'memref' (if non-null),
  // otherwise returns the total output edge count from node 'id'.
  unsigned getOutEdgeCount(unsigned id, Value *memref = nullptr) {
    unsigned outEdgeCount = 0;
    if (outEdges.count(id) > 0)
      for (auto &outEdge : outEdges[id])
        if (!memref || outEdge.value == memref)
          ++outEdgeCount;
    return outEdgeCount;
  }

  // Computes and returns an insertion point instruction, before which the
  // fused <srcId, dstId> loop nest can be inserted while preserving
  // dependences. Returns nullptr if no such insertion point is found.
  Instruction *getFusedLoopNestInsertionPoint(unsigned srcId, unsigned dstId) {
    if (outEdges.count(srcId) == 0)
      return getNode(dstId)->inst;

    // Build set of insts in range (srcId, dstId) which depend on 'srcId'.
    SmallPtrSet<Instruction *, 2> srcDepInsts;
    for (auto &outEdge : outEdges[srcId])
      if (outEdge.id != dstId)
        srcDepInsts.insert(getNode(outEdge.id)->inst);

    // Build set of insts in range (srcId, dstId) on which 'dstId' depends.
    SmallPtrSet<Instruction *, 2> dstDepInsts;
    for (auto &inEdge : inEdges[dstId])
      if (inEdge.id != srcId)
        dstDepInsts.insert(getNode(inEdge.id)->inst);

    Instruction *srcNodeInst = getNode(srcId)->inst;
    Instruction *dstNodeInst = getNode(dstId)->inst;

    // Computing insertion point:
    // *) Walk all instruction positions in Block instruction list in the
    //    range (src, dst). For each instruction 'inst' visited in this search:
    //    *) Store in 'firstSrcDepPos' the first position where 'inst' has a
    //       dependence edge from 'srcNode'.
    //    *) Store in 'lastDstDepPos' the last position where 'inst' has a
    //       dependence edge to 'dstNode'.
    // *) Compare 'firstSrcDepPos' and 'lastDstDepPos' to determine the
    //    instruction insertion point (or return null pointer if no such
    //    insertion point exists: 'firstSrcDepPos' <= 'lastDstDepPos').
    SmallVector<Instruction *, 2> depInsts;
    Optional<unsigned> firstSrcDepPos;
    Optional<unsigned> lastDstDepPos;
    unsigned pos = 0;
    for (Block::iterator it = std::next(Block::iterator(srcNodeInst));
         it != Block::iterator(dstNodeInst); ++it) {
      Instruction *inst = &(*it);
      if (srcDepInsts.count(inst) > 0 && firstSrcDepPos == None)
        firstSrcDepPos = pos;
      if (dstDepInsts.count(inst) > 0)
        lastDstDepPos = pos;
      depInsts.push_back(inst);
      ++pos;
    }

    if (firstSrcDepPos.hasValue()) {
      if (lastDstDepPos.hasValue()) {
        if (firstSrcDepPos.getValue() <= lastDstDepPos.getValue()) {
          // No valid insertion point exists which preserves dependences.
          return nullptr;
        }
      }
      // Return the insertion point at 'firstSrcDepPos'.
      return depInsts[firstSrcDepPos.getValue()];
    }
    // No dependence targets in range (or only dst deps in range), return
    // 'dstNodeInst' insertion point.
    return dstNodeInst;
  }

  // Updates edge mappings from node 'srcId' to node 'dstId' after 'oldMemRef'
  // has been replaced in node at 'dstId' by a private memref.
  void updateEdges(unsigned srcId, unsigned dstId, Value *oldMemRef) {
    // For each edge in 'inEdges[srcId]': add new edge remapping to 'dstId'.
    if (inEdges.count(srcId) > 0) {
      SmallVector<Edge, 2> oldInEdges = inEdges[srcId];
      for (auto &inEdge : oldInEdges) {
        // Add edge from 'inEdge.id' to 'dstId' if not for 'oldMemRef'.
        if (inEdge.value != oldMemRef)
          addEdge(inEdge.id, dstId, inEdge.value);
      }
    }
    // For each edge in 'outEdges[srcId]': remove edge from 'srcId' to 'dstId'.
    if (outEdges.count(srcId) > 0) {
      SmallVector<Edge, 2> oldOutEdges = outEdges[srcId];
      for (auto &outEdge : oldOutEdges) {
        // Remove any out edges from 'srcId' to 'dstId' across memrefs.
        if (outEdge.id == dstId)
          removeEdge(srcId, outEdge.id, outEdge.value);
      }
    }
    // Remove any edges in 'inEdges[dstId]' on 'oldMemRef' (which is being
    // replaced by a private memref). These edges could come from nodes
    // other than 'srcId' which were removed in the previous step.
    if (inEdges.count(dstId) > 0) {
      SmallVector<Edge, 2> oldInEdges = inEdges[dstId];
      for (auto &inEdge : oldInEdges)
        if (inEdge.value == oldMemRef)
          removeEdge(inEdge.id, dstId, inEdge.value);
    }
  }

  // Update edge mappings for nodes 'sibId' and 'dstId' to reflect fusion
  // of sibling node 'sibId' into node 'dstId'.
  void updateEdges(unsigned sibId, unsigned dstId) {
    // For each edge in 'inEdges[sibId]':
    // *) Add new edge from source node 'inEdge.id' to 'dstNode'.
    // *) Remove edge from source node 'inEdge.id' to 'sibNode'.
    if (inEdges.count(sibId) > 0) {
      SmallVector<Edge, 2> oldInEdges = inEdges[sibId];
      for (auto &inEdge : oldInEdges) {
        addEdge(inEdge.id, dstId, inEdge.value);
        removeEdge(inEdge.id, sibId, inEdge.value);
      }
    }

    // For each edge in 'outEdges[sibId]':
    // *) Add new edge from 'dstId' to 'outEdge.id'.
    // *) Remove edge from 'sibId' to 'outEdge.id'.
    if (outEdges.count(sibId) > 0) {
      SmallVector<Edge, 2> oldOutEdges = outEdges[sibId];
      for (auto &outEdge : oldOutEdges) {
        addEdge(dstId, outEdge.id, outEdge.value);
        removeEdge(sibId, outEdge.id, outEdge.value);
      }
    }
  }

  // Adds ops in 'loads' and 'stores' to node at 'id'.
  void addToNode(unsigned id, const SmallVectorImpl<Instruction *> &loads,
                 const SmallVectorImpl<Instruction *> &stores) {
    Node *node = getNode(id);
    for (auto *loadOpInst : loads)
      node->loads.push_back(loadOpInst);
    for (auto *storeOpInst : stores)
      node->stores.push_back(storeOpInst);
  }

  void clearNodeLoadAndStores(unsigned id) {
    Node *node = getNode(id);
    node->loads.clear();
    node->stores.clear();
  }

  // Calls 'callback' for each input edge incident to node 'id' which carries a
  // memref dependence.
  void forEachMemRefInputEdge(unsigned id,
                              const std::function<void(Edge)> &callback) {
    if (inEdges.count(id) > 0)
      forEachMemRefEdge(inEdges[id], callback);
  }
  // Calls 'callback' for each output edge from node 'id' which carries a
  // memref dependence.
  void forEachMemRefOutputEdge(unsigned id,
                               const std::function<void(Edge)> &callback) {
    if (outEdges.count(id) > 0)
      forEachMemRefEdge(outEdges[id], callback);
  }
  // Calls 'callback' for each edge in 'edges' which carries a memref
  // dependence.
  void forEachMemRefEdge(ArrayRef<Edge> edges,
                         const std::function<void(Edge)> &callback) {
    for (auto &edge : edges) {
      // Skip if 'edge' is not a memref dependence edge.
      if (!edge.value->getType().isa<MemRefType>())
        continue;
      assert(nodes.count(edge.id) > 0);
      // Skip if 'edge.id' is not a loop nest.
      if (!getNode(edge.id)->inst->isa<AffineForOp>())
        continue;
      // Visit current input edge 'edge'.
      callback(edge);
    }
  }

  void print(raw_ostream &os) const {
    os << "\nMemRefDependenceGraph\n";
    os << "\nNodes:\n";
    for (auto &idAndNode : nodes) {
      os << "Node: " << idAndNode.first << "\n";
      auto it = inEdges.find(idAndNode.first);
      if (it != inEdges.end()) {
        for (const auto &e : it->second)
          os << "  InEdge: " << e.id << " " << e.value << "\n";
      }
      it = outEdges.find(idAndNode.first);
      if (it != outEdges.end()) {
        for (const auto &e : it->second)
          os << "  OutEdge: " << e.id << " " << e.value << "\n";
      }
    }
  }
  void dump() const { print(llvm::errs()); }
};

// Initializes the data dependence graph by walking instructions in 'f'.
// Assigns each node in the graph a node id based on program order in 'f'.
// TODO(andydavis) Add support for taking a Block arg to construct the
// dependence graph at a different depth.
bool MemRefDependenceGraph::init(Function *f) {
  DenseMap<Value *, SetVector<unsigned>> memrefAccesses;

  // TODO: support multi-block functions.
  if (f->getBlocks().size() != 1)
    return false;

  DenseMap<Instruction *, unsigned> forToNodeMap;
  for (auto &inst : f->front()) {
    if (auto forOp = inst.dyn_cast<AffineForOp>()) {
      // Create graph node 'id' to represent top-level 'forOp' and record
      // all load and store accesses it contains.
      LoopNestStateCollector collector;
      collector.collect(&inst);
      // Return false if a non 'for' region was found (not currently supported).
      if (collector.hasNonForRegion)
        return false;
      Node node(nextNodeId++, &inst);
      for (auto *opInst : collector.loadOpInsts) {
        node.loads.push_back(opInst);
        auto *memref = opInst->cast<LoadOp>()->getMemRef();
        memrefAccesses[memref].insert(node.id);
      }
      for (auto *opInst : collector.storeOpInsts) {
        node.stores.push_back(opInst);
        auto *memref = opInst->cast<StoreOp>()->getMemRef();
        memrefAccesses[memref].insert(node.id);
      }
      forToNodeMap[&inst] = node.id;
      nodes.insert({node.id, node});
    } else if (auto loadOp = inst.dyn_cast<LoadOp>()) {
      // Create graph node for top-level load op.
      Node node(nextNodeId++, &inst);
      node.loads.push_back(&inst);
      auto *memref = inst.cast<LoadOp>()->getMemRef();
      memrefAccesses[memref].insert(node.id);
      nodes.insert({node.id, node});
    } else if (auto storeOp = inst.dyn_cast<StoreOp>()) {
      // Create graph node for top-level store op.
      Node node(nextNodeId++, &inst);
      node.stores.push_back(&inst);
      auto *memref = inst.cast<StoreOp>()->getMemRef();
      memrefAccesses[memref].insert(node.id);
      nodes.insert({node.id, node});
    } else if (inst.getNumBlockLists() != 0) {
      // Return false if another region is found (not currently supported).
      return false;
    } else if (inst.getNumResults() > 0 && !inst.use_empty()) {
      // Create graph node for top-level producer of SSA values, which
      // could be used by loop nest nodes.
      Node node(nextNodeId++, &inst);
      nodes.insert({node.id, node});
    }
  }

  // Add dependence edges between nodes which produce SSA values and their
  // users.
  for (auto &idAndNode : nodes) {
    const Node &node = idAndNode.second;
    if (!node.loads.empty() || !node.stores.empty())
      continue;
    auto *opInst = node.inst;
    for (auto *value : opInst->getResults()) {
      for (auto &use : value->getUses()) {
        SmallVector<OpPointer<AffineForOp>, 4> loops;
        getLoopIVs(*use.getOwner(), &loops);
        if (loops.empty())
          continue;
        assert(forToNodeMap.count(loops[0]->getInstruction()) > 0);
        unsigned userLoopNestId = forToNodeMap[loops[0]->getInstruction()];
        addEdge(node.id, userLoopNestId, value);
      }
    }
  }

  // Walk memref access lists and add graph edges between dependent nodes.
  for (auto &memrefAndList : memrefAccesses) {
    unsigned n = memrefAndList.second.size();
    for (unsigned i = 0; i < n; ++i) {
      unsigned srcId = memrefAndList.second[i];
      bool srcHasStore =
          getNode(srcId)->getStoreOpCount(memrefAndList.first) > 0;
      for (unsigned j = i + 1; j < n; ++j) {
        unsigned dstId = memrefAndList.second[j];
        bool dstHasStore =
            getNode(dstId)->getStoreOpCount(memrefAndList.first) > 0;
        if (srcHasStore || dstHasStore)
          addEdge(srcId, dstId, memrefAndList.first);
      }
    }
  }
  return true;
}

namespace {

// LoopNestStats aggregates various per-loop statistics (eg. loop trip count
// and operation count) for a loop nest up until the innermost loop body.
struct LoopNestStats {
  // Map from AffineForOp to immediate child AffineForOps in its loop body.
  DenseMap<Instruction *, SmallVector<OpPointer<AffineForOp>, 2>> loopMap;
  // Map from AffineForOp to count of operations in its loop body.
  DenseMap<Instruction *, uint64_t> opCountMap;
  // Map from AffineForOp to its constant trip count.
  DenseMap<Instruction *, uint64_t> tripCountMap;
};

// LoopNestStatsCollector walks a single loop nest and gathers per-loop
// trip count and operation count statistics and records them in 'stats'.
struct LoopNestStatsCollector {
  LoopNestStats *stats;
  bool hasLoopWithNonConstTripCount = false;

  LoopNestStatsCollector(LoopNestStats *stats) : stats(stats) {}

  void collect(Instruction *inst) {
    inst->walk<AffineForOp>([&](OpPointer<AffineForOp> forOp) {
      auto *forInst = forOp->getInstruction();
      auto *parentInst = forOp->getInstruction()->getParentInst();
      if (parentInst != nullptr) {
        assert(parentInst->isa<AffineForOp>() && "Expected parent AffineForOp");
        // Add mapping to 'forOp' from its parent AffineForOp.
        stats->loopMap[parentInst].push_back(forOp);
      }

      // Record the number of op instructions in the body of 'forOp'.
      unsigned count = 0;
      stats->opCountMap[forInst] = 0;
      for (auto &inst : *forOp->getBody()) {
        if (!inst.isa<AffineForOp>() && !inst.isa<AffineIfOp>())
          ++count;
      }
      stats->opCountMap[forInst] = count;
      // Record trip count for 'forOp'. Set flag if trip count is not
      // constant.
      Optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp);
      if (!maybeConstTripCount.hasValue()) {
        hasLoopWithNonConstTripCount = true;
        return;
      }
      stats->tripCountMap[forInst] = maybeConstTripCount.getValue();
    });
  }
};

// Computes the total cost of the loop nest rooted at 'forOp'.
// Currently, the total cost is computed by counting the total operation
// instance count (i.e. loop operation count * loop trip count) for the entire
// loop nest.
// If 'tripCountOverrideMap' is non-null, overrides the trip count for loops
// specified in the map when computing the total op instance count.
// NOTE: this is used to compute the cost of computation slices, which are
// sliced along the iteration dimension, and thus reduce the trip count.
// If 'computeCostMap' is non-null, the total op count for forOps specified
// in the map is increased (not overridden) by adding the op count from the
// map to the existing op count for the for loop. This is done before
// multiplying by the loop's trip count, and is used to model the cost of
// inserting a sliced loop nest of known cost into the loop's body.
// NOTE: this is used to compute the cost of fusing a slice of some loop nest
// within another loop.
static int64_t getComputeCost(
    Instruction *forInst, LoopNestStats *stats,
    llvm::SmallDenseMap<Instruction *, uint64_t, 8> *tripCountOverrideMap,
    DenseMap<Instruction *, int64_t> *computeCostMap) {
  // 'opCount' is the total number of operations in one iteration of the body
  // of 'forOp'.
  int64_t opCount = stats->opCountMap[forInst];
  if (stats->loopMap.count(forInst) > 0) {
    for (auto childForOp : stats->loopMap[forInst]) {
      opCount += getComputeCost(childForOp->getInstruction(), stats,
                                tripCountOverrideMap, computeCostMap);
    }
  }
  // Add in additional op instances from slice (if specified in map).
  if (computeCostMap != nullptr) {
    auto it = computeCostMap->find(forInst);
    if (it != computeCostMap->end()) {
      opCount += it->second;
    }
  }
  // Override trip count (if specified in map).
  int64_t tripCount = stats->tripCountMap[forInst];
  if (tripCountOverrideMap != nullptr) {
    auto it = tripCountOverrideMap->find(forInst);
    if (it != tripCountOverrideMap->end()) {
      tripCount = it->second;
    }
  }
  // Returns the total number of dynamic instances of operations in loop body.
  return tripCount * opCount;
}
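// Illustrative example of the cost model above (a sketch, not from the tests):
// for a two-deep nest whose outer loop has trip count 100 and one op in its
// body besides the inner loop, and whose inner loop has trip count 200 and two
// ops in its body, the inner nest costs 200 * 2 = 400 and the total cost is
// 100 * (1 + 400) = 40100 dynamic op instances (absent any map overrides).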

} // end anonymous namespace

// TODO(andydavis,b/126426796): extend this to handle multiple result maps.
static Optional<uint64_t> getConstDifference(AffineMap lbMap, AffineMap ubMap) {
  assert(lbMap.getNumResults() == 1 && "expected single result bound map");
  assert(ubMap.getNumResults() == 1 && "expected single result bound map");
  assert(lbMap.getNumDims() == ubMap.getNumDims());
  assert(lbMap.getNumSymbols() == ubMap.getNumSymbols());
  AffineExpr lbExpr(lbMap.getResult(0));
  AffineExpr ubExpr(ubMap.getResult(0));
  auto loopSpanExpr = simplifyAffineExpr(ubExpr - lbExpr, lbMap.getNumDims(),
                                         lbMap.getNumSymbols());
  auto cExpr = loopSpanExpr.dyn_cast<AffineConstantExpr>();
  if (!cExpr)
    return None;
  return cExpr.getValue();
}
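// For example (illustrative only): with lbMap = (d0) -> (d0) and
// ubMap = (d0) -> (d0 + 32), 'ubExpr - lbExpr' simplifies to the constant 32,
// so 32 is returned; with lbMap = (d0)[s0] -> (d0) and
// ubMap = (d0)[s0] -> (d0 + s0), the difference stays symbolic and None is
// returned.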

// Builds a map 'tripCountMap' from AffineForOp to constant trip count for loop
// nest surrounding 'srcAccess' utilizing slice loop bounds in 'sliceState'.
// Returns true on success, false otherwise (if a non-constant trip count
// was encountered).
// TODO(andydavis) Make this work with non-unit step loops.
static bool buildSliceTripCountMap(
    Instruction *srcOpInst, ComputationSliceState *sliceState,
    llvm::SmallDenseMap<Instruction *, uint64_t, 8> *tripCountMap) {
  SmallVector<OpPointer<AffineForOp>, 4> srcLoopIVs;
  getLoopIVs(*srcOpInst, &srcLoopIVs);
  unsigned numSrcLoopIVs = srcLoopIVs.size();
  // Populate map from AffineForOp -> trip count.
  for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
    AffineMap lbMap = sliceState->lbs[i];
    AffineMap ubMap = sliceState->ubs[i];
    if (lbMap == AffineMap() || ubMap == AffineMap()) {
      // The iteration of src loop IV 'i' was not sliced. Use full loop bounds.
      if (srcLoopIVs[i]->hasConstantLowerBound() &&
          srcLoopIVs[i]->hasConstantUpperBound()) {
        (*tripCountMap)[srcLoopIVs[i]->getInstruction()] =
            srcLoopIVs[i]->getConstantUpperBound() -
            srcLoopIVs[i]->getConstantLowerBound();
        continue;
      }
      return false;
    }
    Optional<uint64_t> tripCount = getConstDifference(lbMap, ubMap);
    if (!tripCount.hasValue())
      return false;
    (*tripCountMap)[srcLoopIVs[i]->getInstruction()] = tripCount.getValue();
  }
  return true;
}

// Removes load operations from 'srcLoads' which operate on 'memref', and
// adds them to 'dstLoads'.
static void
moveLoadsAccessingMemrefTo(Value *memref,
                           SmallVectorImpl<Instruction *> *srcLoads,
                           SmallVectorImpl<Instruction *> *dstLoads) {
  dstLoads->clear();
  SmallVector<Instruction *, 4> srcLoadsToKeep;
  for (auto *load : *srcLoads) {
    if (load->cast<LoadOp>()->getMemRef() == memref)
      dstLoads->push_back(load);
    else
      srcLoadsToKeep.push_back(load);
  }
  srcLoads->swap(srcLoadsToKeep);
}

River Riddleb4992772019-02-04 18:38:47883static unsigned getInnermostCommonLoopDepth(ArrayRef<Instruction *> ops) {
MLIR Team27d067e2019-01-16 17:55:02884 unsigned numOps = ops.size();
885 assert(numOps > 0);
886
River Riddle5052bd82019-02-02 00:42:18887 std::vector<SmallVector<OpPointer<AffineForOp>, 4>> loops(numOps);
MLIR Team27d067e2019-01-16 17:55:02888 unsigned loopDepthLimit = std::numeric_limits<unsigned>::max();
889 for (unsigned i = 0; i < numOps; ++i) {
890 getLoopIVs(*ops[i], &loops[i]);
891 loopDepthLimit =
892 std::min(loopDepthLimit, static_cast<unsigned>(loops[i].size()));
MLIR Team38c2fe32019-01-14 19:26:25893 }
MLIR Team27d067e2019-01-16 17:55:02894
895 unsigned loopDepth = 0;
896 for (unsigned d = 0; d < loopDepthLimit; ++d) {
897 unsigned i;
898 for (i = 1; i < numOps; ++i) {
River Riddle5052bd82019-02-02 00:42:18899 if (loops[i - 1][d] != loops[i][d])
MLIR Team27d067e2019-01-16 17:55:02900 break;
MLIR Team27d067e2019-01-16 17:55:02901 }
902 if (i != numOps)
903 break;
904 ++loopDepth;
905 }
906 return loopDepth;
MLIR Team38c2fe32019-01-14 19:26:25907}
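// For example (illustrative only): if one op is nested under loops (%i, %j)
// and another under loops (%i, %k), the surrounding loops agree only at
// depth 0 (%i), so this returns 1.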

// Returns the maximum loop depth at which no dependences between 'loadOpInsts'
// and 'storeOpInsts' are satisfied.
static unsigned getMaxLoopDepth(ArrayRef<Instruction *> loadOpInsts,
                                ArrayRef<Instruction *> storeOpInsts) {
  // Merge loads and stores into the same array.
  SmallVector<Instruction *, 2> ops(loadOpInsts.begin(), loadOpInsts.end());
  ops.append(storeOpInsts.begin(), storeOpInsts.end());

  // Compute the innermost common loop depth for loads and stores.
  unsigned loopDepth = getInnermostCommonLoopDepth(ops);

  // Return common loop depth for loads if there are no store ops.
  if (storeOpInsts.empty())
    return loopDepth;

  // Check dependences on all pairs of ops in 'ops' and store the minimum
  // loop depth at which a dependence is satisfied.
  for (unsigned i = 0, e = ops.size(); i < e; ++i) {
    auto *srcOpInst = ops[i];
    MemRefAccess srcAccess(srcOpInst);
    for (unsigned j = 0; j < e; ++j) {
      auto *dstOpInst = ops[j];
      MemRefAccess dstAccess(dstOpInst);

      unsigned numCommonLoops =
          getNumCommonSurroundingLoops(*srcOpInst, *dstOpInst);
      for (unsigned d = 1; d <= numCommonLoops + 1; ++d) {
        FlatAffineConstraints dependenceConstraints;
        // TODO(andydavis) Cache dependence analysis results, check cache here.
        if (checkMemrefAccessDependence(srcAccess, dstAccess, d,
                                        &dependenceConstraints,
                                        /*dependenceComponents=*/nullptr)) {
          // Store minimum loop depth and break because we want the min 'd' at
          // which there is a dependence.
          loopDepth = std::min(loopDepth, d - 1);
          break;
        }
      }
    }
  }
  return loopDepth;
}

// Compute loop interchange permutation:
// *) Computes dependence components between all op pairs in 'ops' for loop
//    depths in range [1, 'maxLoopDepth'].
// *) Classifies the outermost 'maxLoopDepth' loops surrounding 'ops' as either
//    parallel or sequential.
// *) Computes the loop permutation which sinks sequential loops deeper into
//    the loop nest, while preserving the relative order between other loops.
// *) Checks each dependence component against the permutation to see if the
//    desired loop interchange would violate dependences by making a
//    dependence component lexicographically negative.
// TODO(andydavis) Move this function to LoopUtils.
static bool
computeLoopInterchangePermutation(ArrayRef<Instruction *> ops,
                                  unsigned maxLoopDepth,
                                  SmallVectorImpl<unsigned> *loopPermMap) {
  // Gather dependence components for dependences between all ops in 'ops'
  // at loop depths in range [1, maxLoopDepth].
  // TODO(andydavis) Refactor this loop into a LoopUtil utility function:
  // mlir::getDependenceComponents().
  // TODO(andydavis) Split this loop into two: first check all dependences,
  // and construct dep vectors. Then, scan through them to detect the parallel
  // ones.
  std::vector<llvm::SmallVector<DependenceComponent, 2>> depCompsVec;
  llvm::SmallVector<bool, 8> isParallelLoop(maxLoopDepth, true);
  unsigned numOps = ops.size();
  for (unsigned d = 1; d <= maxLoopDepth; ++d) {
    for (unsigned i = 0; i < numOps; ++i) {
      auto *srcOpInst = ops[i];
      MemRefAccess srcAccess(srcOpInst);
      for (unsigned j = 0; j < numOps; ++j) {
        auto *dstOpInst = ops[j];
        MemRefAccess dstAccess(dstOpInst);

        FlatAffineConstraints dependenceConstraints;
        llvm::SmallVector<DependenceComponent, 2> depComps;
        // TODO(andydavis,bondhugula) Explore whether it would be profitable
        // to pre-compute and store deps instead of repeatedly checking.
        if (checkMemrefAccessDependence(srcAccess, dstAccess, d,
                                        &dependenceConstraints, &depComps)) {
          isParallelLoop[d - 1] = false;
          depCompsVec.push_back(depComps);
        }
      }
    }
  }
  // Count the number of parallel loops.
  unsigned numParallelLoops = 0;
  for (unsigned i = 0, e = isParallelLoop.size(); i < e; ++i)
    if (isParallelLoop[i])
      ++numParallelLoops;

  // Compute permutation of loops that sinks sequential loops (and thus raises
  // parallel loops) while preserving relative order.
  llvm::SmallVector<unsigned, 4> loopPermMapInv;
  loopPermMapInv.resize(maxLoopDepth);
  loopPermMap->resize(maxLoopDepth);
  unsigned nextSequentialLoop = numParallelLoops;
  unsigned nextParallelLoop = 0;
  for (unsigned i = 0; i < maxLoopDepth; ++i) {
    if (isParallelLoop[i]) {
      (*loopPermMap)[i] = nextParallelLoop;
      loopPermMapInv[nextParallelLoop++] = i;
    } else {
      (*loopPermMap)[i] = nextSequentialLoop;
      loopPermMapInv[nextSequentialLoop++] = i;
    }
  }

  // Check each dependence component against the permutation to see if the
  // desired loop interchange permutation would make the dependence vectors
  // lexicographically negative.
  // Example 1: [-1, 1][0, 0]
  // Example 2: [0, 0][-1, 1]
  for (unsigned i = 0, e = depCompsVec.size(); i < e; ++i) {
    llvm::SmallVector<DependenceComponent, 2> &depComps = depCompsVec[i];
    assert(depComps.size() >= maxLoopDepth);
    // Check if the first non-zero dependence component is positive.
    for (unsigned j = 0; j < maxLoopDepth; ++j) {
      unsigned permIndex = loopPermMapInv[j];
      assert(depComps[permIndex].lb.hasValue());
      int64_t depCompLb = depComps[permIndex].lb.getValue();
      if (depCompLb > 0)
        break;
      if (depCompLb < 0)
        return false;
    }
  }
  return true;
}
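// For example (illustrative only): with three loops classified as
// [parallel, sequential, parallel], numParallelLoops = 2 and the computed
// permutation is loopPermMap = [0, 2, 1], i.e. the sequential middle loop is
// sunk to the innermost position while the two parallel loops keep their
// relative order (subject to the dependence legality check above).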

// Sinks all sequential loops to the innermost levels (while preserving
// relative order among them) and moves all parallel loops to the
// outermost (while again preserving relative order among them).
// This can increase the loop depth at which we can fuse a slice, since we are
// pushing loop carried dependence to a greater depth in the loop nest.
static void sinkSequentialLoops(MemRefDependenceGraph::Node *node) {
  assert(node->inst->isa<AffineForOp>());
  // Get perfectly nested sequence of loops starting at root of loop nest.
  // TODO(andydavis,bondhugula) Share this with similar code in loop tiling.
  SmallVector<OpPointer<AffineForOp>, 4> loops;
  OpPointer<AffineForOp> curr = node->inst->cast<AffineForOp>();
  loops.push_back(curr);
  auto *currBody = curr->getBody();
  while (!currBody->empty() &&
         std::next(currBody->begin()) == currBody->end() &&
         (curr = curr->getBody()->front().dyn_cast<AffineForOp>())) {
    loops.push_back(curr);
    currBody = curr->getBody();
  }
  if (loops.size() < 2)
    return;

  // Merge loads and stores into the same array.
  SmallVector<Instruction *, 2> memOps(node->loads.begin(), node->loads.end());
  memOps.append(node->stores.begin(), node->stores.end());

  // Compute loop permutation in 'loopPermMap'.
  llvm::SmallVector<unsigned, 4> loopPermMap;
  if (!computeLoopInterchangePermutation(memOps, loops.size(), &loopPermMap))
    return;

  int loopNestRootIndex = -1;
  for (int i = loops.size() - 1; i >= 0; --i) {
    int permIndex = static_cast<int>(loopPermMap[i]);
    // Store the index of the for loop which will be the new loop nest root.
    if (permIndex == 0)
      loopNestRootIndex = i;
    if (permIndex > i) {
      // Sink loop 'i' by 'permIndex - i' levels deeper into the loop nest.
      sinkLoop(loops[i], permIndex - i);
    }
  }
  assert(loopNestRootIndex != -1 && "invalid root index");
  node->inst = loops[loopNestRootIndex]->getInstruction();
}

// Returns the slice union of 'sliceStateA' and 'sliceStateB' in 'sliceStateB'
// using a rectangular bounding box.
// TODO(andydavis) This function assumes that lower bounds for 'sliceStateA'
// and 'sliceStateB' are aligned.
// Specifically, when taking the union of overlapping intervals, it assumes
// that both intervals start at zero. Support needs to be added to take into
// account interval start offset when computing the union.
// TODO(andydavis) Move this function to an analysis library.
static bool getSliceUnion(const ComputationSliceState &sliceStateA,
                          ComputationSliceState *sliceStateB) {
  assert(sliceStateA.lbs.size() == sliceStateB->lbs.size());
  assert(sliceStateA.ubs.size() == sliceStateB->ubs.size());

  for (unsigned i = 0, e = sliceStateA.lbs.size(); i < e; ++i) {
    AffineMap lbMapA = sliceStateA.lbs[i];
    AffineMap ubMapA = sliceStateA.ubs[i];
    if (lbMapA == AffineMap()) {
      assert(ubMapA == AffineMap());
      continue;
    }
    assert(ubMapA && "expected non-null ub map");

    AffineMap lbMapB = sliceStateB->lbs[i];
    AffineMap ubMapB = sliceStateB->ubs[i];
    if (lbMapB == AffineMap()) {
      assert(ubMapB == AffineMap());
      // Union 'sliceStateB' does not have a bound for 'i' so copy from A.
      sliceStateB->lbs[i] = lbMapA;
      sliceStateB->ubs[i] = ubMapA;
      continue;
    }

    // TODO(andydavis) Change this code to take the min across all lower bounds
    // and max across all upper bounds for each dimension. This code can fail
    // for cases where a unique min or max could not be statically determined.

    // Assumption: both lower bounds are the same.
    if (lbMapA != lbMapB)
      return false;

    // Add bound with the largest trip count to union.
    Optional<uint64_t> tripCountA = getConstDifference(lbMapA, ubMapA);
    Optional<uint64_t> tripCountB = getConstDifference(lbMapB, ubMapB);
    if (!tripCountA.hasValue() || !tripCountB.hasValue())
      return false;

    if (tripCountA.getValue() > tripCountB.getValue()) {
      sliceStateB->lbs[i] = lbMapA;
      sliceStateB->ubs[i] = ubMapA;
    }
  }
  return true;
}
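// For example (illustrative only): if dimension 'i' of 'sliceStateA' spans
// [0, 32) and the same dimension of 'sliceStateB' spans [0, 16), both lower
// bounds match, the trip counts are 32 and 16, and the union keeps the wider
// [0, 32) bounds in 'sliceStateB'.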

// TODO(mlir-team): improve/complete this when we have target data.
unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
  auto elementType = memRefType.getElementType();

  unsigned sizeInBits;
  if (elementType.isIntOrFloat()) {
    sizeInBits = elementType.getIntOrFloatBitWidth();
  } else {
    auto vectorType = elementType.cast<VectorType>();
    sizeInBits =
        vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
  }
  return llvm::divideCeil(sizeInBits, 8);
}
1156
MLIR Teamc4237ae2019-01-18 16:56:271157// Creates and returns a private (single-user) memref for fused loop rooted
River Riddle5052bd82019-02-02 00:42:181158// at 'forOp', with (potentially reduced) memref size based on the
Uday Bondhugula94a03f82019-01-22 21:58:521159// MemRefRegion written to by 'srcStoreOpInst' at depth 'dstLoopDepth'.
1160// TODO(bondhugula): consider refactoring the common code from generateDma and
1161// this one.
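// As a hypothetical sketch of the effect: if the source nest writes a 100x100
// region of a memref but the slice materialized at 'dstLoopDepth' only writes
// a 1x100 window per destination iteration, the private buffer can shrink to
// a 1x100 memref, with accesses remapped by subtracting the write region's
// lower-bound offsets (computed below).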
River Riddle5052bd82019-02-02 00:42:181162static Value *createPrivateMemRef(OpPointer<AffineForOp> forOp,
River Riddleb4992772019-02-04 18:38:471163 Instruction *srcStoreOpInst,
Uday Bondhugula8be26272019-02-02 01:06:221164 unsigned dstLoopDepth,
1165 Optional<unsigned> fastMemorySpace,
Uday Bondhugulad4b3ff12019-02-27 00:10:191166 uint64_t localBufSizeThreshold) {
River Riddle5052bd82019-02-02 00:42:181167 auto *forInst = forOp->getInstruction();
1168
1169 // Create builder to insert alloc op just before 'forOp'.
MLIR Teamc4237ae2019-01-18 16:56:271170 FuncBuilder b(forInst);
1171 // Builder to create constants at the top level.
1172 FuncBuilder top(forInst->getFunction());
1173 // Create new memref type based on slice bounds.
1174 auto *oldMemRef = srcStoreOpInst->cast<StoreOp>()->getMemRef();
1175 auto oldMemRefType = oldMemRef->getType().cast<MemRefType>();
1176 unsigned rank = oldMemRefType.getRank();
1177
Uday Bondhugula94a03f82019-01-22 21:58:521178 // Compute MemRefRegion for 'srcStoreOpInst' at depth 'dstLoopDepth'.
Uday Bondhugula0f504142019-02-04 21:48:441179 MemRefRegion region(srcStoreOpInst->getLoc());
MLIR Teamd42ef782019-03-04 19:01:251180 bool validRegion = region.compute(srcStoreOpInst, dstLoopDepth);
1181 (void)validRegion;
1182 assert(validRegion && "unexpected memref region failure");
River Riddle6859f332019-01-23 22:39:451183 SmallVector<int64_t, 4> newShape;
MLIR Teamc4237ae2019-01-18 16:56:271184 std::vector<SmallVector<int64_t, 4>> lbs;
Uday Bondhugula94a03f82019-01-22 21:58:521185 SmallVector<int64_t, 8> lbDivisors;
MLIR Teamc4237ae2019-01-18 16:56:271186 lbs.reserve(rank);
1187 // Query 'region' for 'newShape' and lower bounds of MemRefRegion accessed
Uday Bondhugula94a03f82019-01-22 21:58:521188 // by 'srcStoreOpInst' at depth 'dstLoopDepth'.
MLIR Teamc4237ae2019-01-18 16:56:271189 Optional<int64_t> numElements =
Uday Bondhugula0f504142019-02-04 21:48:441190 region.getConstantBoundingSizeAndShape(&newShape, &lbs, &lbDivisors);
Uday Bondhugula8be26272019-02-02 01:06:221191 assert(numElements.hasValue() &&
1192 "non-constant number of elts in local buffer");
MLIR Teamc4237ae2019-01-18 16:56:271193
Uday Bondhugula0f504142019-02-04 21:48:441194 const FlatAffineConstraints *cst = region.getConstraints();
Uday Bondhugula94a03f82019-01-22 21:58:521195  // 'outerIVs' holds the values that this memory region is symbolic/parametric
1196 // on; this would correspond to loop IVs surrounding the level at which the
1197 // slice is being materialized.
1198 SmallVector<Value *, 8> outerIVs;
1199 cst->getIdValues(rank, cst->getNumIds(), &outerIVs);
1200
1201 // Build 'rank' AffineExprs from MemRefRegion 'lbs'
MLIR Teamc4237ae2019-01-18 16:56:271202 SmallVector<AffineExpr, 4> offsets;
1203 offsets.reserve(rank);
1204 for (unsigned d = 0; d < rank; ++d) {
Uday Bondhugula94a03f82019-01-22 21:58:521205 assert(lbs[d].size() == cst->getNumCols() - rank && "incorrect bound size");
1206
MLIR Teamc4237ae2019-01-18 16:56:271207 AffineExpr offset = top.getAffineConstantExpr(0);
1208 for (unsigned j = 0, e = cst->getNumCols() - rank - 1; j < e; j++) {
1209 offset = offset + lbs[d][j] * top.getAffineDimExpr(j);
1210 }
Uday Bondhugula94a03f82019-01-22 21:58:521211 assert(lbDivisors[d] > 0);
1212 offset =
1213 (offset + lbs[d][cst->getNumCols() - 1 - rank]).floorDiv(lbDivisors[d]);
MLIR Teamc4237ae2019-01-18 16:56:271214 offsets.push_back(offset);
1215 }
1216
1217 // Create 'newMemRefType' using 'newShape' from MemRefRegion accessed
1218 // by 'srcStoreOpInst'.
Uday Bondhugula8be26272019-02-02 01:06:221219 uint64_t bufSize =
1220 getMemRefEltSizeInBytes(oldMemRefType) * numElements.getValue();
1221 unsigned newMemSpace;
Uday Bondhugulad4b3ff12019-02-27 00:10:191222 if (bufSize <= localBufSizeThreshold && fastMemorySpace.hasValue()) {
Uday Bondhugula8be26272019-02-02 01:06:221223 newMemSpace = fastMemorySpace.getValue();
1224 } else {
1225 newMemSpace = oldMemRefType.getMemorySpace();
1226 }
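  // E.g. (hypothetical sizes): a 1x100xf32 buffer is 400 bytes; if
  // 'localBufSizeThreshold' is at least 400 and a fast memory space was
  // configured, the check above places the private buffer in that space.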
1227 auto newMemRefType = top.getMemRefType(
1228 newShape, oldMemRefType.getElementType(), {}, newMemSpace);
MLIR Teamc4237ae2019-01-18 16:56:271229 // Gather alloc operands for the dynamic dimensions of the memref.
1230 SmallVector<Value *, 4> allocOperands;
1231 unsigned dynamicDimCount = 0;
1232 for (auto dimSize : oldMemRefType.getShape()) {
1233 if (dimSize == -1)
1234 allocOperands.push_back(
River Riddle5052bd82019-02-02 00:42:181235 top.create<DimOp>(forOp->getLoc(), oldMemRef, dynamicDimCount++));
MLIR Teamc4237ae2019-01-18 16:56:271236 }
1237
River Riddle5052bd82019-02-02 00:42:181238 // Create new private memref for fused loop 'forOp'.
MLIR Teama0f3db402019-01-29 17:36:411239 // TODO(andydavis) Create/move alloc ops for private memrefs closer to their
1240 // consumer loop nests to reduce their live range. Currently they are added
1241 // at the beginning of the function, because loop nests can be reordered
1242 // during the fusion pass.
MLIR Teamc4237ae2019-01-18 16:56:271243 Value *newMemRef =
River Riddle5052bd82019-02-02 00:42:181244 top.create<AllocOp>(forOp->getLoc(), newMemRefType, allocOperands);
MLIR Teamc4237ae2019-01-18 16:56:271245
1246 // Build an AffineMap to remap access functions based on lower bound offsets.
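  // For example (hypothetical values): with rank 1 and a lower bound offset
  // of 10, an access at index %i into the old memref is remapped to index
  // (%i - 10) in the private buffer; if every offset is zero the remap
  // collapses to the identity and no map is built.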
1247 SmallVector<AffineExpr, 4> remapExprs;
1248 remapExprs.reserve(rank);
1249 unsigned zeroOffsetCount = 0;
1250 for (unsigned i = 0; i < rank; i++) {
1251 if (auto constExpr = offsets[i].dyn_cast<AffineConstantExpr>())
1252 if (constExpr.getValue() == 0)
1253 ++zeroOffsetCount;
Uday Bondhugula94a03f82019-01-22 21:58:521254 auto dimExpr = b.getAffineDimExpr(outerIVs.size() + i);
1255
1256 auto remapExpr =
1257 simplifyAffineExpr(dimExpr - offsets[i], outerIVs.size() + rank, 0);
1258 remapExprs.push_back(remapExpr);
MLIR Teamc4237ae2019-01-18 16:56:271259 }
Uday Bondhugula94a03f82019-01-22 21:58:521260 auto indexRemap =
1261 zeroOffsetCount == rank
Nicolas Vasilache0e7a8a92019-01-26 18:41:171262 ? AffineMap()
Uday Bondhugula94a03f82019-01-22 21:58:521263 : b.getAffineMap(outerIVs.size() + rank, 0, remapExprs, {});
MLIR Teamc4237ae2019-01-18 16:56:271264 // Replace all users of 'oldMemRef' with 'newMemRef'.
Uday Bondhugula94a03f82019-01-22 21:58:521265 bool ret =
1266 replaceAllMemRefUsesWith(oldMemRef, newMemRef, {}, indexRemap,
1267 /*extraOperands=*/outerIVs,
River Riddle5052bd82019-02-02 00:42:181268 /*domInstFilter=*/&*forOp->getBody()->begin());
Uday Bondhugula94a03f82019-01-22 21:58:521269 assert(ret && "replaceAllMemrefUsesWith should always succeed here");
MLIR Team71495d52019-01-22 21:23:371270 (void)ret;
MLIR Teamc4237ae2019-01-18 16:56:271271 return newMemRef;
1272}
1273
Uday Bondhugula864d9e02019-01-23 17:16:241274// Returns the iteration count of the slice (product of its constant trip counts).
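// For example, constant slice trip counts {2, 4} yield an iteration count of
// 8; a count of 1 means the slice has a single iteration, which
// isFusionProfitable uses to conclude that store-to-load forwarding is
// guaranteed.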
1275static uint64_t getSliceIterationCount(
River Riddle5052bd82019-02-02 00:42:181276 const llvm::SmallDenseMap<Instruction *, uint64_t, 8> &sliceTripCountMap) {
Uday Bondhugula864d9e02019-01-23 17:16:241277 uint64_t iterCount = 1;
1278 for (const auto &count : sliceTripCountMap) {
1279 iterCount *= count.second;
1280 }
1281 return iterCount;
1282}
1283
MLIR Team58aa3832019-02-16 01:12:191284// Checks if node 'srcId' (which writes to a live out memref) can be safely
1285// fused into node 'dstId'. Returns true if the following conditions are met:
1286// *) 'srcNode' only writes to live out 'memref'.
1287// *) 'srcNode' has exactly one output edge on 'memref' (which is to 'dstId').
1288// *) 'dstNode' does write to 'memref'.
1289// *) 'dstNode's write region to 'memref' is a super set of 'srcNode's write
1290// region to 'memref'.
1291// TODO(andydavis) Generalize this to handle more live in/out cases.
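// As a hypothetical example: if 'srcNode' writes the full 100x100 region of a
// live out memref and 'dstNode' also writes a 100x100 region of it, the
// bounding-box element counts below match and fusion may proceed; if 'dstNode'
// wrote only a 10x100 window, the counts would differ and fusion is rejected.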
1292static bool canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId,
1293 Value *memref,
1294 MemRefDependenceGraph *mdg) {
1295 auto *srcNode = mdg->getNode(srcId);
1296 auto *dstNode = mdg->getNode(dstId);
1297
1298 // Return false if any of the following are true:
1299 // *) 'srcNode' writes to a live in/out memref other than 'memref'.
1300 // *) 'srcNode' has more than one output edge on 'memref'.
1301 // *) 'dstNode' does not write to 'memref'.
1302 if (srcNode->getStoreOpCount(memref) != 1 ||
1303 mdg->getOutEdgeCount(srcNode->id, memref) != 1 ||
1304 dstNode->getStoreOpCount(memref) == 0)
1305 return false;
1306 // Compute MemRefRegion 'srcWriteRegion' for 'srcStoreOpInst' on 'memref'.
1307 auto *srcStoreOpInst = srcNode->stores.front();
1308 MemRefRegion srcWriteRegion(srcStoreOpInst->getLoc());
MLIR Teamd42ef782019-03-04 19:01:251309 if (!srcWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0)) {
1310 LLVM_DEBUG(llvm::dbgs()
1311               << "Unable to compute MemRefRegion for source operation.\n");
1312 return false;
1313 }
MLIR Team58aa3832019-02-16 01:12:191314 SmallVector<int64_t, 4> srcShape;
1315  // Query 'srcWriteRegion' for 'srcShape' and 'srcNumElements' written to
1316  // by 'srcStoreOpInst'.
1317 Optional<int64_t> srcNumElements =
1318 srcWriteRegion.getConstantBoundingSizeAndShape(&srcShape);
1319 if (!srcNumElements.hasValue())
1320 return false;
1321
1322 // Compute MemRefRegion 'dstWriteRegion' for 'dstStoreOpInst' on 'memref'.
1323 SmallVector<Instruction *, 2> dstStoreOps;
1324 dstNode->getStoreOpsForMemref(memref, &dstStoreOps);
1325 assert(dstStoreOps.size() == 1);
1326 auto *dstStoreOpInst = dstStoreOps[0];
1327 MemRefRegion dstWriteRegion(dstStoreOpInst->getLoc());
MLIR Teamd42ef782019-03-04 19:01:251328 if (!dstWriteRegion.compute(dstStoreOpInst, /*loopDepth=*/0)) {
1329 LLVM_DEBUG(llvm::dbgs()
1330               << "Unable to compute MemRefRegion for dest operation.\n");
1331 return false;
1332 }
MLIR Team58aa3832019-02-16 01:12:191333 SmallVector<int64_t, 4> dstShape;
1334  // Query 'dstWriteRegion' for 'dstShape' and 'dstNumElements' written to
1335  // by 'dstStoreOpInst'.
1336 Optional<int64_t> dstNumElements =
1337 dstWriteRegion.getConstantBoundingSizeAndShape(&dstShape);
1338 if (!dstNumElements.hasValue())
1339 return false;
1340
1341  // Return false if 'dstNode's write region is not a superset of 'srcNode's
1342  // write region to 'memref'.
1343 // TODO(andydavis) Check the shape and lower bounds here too.
1344 if (srcNumElements != dstNumElements)
1345 return false;
1346 return true;
1347}
1348
MLIR Team27d067e2019-01-16 17:55:021349// Checks the profitability of fusing a backwards slice of the loop nest
MLIR Teamd7c82442019-01-30 23:53:411350// surrounding 'srcOpInst' into the loop nest surrounding 'dstLoadOpInsts'.
MLIR Teamd038e342019-03-01 19:50:251351// The argument 'srcStoreOpInst' is used to calculate the storage reduction on
1352// the memref being produced and consumed, which is an input to the cost model.
1353// For producer-consumer fusion, 'srcStoreOpInst' will be the same as
1354// 'srcOpInst', as we are slicing w.r.t to that producer.
1355// For input-reuse fusion, 'srcOpInst' will be the src loop nest LoadOp which
1356// reads from the same memref as dst loop nest load ops, and 'srcStoreOpInst'
1357// will be the unique store op in the src node, which will be used to check
1358// that the write region is the same after input-reuse fusion.
Uday Bondhugulab4a14432019-01-26 00:00:501359// Returns true if it is profitable to fuse the candidate loop nests. Returns
1360// false otherwise. `dstLoopDepth` is set to the most profitable depth at which
1361// to materialize the source loop nest slice.
MLIR Team38c2fe32019-01-14 19:26:251362// The profitability model executes the following steps:
MLIR Team27d067e2019-01-16 17:55:021363// *) Computes the backward computation slice at 'srcOpInst'. This
1364// computation slice of the loop nest surrounding 'srcOpInst' is
MLIR Team38c2fe32019-01-14 19:26:251365// represented by modified src loop bounds in 'sliceState', which are
MLIR Team27d067e2019-01-16 17:55:021366// functions of loop IVs in the loop nest surrounding 'srcOpInst'.
MLIR Team38c2fe32019-01-14 19:26:251367// *) Computes the cost of unfused src/dst loop nests (currently the cost of a
1368// loop nest is the total number of dynamic operation instances in the loop
1369// nest).
1370// *) Computes the cost of fusing a slice of the src loop nest into the dst
MLIR Team27d067e2019-01-16 17:55:021371// loop nest at various values of dst loop depth, attempting to fuse
1372//    the largest computation slice at the maximal dst loop depth (closest to the
1373// load) to minimize reuse distance and potentially enable subsequent
1374// load/store forwarding.
MLIR Teamd7c82442019-01-30 23:53:411375// NOTE: If the dst loop nest includes multiple loads in 'dstLoadOpInsts' for
MLIR Team27d067e2019-01-16 17:55:021376// the same memref as is written by 'srcOpInst', then the union of slice
1377// loop bounds is used to compute the slice and associated slice cost.
Uday Bondhugulab4a14432019-01-26 00:00:501378// NOTE: 'dstLoopDepth' refers to the loop depth within the destination loop
MLIR Team38c2fe32019-01-14 19:26:251379// nest, at which the src computation slice is inserted/fused.
MLIR Team27d067e2019-01-16 17:55:021380// NOTE: We attempt to maximize the dst loop depth, but there are cases
1381//    where a particular setting for 'dstLoopDepth' might fuse an unsliced
MLIR Team38c2fe32019-01-14 19:26:251382// loop (within the src computation slice) at a depth which results in
1383//    excessive recomputation (see unit tests for examples).
1384// *) Compares the total cost of the unfused loop nests to the min cost fused
1385// loop nest computed in the previous step, and returns true if the latter
1386// is lower.
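// As a purely illustrative example of the cost model below (made-up numbers):
// with srcLoopNestCost = 1000, dstLoopNestCost = 1000 and a fused nest cost of
// 2100 at some depth, the additional compute fraction is
// 2100 / (1000 + 1000) - 1 = 0.05, i.e. 5% redundant computation; that depth
// is kept only if 5% is within the compute tolerance (or maximal fusion is
// requested) and its storage reduction is the best seen so far.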
River Riddleb4992772019-02-04 18:38:471387static bool isFusionProfitable(Instruction *srcOpInst,
MLIR Teamd038e342019-03-01 19:50:251388 Instruction *srcStoreOpInst,
River Riddleb4992772019-02-04 18:38:471389 ArrayRef<Instruction *> dstLoadOpInsts,
1390 ArrayRef<Instruction *> dstStoreOpInsts,
MLIR Team38c2fe32019-01-14 19:26:251391 ComputationSliceState *sliceState,
MLIR Team27d067e2019-01-16 17:55:021392 unsigned *dstLoopDepth) {
Uday Bondhugula06d21d92019-01-25 01:01:491393 LLVM_DEBUG({
1394 llvm::dbgs() << "Checking whether fusion is profitable between:\n";
Uday Bondhugulaa1dad3a2019-02-20 02:17:191395 llvm::dbgs() << " " << *srcOpInst << " and \n";
MLIR Teamd7c82442019-01-30 23:53:411396 for (auto dstOpInst : dstLoadOpInsts) {
Uday Bondhugulaa1dad3a2019-02-20 02:17:191397 llvm::dbgs() << " " << *dstOpInst << "\n";
Uday Bondhugula06d21d92019-01-25 01:01:491398 };
1399 });
Uday Bondhugula864d9e02019-01-23 17:16:241400
MLIR Team38c2fe32019-01-14 19:26:251401 // Compute cost of sliced and unsliced src loop nest.
River Riddle5052bd82019-02-02 00:42:181402 SmallVector<OpPointer<AffineForOp>, 4> srcLoopIVs;
MLIR Team27d067e2019-01-16 17:55:021403 getLoopIVs(*srcOpInst, &srcLoopIVs);
MLIR Team38c2fe32019-01-14 19:26:251404 unsigned numSrcLoopIVs = srcLoopIVs.size();
1405
1406 // Walk src loop nest and collect stats.
1407 LoopNestStats srcLoopNestStats;
1408 LoopNestStatsCollector srcStatsCollector(&srcLoopNestStats);
River Riddlebf9c3812019-02-05 00:24:441409 srcStatsCollector.collect(srcLoopIVs[0]->getInstruction());
MLIR Team38c2fe32019-01-14 19:26:251410 // Currently only constant trip count loop nests are supported.
1411 if (srcStatsCollector.hasLoopWithNonConstTripCount)
1412 return false;
1413
1414 // Compute cost of dst loop nest.
River Riddle5052bd82019-02-02 00:42:181415 SmallVector<OpPointer<AffineForOp>, 4> dstLoopIVs;
MLIR Teamd7c82442019-01-30 23:53:411416 getLoopIVs(*dstLoadOpInsts[0], &dstLoopIVs);
MLIR Team38c2fe32019-01-14 19:26:251417
1418 LoopNestStats dstLoopNestStats;
1419 LoopNestStatsCollector dstStatsCollector(&dstLoopNestStats);
River Riddlebf9c3812019-02-05 00:24:441420 dstStatsCollector.collect(dstLoopIVs[0]->getInstruction());
MLIR Team38c2fe32019-01-14 19:26:251421 // Currently only constant trip count loop nests are supported.
1422 if (dstStatsCollector.hasLoopWithNonConstTripCount)
1423 return false;
1424
MLIR Teamd7c82442019-01-30 23:53:411425  // Compute the maximum loop depth at which we can insert the src slice
MLIR Teamd038e342019-03-01 19:50:251426 // and still satisfy dest loop nest dependences, for producer-consumer fusion.
1427 unsigned maxDstLoopDepth =
1428 (srcOpInst == srcStoreOpInst)
1429 ? getMaxLoopDepth(dstLoadOpInsts, dstStoreOpInsts)
1430 : dstLoopIVs.size();
MLIR Team27d067e2019-01-16 17:55:021431 if (maxDstLoopDepth == 0)
1432 return false;
1433
1434 // Search for min cost value for 'dstLoopDepth'. At each value of
1435 // 'dstLoopDepth' from 'maxDstLoopDepth' to '1', compute computation slice
1436  // bounds between 'srcOpInst' and each op in 'dstLoadOpInsts' (taking the union
1437 // of these bounds). Next the union slice bounds are used to calculate
1438 // the cost of the slice and the cost of the slice inserted into the dst
1439 // loop nest at 'dstLoopDepth'.
Uday Bondhugula864d9e02019-01-23 17:16:241440 uint64_t minFusedLoopNestComputeCost = std::numeric_limits<uint64_t>::max();
MLIR Teamd038e342019-03-01 19:50:251441 double maxStorageReduction = 0.0;
Uday Bondhugula864d9e02019-01-23 17:16:241442 Optional<uint64_t> sliceMemEstimate = None;
1443
MLIR Team27d067e2019-01-16 17:55:021444 SmallVector<ComputationSliceState, 4> sliceStates;
1445 sliceStates.resize(maxDstLoopDepth);
Uday Bondhugula864d9e02019-01-23 17:16:241446 // The best loop depth at which to materialize the slice.
1447 Optional<unsigned> bestDstLoopDepth = None;
1448
1449 // Compute op instance count for the src loop nest without iteration slicing.
River Riddle5052bd82019-02-02 00:42:181450 uint64_t srcLoopNestCost =
1451 getComputeCost(srcLoopIVs[0]->getInstruction(), &srcLoopNestStats,
1452 /*tripCountOverrideMap=*/nullptr,
1453 /*computeCostMap=*/nullptr);
Uday Bondhugula864d9e02019-01-23 17:16:241454
MLIR Teamb9dde912019-02-06 19:01:101455 // Compute src loop nest write region size.
MLIR Teamd038e342019-03-01 19:50:251456 MemRefRegion srcWriteRegion(srcStoreOpInst->getLoc());
MLIR Teamd42ef782019-03-04 19:01:251457 if (!srcWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0)) {
1458 LLVM_DEBUG(llvm::dbgs()
1459               << "Unable to compute MemRefRegion for source operation.\n");
1460 return false;
1461 }
1462
MLIR Teamb9dde912019-02-06 19:01:101463 Optional<int64_t> maybeSrcWriteRegionSizeBytes =
1464 srcWriteRegion.getRegionSize();
1465 if (!maybeSrcWriteRegionSizeBytes.hasValue())
1466 return false;
1467 int64_t srcWriteRegionSizeBytes = maybeSrcWriteRegionSizeBytes.getValue();
1468
Uday Bondhugula864d9e02019-01-23 17:16:241469  // Compute op instance count for the dst loop nest.
River Riddle5052bd82019-02-02 00:42:181470 uint64_t dstLoopNestCost =
1471 getComputeCost(dstLoopIVs[0]->getInstruction(), &dstLoopNestStats,
1472 /*tripCountOverrideMap=*/nullptr,
1473 /*computeCostMap=*/nullptr);
MLIR Team27d067e2019-01-16 17:55:021474
MLIR Teamb9dde912019-02-06 19:01:101475 // Evaluate all depth choices for materializing the slice in the destination
1476 // loop nest.
River Riddle5052bd82019-02-02 00:42:181477 llvm::SmallDenseMap<Instruction *, uint64_t, 8> sliceTripCountMap;
1478 DenseMap<Instruction *, int64_t> computeCostMap;
MLIR Team27d067e2019-01-16 17:55:021479 for (unsigned i = maxDstLoopDepth; i >= 1; --i) {
1480 MemRefAccess srcAccess(srcOpInst);
1481 // Handle the common case of one dst load without a copy.
1482 if (!mlir::getBackwardComputationSliceState(
MLIR Teamd7c82442019-01-30 23:53:411483 srcAccess, MemRefAccess(dstLoadOpInsts[0]), i, &sliceStates[i - 1]))
MLIR Team27d067e2019-01-16 17:55:021484 return false;
MLIR Teamd038e342019-03-01 19:50:251485
MLIR Teamd7c82442019-01-30 23:53:411486 // Compute the union of slice bound of all ops in 'dstLoadOpInsts'.
1487 for (int j = 1, e = dstLoadOpInsts.size(); j < e; ++j) {
1488 MemRefAccess dstAccess(dstLoadOpInsts[j]);
MLIR Team27d067e2019-01-16 17:55:021489 ComputationSliceState tmpSliceState;
1490 if (!mlir::getBackwardComputationSliceState(srcAccess, dstAccess, i,
1491 &tmpSliceState))
1492 return false;
1493      // Compute slice bound union of 'tmpSliceState' and 'sliceStates[i - 1]'.
Uday Bondhugulac1ca23e2019-01-16 21:13:001494 getSliceUnion(tmpSliceState, &sliceStates[i - 1]);
MLIR Team38c2fe32019-01-14 19:26:251495 }
Uday Bondhugulab4a14432019-01-26 00:00:501496 // Build trip count map for computation slice. We'll skip cases where the
1497 // trip count was non-constant.
MLIR Team27d067e2019-01-16 17:55:021498 sliceTripCountMap.clear();
1499 if (!buildSliceTripCountMap(srcOpInst, &sliceStates[i - 1],
1500 &sliceTripCountMap))
Uday Bondhugula864d9e02019-01-23 17:16:241501 continue;
1502
1503    // Check whether store-to-load forwarding will happen.
1504 int64_t sliceIterationCount = getSliceIterationCount(sliceTripCountMap);
Uday Bondhugula864d9e02019-01-23 17:16:241505 assert(sliceIterationCount > 0);
Uday Bondhugulab4a14432019-01-26 00:00:501506 bool storeLoadFwdGuaranteed = (sliceIterationCount == 1);
Uday Bondhugula864d9e02019-01-23 17:16:241507
1508 // Compute cost of fusion for this dest loop depth.
1509
1510 computeCostMap.clear();
1511
1512 // The store and loads to this memref will disappear.
MLIR Teamd038e342019-03-01 19:50:251513 // TODO(andydavis) Add load coalescing to memref data flow opt pass.
Uday Bondhugula864d9e02019-01-23 17:16:241514 if (storeLoadFwdGuaranteed) {
1515 // A single store disappears: -1 for that.
River Riddle5052bd82019-02-02 00:42:181516 computeCostMap[srcLoopIVs[numSrcLoopIVs - 1]->getInstruction()] = -1;
MLIR Teamd7c82442019-01-30 23:53:411517 for (auto *loadOp : dstLoadOpInsts) {
River Riddle5052bd82019-02-02 00:42:181518 auto *parentInst = loadOp->getParentInst();
River Riddleb4992772019-02-04 18:38:471519 if (parentInst && parentInst->isa<AffineForOp>())
River Riddle5052bd82019-02-02 00:42:181520 computeCostMap[parentInst] = -1;
Uday Bondhugula864d9e02019-01-23 17:16:241521 }
1522 }
MLIR Team27d067e2019-01-16 17:55:021523
MLIR Team38c2fe32019-01-14 19:26:251524 // Compute op instance count for the src loop nest with iteration slicing.
Uday Bondhugula864d9e02019-01-23 17:16:241525 int64_t sliceComputeCost =
River Riddle5052bd82019-02-02 00:42:181526 getComputeCost(srcLoopIVs[0]->getInstruction(), &srcLoopNestStats,
Uday Bondhugula864d9e02019-01-23 17:16:241527 /*tripCountOverrideMap=*/&sliceTripCountMap,
1528 /*computeCostMap=*/&computeCostMap);
MLIR Team38c2fe32019-01-14 19:26:251529
Uday Bondhugula864d9e02019-01-23 17:16:241530 // Compute cost of fusion for this depth.
River Riddle5052bd82019-02-02 00:42:181531 computeCostMap[dstLoopIVs[i - 1]->getInstruction()] = sliceComputeCost;
Uday Bondhugula864d9e02019-01-23 17:16:241532
1533 int64_t fusedLoopNestComputeCost =
River Riddle5052bd82019-02-02 00:42:181534 getComputeCost(dstLoopIVs[0]->getInstruction(), &dstLoopNestStats,
MLIR Team27d067e2019-01-16 17:55:021535 /*tripCountOverrideMap=*/nullptr, &computeCostMap);
Uday Bondhugula864d9e02019-01-23 17:16:241536
1537 double additionalComputeFraction =
1538 fusedLoopNestComputeCost /
1539 (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
1540 1;
1541
MLIR Teamb9dde912019-02-06 19:01:101542 // Compute what the slice write MemRefRegion would be, if the src loop
1543 // nest slice 'sliceStates[i - 1]' were to be inserted into the dst loop
1544    // nest at loop depth 'i'.
MLIR Teamd038e342019-03-01 19:50:251545 MemRefRegion sliceWriteRegion(srcStoreOpInst->getLoc());
MLIR Teamd42ef782019-03-04 19:01:251546 if (!sliceWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0,
1547 &sliceStates[i - 1]))
1548 continue;
1549
MLIR Teamb9dde912019-02-06 19:01:101550 Optional<int64_t> maybeSliceWriteRegionSizeBytes =
1551 sliceWriteRegion.getRegionSize();
1552 if (!maybeSliceWriteRegionSizeBytes.hasValue() ||
1553 maybeSliceWriteRegionSizeBytes.getValue() == 0)
1554 continue;
1555 int64_t sliceWriteRegionSizeBytes =
1556 maybeSliceWriteRegionSizeBytes.getValue();
1557
MLIR Teamd038e342019-03-01 19:50:251558 // If we are fusing for reuse, check that write regions remain the same.
1559 // TODO(andydavis) Write region check should check sizes and offsets in
1560 // each dimension, so that we are sure they are covering the same memref
1561 // region. Also, move this out to a isMemRefRegionSuperSet helper function.
1562 if (srcOpInst != srcStoreOpInst &&
1563 sliceWriteRegionSizeBytes != srcWriteRegionSizeBytes)
1564 continue;
1565
MLIR Teamb9dde912019-02-06 19:01:101566 double storageReduction = static_cast<double>(srcWriteRegionSizeBytes) /
1567 static_cast<double>(sliceWriteRegionSizeBytes);
Uday Bondhugula864d9e02019-01-23 17:16:241568
Uday Bondhugula06d21d92019-01-25 01:01:491569 LLVM_DEBUG({
1570 std::stringstream msg;
1571 msg << " evaluating fusion profitability at depth : " << i << "\n"
Uday Bondhugulad4b3ff12019-02-27 00:10:191572 << std::fixed << std::setprecision(2)
1573 << " additional compute fraction: "
Uday Bondhugula06d21d92019-01-25 01:01:491574 << 100.0 * additionalComputeFraction << "%\n"
1575 << " storage reduction factor: " << storageReduction << "x\n"
1576 << " fused nest cost: " << fusedLoopNestComputeCost << "\n"
Uday Bondhugulaa1dad3a2019-02-20 02:17:191577 << " slice iteration count: " << sliceIterationCount << "\n"
1578 << " src write region size: " << srcWriteRegionSizeBytes << "\n"
1579 << " slice write region size: " << sliceWriteRegionSizeBytes
1580 << "\n";
Uday Bondhugula06d21d92019-01-25 01:01:491581 llvm::dbgs() << msg.str();
1582 });
Uday Bondhugula864d9e02019-01-23 17:16:241583
1584 double computeToleranceThreshold =
1585 clFusionAddlComputeTolerance.getNumOccurrences() > 0
1586 ? clFusionAddlComputeTolerance
1587 : LoopFusion::kComputeToleranceThreshold;
1588
1589 // TODO(b/123247369): This is a placeholder cost model.
1590 // Among all choices that add an acceptable amount of redundant computation
1591 // (as per computeToleranceThreshold), we will simply pick the one that
1592 // reduces the intermediary size the most.
1593 if ((storageReduction > maxStorageReduction) &&
1594 (clMaximalLoopFusion ||
1595 (additionalComputeFraction < computeToleranceThreshold))) {
1596 maxStorageReduction = storageReduction;
MLIR Team27d067e2019-01-16 17:55:021597 bestDstLoopDepth = i;
Uday Bondhugula864d9e02019-01-23 17:16:241598 minFusedLoopNestComputeCost = fusedLoopNestComputeCost;
MLIR Teamb9dde912019-02-06 19:01:101599 sliceMemEstimate = sliceWriteRegionSizeBytes;
MLIR Team38c2fe32019-01-14 19:26:251600 }
1601 }
1602
Uday Bondhugula864d9e02019-01-23 17:16:241603 // A simple cost model: fuse if it reduces the memory footprint. If
1604  // -fusion-maximal is set, fuse nevertheless.
MLIR Team38c2fe32019-01-14 19:26:251605
Uday Bondhugula864d9e02019-01-23 17:16:241606 if (!clMaximalLoopFusion && !bestDstLoopDepth.hasValue()) {
Uday Bondhugulaa1dad3a2019-02-20 02:17:191607 LLVM_DEBUG(
1608 llvm::dbgs()
1609 << "All fusion choices involve more than the threshold amount of "
1610 "redundant computation; NOT fusing.\n");
MLIR Team38c2fe32019-01-14 19:26:251611 return false;
Uday Bondhugula864d9e02019-01-23 17:16:241612 }
1613
MLIR Teamd42ef782019-03-04 19:01:251614 if (!bestDstLoopDepth.hasValue()) {
1615 LLVM_DEBUG(llvm::dbgs() << "no fusion depth could be evaluated.\n");
1616 return false;
1617 }
Uday Bondhugula864d9e02019-01-23 17:16:241618
1619 // Set dstLoopDepth based on best values from search.
1620 *dstLoopDepth = bestDstLoopDepth.getValue();
1621
1622 LLVM_DEBUG(
Uday Bondhugula06d21d92019-01-25 01:01:491623 llvm::dbgs() << " LoopFusion fusion stats:"
1624 << "\n best loop depth: " << bestDstLoopDepth
Uday Bondhugula864d9e02019-01-23 17:16:241625 << "\n src loop nest compute cost: " << srcLoopNestCost
1626 << "\n dst loop nest compute cost: " << dstLoopNestCost
1627 << "\n fused loop nest compute cost: "
1628 << minFusedLoopNestComputeCost << "\n");
1629
River Riddle5052bd82019-02-02 00:42:181630 auto dstMemSize = getMemoryFootprintBytes(dstLoopIVs[0]);
1631 auto srcMemSize = getMemoryFootprintBytes(srcLoopIVs[0]);
Uday Bondhugula864d9e02019-01-23 17:16:241632
1633 Optional<double> storageReduction = None;
1634
1635 if (!clMaximalLoopFusion) {
1636 if (!dstMemSize.hasValue() || !srcMemSize.hasValue()) {
1637 LLVM_DEBUG(
1638 llvm::dbgs()
1639 << " fusion memory benefit cannot be evaluated; NOT fusing.\n");
1640 return false;
1641 }
1642
1643 auto srcMemSizeVal = srcMemSize.getValue();
1644 auto dstMemSizeVal = dstMemSize.getValue();
1645
1646 assert(sliceMemEstimate.hasValue() && "expected value");
1647    // This is an inaccurate estimate since sliceMemEstimate is itself inaccurate.
1648 auto fusedMem = dstMemSizeVal + sliceMemEstimate.getValue();
1649
1650 LLVM_DEBUG(llvm::dbgs() << " src mem: " << srcMemSizeVal << "\n"
1651 << " dst mem: " << dstMemSizeVal << "\n"
1652 << " fused mem: " << fusedMem << "\n"
1653 << " slice mem: " << sliceMemEstimate << "\n");
1654
1655 if (fusedMem > srcMemSizeVal + dstMemSizeVal) {
1656 LLVM_DEBUG(llvm::dbgs() << "Fusion is not profitable; NOT fusing.\n");
1657 return false;
1658 }
1659 storageReduction =
1660 100.0 *
1661 (1.0 - fusedMem / (static_cast<double>(srcMemSizeVal) + dstMemSizeVal));
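    // E.g. (made-up numbers): srcMem = 4 KiB, dstMem = 4 KiB and
    // fusedMem = 5 KiB give a storage reduction of 100 * (1 - 5/8) = 37.5%.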
1662 }
1663
1664 double additionalComputeFraction =
1665 100.0 * (minFusedLoopNestComputeCost /
1666 (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
1667 1);
MLIR Team5c5739d2019-01-25 06:27:401668 (void)additionalComputeFraction;
Uday Bondhugula06d21d92019-01-25 01:01:491669 LLVM_DEBUG({
1670 std::stringstream msg;
1671 msg << " fusion is most profitable at depth " << *dstLoopDepth << " with "
MLIR Team8564b272019-02-22 15:48:591672 << std::setprecision(2) << additionalComputeFraction
Uday Bondhugula06d21d92019-01-25 01:01:491673 << "% redundant computation and a ";
1674 msg << (storageReduction.hasValue()
1675 ? std::to_string(storageReduction.getValue())
1676 : "<unknown>");
1677 msg << "% storage reduction.\n";
1678 llvm::dbgs() << msg.str();
1679 });
Uday Bondhugula864d9e02019-01-23 17:16:241680
MLIR Team27d067e2019-01-16 17:55:021681 // Update return parameter 'sliceState' with 'bestSliceState'.
Uday Bondhugula864d9e02019-01-23 17:16:241682 ComputationSliceState *bestSliceState = &sliceStates[*dstLoopDepth - 1];
MLIR Team27d067e2019-01-16 17:55:021683 sliceState->lbs = bestSliceState->lbs;
1684 sliceState->ubs = bestSliceState->ubs;
1685 sliceState->lbOperands = bestSliceState->lbOperands;
1686 sliceState->ubOperands = bestSliceState->ubOperands;
Uday Bondhugula864d9e02019-01-23 17:16:241687
MLIR Team27d067e2019-01-16 17:55:021688 // Canonicalize slice bound affine maps.
MLIR Team38c2fe32019-01-14 19:26:251689 for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
Nicolas Vasilache0e7a8a92019-01-26 18:41:171690 if (sliceState->lbs[i] != AffineMap()) {
MLIR Team27d067e2019-01-16 17:55:021691 canonicalizeMapAndOperands(&sliceState->lbs[i],
1692 &sliceState->lbOperands[i]);
1693 }
Nicolas Vasilache0e7a8a92019-01-26 18:41:171694 if (sliceState->ubs[i] != AffineMap()) {
MLIR Team27d067e2019-01-16 17:55:021695 canonicalizeMapAndOperands(&sliceState->ubs[i],
1696 &sliceState->ubOperands[i]);
MLIR Team38c2fe32019-01-14 19:26:251697 }
1698 }
1699 return true;
1700}
1701
MLIR Teamd038e342019-03-01 19:50:251702// GreedyFusion greedily fuses loop nests which have a producer/consumer or
1703// input-reuse relationship on a memref, with the goal of improving locality.
MLIR Teamf28e4df2018-11-01 14:26:001704//
MLIR Teamd038e342019-03-01 19:50:251705// The steps of the producer-consumer fusion algorithm are as follows:
MLIR Team3b692302018-12-17 17:57:141706//
MLIR Team6892ffb2018-12-20 04:42:551707// *) A worklist is initialized with node ids from the dependence graph.
1708// *) For each node id in the worklist:
River Riddle5052bd82019-02-02 00:42:181709// *) Pop a AffineForOp of the worklist. This 'dstAffineForOp' will be a
1710// candidate destination AffineForOp into which fusion will be attempted.
1711// *) Add each LoadOp currently in 'dstAffineForOp' into list 'dstLoadOps'.
MLIR Team3b692302018-12-17 17:57:141712// *) For each LoadOp in 'dstLoadOps' do:
MLIR Teamd038e342019-03-01 19:50:251713// *) Lookup dependent loop nests which have a single store op to the same
1714// memref.
1715// *) Check if dependences would be violated by the fusion.
MLIR Team6892ffb2018-12-20 04:42:551716// *) Get a computation slice of 'srcLoopNest', which adjusts its loop
MLIR Team3b692302018-12-17 17:57:141717// bounds to be functions of 'dstLoopNest' IVs and symbols.
1718// *) Fuse the 'srcLoopNest' computation slice into the 'dstLoopNest',
MLIR Teamd038e342019-03-01 19:50:251719// at a loop depth determined by the cost model in 'isFusionProfitable'.
Chris Lattner456ad6a2018-12-29 00:05:351720// *) Add the newly fused load/store operation instructions to the state,
MLIR Team3b692302018-12-17 17:57:141721// and also add newly fuse load ops to 'dstLoopOps' to be considered
1722// as fusion dst load ops in another iteration.
1723// *) Remove old src loop nest and its associated state.
1724//
MLIR Teamd038e342019-03-01 19:50:251725// The steps of the input-reuse fusion algorithm are as follows:
1726//
1727// *) Initialize 'worklist' with node ids from the dependence graph.
1728// *) For each 'dstNode' in the worklist:
1729// *) Find a candidate sibling node 'sibNode' to fuse with 'dstNode' which
1730// loads from the same memref, but which has no dependence paths to/from.
1731// *) Get a computation slice of 'sibLoopNest', which adjusts its loop
1732// bounds to be functions of 'dstLoopNest' IVs and symbols.
1733// *) Fuse the 'sibLoopNest' computation slice into the 'dstLoopNest',
1734// at a loop depth determined by the cost model in 'isFusionProfitable'.
1735// This function also checks that the memref write region of 'sibLoopNest',
1736// is preserved in the fused loop nest.
1737// *) Update graph state to reflect the fusion of 'sibNode' into 'dstNode'.
1738//
Chris Lattner456ad6a2018-12-29 00:05:351739// Given a graph where top-level instructions are vertices in the set 'V' and
MLIR Team3b692302018-12-17 17:57:141740// edges in the set 'E' are dependences between vertices, this algorithm
MLIR Team6892ffb2018-12-20 04:42:551741// takes O(V) time for initialization, and has runtime O(V + E).
MLIR Team3b692302018-12-17 17:57:141742//
MLIR Team6892ffb2018-12-20 04:42:551743// This greedy algorithm is not 'maximal' due to the current restriction of
1744// fusing along single producer consumer edges, but there is a TODO to fix this.
MLIR Team3b692302018-12-17 17:57:141745//
1746// TODO(andydavis) Experiment with other fusion policies.
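//
// As a schematic producer-consumer example (illustrative, not verbatim IR):
//
//   for %i = 0 to 100 { store %v, %A[%i] }        <- producer node
//   for %j = 0 to 100 { %u = load %A[%j] ... }    <- consumer node
//
// fusing the producer's slice at depth 1 of the consumer conceptually yields
//
//   for %j = 0 to 100 { store %v, %A[%j]; %u = load %A[%j] ... }
//
// after which %A may be replaced by a small private buffer via
// createPrivateMemRef() if it is not live out of the function.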
MLIR Team6892ffb2018-12-20 04:42:551747struct GreedyFusion {
1748public:
MLIR Teamd038e342019-03-01 19:50:251749 // The data dependence graph to traverse during fusion.
MLIR Team6892ffb2018-12-20 04:42:551750 MemRefDependenceGraph *mdg;
MLIR Teamd038e342019-03-01 19:50:251751 // Worklist of graph nodes visited during the fusion pass.
MLIR Teama78edcd2019-02-05 14:57:081752 SmallVector<unsigned, 8> worklist;
MLIR Teamd038e342019-03-01 19:50:251753 // Set of graph nodes which are present on the worklist.
MLIR Teama78edcd2019-02-05 14:57:081754 llvm::SmallDenseSet<unsigned, 16> worklistSet;
MLIR Teamd038e342019-03-01 19:50:251755 // Parameter for local buffer size threshold.
1756 unsigned localBufSizeThreshold;
1757 // Parameter for fast memory space.
1758 Optional<unsigned> fastMemorySpace;
MLIR Teamf28e4df2018-11-01 14:26:001759
MLIR Teamd038e342019-03-01 19:50:251760 using Node = MemRefDependenceGraph::Node;
1761
1762 GreedyFusion(MemRefDependenceGraph *mdg, unsigned localBufSizeThreshold,
1763 Optional<unsigned> fastMemorySpace)
1764 : mdg(mdg), localBufSizeThreshold(localBufSizeThreshold),
1765 fastMemorySpace(fastMemorySpace) {}
1766
1767 // Initializes 'worklist' with nodes from 'mdg'
1768 void init() {
MLIR Teama78edcd2019-02-05 14:57:081769 // TODO(andydavis) Add a priority queue for prioritizing nodes by different
1770 // metrics (e.g. arithmetic intensity/flops-to-bytes ratio).
MLIR Teamd038e342019-03-01 19:50:251771 worklist.clear();
1772 worklistSet.clear();
1773 for (auto &idAndNode : mdg->nodes) {
1774 const Node &node = idAndNode.second;
1775 worklist.push_back(node.id);
1776 worklistSet.insert(node.id);
1777 }
MLIR Team6892ffb2018-12-20 04:42:551778 }
MLIR Team3b692302018-12-17 17:57:141779
MLIR Teamd038e342019-03-01 19:50:251780 // Run the GreedyFusion pass.
1781 // *) First pass through the nodes fuses single-use producer nodes into their
1782 // unique consumer.
1783 // *) Second pass fuses sibling nodes which share no dependence edges.
1784 // *) Third pass fuses any remaining producer nodes into their users.
1785 void run() {
1786 fuseProducerConsumerNodes(/*maxSrcUserCount=*/1);
1787 fuseSiblingNodes();
1788 fuseProducerConsumerNodes(
1789 /*maxSrcUserCount=*/std::numeric_limits<unsigned>::max());
1790 eraseUnusedMemRefAllocations();
1791 }
1792
1793 void fuseProducerConsumerNodes(unsigned maxSrcUserCount) {
1794 init();
MLIR Team3b692302018-12-17 17:57:141795 while (!worklist.empty()) {
MLIR Team6892ffb2018-12-20 04:42:551796 unsigned dstId = worklist.back();
MLIR Team3b692302018-12-17 17:57:141797 worklist.pop_back();
MLIR Teama78edcd2019-02-05 14:57:081798 worklistSet.erase(dstId);
1799
MLIR Team6892ffb2018-12-20 04:42:551800 // Skip if this node was removed (fused into another node).
1801 if (mdg->nodes.count(dstId) == 0)
MLIR Team3b692302018-12-17 17:57:141802 continue;
MLIR Team6892ffb2018-12-20 04:42:551803 // Get 'dstNode' into which to attempt fusion.
1804 auto *dstNode = mdg->getNode(dstId);
1805 // Skip if 'dstNode' is not a loop nest.
River Riddleb4992772019-02-04 18:38:471806 if (!dstNode->inst->isa<AffineForOp>())
MLIR Team3b692302018-12-17 17:57:141807 continue;
MLIR Team8f5f2c72019-02-15 17:32:181808 // Sink sequential loops in 'dstNode' (and thus raise parallel loops)
1809 // while preserving relative order. This can increase the maximum loop
1810 // depth at which we can fuse a slice of a producer loop nest into a
1811 // consumer loop nest.
1812 sinkSequentialLoops(dstNode);
MLIR Team3b692302018-12-17 17:57:141813
River Riddleb4992772019-02-04 18:38:471814 SmallVector<Instruction *, 4> loads = dstNode->loads;
1815 SmallVector<Instruction *, 4> dstLoadOpInsts;
MLIR Teamc4237ae2019-01-18 16:56:271816 DenseSet<Value *> visitedMemrefs;
MLIR Team6892ffb2018-12-20 04:42:551817 while (!loads.empty()) {
MLIR Team27d067e2019-01-16 17:55:021818 // Get memref of load on top of the stack.
1819 auto *memref = loads.back()->cast<LoadOp>()->getMemRef();
MLIR Teamc4237ae2019-01-18 16:56:271820 if (visitedMemrefs.count(memref) > 0)
1821 continue;
1822 visitedMemrefs.insert(memref);
MLIR Team27d067e2019-01-16 17:55:021823 // Move all loads in 'loads' accessing 'memref' to 'dstLoadOpInsts'.
1824 moveLoadsAccessingMemrefTo(memref, &loads, &dstLoadOpInsts);
MLIR Team6892ffb2018-12-20 04:42:551825 // Skip if no input edges along which to fuse.
1826 if (mdg->inEdges.count(dstId) == 0)
MLIR Team3b692302018-12-17 17:57:141827 continue;
MLIR Team1e851912019-01-31 00:01:461828 // Iterate through in edges for 'dstId' and src node id for any
1829 // edges on 'memref'.
1830 SmallVector<unsigned, 2> srcNodeIds;
MLIR Team6892ffb2018-12-20 04:42:551831 for (auto &srcEdge : mdg->inEdges[dstId]) {
1832 // Skip 'srcEdge' if not for 'memref'.
MLIR Teama0f3db402019-01-29 17:36:411833 if (srcEdge.value != memref)
MLIR Team6892ffb2018-12-20 04:42:551834 continue;
MLIR Team1e851912019-01-31 00:01:461835 srcNodeIds.push_back(srcEdge.id);
1836 }
1837 for (unsigned srcId : srcNodeIds) {
1838 // Skip if this node was removed (fused into another node).
1839 if (mdg->nodes.count(srcId) == 0)
1840 continue;
1841 // Get 'srcNode' from which to attempt fusion into 'dstNode'.
1842 auto *srcNode = mdg->getNode(srcId);
MLIR Team6892ffb2018-12-20 04:42:551843 // Skip if 'srcNode' is not a loop nest.
River Riddleb4992772019-02-04 18:38:471844 if (!srcNode->inst->isa<AffineForOp>())
MLIR Team6892ffb2018-12-20 04:42:551845 continue;
MLIR Teamb28009b2019-01-23 19:11:431846 // Skip if 'srcNode' has more than one store to any memref.
1847 // TODO(andydavis) Support fusing multi-output src loop nests.
1848 if (srcNode->stores.size() != 1)
MLIR Team6892ffb2018-12-20 04:42:551849 continue;
Uday Bondhugula864d9e02019-01-23 17:16:241850
MLIR Teama0f3db402019-01-29 17:36:411851 // Skip 'srcNode' if it has in edges on 'memref'.
MLIR Team6892ffb2018-12-20 04:42:551852 // TODO(andydavis) Track dependence type with edges, and just check
MLIR Teama0f3db402019-01-29 17:36:411853 // for WAW dependence edge here. Note that this check is overly
1854 // conservative and will be removed in the future.
1855 if (mdg->getIncomingMemRefAccesses(srcNode->id, memref) != 0)
MLIR Team6892ffb2018-12-20 04:42:551856 continue;
Uday Bondhugula864d9e02019-01-23 17:16:241857
MLIR Team58aa3832019-02-16 01:12:191858 // Skip if 'srcNode' writes to any live in or escaping memrefs,
1859 // and cannot be fused.
1860 bool writesToLiveInOrOut =
1861 mdg->writesToLiveInOrEscapingMemrefs(srcNode->id);
1862 if (writesToLiveInOrOut &&
1863 !canFuseSrcWhichWritesToLiveOut(srcId, dstId, memref, mdg))
MLIR Teamd7c82442019-01-30 23:53:411864 continue;
1865
MLIR Teamd038e342019-03-01 19:50:251866 // Skip if 'srcNode' out edge count on 'memref' > 'maxSrcUserCount'.
1867 if (mdg->getOutEdgeCount(srcNode->id, memref) > maxSrcUserCount)
1868 continue;
1869
MLIR Teama0f3db402019-01-29 17:36:411870 // Compute an instruction list insertion point for the fused loop
1871 // nest which preserves dependences.
MLIR Teama78edcd2019-02-05 14:57:081872 Instruction *insertPointInst =
1873 mdg->getFusedLoopNestInsertionPoint(srcNode->id, dstNode->id);
MLIR Teama0f3db402019-01-29 17:36:411874 if (insertPointInst == nullptr)
MLIR Team6892ffb2018-12-20 04:42:551875 continue;
Uday Bondhugula864d9e02019-01-23 17:16:241876
MLIR Team6892ffb2018-12-20 04:42:551877 // Get unique 'srcNode' store op.
Chris Lattner456ad6a2018-12-29 00:05:351878 auto *srcStoreOpInst = srcNode->stores.front();
MLIR Teamd7c82442019-01-30 23:53:411879 // Gather 'dstNode' store ops to 'memref'.
River Riddleb4992772019-02-04 18:38:471880 SmallVector<Instruction *, 2> dstStoreOpInsts;
MLIR Teamd7c82442019-01-30 23:53:411881 for (auto *storeOpInst : dstNode->stores)
1882 if (storeOpInst->cast<StoreOp>()->getMemRef() == memref)
1883 dstStoreOpInsts.push_back(storeOpInst);
1884
Uday Bondhugulab4a14432019-01-26 00:00:501885 unsigned bestDstLoopDepth;
MLIR Team38c2fe32019-01-14 19:26:251886 mlir::ComputationSliceState sliceState;
MLIR Teama0f3db402019-01-29 17:36:411887 // Check if fusion would be profitable.
MLIR Teamd038e342019-03-01 19:50:251888 if (!isFusionProfitable(srcStoreOpInst, srcStoreOpInst,
1889 dstLoadOpInsts, dstStoreOpInsts, &sliceState,
Uday Bondhugulab4a14432019-01-26 00:00:501890 &bestDstLoopDepth))
MLIR Team38c2fe32019-01-14 19:26:251891 continue;
Uday Bondhugula864d9e02019-01-23 17:16:241892
MLIR Team6892ffb2018-12-20 04:42:551893 // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'.
River Riddle5052bd82019-02-02 00:42:181894 auto sliceLoopNest = mlir::insertBackwardComputationSlice(
Uday Bondhugulab4a14432019-01-26 00:00:501895 srcStoreOpInst, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState);
MLIR Team6892ffb2018-12-20 04:42:551896 if (sliceLoopNest != nullptr) {
Uday Bondhugulaa1dad3a2019-02-20 02:17:191897 LLVM_DEBUG(llvm::dbgs()
1898 << "\tslice loop nest:\n"
1899 << *sliceLoopNest->getInstruction() << "\n");
River Riddle5052bd82019-02-02 00:42:181900 // Move 'dstAffineForOp' before 'insertPointInst' if needed.
River Riddleb4992772019-02-04 18:38:471901 auto dstAffineForOp = dstNode->inst->cast<AffineForOp>();
River Riddle5052bd82019-02-02 00:42:181902 if (insertPointInst != dstAffineForOp->getInstruction()) {
1903 dstAffineForOp->getInstruction()->moveBefore(insertPointInst);
MLIR Teama0f3db402019-01-29 17:36:411904 }
MLIR Teamc4237ae2019-01-18 16:56:271905 // Update edges between 'srcNode' and 'dstNode'.
MLIR Teama0f3db402019-01-29 17:36:411906 mdg->updateEdges(srcNode->id, dstNode->id, memref);
MLIR Teamc4237ae2019-01-18 16:56:271907
1908 // Collect slice loop stats.
1909 LoopNestStateCollector sliceCollector;
River Riddlebf9c3812019-02-05 00:24:441910 sliceCollector.collect(sliceLoopNest->getInstruction());
MLIR Teamc4237ae2019-01-18 16:56:271911 // Promote single iteration slice loops to single IV value.
River Riddle5052bd82019-02-02 00:42:181912 for (auto forOp : sliceCollector.forOps) {
1913 promoteIfSingleIteration(forOp);
MLIR Team6892ffb2018-12-20 04:42:551914 }
MLIR Team58aa3832019-02-16 01:12:191915 if (!writesToLiveInOrOut) {
1916 // Create private memref for 'memref' in 'dstAffineForOp'.
1917 SmallVector<Instruction *, 4> storesForMemref;
1918 for (auto *storeOpInst : sliceCollector.storeOpInsts) {
1919 if (storeOpInst->cast<StoreOp>()->getMemRef() == memref)
1920 storesForMemref.push_back(storeOpInst);
1921 }
1922 assert(storesForMemref.size() == 1);
1923 auto *newMemRef = createPrivateMemRef(
1924 dstAffineForOp, storesForMemref[0], bestDstLoopDepth,
1925 fastMemorySpace, localBufSizeThreshold);
1926 visitedMemrefs.insert(newMemRef);
1927 // Create new node in dependence graph for 'newMemRef' alloc op.
1928 unsigned newMemRefNodeId =
1929 mdg->addNode(newMemRef->getDefiningInst());
1930 // Add edge from 'newMemRef' node to dstNode.
1931 mdg->addEdge(newMemRefNodeId, dstId, newMemRef);
MLIR Teamc4237ae2019-01-18 16:56:271932 }
MLIR Teamc4237ae2019-01-18 16:56:271933
1934          // Collect dst loop stats after memref privatization transformation.
1935 LoopNestStateCollector dstLoopCollector;
River Riddlebf9c3812019-02-05 00:24:441936 dstLoopCollector.collect(dstAffineForOp->getInstruction());
MLIR Teamc4237ae2019-01-18 16:56:271937
1938 // Add new load ops to current Node load op list 'loads' to
1939 // continue fusing based on new operands.
1940 for (auto *loadOpInst : dstLoopCollector.loadOpInsts) {
1941 auto *loadMemRef = loadOpInst->cast<LoadOp>()->getMemRef();
1942 if (visitedMemrefs.count(loadMemRef) == 0)
1943 loads.push_back(loadOpInst);
1944 }
1945
1946 // Clear and add back loads and stores
1947 mdg->clearNodeLoadAndStores(dstNode->id);
1948 mdg->addToNode(dstId, dstLoopCollector.loadOpInsts,
1949 dstLoopCollector.storeOpInsts);
MLIR Team71495d52019-01-22 21:23:371950 // Remove old src loop nest if it no longer has outgoing dependence
1951 // edges, and it does not write to a memref which escapes the
MLIR Team58aa3832019-02-16 01:12:191952 // function. If 'writesToLiveInOrOut' is true, then 'srcNode' has
1953 // been fused into 'dstNode' and write region of 'dstNode' covers
1954 // the write region of 'srcNode', and 'srcNode' has no other users
1955 // so it is safe to remove.
1956 if (writesToLiveInOrOut || mdg->canRemoveNode(srcNode->id)) {
MLIR Teamc4237ae2019-01-18 16:56:271957 mdg->removeNode(srcNode->id);
River Riddle5052bd82019-02-02 00:42:181958 srcNode->inst->erase();
MLIR Teama78edcd2019-02-05 14:57:081959 } else {
1960 // Add remaining users of 'oldMemRef' back on the worklist (if not
1961 // already there), as its replacement with a local/private memref
1962 // has reduced dependences on 'oldMemRef' which may have created
1963 // new fusion opportunities.
1964 if (mdg->outEdges.count(srcNode->id) > 0) {
1965 SmallVector<MemRefDependenceGraph::Edge, 2> oldOutEdges =
1966 mdg->outEdges[srcNode->id];
1967 for (auto &outEdge : oldOutEdges) {
1968 if (outEdge.value == memref &&
1969 worklistSet.count(outEdge.id) == 0) {
1970 worklist.push_back(outEdge.id);
1971 worklistSet.insert(outEdge.id);
1972 }
1973 }
1974 }
MLIR Teamc4237ae2019-01-18 16:56:271975 }
MLIR Team3b692302018-12-17 17:57:141976 }
MLIR Team3b692302018-12-17 17:57:141977 }
1978 }
1979 }
MLIR Teamd038e342019-03-01 19:50:251980 }
1981
1982 // Visits each node in the graph, and for each node, attempts to fuse it with
1983 // its sibling nodes (nodes which share a parent, but no dependence edges).
1984 void fuseSiblingNodes() {
1985 init();
1986 while (!worklist.empty()) {
1987 unsigned dstId = worklist.back();
1988 worklist.pop_back();
1989 worklistSet.erase(dstId);
1990
1991 // Skip if this node was removed (fused into another node).
1992 if (mdg->nodes.count(dstId) == 0)
1993 continue;
1994 // Get 'dstNode' into which to attempt fusion.
1995 auto *dstNode = mdg->getNode(dstId);
1996 // Skip if 'dstNode' is not a loop nest.
1997 if (!dstNode->inst->isa<AffineForOp>())
1998 continue;
1999 // Attempt to fuse 'dstNode' with its sibling nodes in the graph.
2000 fuseWithSiblingNodes(dstNode);
2001 }
2002 }
2003
2004 // Attempt to fuse 'dstNode' with sibling nodes in the graph.
2005 void fuseWithSiblingNodes(Node *dstNode) {
2006 DenseSet<unsigned> visitedSibNodeIds;
2007 std::pair<unsigned, Value *> idAndMemref;
2008 while (findSiblingNodeToFuse(dstNode, &visitedSibNodeIds, &idAndMemref)) {
2009 unsigned sibId = idAndMemref.first;
2010 Value *memref = idAndMemref.second;
2011 // TODO(andydavis) Check that 'sibStoreOpInst' post-dominates all other
2012 // stores to the same memref in 'sibNode' loop nest.
2013 auto *sibNode = mdg->getNode(sibId);
2014 // Compute an instruction list insertion point for the fused loop
2015 // nest which preserves dependences.
2016 assert(sibNode->inst->getBlock() == dstNode->inst->getBlock());
2017 Instruction *insertPointInst =
2018 sibNode->inst->isBeforeInBlock(dstNode->inst)
2019 ? mdg->getFusedLoopNestInsertionPoint(sibNode->id, dstNode->id)
2020 : mdg->getFusedLoopNestInsertionPoint(dstNode->id, sibNode->id);
2021 if (insertPointInst == nullptr)
2022 continue;
2023
2024 // Check if fusion would be profitable and at what depth.
2025
2026 // Get unique 'sibNode' load op to 'memref'.
2027 SmallVector<Instruction *, 2> sibLoadOpInsts;
2028 sibNode->getLoadOpsForMemref(memref, &sibLoadOpInsts);
2029 // Currently findSiblingNodeToFuse searches for siblings with one load.
2030 assert(sibLoadOpInsts.size() == 1);
2031 Instruction *sibLoadOpInst = sibLoadOpInsts[0];
2032 assert(!sibNode->stores.empty());
2033 // TODO(andydavis) Choose the store which postdominates all other stores.
2034 auto *sibStoreOpInst = sibNode->stores.back();
2035
2036 // Gather 'dstNode' load ops to 'memref'.
2037 SmallVector<Instruction *, 2> dstLoadOpInsts;
2038 dstNode->getLoadOpsForMemref(memref, &dstLoadOpInsts);
2039
2040 // Gather 'dstNode' store ops to 'memref'.
2041 SmallVector<Instruction *, 2> dstStoreOpInsts;
2042 dstNode->getStoreOpsForMemref(memref, &dstStoreOpInsts);
2043
2044 unsigned bestDstLoopDepth;
2045 mlir::ComputationSliceState sliceState;
2046
2047 // Check if fusion would be profitable.
2048 if (!isFusionProfitable(sibLoadOpInst, sibStoreOpInst, dstLoadOpInsts,
2049 dstStoreOpInsts, &sliceState, &bestDstLoopDepth))
2050 continue;
2051
2052 // Fuse computation slice of 'sibLoopNest' into 'dstLoopNest'.
2053 auto sliceLoopNest = mlir::insertBackwardComputationSlice(
2054 sibLoadOpInst, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState);
2055 if (sliceLoopNest != nullptr) {
2056 auto dstForInst = dstNode->inst->cast<AffineForOp>();
2057 // Update instruction position of fused loop nest (if needed).
2058 if (insertPointInst != dstForInst->getInstruction()) {
2059 dstForInst->getInstruction()->moveBefore(insertPointInst);
2060 }
2061 // Update data dependence graph state post fusion.
2062 updateStateAfterSiblingFusion(sliceLoopNest, sibNode, dstNode);
2063 }
2064 }
2065 }
2066
2067 // Searches the graph from 'dstNode' looking for a fusion candidate sibling
2068 // node which shares no dependences with 'dstNode' but which loads from the
2069 // same memref. Returns true and sets 'idAndMemrefToFuse' on success. Returns
2070 // false otherwise.
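  // As a schematic example: two loop nests that both load from a hypothetical
  // memref %B (discovered via a common producer of %B), with no dependence
  // path between them in the graph, form a sibling pair; the pair
  // (sibling id, %B) is returned so the sibling can later be fused into
  // 'dstNode' for input reuse.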
2071 bool findSiblingNodeToFuse(Node *dstNode,
2072 DenseSet<unsigned> *visitedSibNodeIds,
2073 std::pair<unsigned, Value *> *idAndMemrefToFuse) {
2074 // TODO(andydavis) Currently we discover siblings by following edges
2075 // through an intermediate src node. We should also consider siblings
2076 // which load from the same memref, but which do not necessarily share
2077 // a src node parent (e.g. loading from a memref which is a function arg).
2078 // Collect candidate 'dstNode' input edges in 'inEdges'.
2079 SmallVector<MemRefDependenceGraph::Edge, 2> inEdges;
2080 mdg->forEachMemRefInputEdge(
2081 dstNode->id, [&](MemRefDependenceGraph::Edge inEdge) {
2082 // Add 'inEdge' if it is a read-after-write dependence.
2083 if (dstNode->getLoadOpCount(inEdge.value) > 0 &&
2084 mdg->getNode(inEdge.id)->getStoreOpCount(inEdge.value) > 0)
2085 inEdges.push_back(inEdge);
2086 });
2087
2088 // Search for sibling nodes to fuse by visiting output edges from each input
2089 // edge in 'inEdges'.
2090 for (auto &inEdge : inEdges) {
2091 // Collect candidate output edges from each node 'inEdge.id' in 'inEdges'.
2092 SmallVector<MemRefDependenceGraph::Edge, 2> outEdges;
2093 mdg->forEachMemRefOutputEdge(
2094 inEdge.id, [&](MemRefDependenceGraph::Edge outEdge) {
2095 unsigned sibNodeId = outEdge.id;
2096 if (visitedSibNodeIds->count(sibNodeId) > 0)
2097 return;
2098 // Skip output edge if not a sibling using the same memref.
2099 if (outEdge.id == dstNode->id || outEdge.value != inEdge.value)
2100 return;
2101 auto *sibNode = mdg->getNode(sibNodeId);
2102 if (!sibNode->inst->isa<AffineForOp>())
2103 return;
2104 // Skip if 'outEdge' is not a read-after-write dependence.
2105              // TODO(andydavis) Remove the restriction to a single load op.
2106 if (sibNode->getLoadOpCount(inEdge.value) != 1)
2107 return;
2108 // Skip if there exists a path of dependent edges between
2109 // 'sibNode' and 'dstNode'.
2110 if (mdg->hasDependencePath(sibNodeId, dstNode->id) ||
2111 mdg->hasDependencePath(dstNode->id, sibNodeId))
2112 return;
2113              // Skip sib node if it loads from (and stores to) the same memref on
2114 // which it also has an input dependence edge.
2115 DenseSet<Value *> loadAndStoreMemrefSet;
2116 sibNode->getLoadAndStoreMemrefSet(&loadAndStoreMemrefSet);
2117 if (llvm::any_of(loadAndStoreMemrefSet, [=](Value *memref) {
2118 return mdg->getIncomingMemRefAccesses(sibNode->id, memref) >
2119 0;
2120 }))
2121 return;
2122 // Check that all stores are to the same memref.
2123 DenseSet<Value *> storeMemrefs;
2124 for (auto *storeOpInst : sibNode->stores) {
2125 storeMemrefs.insert(storeOpInst->cast<StoreOp>()->getMemRef());
2126 }
2127 if (storeMemrefs.size() != 1)
2128 return;
2129 // Add candidate 'outEdge' to sibling node.
2130 outEdges.push_back(outEdge);
2131 });
2132
2133 // Add first candidate if any were returned.
2134 if (!outEdges.empty()) {
2135 visitedSibNodeIds->insert(outEdges[0].id);
2136 idAndMemrefToFuse->first = outEdges[0].id;
2137 idAndMemrefToFuse->second = outEdges[0].value;
2138 return true;
2139 }
2140 }
2141 return false;
2142 }
2143
2144 void updateStateAfterSiblingFusion(OpPointer<AffineForOp> sliceLoopNest,
2145 Node *sibNode, Node *dstNode) {
2146 // Update 'sibNode' and 'dstNode' input/output edges to reflect fusion.
2147 mdg->updateEdges(sibNode->id, dstNode->id);
2148
2149 // Collect slice loop stats.
2150 LoopNestStateCollector sliceCollector;
2151 sliceCollector.collect(sliceLoopNest->getInstruction());
2152 // Promote single iteration slice loops to single IV value.
2153 for (auto forOp : sliceCollector.forOps) {
2154 promoteIfSingleIteration(forOp);
2155 }
2156
2157    // Collect dst loop stats after memref privatization transformation.
2158 auto dstForInst = dstNode->inst->cast<AffineForOp>();
2159 LoopNestStateCollector dstLoopCollector;
2160 dstLoopCollector.collect(dstForInst->getInstruction());
2161 // Clear and add back loads and stores
2162 mdg->clearNodeLoadAndStores(dstNode->id);
2163 mdg->addToNode(dstNode->id, dstLoopCollector.loadOpInsts,
2164 dstLoopCollector.storeOpInsts);
2165 // Remove old sibling loop nest if it no longer has outgoing dependence
2166 // edges, and it does not write to a memref which escapes the
2167 // function.
2168 if (mdg->getOutEdgeCount(sibNode->id) == 0) {
2169 mdg->removeNode(sibNode->id);
2170 sibNode->inst->cast<AffineForOp>()->erase();
2171 }
2172 }
2173
2174 // Clean up any allocs with no users.
2175 void eraseUnusedMemRefAllocations() {
MLIR Teamc4237ae2019-01-18 16:56:272176 for (auto &pair : mdg->memrefEdgeCount) {
2177 if (pair.second > 0)
2178 continue;
2179 auto *memref = pair.first;
MLIR Team71495d52019-01-22 21:23:372180 // Skip if there exist other uses (return instruction or function calls).
2181 if (!memref->use_empty())
2182 continue;
MLIR Teamc4237ae2019-01-18 16:56:272183 // Use list expected to match the dep graph info.
MLIR Teamc4237ae2019-01-18 16:56:272184 auto *inst = memref->getDefiningInst();
River Riddleb4992772019-02-04 18:38:472185 if (inst && inst->isa<AllocOp>())
2186 inst->erase();
MLIR Teamc4237ae2019-01-18 16:56:272187 }
MLIR Teamf28e4df2018-11-01 14:26:002188 }
MLIR Team3b692302018-12-17 17:57:142189};
2190
2191} // end anonymous namespace
MLIR Teamf28e4df2018-11-01 14:26:002192
River Riddleed5fe202019-02-28 22:50:422193void LoopFusion::runOnFunction() {
Uday Bondhugulad4b3ff12019-02-27 00:10:192194 // Override if a command line argument was provided.
Uday Bondhugula8be26272019-02-02 01:06:222195 if (clFusionFastMemorySpace.getNumOccurrences() > 0) {
2196 fastMemorySpace = clFusionFastMemorySpace.getValue();
2197 }
2198
Uday Bondhugulad4b3ff12019-02-27 00:10:192199 // Override if a command line argument was provided.
2200 if (clFusionLocalBufThreshold.getNumOccurrences() > 0) {
2201 localBufSizeThreshold = clFusionLocalBufThreshold * 1024;
2202 }
2203
MLIR Team6892ffb2018-12-20 04:42:552204 MemRefDependenceGraph g;
River Riddlec6c53442019-02-27 18:59:292205 if (g.init(&getFunction()))
MLIR Teamd038e342019-03-01 19:50:252206 GreedyFusion(&g, localBufSizeThreshold, fastMemorySpace).run();
MLIR Teamf28e4df2018-11-01 14:26:002207}
Jacques Pienaar6f0fb222018-11-07 02:34:182208
2209static PassRegistration<LoopFusion> pass("loop-fusion", "Fuse loop nests");