//===- LoopFusion.cpp - Code to perform loop fusion -----------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements loop fusion.
//
//===----------------------------------------------------------------------===//

#include "mlir/AffineOps/AffineOps.h"
#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/LoopAnalysis.h"
#include "mlir/Analysis/Utils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/AffineStructures.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass.h"
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/Transforms/LoopUtils.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/Utils.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <iomanip>

#define DEBUG_TYPE "loop-fusion"

using llvm::SetVector;

using namespace mlir;

static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options");

/// Disables fusion profitability check and fuses if valid.
static llvm::cl::opt<bool>
    clMaximalLoopFusion("fusion-maximal", llvm::cl::Hidden,
                        llvm::cl::desc("Enables maximal loop fusion"),
                        llvm::cl::cat(clOptionsCategory));

/// A threshold, expressed as a fraction of the original computation, of the
/// additional computation allowed when fusing.
static llvm::cl::opt<double> clFusionAddlComputeTolerance(
    "fusion-compute-tolerance", llvm::cl::Hidden,
    llvm::cl::desc("Fractional increase in additional"
                   " computation tolerated while fusing"),
    llvm::cl::cat(clOptionsCategory));

static llvm::cl::opt<unsigned> clFusionFastMemorySpace(
    "fusion-fast-mem-space", llvm::cl::Hidden,
    llvm::cl::desc("Faster memory space number to promote fusion buffers to"),
    llvm::cl::cat(clOptionsCategory));

static llvm::cl::opt<unsigned> clFusionLocalBufThreshold(
    "fusion-local-buf-threshold", llvm::cl::Hidden,
    llvm::cl::desc("Threshold size (bytes) for promoting local buffers to fast "
                   "memory space"),
    llvm::cl::cat(clOptionsCategory));

namespace {

/// Loop fusion pass. This pass currently supports a greedy fusion policy,
/// which fuses loop nests with single-writer/single-reader memref dependences
/// with the goal of improving locality.
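///
/// As a rough sketch of the transformation (the IR below is illustrative
/// only; shapes, bounds, and values are made up), a producer nest storing to
/// a memref and a consumer nest loading from it:
///
///   for %i = 0 to 100 {
///     store %v, %m[%i] : memref<100xf32>
///   }
///   for %j = 0 to 100 {
///     %u = load %m[%j] : memref<100xf32>
///   }
///
/// can be fused by recomputing the producer slice inside the consumer loop,
/// typically through a reduced-size private buffer:
///
///   for %j = 0 to 100 {
///     store %v, %tmp[0] : memref<1xf32>
///     %u = load %tmp[0] : memref<1xf32>
///   }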

// TODO(andydavis) Support fusion of source loop nests which write to multiple
// memrefs, where each memref can have multiple users (if profitable).
// TODO(andydavis) Extend this pass to check for fusion preventing dependences,
// and add support for more general loop fusion algorithms.

struct LoopFusion : public FunctionPass {
  LoopFusion() : FunctionPass(&LoopFusion::passID) {}

  PassResult runOnFunction(Function *f) override;
  static char passID;

  // Any local buffers smaller than this size will be created in
  // `fastMemorySpace` if provided.
  unsigned localBufSizeThreshold = 1024;
  Optional<unsigned> fastMemorySpace = None;

  // The amount of additional computation that is tolerated while fusing
  // pair-wise as a fraction of the total computation.
  constexpr static double kComputeToleranceThreshold = 0.30f;
};

} // end anonymous namespace

char LoopFusion::passID = 0;

FunctionPass *mlir::createLoopFusionPass() { return new LoopFusion; }

namespace {

// LoopNestStateCollector walks loop nests and collects load and store
// operations, and records whether an operation with nested regions other than
// an AffineForOp was encountered in the loop nest.
struct LoopNestStateCollector {
  SmallVector<OpPointer<AffineForOp>, 4> forOps;
  SmallVector<Instruction *, 4> loadOpInsts;
  SmallVector<Instruction *, 4> storeOpInsts;
  bool hasNonForRegion = false;

  void collect(Instruction *instToWalk) {
    instToWalk->walk([&](Instruction *opInst) {
      if (opInst->isa<AffineForOp>())
        forOps.push_back(opInst->cast<AffineForOp>());
      else if (opInst->getNumBlockLists() != 0)
        hasNonForRegion = true;
      else if (opInst->isa<LoadOp>())
        loadOpInsts.push_back(opInst);
      else if (opInst->isa<StoreOp>())
        storeOpInsts.push_back(opInst);
    });
  }
};

// TODO(b/117228571) Replace when this is modeled through side-effects/op traits
static bool isMemRefDereferencingOp(const Instruction &op) {
  if (op.isa<LoadOp>() || op.isa<StoreOp>() || op.isa<DmaStartOp>() ||
      op.isa<DmaWaitOp>())
    return true;
  return false;
}
// MemRefDependenceGraph is a graph data structure where graph nodes are
// top-level instructions in a Function which contain load/store ops, and edges
// are memref dependences between the nodes.
// TODO(andydavis) Add a more flexible dependence graph representation.
// TODO(andydavis) Add a depth parameter to dependence graph construction.
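// As a small illustrative example (node ids and names hypothetical): for two
// top-level loop nests where the first stores to memref '%m' and the second
// loads from it, init() creates nodes {0, 1} and a memref dependence edge
// 0 -> 1 on '%m'; a top-level instruction whose result is used inside node 1
// additionally gets an SSA value edge to node 1.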
struct MemRefDependenceGraph {
public:
  // Node represents a node in the graph. A Node is either an entire loop nest
  // rooted at the top level which contains loads/stores, or a top level
  // load/store.
  struct Node {
    // The unique identifier of this node in the graph.
    unsigned id;
    // The top-level instruction which is (or contains) loads/stores.
    Instruction *inst;
    // List of load operations.
    SmallVector<Instruction *, 4> loads;
    // List of store op insts.
    SmallVector<Instruction *, 4> stores;
    Node(unsigned id, Instruction *inst) : id(id), inst(inst) {}

    // Returns the load op count for 'memref'.
    unsigned getLoadOpCount(Value *memref) {
      unsigned loadOpCount = 0;
      for (auto *loadOpInst : loads) {
        if (memref == loadOpInst->cast<LoadOp>()->getMemRef())
          ++loadOpCount;
      }
      return loadOpCount;
    }

    // Returns the store op count for 'memref'.
    unsigned getStoreOpCount(Value *memref) {
      unsigned storeOpCount = 0;
      for (auto *storeOpInst : stores) {
        if (memref == storeOpInst->cast<StoreOp>()->getMemRef())
          ++storeOpCount;
      }
      return storeOpCount;
    }
  };

  // Edge represents a data dependence between nodes in the graph.
  struct Edge {
    // The id of the node at the other end of the edge.
    // If this edge is stored in Edge = Node.inEdges[i], then
    // 'Node.inEdges[i].id' is the identifier of the source node of the edge.
    // If this edge is stored in Edge = Node.outEdges[i], then
    // 'Node.outEdges[i].id' is the identifier of the dest node of the edge.
    unsigned id;
    // The SSA value on which this edge represents a dependence.
    // If the value is a memref, then the dependence is between graph nodes
    // which contain accesses to the same memref 'value'. If the value is a
    // non-memref value, then the dependence is between a graph node which
    // defines an SSA value and another graph node which uses the SSA value
    // (e.g. a constant instruction defining a value which is used inside a loop
    // nest).
    Value *value;
  };

  // Map from node id to Node.
  DenseMap<unsigned, Node> nodes;
  // Map from node id to list of input edges.
  DenseMap<unsigned, SmallVector<Edge, 2>> inEdges;
  // Map from node id to list of output edges.
  DenseMap<unsigned, SmallVector<Edge, 2>> outEdges;
  // Map from memref to a count on the dependence edges associated with that
  // memref.
  DenseMap<Value *, unsigned> memrefEdgeCount;
  // The next unique identifier to use for newly created graph nodes.
  unsigned nextNodeId = 0;

  MemRefDependenceGraph() {}

  // Initializes the dependence graph based on operations in 'f'.
  // Returns true on success, false otherwise.
  bool init(Function *f);

  // Returns the graph node for 'id'.
  Node *getNode(unsigned id) {
    auto it = nodes.find(id);
    assert(it != nodes.end());
    return &it->second;
  }

  // Adds a node with 'inst' to the graph and returns its unique identifier.
  unsigned addNode(Instruction *inst) {
    Node node(nextNodeId++, inst);
    nodes.insert({node.id, node});
    return node.id;
  }

  // Remove node 'id' (and its associated edges) from graph.
  void removeNode(unsigned id) {
    // Remove each edge in 'inEdges[id]'.
    if (inEdges.count(id) > 0) {
      SmallVector<Edge, 2> oldInEdges = inEdges[id];
      for (auto &inEdge : oldInEdges) {
        removeEdge(inEdge.id, id, inEdge.value);
      }
    }
    // Remove each edge in 'outEdges[id]'.
    if (outEdges.count(id) > 0) {
      SmallVector<Edge, 2> oldOutEdges = outEdges[id];
      for (auto &outEdge : oldOutEdges) {
        removeEdge(id, outEdge.id, outEdge.value);
      }
    }
    // Erase remaining node state.
    inEdges.erase(id);
    outEdges.erase(id);
    nodes.erase(id);
  }

  // Returns true if node 'id' writes to any memref which escapes (or is an
  // argument to) the function/block. Returns false otherwise.
  bool writesToLiveInOrEscapingMemrefs(unsigned id) {
    Node *node = getNode(id);
    for (auto *storeOpInst : node->stores) {
      auto *memref = storeOpInst->cast<StoreOp>()->getMemRef();
      auto *inst = memref->getDefiningInst();
      // Return true if 'memref' is a block argument.
      if (!inst)
        return true;
      // Return true if any use of 'memref' escapes the function.
      for (auto &use : memref->getUses())
        if (!isMemRefDereferencingOp(*use.getOwner()))
          return true;
    }
    return false;
  }

  // Returns true if node 'id' can be removed from the graph. Returns false
  // otherwise. A node can be removed from the graph iff the following
  // conditions are met:
  // *) The node does not write to any memref which escapes (or is a
  //    function/block argument).
  // *) The node has no successors in the dependence graph.
  bool canRemoveNode(unsigned id) {
    if (writesToLiveInOrEscapingMemrefs(id))
      return false;
    Node *node = getNode(id);
    for (auto *storeOpInst : node->stores) {
      // Return false if there exist out edges from 'id' on 'memref'.
      if (getOutEdgeCount(id, storeOpInst->cast<StoreOp>()->getMemRef()) > 0)
        return false;
    }
    return true;
  }

  // Returns true iff there is an edge from node 'srcId' to node 'dstId' for
  // 'value'. Returns false otherwise.
  bool hasEdge(unsigned srcId, unsigned dstId, Value *value) {
    if (outEdges.count(srcId) == 0 || inEdges.count(dstId) == 0) {
      return false;
    }
    bool hasOutEdge = llvm::any_of(outEdges[srcId], [=](Edge &edge) {
      return edge.id == dstId && edge.value == value;
    });
    bool hasInEdge = llvm::any_of(inEdges[dstId], [=](Edge &edge) {
      return edge.id == srcId && edge.value == value;
    });
    return hasOutEdge && hasInEdge;
  }

  // Adds an edge from node 'srcId' to node 'dstId' for 'value'.
  void addEdge(unsigned srcId, unsigned dstId, Value *value) {
    if (!hasEdge(srcId, dstId, value)) {
      outEdges[srcId].push_back({dstId, value});
      inEdges[dstId].push_back({srcId, value});
      if (value->getType().isa<MemRefType>())
        memrefEdgeCount[value]++;
    }
  }

  // Removes an edge from node 'srcId' to node 'dstId' for 'value'.
  void removeEdge(unsigned srcId, unsigned dstId, Value *value) {
    assert(inEdges.count(dstId) > 0);
    assert(outEdges.count(srcId) > 0);
    if (value->getType().isa<MemRefType>()) {
      assert(memrefEdgeCount.count(value) > 0);
      memrefEdgeCount[value]--;
    }
    // Remove 'srcId' from 'inEdges[dstId]'.
    for (auto it = inEdges[dstId].begin(); it != inEdges[dstId].end(); ++it) {
      if ((*it).id == srcId && (*it).value == value) {
        inEdges[dstId].erase(it);
        break;
      }
    }
    // Remove 'dstId' from 'outEdges[srcId]'.
    for (auto it = outEdges[srcId].begin(); it != outEdges[srcId].end(); ++it) {
      if ((*it).id == dstId && (*it).value == value) {
        outEdges[srcId].erase(it);
        break;
      }
    }
  }

  // Returns the input edge count for node 'id' and 'memref' from src nodes
  // which access 'memref'.
  unsigned getIncomingMemRefAccesses(unsigned id, Value *memref) {
    unsigned inEdgeCount = 0;
    if (inEdges.count(id) > 0)
      for (auto &inEdge : inEdges[id])
        if (inEdge.value == memref) {
          Node *srcNode = getNode(inEdge.id);
          // Only count in edges from 'srcNode' if 'srcNode' accesses 'memref'.
          if (srcNode->getLoadOpCount(memref) > 0 ||
              srcNode->getStoreOpCount(memref) > 0)
            ++inEdgeCount;
        }
    return inEdgeCount;
  }

  // Returns the output edge count for node 'id' and 'memref'.
  unsigned getOutEdgeCount(unsigned id, Value *memref) {
    unsigned outEdgeCount = 0;
    if (outEdges.count(id) > 0)
      for (auto &outEdge : outEdges[id])
        if (outEdge.value == memref)
          ++outEdgeCount;
    return outEdgeCount;
  }

  // Computes and returns an insertion point instruction, before which the
  // fused <srcId, dstId> loop nest can be inserted while preserving
  // dependences. Returns nullptr if no such insertion point is found.
  Instruction *getFusedLoopNestInsertionPoint(unsigned srcId, unsigned dstId) {
    if (outEdges.count(srcId) == 0)
      return getNode(dstId)->inst;

    // Build set of insts in range (srcId, dstId) which depend on 'srcId'.
    SmallPtrSet<Instruction *, 2> srcDepInsts;
    for (auto &outEdge : outEdges[srcId])
      if (outEdge.id != dstId)
        srcDepInsts.insert(getNode(outEdge.id)->inst);

    // Build set of insts in range (srcId, dstId) on which 'dstId' depends.
    SmallPtrSet<Instruction *, 2> dstDepInsts;
    for (auto &inEdge : inEdges[dstId])
      if (inEdge.id != srcId)
        dstDepInsts.insert(getNode(inEdge.id)->inst);

    Instruction *srcNodeInst = getNode(srcId)->inst;
    Instruction *dstNodeInst = getNode(dstId)->inst;

    // Computing insertion point:
    // *) Walk all instruction positions in Block instruction list in the
    //    range (src, dst). For each instruction 'inst' visited in this search:
    //   *) Store in 'firstSrcDepPos' the first position where 'inst' has a
    //      dependence edge from 'srcNode'.
    //   *) Store in 'lastDstDepPos' the last position where 'inst' has a
    //      dependence edge to 'dstNode'.
    // *) Compare 'firstSrcDepPos' and 'lastDstDepPos' to determine the
    //    instruction insertion point (or return null pointer if no such
    //    insertion point exists: 'firstSrcDepPos' <= 'lastDstDepPos').
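    // For example (a hypothetical block layout): with instructions
    // [src, A, B, dst] where 'dst' depends on A (lastDstDepPos = 0) and B
    // depends on 'src' (firstSrcDepPos = 1), the fused nest is inserted
    // before B; if instead A depended on 'src' and 'dst' depended on B,
    // firstSrcDepPos (0) <= lastDstDepPos (1) and no valid point exists.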
    SmallVector<Instruction *, 2> depInsts;
    Optional<unsigned> firstSrcDepPos;
    Optional<unsigned> lastDstDepPos;
    unsigned pos = 0;
    for (Block::iterator it = std::next(Block::iterator(srcNodeInst));
         it != Block::iterator(dstNodeInst); ++it) {
      Instruction *inst = &(*it);
      if (srcDepInsts.count(inst) > 0 && firstSrcDepPos == None)
        firstSrcDepPos = pos;
      if (dstDepInsts.count(inst) > 0)
        lastDstDepPos = pos;
      depInsts.push_back(inst);
      ++pos;
    }

    if (firstSrcDepPos.hasValue()) {
      if (lastDstDepPos.hasValue()) {
        if (firstSrcDepPos.getValue() <= lastDstDepPos.getValue()) {
          // No valid insertion point exists which preserves dependences.
          return nullptr;
        }
      }
      // Return the insertion point at 'firstSrcDepPos'.
      return depInsts[firstSrcDepPos.getValue()];
    }
    // No dependence targets in range (or only dst deps in range); return
    // 'dstNodeInst' insertion point.
    return dstNodeInst;
  }

  // Updates edge mappings from node 'srcId' to node 'dstId' after 'oldMemRef'
  // has been replaced in node at 'dstId' by a private memref.
  void updateEdges(unsigned srcId, unsigned dstId, Value *oldMemRef) {
    // For each edge in 'inEdges[srcId]': add new edge remapping to 'dstId'.
    if (inEdges.count(srcId) > 0) {
      SmallVector<Edge, 2> oldInEdges = inEdges[srcId];
      for (auto &inEdge : oldInEdges) {
        // Add edge from 'inEdge.id' to 'dstId' if not for 'oldMemRef'.
        if (inEdge.value != oldMemRef)
          addEdge(inEdge.id, dstId, inEdge.value);
      }
    }
    // For each edge in 'outEdges[srcId]': remove edge from 'srcId' to 'dstId'.
    if (outEdges.count(srcId) > 0) {
      SmallVector<Edge, 2> oldOutEdges = outEdges[srcId];
      for (auto &outEdge : oldOutEdges) {
        // Remove any out edges from 'srcId' to 'dstId' across memrefs.
        if (outEdge.id == dstId)
          removeEdge(srcId, outEdge.id, outEdge.value);
      }
    }
    // Remove any edges in 'inEdges[dstId]' on 'oldMemRef' (which is being
    // replaced by a private memref). These edges could come from nodes
    // other than 'srcId' which were removed in the previous step.
    if (inEdges.count(dstId) > 0) {
      SmallVector<Edge, 2> oldInEdges = inEdges[dstId];
      for (auto &inEdge : oldInEdges)
        if (inEdge.value == oldMemRef)
          removeEdge(inEdge.id, dstId, inEdge.value);
    }
  }

  // Adds ops in 'loads' and 'stores' to node at 'id'.
  void addToNode(unsigned id, const SmallVectorImpl<Instruction *> &loads,
                 const SmallVectorImpl<Instruction *> &stores) {
    Node *node = getNode(id);
    for (auto *loadOpInst : loads)
      node->loads.push_back(loadOpInst);
    for (auto *storeOpInst : stores)
      node->stores.push_back(storeOpInst);
  }

  void clearNodeLoadAndStores(unsigned id) {
    Node *node = getNode(id);
    node->loads.clear();
    node->stores.clear();
  }

  void print(raw_ostream &os) const {
    os << "\nMemRefDependenceGraph\n";
    os << "\nNodes:\n";
    for (auto &idAndNode : nodes) {
      os << "Node: " << idAndNode.first << "\n";
      auto it = inEdges.find(idAndNode.first);
      if (it != inEdges.end()) {
        for (const auto &e : it->second)
          os << "  InEdge: " << e.id << " " << e.value << "\n";
      }
      it = outEdges.find(idAndNode.first);
      if (it != outEdges.end()) {
        for (const auto &e : it->second)
          os << "  OutEdge: " << e.id << " " << e.value << "\n";
      }
    }
  }
  void dump() const { print(llvm::errs()); }
};

// Initializes the data dependence graph by walking instructions in 'f'.
// Assigns each node in the graph a node id based on program order in 'f'.
// TODO(andydavis) Add support for taking a Block arg to construct the
// dependence graph at a different depth.
bool MemRefDependenceGraph::init(Function *f) {
  DenseMap<Value *, SetVector<unsigned>> memrefAccesses;

  // TODO: support multi-block functions.
  if (f->getBlocks().size() != 1)
    return false;

  DenseMap<Instruction *, unsigned> forToNodeMap;
  for (auto &inst : f->front()) {
    if (auto forOp = inst.dyn_cast<AffineForOp>()) {
      // Create graph node 'id' to represent top-level 'forOp' and record
      // all loads and store accesses it contains.
      LoopNestStateCollector collector;
      collector.collect(&inst);
      // Return false if a non 'for' region was found (not currently supported).
      if (collector.hasNonForRegion)
        return false;
      Node node(nextNodeId++, &inst);
      for (auto *opInst : collector.loadOpInsts) {
        node.loads.push_back(opInst);
        auto *memref = opInst->cast<LoadOp>()->getMemRef();
        memrefAccesses[memref].insert(node.id);
      }
      for (auto *opInst : collector.storeOpInsts) {
        node.stores.push_back(opInst);
        auto *memref = opInst->cast<StoreOp>()->getMemRef();
        memrefAccesses[memref].insert(node.id);
      }
      forToNodeMap[&inst] = node.id;
      nodes.insert({node.id, node});
    } else if (auto loadOp = inst.dyn_cast<LoadOp>()) {
      // Create graph node for top-level load op.
      Node node(nextNodeId++, &inst);
      node.loads.push_back(&inst);
      auto *memref = inst.cast<LoadOp>()->getMemRef();
      memrefAccesses[memref].insert(node.id);
      nodes.insert({node.id, node});
    } else if (auto storeOp = inst.dyn_cast<StoreOp>()) {
      // Create graph node for top-level store op.
      Node node(nextNodeId++, &inst);
      node.stores.push_back(&inst);
      auto *memref = inst.cast<StoreOp>()->getMemRef();
      memrefAccesses[memref].insert(node.id);
      nodes.insert({node.id, node});
    } else if (inst.getNumBlockLists() != 0) {
      // Return false if another region is found (not currently supported).
      return false;
    } else if (inst.getNumResults() > 0 && !inst.use_empty()) {
      // Create graph node for top-level producer of SSA values, which
      // could be used by loop nest nodes.
      Node node(nextNodeId++, &inst);
      nodes.insert({node.id, node});
    }
  }

  // Add dependence edges between nodes which produce SSA values and their
  // users.
  for (auto &idAndNode : nodes) {
    const Node &node = idAndNode.second;
    if (!node.loads.empty() || !node.stores.empty())
      continue;
    auto *opInst = node.inst;
    for (auto *value : opInst->getResults()) {
      for (auto &use : value->getUses()) {
        SmallVector<OpPointer<AffineForOp>, 4> loops;
        getLoopIVs(*use.getOwner(), &loops);
        if (loops.empty())
          continue;
        assert(forToNodeMap.count(loops[0]->getInstruction()) > 0);
        unsigned userLoopNestId = forToNodeMap[loops[0]->getInstruction()];
        addEdge(node.id, userLoopNestId, value);
      }
    }
  }

  // Walk memref access lists and add graph edges between dependent nodes.
  for (auto &memrefAndList : memrefAccesses) {
    unsigned n = memrefAndList.second.size();
    for (unsigned i = 0; i < n; ++i) {
      unsigned srcId = memrefAndList.second[i];
      bool srcHasStore =
          getNode(srcId)->getStoreOpCount(memrefAndList.first) > 0;
      for (unsigned j = i + 1; j < n; ++j) {
        unsigned dstId = memrefAndList.second[j];
        bool dstHasStore =
            getNode(dstId)->getStoreOpCount(memrefAndList.first) > 0;
        if (srcHasStore || dstHasStore)
          addEdge(srcId, dstId, memrefAndList.first);
      }
    }
  }
  return true;
}

namespace {

// LoopNestStats aggregates various per-loop statistics (eg. loop trip count
// and operation count) for a loop nest up until the innermost loop body.
struct LoopNestStats {
  // Map from AffineForOp to immediate child AffineForOps in its loop body.
  DenseMap<Instruction *, SmallVector<OpPointer<AffineForOp>, 2>> loopMap;
  // Map from AffineForOp to count of operations in its loop body.
  DenseMap<Instruction *, uint64_t> opCountMap;
  // Map from AffineForOp to its constant trip count.
  DenseMap<Instruction *, uint64_t> tripCountMap;
};

// LoopNestStatsCollector walks a single loop nest and gathers per-loop
// trip count and operation count statistics and records them in 'stats'.
struct LoopNestStatsCollector {
  LoopNestStats *stats;
  bool hasLoopWithNonConstTripCount = false;

  LoopNestStatsCollector(LoopNestStats *stats) : stats(stats) {}

  void collect(Instruction *inst) {
    inst->walk<AffineForOp>([&](OpPointer<AffineForOp> forOp) {
      auto *forInst = forOp->getInstruction();
      auto *parentInst = forOp->getInstruction()->getParentInst();
      if (parentInst != nullptr) {
        assert(parentInst->isa<AffineForOp>() && "Expected parent AffineForOp");
        // Add mapping to 'forOp' from its parent AffineForOp.
        stats->loopMap[parentInst].push_back(forOp);
      }

      // Record the number of op instructions in the body of 'forOp'.
      unsigned count = 0;
      stats->opCountMap[forInst] = 0;
      for (auto &inst : *forOp->getBody()) {
        if (!(inst.isa<AffineForOp>() || inst.isa<AffineIfOp>()))
          ++count;
      }
      stats->opCountMap[forInst] = count;
      // Record trip count for 'forOp'. Set flag if trip count is not
      // constant.
      Optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp);
      if (!maybeConstTripCount.hasValue()) {
        hasLoopWithNonConstTripCount = true;
        return;
      }
      stats->tripCountMap[forInst] = maybeConstTripCount.getValue();
    });
  }
};

// Computes the total cost of the loop nest rooted at 'forOp'.
// Currently, the total cost is computed by counting the total operation
// instance count (i.e. total number of operations in the loop body: loop
// operation count * loop trip count) for the entire loop nest.
// If 'tripCountOverrideMap' is non-null, overrides the trip count for loops
// specified in the map when computing the total op instance count.
// NOTE: this is used to compute the cost of computation slices, which are
// sliced along the iteration dimension, and thus reduce the trip count.
// If 'computeCostMap' is non-null, the total op count for forOps specified
// in the map is increased (not overridden) by adding the op count from the
// map to the existing op count for the for loop. This is done before
// multiplying by the loop's trip count, and is used to model the cost of
// inserting a sliced loop nest of known cost into the loop's body.
// NOTE: this is used to compute the cost of fusing a slice of some loop nest
// within another loop.
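// For example (with made-up numbers): a nest 'for %i' (trip count 100)
// containing 'for %j' (trip count 50) whose body holds 2 ops costs
// 100 * (50 * 2) = 10000 dynamic op instances; overriding the inner trip
// count to 1, as a fully sliced loop would, reduces this to 100 * 2 = 200.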
static int64_t getComputeCost(
    Instruction *forInst, LoopNestStats *stats,
    llvm::SmallDenseMap<Instruction *, uint64_t, 8> *tripCountOverrideMap,
    DenseMap<Instruction *, int64_t> *computeCostMap) {
  // 'opCount' is the total number of operations in one iteration of the body
  // of 'forInst'.
  int64_t opCount = stats->opCountMap[forInst];
  if (stats->loopMap.count(forInst) > 0) {
    for (auto childForOp : stats->loopMap[forInst]) {
      opCount += getComputeCost(childForOp->getInstruction(), stats,
                                tripCountOverrideMap, computeCostMap);
    }
  }
  // Add in additional op instances from slice (if specified in map).
  if (computeCostMap != nullptr) {
    auto it = computeCostMap->find(forInst);
    if (it != computeCostMap->end()) {
      opCount += it->second;
    }
  }
  // Override trip count (if specified in map).
  int64_t tripCount = stats->tripCountMap[forInst];
  if (tripCountOverrideMap != nullptr) {
    auto it = tripCountOverrideMap->find(forInst);
    if (it != tripCountOverrideMap->end()) {
      tripCount = it->second;
    }
  }
  // Returns the total number of dynamic instances of operations in loop body.
  return tripCount * opCount;
}

} // end anonymous namespace

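// Returns the constant difference (ub - lb) between single-result bound maps
// 'lbMap' and 'ubMap' if it simplifies to a constant, and None otherwise.
// For example, lb: (d0) -> (d0) and ub: (d0) -> (d0 + 4) yield 4.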
static Optional<uint64_t> getConstDifference(AffineMap lbMap, AffineMap ubMap) {
  assert(lbMap.getNumResults() == 1 && "expected single result bound map");
  assert(ubMap.getNumResults() == 1 && "expected single result bound map");
  assert(lbMap.getNumDims() == ubMap.getNumDims());
  assert(lbMap.getNumSymbols() == ubMap.getNumSymbols());
  // TODO(andydavis) Merge this code with 'mlir::getTripCountExpr'.
  // ub_expr - lb_expr
  AffineExpr lbExpr(lbMap.getResult(0));
  AffineExpr ubExpr(ubMap.getResult(0));
  auto loopSpanExpr = simplifyAffineExpr(ubExpr - lbExpr, lbMap.getNumDims(),
                                         lbMap.getNumSymbols());
  auto cExpr = loopSpanExpr.dyn_cast<AffineConstantExpr>();
  if (!cExpr)
    return None;
  return cExpr.getValue();
}

// Builds a map 'tripCountMap' from AffineForOp to constant trip count for the
// loop nest surrounding 'srcOpInst', utilizing slice loop bounds in
// 'sliceState'.
// Returns true on success, false otherwise (if a non-constant trip count
// was encountered).
// TODO(andydavis) Make this work with non-unit step loops.
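// For example (hypothetical bounds): a source loop 'for %i = 0 to 100' with
// slice bounds lb: (d0) -> (d0) and ub: (d0) -> (d0 + 1) maps to trip count 1;
// an unsliced loop falls back to its full constant trip count of 100.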
static bool buildSliceTripCountMap(
    Instruction *srcOpInst, ComputationSliceState *sliceState,
    llvm::SmallDenseMap<Instruction *, uint64_t, 8> *tripCountMap) {
  SmallVector<OpPointer<AffineForOp>, 4> srcLoopIVs;
  getLoopIVs(*srcOpInst, &srcLoopIVs);
  unsigned numSrcLoopIVs = srcLoopIVs.size();
  // Populate map from AffineForOp -> trip count.
  for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
    AffineMap lbMap = sliceState->lbs[i];
    AffineMap ubMap = sliceState->ubs[i];
    if (lbMap == AffineMap() || ubMap == AffineMap()) {
      // The iteration of src loop IV 'i' was not sliced. Use full loop bounds.
      if (srcLoopIVs[i]->hasConstantLowerBound() &&
          srcLoopIVs[i]->hasConstantUpperBound()) {
        (*tripCountMap)[srcLoopIVs[i]->getInstruction()] =
            srcLoopIVs[i]->getConstantUpperBound() -
            srcLoopIVs[i]->getConstantLowerBound();
        continue;
      }
      return false;
    }
    Optional<uint64_t> tripCount = getConstDifference(lbMap, ubMap);
    if (!tripCount.hasValue())
      return false;
    (*tripCountMap)[srcLoopIVs[i]->getInstruction()] = tripCount.getValue();
  }
  return true;
}

// Removes load operations from 'srcLoads' which operate on 'memref', and
// adds them to 'dstLoads'.
static void
moveLoadsAccessingMemrefTo(Value *memref,
                           SmallVectorImpl<Instruction *> *srcLoads,
                           SmallVectorImpl<Instruction *> *dstLoads) {
  dstLoads->clear();
  SmallVector<Instruction *, 4> srcLoadsToKeep;
  for (auto *load : *srcLoads) {
    if (load->cast<LoadOp>()->getMemRef() == memref)
      dstLoads->push_back(load);
    else
      srcLoadsToKeep.push_back(load);
  }
  srcLoads->swap(srcLoadsToKeep);
}

// Returns the innermost common loop depth for the set of operations in 'ops'.
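// For example, an op nested under loops (%i, %j) and an op nested under
// loops (%i, %k) share only the outer loop %i, giving a common depth of 1.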
static unsigned getInnermostCommonLoopDepth(ArrayRef<Instruction *> ops) {
  unsigned numOps = ops.size();
  assert(numOps > 0);

  std::vector<SmallVector<OpPointer<AffineForOp>, 4>> loops(numOps);
  unsigned loopDepthLimit = std::numeric_limits<unsigned>::max();
  for (unsigned i = 0; i < numOps; ++i) {
    getLoopIVs(*ops[i], &loops[i]);
    loopDepthLimit =
        std::min(loopDepthLimit, static_cast<unsigned>(loops[i].size()));
  }

  unsigned loopDepth = 0;
  for (unsigned d = 0; d < loopDepthLimit; ++d) {
    unsigned i;
    for (i = 1; i < numOps; ++i) {
      if (loops[i - 1][d] != loops[i][d])
        break;
    }
    if (i != numOps)
      break;
    ++loopDepth;
  }
  return loopDepth;
}

// Returns the maximum loop depth at which no dependences between 'loadOpInsts'
// and 'storeOpInsts' are satisfied.
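// For example (hypothetical): if a load/store pair on the same memref shares
// three common loops and their dependence is first satisfied at depth 2, the
// maximum legal fusion depth returned is 1.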
static unsigned getMaxLoopDepth(ArrayRef<Instruction *> loadOpInsts,
                                ArrayRef<Instruction *> storeOpInsts) {
  // Merge loads and stores into the same array.
  SmallVector<Instruction *, 2> ops(loadOpInsts.begin(), loadOpInsts.end());
  ops.append(storeOpInsts.begin(), storeOpInsts.end());

  // Compute the innermost common loop depth for loads and stores.
  unsigned loopDepth = getInnermostCommonLoopDepth(ops);

  // Return common loop depth for loads if there are no store ops.
  if (storeOpInsts.empty())
    return loopDepth;

  // Check dependences on all pairs of ops in 'ops' and store the minimum
  // loop depth at which a dependence is satisfied.
  for (unsigned i = 0, e = ops.size(); i < e; ++i) {
    auto *srcOpInst = ops[i];
    MemRefAccess srcAccess(srcOpInst);
    for (unsigned j = 0; j < e; ++j) {
      auto *dstOpInst = ops[j];
      MemRefAccess dstAccess(dstOpInst);

      unsigned numCommonLoops =
          getNumCommonSurroundingLoops(*srcOpInst, *dstOpInst);
      for (unsigned d = 1; d <= numCommonLoops + 1; ++d) {
        FlatAffineConstraints dependenceConstraints;
        // TODO(andydavis) Cache dependence analysis results, check cache here.
        if (checkMemrefAccessDependence(srcAccess, dstAccess, d,
                                        &dependenceConstraints,
                                        /*dependenceComponents=*/nullptr)) {
          // Store minimum loop depth and break because we want the min 'd' at
          // which there is a dependence.
          loopDepth = std::min(loopDepth, d - 1);
          break;
        }
      }
    }
  }
  return loopDepth;
}

// Returns the slice union of 'sliceStateA' and 'sliceStateB' in 'sliceStateB'
// using a rectangular bounding box.
// TODO(andydavis) This function assumes that lower bounds for 'sliceStateA'
// and 'sliceStateB' are aligned.
// Specifically, when taking the union of overlapping intervals, it assumes
// that both intervals start at zero. Support needs to be added to take into
// account interval start offset when computing the union.
// TODO(andydavis) Move this function to an analysis library.
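// For example (assuming aligned lower bounds): if dimension 'i' spans 16
// iterations in 'sliceStateA' and 10 in 'sliceStateB', the union keeps A's
// bounds, i.e. the rectangular bounding box covering both slices.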
static bool getSliceUnion(const ComputationSliceState &sliceStateA,
                          ComputationSliceState *sliceStateB) {
  assert(sliceStateA.lbs.size() == sliceStateB->lbs.size());
  assert(sliceStateA.ubs.size() == sliceStateB->ubs.size());

  for (unsigned i = 0, e = sliceStateA.lbs.size(); i < e; ++i) {
    AffineMap lbMapA = sliceStateA.lbs[i];
    AffineMap ubMapA = sliceStateA.ubs[i];
    if (lbMapA == AffineMap()) {
      assert(ubMapA == AffineMap());
      continue;
    }
    assert(ubMapA && "expected non-null ub map");

    AffineMap lbMapB = sliceStateB->lbs[i];
    AffineMap ubMapB = sliceStateB->ubs[i];
    if (lbMapB == AffineMap()) {
      assert(ubMapB == AffineMap());
      // Union 'sliceStateB' does not have a bound for 'i' so copy from A.
      sliceStateB->lbs[i] = lbMapA;
      sliceStateB->ubs[i] = ubMapA;
      continue;
    }

    // TODO(andydavis) Change this code to take the min across all lower bounds
    // and max across all upper bounds for each dimension. This code can fail
    // for cases where a unique min or max could not be statically determined.

    // Assumption: both lower bounds are the same.
    if (lbMapA != lbMapB)
      return false;

    // Add bound with the largest trip count to union.
    Optional<uint64_t> tripCountA = getConstDifference(lbMapA, ubMapA);
    Optional<uint64_t> tripCountB = getConstDifference(lbMapB, ubMapB);
    if (!tripCountA.hasValue() || !tripCountB.hasValue())
      return false;

    if (tripCountA.getValue() > tripCountB.getValue()) {
      sliceStateB->lbs[i] = lbMapA;
      sliceStateB->ubs[i] = ubMapA;
    }
  }
  return true;
}

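// Returns the size of 'memRefType's element type in bytes, rounding bit
// widths up to whole bytes: for example, f32 -> 4, i1 -> 1, and
// vector<4xf32> -> 16.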
// TODO(mlir-team): improve/complete this when we have target data.
unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
  auto elementType = memRefType.getElementType();

  unsigned sizeInBits;
  if (elementType.isIntOrFloat()) {
    sizeInBits = elementType.getIntOrFloatBitWidth();
  } else {
    auto vectorType = elementType.cast<VectorType>();
    sizeInBits =
        vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
  }
  return llvm::divideCeil(sizeInBits, 8);
}

// Creates and returns a private (single-user) memref for fused loop rooted
// at 'forOp', with (potentially reduced) memref size based on the
// MemRefRegion written to by 'srcStoreOpInst' at depth 'dstLoopDepth'.
// TODO(bondhugula): consider refactoring the common code from generateDma and
// this one.
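// For example (schematically; names and shapes are illustrative): if the
// slice writes a 1x100 region of a memref<100x100xf32> per destination
// iteration, this creates '%tmp = alloc() : memref<1x100xf32>' at the top of
// the function and remaps accesses such as '%m[%i, %j]' to '%tmp[0, %j]' by
// subtracting the region's lower bound offsets.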
static Value *createPrivateMemRef(OpPointer<AffineForOp> forOp,
                                  Instruction *srcStoreOpInst,
                                  unsigned dstLoopDepth,
                                  Optional<unsigned> fastMemorySpace,
                                  unsigned localBufSizeThreshold) {
  auto *forInst = forOp->getInstruction();

  // Create builder to insert alloc op just before 'forOp'.
  FuncBuilder b(forInst);
  // Builder to create constants at the top level.
  FuncBuilder top(forInst->getFunction());
  // Create new memref type based on slice bounds.
  auto *oldMemRef = srcStoreOpInst->cast<StoreOp>()->getMemRef();
  auto oldMemRefType = oldMemRef->getType().cast<MemRefType>();
  unsigned rank = oldMemRefType.getRank();

  // Compute MemRefRegion for 'srcStoreOpInst' at depth 'dstLoopDepth'.
  MemRefRegion region(srcStoreOpInst->getLoc());
  region.compute(srcStoreOpInst, dstLoopDepth);
  SmallVector<int64_t, 4> newShape;
  std::vector<SmallVector<int64_t, 4>> lbs;
  SmallVector<int64_t, 8> lbDivisors;
  lbs.reserve(rank);
  // Query 'region' for 'newShape' and lower bounds of MemRefRegion accessed
  // by 'srcStoreOpInst' at depth 'dstLoopDepth'.
  Optional<int64_t> numElements =
      region.getConstantBoundingSizeAndShape(&newShape, &lbs, &lbDivisors);
  assert(numElements.hasValue() &&
         "non-constant number of elts in local buffer");

  const FlatAffineConstraints *cst = region.getConstraints();
  // 'outerIVs' holds the values that this memory region is symbolic/parametric
  // on; this would correspond to loop IVs surrounding the level at which the
  // slice is being materialized.
  SmallVector<Value *, 8> outerIVs;
  cst->getIdValues(rank, cst->getNumIds(), &outerIVs);

  // Build 'rank' AffineExprs from MemRefRegion 'lbs'.
  SmallVector<AffineExpr, 4> offsets;
  offsets.reserve(rank);
  for (unsigned d = 0; d < rank; ++d) {
    assert(lbs[d].size() == cst->getNumCols() - rank && "incorrect bound size");

    AffineExpr offset = top.getAffineConstantExpr(0);
    for (unsigned j = 0, e = cst->getNumCols() - rank - 1; j < e; j++) {
      offset = offset + lbs[d][j] * top.getAffineDimExpr(j);
    }
    assert(lbDivisors[d] > 0);
    offset =
        (offset + lbs[d][cst->getNumCols() - 1 - rank]).floorDiv(lbDivisors[d]);
    offsets.push_back(offset);
  }

  // Create 'newMemRefType' using 'newShape' from MemRefRegion accessed
  // by 'srcStoreOpInst'.
  uint64_t bufSize =
      getMemRefEltSizeInBytes(oldMemRefType) * numElements.getValue();
  unsigned newMemSpace;
  if (bufSize < localBufSizeThreshold && fastMemorySpace.hasValue()) {
    newMemSpace = fastMemorySpace.getValue();
  } else {
    newMemSpace = oldMemRefType.getMemorySpace();
  }
  auto newMemRefType = top.getMemRefType(
      newShape, oldMemRefType.getElementType(), {}, newMemSpace);
  // Gather alloc operands for the dynamic dimensions of the memref.
  SmallVector<Value *, 4> allocOperands;
  unsigned dynamicDimCount = 0;
  for (auto dimSize : oldMemRefType.getShape()) {
    if (dimSize == -1)
      allocOperands.push_back(
          top.create<DimOp>(forOp->getLoc(), oldMemRef, dynamicDimCount++));
  }

  // Create new private memref for fused loop 'forOp'.
  // TODO(andydavis) Create/move alloc ops for private memrefs closer to their
  // consumer loop nests to reduce their live range. Currently they are added
  // at the beginning of the function, because loop nests can be reordered
  // during the fusion pass.
  Value *newMemRef =
      top.create<AllocOp>(forOp->getLoc(), newMemRefType, allocOperands);

  // Build an AffineMap to remap access functions based on lower bound offsets.
  SmallVector<AffineExpr, 4> remapExprs;
  remapExprs.reserve(rank);
  unsigned zeroOffsetCount = 0;
  for (unsigned i = 0; i < rank; i++) {
    if (auto constExpr = offsets[i].dyn_cast<AffineConstantExpr>())
      if (constExpr.getValue() == 0)
        ++zeroOffsetCount;
    auto dimExpr = b.getAffineDimExpr(outerIVs.size() + i);

    auto remapExpr =
        simplifyAffineExpr(dimExpr - offsets[i], outerIVs.size() + rank, 0);
    remapExprs.push_back(remapExpr);
  }
  auto indexRemap =
      zeroOffsetCount == rank
          ? AffineMap()
          : b.getAffineMap(outerIVs.size() + rank, 0, remapExprs, {});
  // Replace all users of 'oldMemRef' with 'newMemRef'.
  bool ret =
      replaceAllMemRefUsesWith(oldMemRef, newMemRef, {}, indexRemap,
                               /*extraOperands=*/outerIVs,
                               /*domInstFilter=*/&*forOp->getBody()->begin());
  assert(ret && "replaceAllMemrefUsesWith should always succeed here");
  (void)ret;
  return newMemRef;
}

// Returns the total iteration count of the slice, i.e. the product of the
// trip counts in 'sliceTripCountMap'.
static uint64_t getSliceIterationCount(
    const llvm::SmallDenseMap<Instruction *, uint64_t, 8> &sliceTripCountMap) {
  uint64_t iterCount = 1;
  for (const auto &count : sliceTripCountMap) {
    iterCount *= count.second;
  }
  return iterCount;
}

// Checks the profitability of fusing a backwards slice of the loop nest
// surrounding 'srcOpInst' into the loop nest surrounding 'dstLoadOpInsts'.
// Returns true if it is profitable to fuse the candidate loop nests. Returns
// false otherwise. `dstLoopDepth` is set to the most profitable depth at which
// to materialize the source loop nest slice.
// The profitability model executes the following steps:
// *) Computes the backward computation slice at 'srcOpInst'. This
//    computation slice of the loop nest surrounding 'srcOpInst' is
//    represented by modified src loop bounds in 'sliceState', which are
//    functions of loop IVs in the loop nest surrounding 'srcOpInst'.
// *) Computes the cost of unfused src/dst loop nests (currently the cost of a
//    loop nest is the total number of dynamic operation instances in the loop
//    nest).
// *) Computes the cost of fusing a slice of the src loop nest into the dst
//    loop nest at various values of dst loop depth, attempting to fuse
//    the largest computation slice at the maximal dst loop depth (closest to the
//    load) to minimize reuse distance and potentially enable subsequent
//    load/store forwarding.
// NOTE: If the dst loop nest includes multiple loads in 'dstLoadOpInsts' for
//    the same memref as is written by 'srcOpInst', then the union of slice
//    loop bounds is used to compute the slice and associated slice cost.
// NOTE: 'dstLoopDepth' refers to the loop depth within the destination loop
//    nest, at which the src computation slice is inserted/fused.
// NOTE: We attempt to maximize the dst loop depth, but there are cases
//    where a particular setting for 'dstLoopDepth' might fuse an unsliced
//    loop (within the src computation slice) at a depth which results in
//    excessive recomputation (see unit tests for examples).
// *) Compares the total cost of the unfused loop nests to the min cost fused
//    loop nest computed in the previous step, and returns true if the latter
//    is lower.
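// For example (with made-up numbers): if srcLoopNestCost = 10000 and
// dstLoopNestCost = 20000, the unfused total is 30000; a minimum fused cost
// of 32000 is an additional compute fraction of 2000 / 30000 (under 7%),
// which the default 30% tolerance (kComputeToleranceThreshold) would accept.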
static bool isFusionProfitable(Instruction *srcOpInst,
                               ArrayRef<Instruction *> dstLoadOpInsts,
                               ArrayRef<Instruction *> dstStoreOpInsts,
                               ComputationSliceState *sliceState,
                               unsigned *dstLoopDepth) {
  LLVM_DEBUG({
    llvm::dbgs() << "Checking whether fusion is profitable between:\n";
    llvm::dbgs() << " ";
    srcOpInst->dump();
    llvm::dbgs() << " and \n";
    for (auto dstOpInst : dstLoadOpInsts) {
      llvm::dbgs() << " ";
      dstOpInst->dump();
    };
  });

  // Compute cost of sliced and unsliced src loop nest.
  SmallVector<OpPointer<AffineForOp>, 4> srcLoopIVs;
  getLoopIVs(*srcOpInst, &srcLoopIVs);
  unsigned numSrcLoopIVs = srcLoopIVs.size();

  // Walk src loop nest and collect stats.
  LoopNestStats srcLoopNestStats;
  LoopNestStatsCollector srcStatsCollector(&srcLoopNestStats);
  srcStatsCollector.collect(srcLoopIVs[0]->getInstruction());
  // Currently only constant trip count loop nests are supported.
  if (srcStatsCollector.hasLoopWithNonConstTripCount)
    return false;

  // Compute cost of dst loop nest.
  SmallVector<OpPointer<AffineForOp>, 4> dstLoopIVs;
  getLoopIVs(*dstLoadOpInsts[0], &dstLoopIVs);

  LoopNestStats dstLoopNestStats;
  LoopNestStatsCollector dstStatsCollector(&dstLoopNestStats);
  dstStatsCollector.collect(dstLoopIVs[0]->getInstruction());
  // Currently only constant trip count loop nests are supported.
  if (dstStatsCollector.hasLoopWithNonConstTripCount)
    return false;

MLIR Teamd7c82442019-01-30 23:53:411094 // Compute the maximum loop depth at which we can insert the src slice
1095 // and still satisfy dest loop nest dependences.
1096 unsigned maxDstLoopDepth = getMaxLoopDepth(dstLoadOpInsts, dstStoreOpInsts);
MLIR Team27d067e2019-01-16 17:55:021097 if (maxDstLoopDepth == 0)
1098 return false;
1099
1100 // Search for min cost value for 'dstLoopDepth'. At each value of
1101 // 'dstLoopDepth' from 'maxDstLoopDepth' to '1', compute computation slice
1102 // bounds between 'srcOpInst' and each op in 'dstLoadOpInsts' (taking the union
1103 // of these bounds). Next, the union of the slice bounds is used to calculate
1104 // the cost of the slice and the cost of the slice inserted into the dst
1105 // loop nest at 'dstLoopDepth'.
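// In pseudocode, the search below is:
//   for i = maxDstLoopDepth down to 1:
//     slice(i) = union of backward slices at depth i over 'dstLoadOpInsts'
//     cost(i)  = compute cost of dst nest with slice(i) inserted at depth i
//     keep (i, slice(i)) with the best storage reduction whose additional
//     compute fraction stays within the tolerance threshold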
Uday Bondhugula864d9e02019-01-23 17:16:241106 uint64_t minFusedLoopNestComputeCost = std::numeric_limits<uint64_t>::max();
1107 uint64_t maxStorageReduction = 0;
1108 Optional<uint64_t> sliceMemEstimate = None;
1109
MLIR Team27d067e2019-01-16 17:55:021110 SmallVector<ComputationSliceState, 4> sliceStates;
1111 sliceStates.resize(maxDstLoopDepth);
Uday Bondhugula864d9e02019-01-23 17:16:241112 // The best loop depth at which to materialize the slice.
1113 Optional<unsigned> bestDstLoopDepth = None;
1114
1115 // Compute op instance count for the src loop nest without iteration slicing.
River Riddle5052bd82019-02-02 00:42:181116 uint64_t srcLoopNestCost =
1117 getComputeCost(srcLoopIVs[0]->getInstruction(), &srcLoopNestStats,
1118 /*tripCountOverrideMap=*/nullptr,
1119 /*computeCostMap=*/nullptr);
Uday Bondhugula864d9e02019-01-23 17:16:241120
MLIR Teamb9dde912019-02-06 19:01:101121 // Compute src loop nest write region size.
1122 MemRefRegion srcWriteRegion(srcOpInst->getLoc());
1123 srcWriteRegion.compute(srcOpInst, /*loopDepth=*/0);
1124 Optional<int64_t> maybeSrcWriteRegionSizeBytes =
1125 srcWriteRegion.getRegionSize();
1126 if (!maybeSrcWriteRegionSizeBytes.hasValue())
1127 return false;
1128 int64_t srcWriteRegionSizeBytes = maybeSrcWriteRegionSizeBytes.getValue();
1129
Uday Bondhugula864d9e02019-01-23 17:16:241130 // Compute op instance count for the dst loop nest.
River Riddle5052bd82019-02-02 00:42:181131 uint64_t dstLoopNestCost =
1132 getComputeCost(dstLoopIVs[0]->getInstruction(), &dstLoopNestStats,
1133 /*tripCountOverrideMap=*/nullptr,
1134 /*computeCostMap=*/nullptr);
MLIR Team27d067e2019-01-16 17:55:021135
MLIR Teamb9dde912019-02-06 19:01:101136 // Evaluate all depth choices for materializing the slice in the destination
1137 // loop nest.
River Riddle5052bd82019-02-02 00:42:181138 llvm::SmallDenseMap<Instruction *, uint64_t, 8> sliceTripCountMap;
1139 DenseMap<Instruction *, int64_t> computeCostMap;
MLIR Team27d067e2019-01-16 17:55:021140 for (unsigned i = maxDstLoopDepth; i >= 1; --i) {
1141 MemRefAccess srcAccess(srcOpInst);
1142 // Handle the common case of one dst load without a copy.
1143 if (!mlir::getBackwardComputationSliceState(
MLIR Teamd7c82442019-01-30 23:53:411144 srcAccess, MemRefAccess(dstLoadOpInsts[0]), i, &sliceStates[i - 1]))
MLIR Team27d067e2019-01-16 17:55:021145 return false;
MLIR Teamd7c82442019-01-30 23:53:411146 // Compute the union of slice bound of all ops in 'dstLoadOpInsts'.
1147 for (int j = 1, e = dstLoadOpInsts.size(); j < e; ++j) {
1148 MemRefAccess dstAccess(dstLoadOpInsts[j]);
MLIR Team27d067e2019-01-16 17:55:021149 ComputationSliceState tmpSliceState;
1150 if (!mlir::getBackwardComputationSliceState(srcAccess, dstAccess, i,
1151 &tmpSliceState))
1152 return false;
1153 // Compute the slice bound union of 'tmpSliceState' and 'sliceStates[i - 1]'.
Uday Bondhugulac1ca23e2019-01-16 21:13:001154 getSliceUnion(tmpSliceState, &sliceStates[i - 1]);
MLIR Team38c2fe32019-01-14 19:26:251155 }
Uday Bondhugulab4a14432019-01-26 00:00:501156 // Build trip count map for computation slice. We'll skip cases where the
1157 // trip count was non-constant.
MLIR Team27d067e2019-01-16 17:55:021158 sliceTripCountMap.clear();
1159 if (!buildSliceTripCountMap(srcOpInst, &sliceStates[i - 1],
1160 &sliceTripCountMap))
Uday Bondhugula864d9e02019-01-23 17:16:241161 continue;
1162
1163 // Check whether store-to-load forwarding will happen.
1164 int64_t sliceIterationCount = getSliceIterationCount(sliceTripCountMap);
Uday Bondhugula864d9e02019-01-23 17:16:241165 assert(sliceIterationCount > 0);
Uday Bondhugulab4a14432019-01-26 00:00:501166 bool storeLoadFwdGuaranteed = (sliceIterationCount == 1);
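// Rationale (a sketch, not re-derived here): a slice iteration count of 1
// means the sliced src nest runs once per dst iteration and its single
// store produces exactly the value the fused loads consume, so a
// subsequent store-to-load forwarding pass can elide the memref traffic.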
Uday Bondhugula864d9e02019-01-23 17:16:241167
1168 // Compute cost of fusion for this dest loop depth.
1169
1170 computeCostMap.clear();
1171
1172 // The store and loads to this memref will disappear.
1173 if (storeLoadFwdGuaranteed) {
1174 // A single store disappears: -1 for that.
River Riddle5052bd82019-02-02 00:42:181175 computeCostMap[srcLoopIVs[numSrcLoopIVs - 1]->getInstruction()] = -1;
MLIR Teamd7c82442019-01-30 23:53:411176 for (auto *loadOp : dstLoadOpInsts) {
River Riddle5052bd82019-02-02 00:42:181177 auto *parentInst = loadOp->getParentInst();
River Riddleb4992772019-02-04 18:38:471178 if (parentInst && parentInst->isa<AffineForOp>())
River Riddle5052bd82019-02-02 00:42:181179 computeCostMap[parentInst] = -1;
Uday Bondhugula864d9e02019-01-23 17:16:241180 }
1181 }
MLIR Team27d067e2019-01-16 17:55:021182
MLIR Team38c2fe32019-01-14 19:26:251183 // Compute op instance count for the src loop nest with iteration slicing.
Uday Bondhugula864d9e02019-01-23 17:16:241184 int64_t sliceComputeCost =
River Riddle5052bd82019-02-02 00:42:181185 getComputeCost(srcLoopIVs[0]->getInstruction(), &srcLoopNestStats,
Uday Bondhugula864d9e02019-01-23 17:16:241186 /*tripCountOverrideMap=*/&sliceTripCountMap,
1187 /*computeCostMap=*/&computeCostMap);
MLIR Team38c2fe32019-01-14 19:26:251188
Uday Bondhugula864d9e02019-01-23 17:16:241189 // Compute cost of fusion for this depth.
River Riddle5052bd82019-02-02 00:42:181190 computeCostMap[dstLoopIVs[i - 1]->getInstruction()] = sliceComputeCost;
Uday Bondhugula864d9e02019-01-23 17:16:241191
1192 int64_t fusedLoopNestComputeCost =
River Riddle5052bd82019-02-02 00:42:181193 getComputeCost(dstLoopIVs[0]->getInstruction(), &dstLoopNestStats,
MLIR Team27d067e2019-01-16 17:55:021194 /*tripCountOverrideMap=*/nullptr, &computeCostMap);
Uday Bondhugula864d9e02019-01-23 17:16:241195
1196 double additionalComputeFraction =
1197 fusedLoopNestComputeCost /
1198 (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
1199 1;
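// Note this fraction can be negative: when forwarding is guaranteed, the
// -1 entries added to 'computeCostMap' above model the store and loads
// that disappear; e.g. a fused cost of 150 against src + dst = 200 gives
// 150 / 200 - 1 = -0.25, i.e. a 25% reduction in computation.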
1200
MLIR Teamb9dde912019-02-06 19:01:101201 // Compute what the slice write MemRefRegion would be, if the src loop
1202 // nest slice 'sliceStates[i - 1]' were to be inserted into the dst loop
1203 // nest at loop depth 'i'.
1204 MemRefRegion sliceWriteRegion(srcOpInst->getLoc());
1205 sliceWriteRegion.compute(srcOpInst, /*loopDepth=*/0, &sliceStates[i - 1]);
1206 Optional<int64_t> maybeSliceWriteRegionSizeBytes =
1207 sliceWriteRegion.getRegionSize();
1208 if (!maybeSliceWriteRegionSizeBytes.hasValue() ||
1209 maybeSliceWriteRegionSizeBytes.getValue() == 0)
1210 continue;
1211 int64_t sliceWriteRegionSizeBytes =
1212 maybeSliceWriteRegionSizeBytes.getValue();
1213
1214 double storageReduction = static_cast<double>(srcWriteRegionSizeBytes) /
1215 static_cast<double>(sliceWriteRegionSizeBytes);
Uday Bondhugula864d9e02019-01-23 17:16:241216
Uday Bondhugula06d21d92019-01-25 01:01:491217 LLVM_DEBUG({
1218 std::stringstream msg;
1219 msg << " evaluating fusion profitability at depth : " << i << "\n"
1220 << std::setprecision(2) << " additional compute fraction: "
1221 << 100.0 * additionalComputeFraction << "%\n"
1222 << " storage reduction factor: " << storageReduction << "x\n"
1223 << " fused nest cost: " << fusedLoopNestComputeCost << "\n"
1224 << " slice iteration count: " << sliceIterationCount << "\n";
1225 llvm::dbgs() << msg.str();
1226 });
Uday Bondhugula864d9e02019-01-23 17:16:241227
1228 double computeToleranceThreshold =
1229 clFusionAddlComputeTolerance.getNumOccurrences() > 0
1230 ? clFusionAddlComputeTolerance
1231 : LoopFusion::kComputeToleranceThreshold;
1232
1233 // TODO(b/123247369): This is a placeholder cost model.
1234 // Among all choices that add an acceptable amount of redundant computation
1235 // (as per computeToleranceThreshold), we will simply pick the one that
1236 // reduces the intermediary size the most.
1237 if ((storageReduction > maxStorageReduction) &&
1238 (clMaximalLoopFusion ||
1239 (additionalComputeFraction < computeToleranceThreshold))) {
1240 maxStorageReduction = storageReduction;
MLIR Team27d067e2019-01-16 17:55:021241 bestDstLoopDepth = i;
Uday Bondhugula864d9e02019-01-23 17:16:241242 minFusedLoopNestComputeCost = fusedLoopNestComputeCost;
MLIR Teamb9dde912019-02-06 19:01:101243 sliceMemEstimate = sliceWriteRegionSizeBytes;
MLIR Team38c2fe32019-01-14 19:26:251244 }
1245 }
1246
Uday Bondhugula864d9e02019-01-23 17:16:241247 // A simple cost model: fuse if it reduces the memory footprint. If
1248 // '-fusion-maximal' is set, fuse nevertheless.
MLIR Team38c2fe32019-01-14 19:26:251249
Uday Bondhugula864d9e02019-01-23 17:16:241250 if (!clMaximalLoopFusion && !bestDstLoopDepth.hasValue()) {
1251 LLVM_DEBUG(llvm::dbgs()
1252 << "All fusion choices involve more than the threshold amount of "
1253 "redundant computation; NOT fusing.\n");
MLIR Team38c2fe32019-01-14 19:26:251254 return false;
Uday Bondhugula864d9e02019-01-23 17:16:241255 }
1256
1257 assert(bestDstLoopDepth.hasValue() &&
1258 "expected to have a value per logic above");
1259
1260 // Set dstLoopDepth based on best values from search.
1261 *dstLoopDepth = bestDstLoopDepth.getValue();
1262
1263 LLVM_DEBUG(
Uday Bondhugula06d21d92019-01-25 01:01:491264 llvm::dbgs() << " LoopFusion fusion stats:"
1265 << "\n best loop depth: " << bestDstLoopDepth
Uday Bondhugula864d9e02019-01-23 17:16:241266 << "\n src loop nest compute cost: " << srcLoopNestCost
1267 << "\n dst loop nest compute cost: " << dstLoopNestCost
1268 << "\n fused loop nest compute cost: "
1269 << minFusedLoopNestComputeCost << "\n");
1270
River Riddle5052bd82019-02-02 00:42:181271 auto dstMemSize = getMemoryFootprintBytes(dstLoopIVs[0]);
1272 auto srcMemSize = getMemoryFootprintBytes(srcLoopIVs[0]);
Uday Bondhugula864d9e02019-01-23 17:16:241273
1274 Optional<double> storageReduction = None;
1275
1276 if (!clMaximalLoopFusion) {
1277 if (!dstMemSize.hasValue() || !srcMemSize.hasValue()) {
1278 LLVM_DEBUG(
1279 llvm::dbgs()
1280 << " fusion memory benefit cannot be evaluated; NOT fusing.\n");
1281 return false;
1282 }
1283
1284 auto srcMemSizeVal = srcMemSize.getValue();
1285 auto dstMemSizeVal = dstMemSize.getValue();
1286
1287 assert(sliceMemEstimate.hasValue() && "expected value");
1288 // This is only an estimate since 'sliceMemEstimate' is itself approximate.
1289 auto fusedMem = dstMemSizeVal + sliceMemEstimate.getValue();
1290
1291 LLVM_DEBUG(llvm::dbgs() << " src mem: " << srcMemSizeVal << "\n"
1292 << " dst mem: " << dstMemSizeVal << "\n"
1293 << " fused mem: " << fusedMem << "\n"
1294 << " slice mem: " << sliceMemEstimate << "\n");
1295
1296 if (fusedMem > srcMemSizeVal + dstMemSizeVal) {
1297 LLVM_DEBUG(llvm::dbgs() << "Fusion is not profitable; NOT fusing.\n");
1298 return false;
1299 }
1300 storageReduction =
1301 100.0 *
1302 (1.0 - fusedMem / (static_cast<double>(srcMemSizeVal) + dstMemSizeVal));
1303 }
1304
1305 double additionalComputeFraction =
1306 100.0 * (minFusedLoopNestComputeCost /
1307 (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
1308 1);
MLIR Team5c5739d2019-01-25 06:27:401309 (void)additionalComputeFraction;
Uday Bondhugula06d21d92019-01-25 01:01:491310 LLVM_DEBUG({
1311 std::stringstream msg;
1312 msg << " fusion is most profitable at depth " << *dstLoopDepth << " with "
1313 << std::setprecision(2) << additionalComputeFraction
1314 << "% redundant computation and a ";
1315 msg << (storageReduction.hasValue()
1316 ? std::to_string(storageReduction.getValue())
1317 : "<unknown>");
1318 msg << "% storage reduction.\n";
1319 llvm::dbgs() << msg.str();
1320 });
Uday Bondhugula864d9e02019-01-23 17:16:241321
MLIR Team27d067e2019-01-16 17:55:021322 // Update return parameter 'sliceState' with 'bestSliceState'.
Uday Bondhugula864d9e02019-01-23 17:16:241323 ComputationSliceState *bestSliceState = &sliceStates[*dstLoopDepth - 1];
MLIR Team27d067e2019-01-16 17:55:021324 sliceState->lbs = bestSliceState->lbs;
1325 sliceState->ubs = bestSliceState->ubs;
1326 sliceState->lbOperands = bestSliceState->lbOperands;
1327 sliceState->ubOperands = bestSliceState->ubOperands;
Uday Bondhugula864d9e02019-01-23 17:16:241328
MLIR Team27d067e2019-01-16 17:55:021329 // Canonicalize slice bound affine maps.
MLIR Team38c2fe32019-01-14 19:26:251330 for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
Nicolas Vasilache0e7a8a92019-01-26 18:41:171331 if (sliceState->lbs[i] != AffineMap()) {
MLIR Team27d067e2019-01-16 17:55:021332 canonicalizeMapAndOperands(&sliceState->lbs[i],
1333 &sliceState->lbOperands[i]);
1334 }
Nicolas Vasilache0e7a8a92019-01-26 18:41:171335 if (sliceState->ubs[i] != AffineMap()) {
MLIR Team27d067e2019-01-16 17:55:021336 canonicalizeMapAndOperands(&sliceState->ubs[i],
1337 &sliceState->ubOperands[i]);
MLIR Team38c2fe32019-01-14 19:26:251338 }
1339 }
1340 return true;
1341}
1342
MLIR Team6892ffb2018-12-20 04:42:551343// GreedyFusion greedily fuses loop nests which have a producer/consumer
MLIR Team3b692302018-12-17 17:57:141344// relationship on a memref, with the goal of improving locality. Currently,
1345// the producer/consumer relationship is required to be unique in the
Chris Lattner69d9e992018-12-28 16:48:091346// Function (there are TODOs to relax this constraint in the future).
MLIR Teamf28e4df2018-11-01 14:26:001347//
MLIR Team3b692302018-12-17 17:57:141348// The steps of the algorithm are as follows:
1349//
MLIR Team6892ffb2018-12-20 04:42:551350// *) A worklist is initialized with node ids from the dependence graph.
1351// *) For each node id in the worklist:
River Riddle5052bd82019-02-02 00:42:181352// *) Pop an AffineForOp off the worklist. This 'dstAffineForOp' will be a
1353// candidate destination AffineForOp into which fusion will be attempted.
1354// *) Add each LoadOp currently in 'dstAffineForOp' into list 'dstLoadOps'.
MLIR Team3b692302018-12-17 17:57:141355// *) For each LoadOp in 'dstLoadOps' do:
Chris Lattner69d9e992018-12-28 16:48:091356// *) Lookup dependent loop nests at earlier positions in the Function
MLIR Team3b692302018-12-17 17:57:141357// which have a single store op to the same memref.
1358// *) Check if dependences would be violated by the fusion. For example,
1359// the src loop nest may load from memrefs which are different than
1360// the producer-consumer memref between src and dest loop nests.
MLIR Team6892ffb2018-12-20 04:42:551361// *) Get a computation slice of 'srcLoopNest', which adjusts its loop
MLIR Team3b692302018-12-17 17:57:141362// bounds to be functions of 'dstLoopNest' IVs and symbols.
1363// *) Fuse the 'srcLoopNest' computation slice into the 'dstLoopNest',
1364// just before the dst load op user.
Chris Lattner456ad6a2018-12-29 00:05:351365// *) Add the newly fused load/store operation instructions to the state,
MLIR Team3b692302018-12-17 17:57:141366// and also add newly fused load ops to 'dstLoadOps' to be considered
1367// as fusion dst load ops in another iteration.
1368// *) Remove old src loop nest and its associated state.
1369//
Chris Lattner456ad6a2018-12-29 00:05:351370// Given a graph where top-level instructions are vertices in the set 'V' and
MLIR Team3b692302018-12-17 17:57:141371// edges in the set 'E' are dependences between vertices, this algorithm
MLIR Team6892ffb2018-12-20 04:42:551372// takes O(V) time for initialization, and has runtime O(V + E).
MLIR Team3b692302018-12-17 17:57:141373//
MLIR Team6892ffb2018-12-20 04:42:551374// This greedy algorithm is not 'maximal' due to the current restriction of
1375// fusing along single producer-consumer edges, but there is a TODO to fix this.
MLIR Team3b692302018-12-17 17:57:141376//
1377// TODO(andydavis) Experiment with other fusion policies.
MLIR Team6892ffb2018-12-20 04:42:551378// TODO(andydavis) Add support for fusing for input reuse (perhaps by
1379// constructing a graph with edges which represent loads from the same memref
MLIR Team5c5739d2019-01-25 06:27:401380// in two different loop nests).
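// A condensed sketch of the driver below (pseudocode; names match the
// implementation):
//
//   while (!worklist.empty()) {
//     dstNode = pop worklist;
//     for each memref loaded in 'dstNode':
//       for each srcNode with a store to 'memref':
//         if (isFusionProfitable(srcStoreOpInst, dstLoadOpInsts, ...))
//           insertBackwardComputationSlice(...), privatize 'memref',
//           and update the dependence graph and worklist;
//   }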
MLIR Team6892ffb2018-12-20 04:42:551381struct GreedyFusion {
1382public:
1383 MemRefDependenceGraph *mdg;
MLIR Teama78edcd2019-02-05 14:57:081384 SmallVector<unsigned, 8> worklist;
1385 llvm::SmallDenseSet<unsigned, 16> worklistSet;
MLIR Teamf28e4df2018-11-01 14:26:001386
MLIR Team6892ffb2018-12-20 04:42:551387 GreedyFusion(MemRefDependenceGraph *mdg) : mdg(mdg) {
1388 // Initialize worklist with nodes from 'mdg'.
MLIR Teama78edcd2019-02-05 14:57:081389 // TODO(andydavis) Add a priority queue for prioritizing nodes by different
1390 // metrics (e.g. arithmetic intensity/flops-to-bytes ratio).
MLIR Team6892ffb2018-12-20 04:42:551391 worklist.resize(mdg->nodes.size());
1392 std::iota(worklist.begin(), worklist.end(), 0);
MLIR Teama78edcd2019-02-05 14:57:081393 worklistSet.insert(worklist.begin(), worklist.end());
MLIR Team6892ffb2018-12-20 04:42:551394 }
MLIR Team3b692302018-12-17 17:57:141395
Uday Bondhugula8be26272019-02-02 01:06:221396 void run(unsigned localBufSizeThreshold, Optional<unsigned> fastMemorySpace) {
MLIR Team3b692302018-12-17 17:57:141397 while (!worklist.empty()) {
MLIR Team6892ffb2018-12-20 04:42:551398 unsigned dstId = worklist.back();
MLIR Team3b692302018-12-17 17:57:141399 worklist.pop_back();
MLIR Teama78edcd2019-02-05 14:57:081400 worklistSet.erase(dstId);
1401
MLIR Team6892ffb2018-12-20 04:42:551402 // Skip if this node was removed (fused into another node).
1403 if (mdg->nodes.count(dstId) == 0)
MLIR Team3b692302018-12-17 17:57:141404 continue;
MLIR Team6892ffb2018-12-20 04:42:551405 // Get 'dstNode' into which to attempt fusion.
1406 auto *dstNode = mdg->getNode(dstId);
1407 // Skip if 'dstNode' is not a loop nest.
River Riddleb4992772019-02-04 18:38:471408 if (!dstNode->inst->isa<AffineForOp>())
MLIR Team3b692302018-12-17 17:57:141409 continue;
1410
River Riddleb4992772019-02-04 18:38:471411 SmallVector<Instruction *, 4> loads = dstNode->loads;
1412 SmallVector<Instruction *, 4> dstLoadOpInsts;
MLIR Teamc4237ae2019-01-18 16:56:271413 DenseSet<Value *> visitedMemrefs;
MLIR Team6892ffb2018-12-20 04:42:551414 while (!loads.empty()) {
MLIR Team27d067e2019-01-16 17:55:021415 // Get memref of load on top of the stack.
1416 auto *memref = loads.back()->cast<LoadOp>()->getMemRef();
MLIR Teamc4237ae2019-01-18 16:56:271417 if (visitedMemrefs.count(memref) > 0) {
1418 loads.pop_back(); // Pop, or this loop would revisit the load forever.
 continue;
 }
1419 visitedMemrefs.insert(memref);
MLIR Team27d067e2019-01-16 17:55:021420 // Move all loads in 'loads' accessing 'memref' to 'dstLoadOpInsts'.
1421 moveLoadsAccessingMemrefTo(memref, &loads, &dstLoadOpInsts);
MLIR Team6892ffb2018-12-20 04:42:551422 // Skip if no input edges along which to fuse.
1423 if (mdg->inEdges.count(dstId) == 0)
MLIR Team3b692302018-12-17 17:57:141424 continue;
MLIR Team1e851912019-01-31 00:01:461425 // Iterate through the in-edges for 'dstId' and collect the src node
1426 // ids of any edges on 'memref'.
1427 SmallVector<unsigned, 2> srcNodeIds;
MLIR Team6892ffb2018-12-20 04:42:551428 for (auto &srcEdge : mdg->inEdges[dstId]) {
1429 // Skip 'srcEdge' if not for 'memref'.
MLIR Teama0f3db402019-01-29 17:36:411430 if (srcEdge.value != memref)
MLIR Team6892ffb2018-12-20 04:42:551431 continue;
MLIR Team1e851912019-01-31 00:01:461432 srcNodeIds.push_back(srcEdge.id);
1433 }
1434 for (unsigned srcId : srcNodeIds) {
1435 // Skip if this node was removed (fused into another node).
1436 if (mdg->nodes.count(srcId) == 0)
1437 continue;
1438 // Get 'srcNode' from which to attempt fusion into 'dstNode'.
1439 auto *srcNode = mdg->getNode(srcId);
MLIR Team6892ffb2018-12-20 04:42:551440 // Skip if 'srcNode' is not a loop nest.
River Riddleb4992772019-02-04 18:38:471441 if (!srcNode->inst->isa<AffineForOp>())
MLIR Team6892ffb2018-12-20 04:42:551442 continue;
MLIR Teamb28009b2019-01-23 19:11:431443 // Skip unless 'srcNode' has exactly one store (to any memref).
1444 // TODO(andydavis) Support fusing multi-output src loop nests.
1445 if (srcNode->stores.size() != 1)
MLIR Team6892ffb2018-12-20 04:42:551446 continue;
Uday Bondhugula864d9e02019-01-23 17:16:241447
MLIR Teama0f3db402019-01-29 17:36:411448 // Skip 'srcNode' if it has in edges on 'memref'.
MLIR Team6892ffb2018-12-20 04:42:551449 // TODO(andydavis) Track dependence type with edges, and just check
MLIR Teama0f3db402019-01-29 17:36:411450 // for WAW dependence edge here. Note that this check is overly
1451 // conservative and will be removed in the future.
1452 if (mdg->getIncomingMemRefAccesses(srcNode->id, memref) != 0)
MLIR Team6892ffb2018-12-20 04:42:551453 continue;
Uday Bondhugula864d9e02019-01-23 17:16:241454
MLIR Teamd7c82442019-01-30 23:53:411455 // Skip if 'srcNode' writes to any live in or escaping memrefs.
1456 if (mdg->writesToLiveInOrEscapingMemrefs(srcNode->id))
1457 continue;
1458
MLIR Teama0f3db402019-01-29 17:36:411459 // Compute an instruction list insertion point for the fused loop
1460 // nest which preserves dependences.
MLIR Teama78edcd2019-02-05 14:57:081461 Instruction *insertPointInst =
1462 mdg->getFusedLoopNestInsertionPoint(srcNode->id, dstNode->id);
MLIR Teama0f3db402019-01-29 17:36:411463 if (insertPointInst == nullptr)
MLIR Team6892ffb2018-12-20 04:42:551464 continue;
Uday Bondhugula864d9e02019-01-23 17:16:241465
MLIR Team6892ffb2018-12-20 04:42:551466 // Get unique 'srcNode' store op.
Chris Lattner456ad6a2018-12-29 00:05:351467 auto *srcStoreOpInst = srcNode->stores.front();
MLIR Teamd7c82442019-01-30 23:53:411468 // Gather 'dstNode' store ops to 'memref'.
River Riddleb4992772019-02-04 18:38:471469 SmallVector<Instruction *, 2> dstStoreOpInsts;
MLIR Teamd7c82442019-01-30 23:53:411470 for (auto *storeOpInst : dstNode->stores)
1471 if (storeOpInst->cast<StoreOp>()->getMemRef() == memref)
1472 dstStoreOpInsts.push_back(storeOpInst);
1473
Uday Bondhugulab4a14432019-01-26 00:00:501474 unsigned bestDstLoopDepth;
MLIR Team38c2fe32019-01-14 19:26:251475 mlir::ComputationSliceState sliceState;
MLIR Teama0f3db402019-01-29 17:36:411476 // Check if fusion would be profitable.
MLIR Teamd7c82442019-01-30 23:53:411477 if (!isFusionProfitable(srcStoreOpInst, dstLoadOpInsts,
1478 dstStoreOpInsts, &sliceState,
Uday Bondhugulab4a14432019-01-26 00:00:501479 &bestDstLoopDepth))
MLIR Team38c2fe32019-01-14 19:26:251480 continue;
Uday Bondhugula864d9e02019-01-23 17:16:241481
MLIR Team6892ffb2018-12-20 04:42:551482 // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'.
River Riddle5052bd82019-02-02 00:42:181483 auto sliceLoopNest = mlir::insertBackwardComputationSlice(
Uday Bondhugulab4a14432019-01-26 00:00:501484 srcStoreOpInst, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState);
MLIR Team6892ffb2018-12-20 04:42:551485 if (sliceLoopNest != nullptr) {
River Riddle5052bd82019-02-02 00:42:181486 // Move 'dstAffineForOp' before 'insertPointInst' if needed.
River Riddleb4992772019-02-04 18:38:471487 auto dstAffineForOp = dstNode->inst->cast<AffineForOp>();
River Riddle5052bd82019-02-02 00:42:181488 if (insertPointInst != dstAffineForOp->getInstruction()) {
1489 dstAffineForOp->getInstruction()->moveBefore(insertPointInst);
MLIR Teama0f3db402019-01-29 17:36:411490 }
MLIR Teamc4237ae2019-01-18 16:56:271491 // Update edges between 'srcNode' and 'dstNode'.
MLIR Teama0f3db402019-01-29 17:36:411492 mdg->updateEdges(srcNode->id, dstNode->id, memref);
MLIR Teamc4237ae2019-01-18 16:56:271493
1494 // Collect slice loop stats.
1495 LoopNestStateCollector sliceCollector;
River Riddlebf9c3812019-02-05 00:24:441496 sliceCollector.collect(sliceLoopNest->getInstruction());
MLIR Teamc4237ae2019-01-18 16:56:271497 // Promote single iteration slice loops to single IV value.
River Riddle5052bd82019-02-02 00:42:181498 for (auto forOp : sliceCollector.forOps) {
1499 promoteIfSingleIteration(forOp);
MLIR Team6892ffb2018-12-20 04:42:551500 }
River Riddle5052bd82019-02-02 00:42:181501 // Create private memref for 'memref' in 'dstAffineForOp'.
River Riddleb4992772019-02-04 18:38:471502 SmallVector<Instruction *, 4> storesForMemref;
MLIR Teamc4237ae2019-01-18 16:56:271503 for (auto *storeOpInst : sliceCollector.storeOpInsts) {
1504 if (storeOpInst->cast<StoreOp>()->getMemRef() == memref)
1505 storesForMemref.push_back(storeOpInst);
1506 }
1507 assert(storesForMemref.size() == 1);
Uday Bondhugula94a03f82019-01-22 21:58:521508 auto *newMemRef = createPrivateMemRef(
Uday Bondhugula8be26272019-02-02 01:06:221509 dstAffineForOp, storesForMemref[0], bestDstLoopDepth,
1510 fastMemorySpace, localBufSizeThreshold);
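// Privatization gives the fused nest its own storage for the producer's
// values; the original 'memref' may then lose all its dependences,
// allowing the src loop nest to be removed further below.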
MLIR Teamc4237ae2019-01-18 16:56:271511 visitedMemrefs.insert(newMemRef);
MLIR Teama0f3db402019-01-29 17:36:411512 // Create new node in dependence graph for 'newMemRef' alloc op.
1513 unsigned newMemRefNodeId =
1514 mdg->addNode(newMemRef->getDefiningInst());
1515 // Add edge from 'newMemRef' node to dstNode.
1516 mdg->addEdge(newMemRefNodeId, dstId, newMemRef);
MLIR Teamc4237ae2019-01-18 16:56:271517
1518 // Collect dst loop stats after memref privatization transformation.
1519 LoopNestStateCollector dstLoopCollector;
River Riddlebf9c3812019-02-05 00:24:441520 dstLoopCollector.collect(dstAffineForOp->getInstruction());
MLIR Teamc4237ae2019-01-18 16:56:271521
1522 // Add new load ops to current Node load op list 'loads' to
1523 // continue fusing based on new operands.
1524 for (auto *loadOpInst : dstLoopCollector.loadOpInsts) {
1525 auto *loadMemRef = loadOpInst->cast<LoadOp>()->getMemRef();
1526 if (visitedMemrefs.count(loadMemRef) == 0)
1527 loads.push_back(loadOpInst);
1528 }
1529
1530 // Clear and add back loads and stores.
1531 mdg->clearNodeLoadAndStores(dstNode->id);
1532 mdg->addToNode(dstId, dstLoopCollector.loadOpInsts,
1533 dstLoopCollector.storeOpInsts);
MLIR Team71495d52019-01-22 21:23:371534 // Remove old src loop nest if it no longer has outgoing dependence
1535 // edges, and it does not write to a memref which escapes the
1536 // function.
MLIR Teama0f3db402019-01-29 17:36:411537 if (mdg->canRemoveNode(srcNode->id)) {
MLIR Teamc4237ae2019-01-18 16:56:271538 mdg->removeNode(srcNode->id);
River Riddle5052bd82019-02-02 00:42:181539 srcNode->inst->erase();
MLIR Teama78edcd2019-02-05 14:57:081540 } else {
1541 // Add remaining users of 'oldMemRef' back on the worklist (if not
1542 // already there), as its replacement with a local/private memref
1543 // has reduced dependences on 'oldMemRef' which may have created
1544 // new fusion opportunities.
1545 if (mdg->outEdges.count(srcNode->id) > 0) {
1546 SmallVector<MemRefDependenceGraph::Edge, 2> oldOutEdges =
1547 mdg->outEdges[srcNode->id];
1548 for (auto &outEdge : oldOutEdges) {
1549 if (outEdge.value == memref &&
1550 worklistSet.count(outEdge.id) == 0) {
1551 worklist.push_back(outEdge.id);
1552 worklistSet.insert(outEdge.id);
1553 }
1554 }
1555 }
MLIR Teamc4237ae2019-01-18 16:56:271556 }
MLIR Team3b692302018-12-17 17:57:141557 }
MLIR Team3b692302018-12-17 17:57:141558 }
1559 }
1560 }
MLIR Teamc4237ae2019-01-18 16:56:271561 // Clean up any allocs with no users.
1562 for (auto &pair : mdg->memrefEdgeCount) {
1563 if (pair.second > 0)
1564 continue;
1565 auto *memref = pair.first;
MLIR Team71495d52019-01-22 21:23:371566 // Skip if there exist other uses (return instruction or function calls).
1567 if (!memref->use_empty())
1568 continue;
MLIR Teamc4237ae2019-01-18 16:56:271569 // Use list expected to match the dep graph info.
MLIR Teamc4237ae2019-01-18 16:56:271570 auto *inst = memref->getDefiningInst();
River Riddleb4992772019-02-04 18:38:471571 if (inst && inst->isa<AllocOp>())
1572 inst->erase();
MLIR Teamc4237ae2019-01-18 16:56:271573 }
MLIR Teamf28e4df2018-11-01 14:26:001574 }
MLIR Team3b692302018-12-17 17:57:141575};
1576
1577} // end anonymous namespace
MLIR Teamf28e4df2018-11-01 14:26:001578
Chris Lattner79748892018-12-31 07:10:351579PassResult LoopFusion::runOnFunction(Function *f) {
Uday Bondhugula8be26272019-02-02 01:06:221580 if (clFusionFastMemorySpace.getNumOccurrences() > 0) {
1581 fastMemorySpace = clFusionFastMemorySpace.getValue();
1582 }
1583
MLIR Team6892ffb2018-12-20 04:42:551584 MemRefDependenceGraph g;
1585 if (g.init(f))
Uday Bondhugula8be26272019-02-02 01:06:221586 GreedyFusion(&g).run(localBufSizeThreshold, fastMemorySpace);
MLIR Teamf28e4df2018-11-01 14:26:001587 return success();
1588}
Jacques Pienaar6f0fb222018-11-07 02:34:181589
1590static PassRegistration<LoopFusion> pass("loop-fusion", "Fuse loop nests");