//===- LoopFusion.cpp - Code to perform loop fusion -----------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements loop fusion.
//
//===----------------------------------------------------------------------===//

#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/AffineStructures.h"
#include "mlir/Analysis/LoopAnalysis.h"
#include "mlir/Analysis/Utils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/Pass.h"
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/Transforms/LoopUtils.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/Utils.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <iomanip>
#include <limits>
#include <numeric>
#include <sstream>

#define DEBUG_TYPE "loop-fusion"

using llvm::SetVector;

using namespace mlir;

static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options");

/// Disables fusion profitability check and fuses if valid.
static llvm::cl::opt<bool>
    clMaximalLoopFusion("fusion-maximal", llvm::cl::Hidden,
                        llvm::cl::desc("Enables maximal loop fusion"),
                        llvm::cl::cat(clOptionsCategory));

/// A threshold in percent of additional computation allowed when fusing.
static llvm::cl::opt<double> clFusionAddlComputeTolerance(
    "fusion-compute-tolerance", llvm::cl::Hidden,
    llvm::cl::desc("Fractional increase in additional"
                   " computation tolerated while fusing"),
    llvm::cl::cat(clOptionsCategory));

namespace {

/// Loop fusion pass. This pass currently supports a greedy fusion policy,
/// which fuses loop nests with single-writer/single-reader memref dependences
/// with the goal of improving locality.

// TODO(andydavis) Support fusion of source loop nests which write to multiple
// memrefs, where each memref can have multiple users (if profitable).
// TODO(andydavis) Extend this pass to check for fusion preventing dependences,
// and add support for more general loop fusion algorithms.
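//
// As a schematic example (illustrative only; the exact IR produced depends on
// the computed slice and the privatized memref):
//
//   for %i0 = 0 to 100 {
//     store %v, %m[%i0]        // producer loop nest
//   }
//   for %i1 = 0 to 100 {
//     %u = load %m[%i1]        // consumer loop nest
//   }
//
// may be rewritten so that the producer's computation slice is materialized
// inside the consumer loop and '%m' is replaced by a (potentially smaller)
// private buffer:
//
//   for %i1 = 0 to 100 {
//     store %v, %m_priv[...]
//     %u = load %m_priv[...]
//   }
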
struct LoopFusion : public FunctionPass {
  LoopFusion() : FunctionPass(&LoopFusion::passID) {}

  PassResult runOnFunction(Function *f) override;
  static char passID;

  // The amount of additional computation that is tolerated while fusing
  // pair-wise as a fraction of the total computation.
  constexpr static double kComputeToleranceThreshold = 0.30f;
};

} // end anonymous namespace

char LoopFusion::passID = 0;

FunctionPass *mlir::createLoopFusionPass() { return new LoopFusion; }

namespace {

// LoopNestStateCollector walks loop nests and collects load and store
// operations, and whether or not an IfInst was encountered in the loop nest.
class LoopNestStateCollector : public InstWalker<LoopNestStateCollector> {
public:
  SmallVector<ForInst *, 4> forInsts;
  SmallVector<OperationInst *, 4> loadOpInsts;
  SmallVector<OperationInst *, 4> storeOpInsts;
  bool hasIfInst = false;

  void visitForInst(ForInst *forInst) { forInsts.push_back(forInst); }

  void visitIfInst(IfInst *ifInst) { hasIfInst = true; }

  void visitOperationInst(OperationInst *opInst) {
    if (opInst->isa<LoadOp>())
      loadOpInsts.push_back(opInst);
    if (opInst->isa<StoreOp>())
      storeOpInsts.push_back(opInst);
  }
};

// TODO(b/117228571) Replace when this is modeled through side-effects/op traits
static bool isMemRefDereferencingOp(const OperationInst &op) {
  if (op.isa<LoadOp>() || op.isa<StoreOp>() || op.isa<DmaStartOp>() ||
      op.isa<DmaWaitOp>())
    return true;
  return false;
}

// MemRefDependenceGraph is a graph data structure where graph nodes are
// top-level instructions in a Function which contain load/store ops, and edges
// are memref dependences between the nodes.
// TODO(andydavis) Add a more flexible dependence graph representation.
// TODO(andydavis) Add a depth parameter to dependence graph construction.
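//
// For example (a hypothetical program, shown schematically): with two
// top-level loop nests N0 and N1, where N0 stores to memref %m and N1 loads
// from %m, the graph would contain nodes {0: N0, 1: N1} and a single edge
// 0 -> 1 keyed on %m (a store/load dependence), i.e.,
// outEdges[0] = [{1, %m}] and inEdges[1] = [{0, %m}].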
struct MemRefDependenceGraph {
public:
  // Node represents a node in the graph. A Node is either an entire loop nest
  // rooted at the top level which contains loads/stores, or a top-level
  // load/store.
  struct Node {
    // The unique identifier of this node in the graph.
    unsigned id;
    // The top-level instruction which is (or contains) loads/stores.
    Instruction *inst;
    // List of load op insts.
    SmallVector<OperationInst *, 4> loads;
    // List of store op insts.
    SmallVector<OperationInst *, 4> stores;
    Node(unsigned id, Instruction *inst) : id(id), inst(inst) {}

    // Returns the load op count for 'memref'.
    unsigned getLoadOpCount(Value *memref) {
      unsigned loadOpCount = 0;
      for (auto *loadOpInst : loads) {
        if (memref == loadOpInst->cast<LoadOp>()->getMemRef())
          ++loadOpCount;
      }
      return loadOpCount;
    }

    // Returns the store op count for 'memref'.
    unsigned getStoreOpCount(Value *memref) {
      unsigned storeOpCount = 0;
      for (auto *storeOpInst : stores) {
        if (memref == storeOpInst->cast<StoreOp>()->getMemRef())
          ++storeOpCount;
      }
      return storeOpCount;
    }
  };

  // Edge represents a memref data dependence between nodes in the graph.
  struct Edge {
    // The id of the node at the other end of the edge.
    unsigned id;
    // The memref on which this edge represents a dependence.
    Value *memref;
  };

  // Map from node id to Node.
  DenseMap<unsigned, Node> nodes;
  // Map from node id to list of input edges.
  DenseMap<unsigned, SmallVector<Edge, 2>> inEdges;
  // Map from node id to list of output edges.
  DenseMap<unsigned, SmallVector<Edge, 2>> outEdges;
  // Map from memref to the count of dependence edges associated with that
  // memref.
  DenseMap<Value *, unsigned> memrefEdgeCount;

  MemRefDependenceGraph() {}

  // Initializes the dependence graph based on operations in 'f'.
  // Returns true on success, false otherwise.
  bool init(Function *f);

  // Returns the graph node for 'id'.
  Node *getNode(unsigned id) {
    auto it = nodes.find(id);
    assert(it != nodes.end());
    return &it->second;
  }

  // Remove node 'id' (and its associated edges) from graph.
  void removeNode(unsigned id) {
    // Remove each edge in 'inEdges[id]'.
    if (inEdges.count(id) > 0) {
      SmallVector<Edge, 2> oldInEdges = inEdges[id];
      for (auto &inEdge : oldInEdges) {
        removeEdge(inEdge.id, id, inEdge.memref);
      }
    }
    // Remove each edge in 'outEdges[id]'.
    if (outEdges.count(id) > 0) {
      SmallVector<Edge, 2> oldOutEdges = outEdges[id];
      for (auto &outEdge : oldOutEdges) {
        removeEdge(id, outEdge.id, outEdge.memref);
      }
    }
    // Erase remaining node state.
    inEdges.erase(id);
    outEdges.erase(id);
    nodes.erase(id);
  }

  bool hasOutEdges(unsigned id) {
    return outEdges.count(id) > 0 && !outEdges[id].empty();
  }

  // Returns true if node 'id' writes to any memref which escapes (or is an
  // argument to) the function/block. Returns false otherwise.
  bool writesToLiveInOrEscapingMemrefs(unsigned id) {
    Node *node = getNode(id);
    for (auto *storeOpInst : node->stores) {
      auto *memref = storeOpInst->cast<StoreOp>()->getMemRef();
      auto *inst = memref->getDefiningInst();
      auto *opInst = dyn_cast_or_null<OperationInst>(inst);
      // Return true if 'memref' is a block argument or function argument.
      if (opInst == nullptr)
        return true;
      // Return true if any use of 'memref' escapes the function.
      for (auto &use : memref->getUses()) {
        auto *user = dyn_cast<OperationInst>(use.getOwner());
        if (!user || !isMemRefDereferencingOp(*user))
          return true;
      }
    }
    return false;
  }

  // Returns true iff there is an edge from node 'srcId' to node 'dstId' for
  // 'memref'. Returns false otherwise.
  bool hasEdge(unsigned srcId, unsigned dstId, Value *memref) {
    if (outEdges.count(srcId) == 0 || inEdges.count(dstId) == 0) {
      return false;
    }
    bool hasOutEdge = llvm::any_of(outEdges[srcId], [=](Edge &edge) {
      return edge.id == dstId && edge.memref == memref;
    });
    bool hasInEdge = llvm::any_of(inEdges[dstId], [=](Edge &edge) {
      return edge.id == srcId && edge.memref == memref;
    });
    return hasOutEdge && hasInEdge;
  }

  // Adds an edge from node 'srcId' to node 'dstId' for 'memref'.
  void addEdge(unsigned srcId, unsigned dstId, Value *memref) {
    if (!hasEdge(srcId, dstId, memref)) {
      outEdges[srcId].push_back({dstId, memref});
      inEdges[dstId].push_back({srcId, memref});
      memrefEdgeCount[memref]++;
    }
  }

  // Removes an edge from node 'srcId' to node 'dstId' for 'memref'.
  void removeEdge(unsigned srcId, unsigned dstId, Value *memref) {
    assert(inEdges.count(dstId) > 0);
    assert(outEdges.count(srcId) > 0);
    assert(memrefEdgeCount.count(memref) > 0);
    memrefEdgeCount[memref]--;
    // Remove 'srcId' from 'inEdges[dstId]'.
    for (auto it = inEdges[dstId].begin(); it != inEdges[dstId].end(); ++it) {
      if ((*it).id == srcId && (*it).memref == memref) {
        inEdges[dstId].erase(it);
        break;
      }
    }
    // Remove 'dstId' from 'outEdges[srcId]'.
    for (auto it = outEdges[srcId].begin(); it != outEdges[srcId].end(); ++it) {
      if ((*it).id == dstId && (*it).memref == memref) {
        outEdges[srcId].erase(it);
        break;
      }
    }
  }

  // Returns the input edge count for node 'id' and 'memref'.
  unsigned getInEdgeCount(unsigned id, Value *memref) {
    unsigned inEdgeCount = 0;
    if (inEdges.count(id) > 0)
      for (auto &inEdge : inEdges[id])
        if (inEdge.memref == memref)
          ++inEdgeCount;
    return inEdgeCount;
  }

  // Returns the output edge count for node 'id' and 'memref'.
  unsigned getOutEdgeCount(unsigned id, Value *memref) {
    unsigned outEdgeCount = 0;
    if (outEdges.count(id) > 0)
      for (auto &outEdge : outEdges[id])
        if (outEdge.memref == memref)
          ++outEdgeCount;
    return outEdgeCount;
  }

  // Check for a dependence in Block instruction list range (srcId, dstId) on
  // memrefs other than 'memrefToSkip' (which will be privatized for the fused
  // loop).
  bool hasDependenceTargetInRange(unsigned srcId, unsigned dstId,
                                  Value *memrefToSkip) {
    if (outEdges.count(srcId) == 0)
      return false;
    // Check if any of the outgoing edge targets from srcId lie in
    // (srcId, dstId).
    SmallPtrSet<Instruction *, 2> depInsts;
    for (auto &outEdge : outEdges[srcId]) {
      if (outEdge.id != dstId && outEdge.memref != memrefToSkip) {
        Node *node = getNode(outEdge.id);
        depInsts.insert(node->inst);
      }
    }
    // Do a linear walk from 'srcNode.inst' to 'dstNode.inst' and for each
    // instruction 'inst' in range ('srcNode.inst', 'dstNode.inst') test
    // if 'depInsts' contains 'inst', and return true if it does.
    // TODO(andydavis) If this linear search becomes a compile time issue,
    // create a data structure which allows a faster search through ForInsts
    // in a Block.
    Block::iterator it = std::next(Block::iterator(getNode(srcId)->inst));
    Block::iterator itEnd = Block::iterator(getNode(dstId)->inst);
    return std::any_of(it, itEnd, [&](Instruction &inst) {
      return depInsts.count(&inst) > 0;
    });
  }

  // Updates edge mappings from node 'srcId' to node 'dstId'.
  void updateEdges(unsigned srcId, unsigned dstId) {
    // For each edge in 'inEdges[srcId]': add new edge remapping to 'dstId'.
    if (inEdges.count(srcId) > 0) {
      SmallVector<Edge, 2> oldInEdges = inEdges[srcId];
      for (auto &inEdge : oldInEdges) {
        // Add edge from 'inEdge.id' to 'dstId'.
        addEdge(inEdge.id, dstId, inEdge.memref);
      }
    }
    // For each edge in 'outEdges[srcId]': remove edge from 'srcId' to 'dstId'.
    if (outEdges.count(srcId) > 0) {
      SmallVector<Edge, 2> oldOutEdges = outEdges[srcId];
      for (auto &outEdge : oldOutEdges) {
        // Remove any out edges from 'srcId' to 'dstId' across memrefs.
        if (outEdge.id == dstId)
          removeEdge(srcId, outEdge.id, outEdge.memref);
      }
    }
  }

  // Adds ops in 'loads' and 'stores' to node at 'id'.
  void addToNode(unsigned id, const SmallVectorImpl<OperationInst *> &loads,
                 const SmallVectorImpl<OperationInst *> &stores) {
    Node *node = getNode(id);
    for (auto *loadOpInst : loads)
      node->loads.push_back(loadOpInst);
    for (auto *storeOpInst : stores)
      node->stores.push_back(storeOpInst);
  }

  void clearNodeLoadAndStores(unsigned id) {
    Node *node = getNode(id);
    node->loads.clear();
    node->stores.clear();
  }

  void print(raw_ostream &os) const {
    os << "\nMemRefDependenceGraph\n";
    os << "\nNodes:\n";
    for (auto &idAndNode : nodes) {
      os << "Node: " << idAndNode.first << "\n";
      auto it = inEdges.find(idAndNode.first);
      if (it != inEdges.end()) {
        for (const auto &e : it->second)
          os << " InEdge: " << e.id << " " << e.memref << "\n";
      }
      it = outEdges.find(idAndNode.first);
      if (it != outEdges.end()) {
        for (const auto &e : it->second)
          os << " OutEdge: " << e.id << " " << e.memref << "\n";
      }
    }
  }
  void dump() const { print(llvm::errs()); }
};

// Initializes the data dependence graph by walking instructions in 'f'.
// Assigns each node in the graph a node id based on program order in 'f'.
// TODO(andydavis) Add support for taking a Block arg to construct the
// dependence graph at a different depth.
bool MemRefDependenceGraph::init(Function *f) {
  unsigned id = 0;
  DenseMap<Value *, SetVector<unsigned>> memrefAccesses;

  // TODO: support multi-block functions.
  if (f->getBlocks().size() != 1)
    return false;

  for (auto &inst : f->front()) {
    if (auto *forInst = dyn_cast<ForInst>(&inst)) {
      // Create graph node 'id' to represent top-level 'forInst' and record
      // all loads and store accesses it contains.
      LoopNestStateCollector collector;
      collector.walkForInst(forInst);
      // Return false if IfInsts are found (not currently supported).
      if (collector.hasIfInst)
        return false;
      Node node(id++, &inst);
      for (auto *opInst : collector.loadOpInsts) {
        node.loads.push_back(opInst);
        auto *memref = opInst->cast<LoadOp>()->getMemRef();
        memrefAccesses[memref].insert(node.id);
      }
      for (auto *opInst : collector.storeOpInsts) {
        node.stores.push_back(opInst);
        auto *memref = opInst->cast<StoreOp>()->getMemRef();
        memrefAccesses[memref].insert(node.id);
      }
      nodes.insert({node.id, node});
    }
    if (auto *opInst = dyn_cast<OperationInst>(&inst)) {
      if (auto loadOp = opInst->dyn_cast<LoadOp>()) {
        // Create graph node for top-level load op.
        Node node(id++, &inst);
        node.loads.push_back(opInst);
        auto *memref = opInst->cast<LoadOp>()->getMemRef();
        memrefAccesses[memref].insert(node.id);
        nodes.insert({node.id, node});
      }
      if (auto storeOp = opInst->dyn_cast<StoreOp>()) {
        // Create graph node for top-level store op.
        Node node(id++, &inst);
        node.stores.push_back(opInst);
        auto *memref = opInst->cast<StoreOp>()->getMemRef();
        memrefAccesses[memref].insert(node.id);
        nodes.insert({node.id, node});
      }
    }
    // Return false if IfInsts are found (not currently supported).
    if (isa<IfInst>(&inst))
      return false;
  }

  // Walk memref access lists and add graph edges between dependent nodes.
  for (auto &memrefAndList : memrefAccesses) {
    unsigned n = memrefAndList.second.size();
    for (unsigned i = 0; i < n; ++i) {
      unsigned srcId = memrefAndList.second[i];
      bool srcHasStore =
          getNode(srcId)->getStoreOpCount(memrefAndList.first) > 0;
      for (unsigned j = i + 1; j < n; ++j) {
        unsigned dstId = memrefAndList.second[j];
        bool dstHasStore =
            getNode(dstId)->getStoreOpCount(memrefAndList.first) > 0;
        if (srcHasStore || dstHasStore)
          addEdge(srcId, dstId, memrefAndList.first);
      }
    }
  }
  return true;
}

namespace {

// LoopNestStats aggregates various per-loop statistics (eg. loop trip count
// and operation count) for a loop nest up until the innermost loop body.
struct LoopNestStats {
  // Map from ForInst to immediate child ForInsts in its loop body.
  DenseMap<ForInst *, SmallVector<ForInst *, 2>> loopMap;
  // Map from ForInst to count of operations in its loop body.
  DenseMap<ForInst *, uint64_t> opCountMap;
  // Map from ForInst to its constant trip count.
  DenseMap<ForInst *, uint64_t> tripCountMap;
};

// LoopNestStatsCollector walks a single loop nest and gathers per-loop
// trip count and operation count statistics and records them in 'stats'.
class LoopNestStatsCollector : public InstWalker<LoopNestStatsCollector> {
public:
  LoopNestStats *stats;
  bool hasLoopWithNonConstTripCount = false;

  LoopNestStatsCollector(LoopNestStats *stats) : stats(stats) {}

  void visitForInst(ForInst *forInst) {
    auto *parentInst = forInst->getParentInst();
    if (parentInst != nullptr) {
      assert(isa<ForInst>(parentInst) && "Expected parent ForInst");
      // Add mapping to 'forInst' from its parent ForInst.
      stats->loopMap[cast<ForInst>(parentInst)].push_back(forInst);
    }
    // Record the number of op instructions in the body of 'forInst'.
    unsigned count = 0;
    stats->opCountMap[forInst] = 0;
    for (auto &inst : *forInst->getBody()) {
      if (isa<OperationInst>(&inst))
        ++count;
    }
    stats->opCountMap[forInst] = count;
    // Record trip count for 'forInst'. Set flag if trip count is not constant.
    Optional<uint64_t> maybeConstTripCount = getConstantTripCount(*forInst);
    if (!maybeConstTripCount.hasValue()) {
      hasLoopWithNonConstTripCount = true;
      return;
    }
    stats->tripCountMap[forInst] = maybeConstTripCount.getValue();
  }
};

// Computes the total cost of the loop nest rooted at 'forInst'.
// Currently, the total cost is computed by counting the total operation
// instance count (i.e. total number of operations in the loop body:
// loop operation count * loop trip count) for the entire loop nest.
// If 'tripCountOverrideMap' is non-null, overrides the trip count for loops
// specified in the map when computing the total op instance count.
// NOTE: this is used to compute the cost of computation slices, which are
// sliced along the iteration dimension, and thus reduce the trip count.
// If 'computeCostMap' is non-null, the total op count for forInsts specified
// in the map is increased (not overridden) by adding the op count from the
// map to the existing op count for the for loop. This is done before
// multiplying by the loop's trip count, and is used to model the cost of
// inserting a sliced loop nest of known cost into the loop's body.
// NOTE: this is used to compute the cost of fusing a slice of some loop nest
// within another loop.
static int64_t getComputeCost(
    ForInst *forInst, LoopNestStats *stats,
    llvm::SmallDenseMap<ForInst *, uint64_t, 8> *tripCountOverrideMap,
    DenseMap<ForInst *, int64_t> *computeCostMap) {
  // 'opCount' is the total number of operations in one iteration of the body
  // of 'forInst'.
  int64_t opCount = stats->opCountMap[forInst];
  if (stats->loopMap.count(forInst) > 0) {
    for (auto *childForInst : stats->loopMap[forInst]) {
      opCount += getComputeCost(childForInst, stats, tripCountOverrideMap,
                                computeCostMap);
    }
  }
  // Add in additional op instances from slice (if specified in map).
  if (computeCostMap != nullptr) {
    auto it = computeCostMap->find(forInst);
    if (it != computeCostMap->end()) {
      opCount += it->second;
    }
  }
  // Override trip count (if specified in map).
  int64_t tripCount = stats->tripCountMap[forInst];
  if (tripCountOverrideMap != nullptr) {
    auto it = tripCountOverrideMap->find(forInst);
    if (it != tripCountOverrideMap->end()) {
      tripCount = it->second;
    }
  }
  // Returns the total number of dynamic instances of operations in loop body.
  return tripCount * opCount;
}

} // end anonymous namespace

static Optional<uint64_t> getConstDifference(AffineMap lbMap, AffineMap ubMap) {
  assert(lbMap.getNumResults() == 1 && "expected single result bound map");
  assert(ubMap.getNumResults() == 1 && "expected single result bound map");
  assert(lbMap.getNumDims() == ubMap.getNumDims());
  assert(lbMap.getNumSymbols() == ubMap.getNumSymbols());
  // TODO(andydavis) Merge this code with 'mlir::getTripCountExpr'.
  // ub_expr - lb_expr
  AffineExpr lbExpr(lbMap.getResult(0));
  AffineExpr ubExpr(ubMap.getResult(0));
  auto loopSpanExpr = simplifyAffineExpr(ubExpr - lbExpr, lbMap.getNumDims(),
                                         lbMap.getNumSymbols());
  auto cExpr = loopSpanExpr.dyn_cast<AffineConstantExpr>();
  if (!cExpr)
    return None;
  return cExpr.getValue();
}

// Builds a map 'tripCountMap' from ForInst to constant trip count for the loop
// nest surrounding 'srcOpInst' utilizing slice loop bounds in 'sliceState'.
// Returns true on success, false otherwise (if a non-constant trip count
// was encountered).
// TODO(andydavis) Make this work with non-unit step loops.
static bool buildSliceTripCountMap(
    OperationInst *srcOpInst, ComputationSliceState *sliceState,
    llvm::SmallDenseMap<ForInst *, uint64_t, 8> *tripCountMap) {
  SmallVector<ForInst *, 4> srcLoopIVs;
  getLoopIVs(*srcOpInst, &srcLoopIVs);
  unsigned numSrcLoopIVs = srcLoopIVs.size();
  // Populate map from ForInst -> trip count.
  for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
    AffineMap lbMap = sliceState->lbs[i];
    AffineMap ubMap = sliceState->ubs[i];
    if (lbMap == AffineMap() || ubMap == AffineMap()) {
      // The iteration of src loop IV 'i' was not sliced. Use full loop bounds.
      if (srcLoopIVs[i]->hasConstantLowerBound() &&
          srcLoopIVs[i]->hasConstantUpperBound()) {
        (*tripCountMap)[srcLoopIVs[i]] =
            srcLoopIVs[i]->getConstantUpperBound() -
            srcLoopIVs[i]->getConstantLowerBound();
        continue;
      }
      return false;
    }
    Optional<uint64_t> tripCount = getConstDifference(lbMap, ubMap);
    if (!tripCount.hasValue())
      return false;
    (*tripCountMap)[srcLoopIVs[i]] = tripCount.getValue();
  }
  return true;
}

// Removes load operations from 'srcLoads' which operate on 'memref', and
// adds them to 'dstLoads'.
static void
moveLoadsAccessingMemrefTo(Value *memref,
                           SmallVectorImpl<OperationInst *> *srcLoads,
                           SmallVectorImpl<OperationInst *> *dstLoads) {
  dstLoads->clear();
  SmallVector<OperationInst *, 4> srcLoadsToKeep;
  for (auto *load : *srcLoads) {
    if (load->cast<LoadOp>()->getMemRef() == memref)
      dstLoads->push_back(load);
    else
      srcLoadsToKeep.push_back(load);
  }
  srcLoads->swap(srcLoadsToKeep);
}

// Returns the innermost common loop depth for the set of operations in 'ops'.
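// E.g. (hypothetical): for two ops nested under loops (i, j) and (i, k)
// respectively, the innermost common loop depth is 1, since only the
// outermost loop 'i' is shared.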
static unsigned getInnermostCommonLoopDepth(ArrayRef<OperationInst *> ops) {
  unsigned numOps = ops.size();
  assert(numOps > 0);

  std::vector<SmallVector<ForInst *, 4>> loops(numOps);
  unsigned loopDepthLimit = std::numeric_limits<unsigned>::max();
  for (unsigned i = 0; i < numOps; ++i) {
    getLoopIVs(*ops[i], &loops[i]);
    loopDepthLimit =
        std::min(loopDepthLimit, static_cast<unsigned>(loops[i].size()));
  }

  unsigned loopDepth = 0;
  for (unsigned d = 0; d < loopDepthLimit; ++d) {
    unsigned i;
    for (i = 1; i < numOps; ++i) {
      if (loops[i - 1][d] != loops[i][d]) {
        break;
      }
    }
    if (i != numOps)
      break;
    ++loopDepth;
  }
  return loopDepth;
}

// Returns the slice union of 'sliceStateA' and 'sliceStateB' in 'sliceStateB'
// using a rectangular bounding box.
// TODO(andydavis) This function assumes that lower bounds for 'sliceStateA'
// and 'sliceStateB' are aligned.
// Specifically, when taking the union of overlapping intervals, it assumes
// that both intervals start at zero. Support needs to be added to take into
// account interval start offset when computing the union.
// TODO(andydavis) Move this function to an analysis library.
static bool getSliceUnion(const ComputationSliceState &sliceStateA,
                          ComputationSliceState *sliceStateB) {
  assert(sliceStateA.lbs.size() == sliceStateB->lbs.size());
  assert(sliceStateA.ubs.size() == sliceStateB->ubs.size());

  for (unsigned i = 0, e = sliceStateA.lbs.size(); i < e; ++i) {
    AffineMap lbMapA = sliceStateA.lbs[i];
    AffineMap ubMapA = sliceStateA.ubs[i];
    if (lbMapA == AffineMap()) {
      assert(ubMapA == AffineMap());
      continue;
    }
    assert(ubMapA && "expected non-null ub map");

    AffineMap lbMapB = sliceStateB->lbs[i];
    AffineMap ubMapB = sliceStateB->ubs[i];
    if (lbMapB == AffineMap()) {
      assert(ubMapB == AffineMap());
      // Union 'sliceStateB' does not have a bound for 'i' so copy from A.
      sliceStateB->lbs[i] = lbMapA;
      sliceStateB->ubs[i] = ubMapA;
      continue;
    }

    // TODO(andydavis) Change this code to take the min across all lower bounds
    // and max across all upper bounds for each dimension. This code can fail
    // for cases where a unique min or max could not be statically determined.

    // Assumption: both lower bounds are the same.
    if (lbMapA != lbMapB)
      return false;

    // Add bound with the largest trip count to union.
    Optional<uint64_t> tripCountA = getConstDifference(lbMapA, ubMapA);
    Optional<uint64_t> tripCountB = getConstDifference(lbMapB, ubMapB);
    if (!tripCountA.hasValue() || !tripCountB.hasValue())
      return false;

    if (tripCountA.getValue() > tripCountB.getValue()) {
      sliceStateB->lbs[i] = lbMapA;
      sliceStateB->ubs[i] = ubMapA;
    }
  }
  return true;
}

// Creates and returns a private (single-user) memref for fused loop rooted
// at 'forInst', with (potentially reduced) memref size based on the
// MemRefRegion written to by 'srcStoreOpInst' at depth 'dstLoopDepth'.
// TODO(bondhugula): consider refactoring the common code from generateDma and
// this one.
static Value *createPrivateMemRef(ForInst *forInst,
                                  OperationInst *srcStoreOpInst,
                                  unsigned dstLoopDepth) {
  // Create builder to insert alloc op just before 'forInst'.
  FuncBuilder b(forInst);
  // Builder to create constants at the top level.
  FuncBuilder top(forInst->getFunction());
  // Create new memref type based on slice bounds.
  auto *oldMemRef = srcStoreOpInst->cast<StoreOp>()->getMemRef();
  auto oldMemRefType = oldMemRef->getType().cast<MemRefType>();
  unsigned rank = oldMemRefType.getRank();

  // Compute MemRefRegion for 'srcStoreOpInst' at depth 'dstLoopDepth'.
  MemRefRegion region;
  getMemRefRegion(srcStoreOpInst, dstLoopDepth, &region);
  SmallVector<int64_t, 4> newShape;
  std::vector<SmallVector<int64_t, 4>> lbs;
  SmallVector<int64_t, 8> lbDivisors;
  lbs.reserve(rank);
  // Query 'region' for 'newShape' and lower bounds of MemRefRegion accessed
  // by 'srcStoreOpInst' at depth 'dstLoopDepth'.
  Optional<int64_t> numElements =
      region.getConstantBoundingSizeAndShape(&newShape, &lbs, &lbDivisors);
  assert(numElements.hasValue());

  const FlatAffineConstraints *cst = region.getConstraints();
  // 'outerIVs' holds the values that this memory region is symbolic/parametric
  // on; this would correspond to loop IVs surrounding the level at which the
  // slice is being materialized.
  SmallVector<Value *, 8> outerIVs;
  cst->getIdValues(rank, cst->getNumIds(), &outerIVs);

  // Build 'rank' AffineExprs from MemRefRegion 'lbs'.
  SmallVector<AffineExpr, 4> offsets;
  offsets.reserve(rank);
  for (unsigned d = 0; d < rank; ++d) {
    assert(lbs[d].size() == cst->getNumCols() - rank && "incorrect bound size");

    AffineExpr offset = top.getAffineConstantExpr(0);
    for (unsigned j = 0, e = cst->getNumCols() - rank - 1; j < e; j++) {
      offset = offset + lbs[d][j] * top.getAffineDimExpr(j);
    }
    assert(lbDivisors[d] > 0);
    offset =
        (offset + lbs[d][cst->getNumCols() - 1 - rank]).floorDiv(lbDivisors[d]);
    offsets.push_back(offset);
  }

  // Create 'newMemRefType' using 'newShape' from MemRefRegion accessed
  // by 'srcStoreOpInst'.
  auto newMemRefType = b.getMemRefType(newShape, oldMemRefType.getElementType(),
                                       {}, oldMemRefType.getMemorySpace());
  // Gather alloc operands for the dynamic dimensions of the memref.
  SmallVector<Value *, 4> allocOperands;
  unsigned dynamicDimCount = 0;
  for (auto dimSize : oldMemRefType.getShape()) {
    if (dimSize == -1)
      allocOperands.push_back(
          b.create<DimOp>(forInst->getLoc(), oldMemRef, dynamicDimCount++));
  }

  // Create new private memref for fused loop 'forInst'.
  Value *newMemRef =
      b.create<AllocOp>(forInst->getLoc(), newMemRefType, allocOperands);

  // Build an AffineMap to remap access functions based on lower bound offsets.
  SmallVector<AffineExpr, 4> remapExprs;
  remapExprs.reserve(rank);
  unsigned zeroOffsetCount = 0;
  for (unsigned i = 0; i < rank; i++) {
    if (auto constExpr = offsets[i].dyn_cast<AffineConstantExpr>())
      if (constExpr.getValue() == 0)
        ++zeroOffsetCount;
    auto dimExpr = b.getAffineDimExpr(outerIVs.size() + i);

    auto remapExpr =
        simplifyAffineExpr(dimExpr - offsets[i], outerIVs.size() + rank, 0);
    remapExprs.push_back(remapExpr);
  }
  auto indexRemap =
      zeroOffsetCount == rank
          ? AffineMap()
          : b.getAffineMap(outerIVs.size() + rank, 0, remapExprs, {});
  // Replace all users of 'oldMemRef' with 'newMemRef'.
  bool ret =
      replaceAllMemRefUsesWith(oldMemRef, newMemRef, {}, indexRemap,
                               /*extraOperands=*/outerIVs,
                               /*domInstFilter=*/&*forInst->getBody()->begin());
  assert(ret && "replaceAllMemrefUsesWith should always succeed here");
  (void)ret;
  return newMemRef;
}

// Returns the iteration count of the slice: the product of the trip counts of
// all loops in 'sliceTripCountMap'. A count of one means the slice has a
// single iteration (which guarantees store-to-load forwarding).
static uint64_t getSliceIterationCount(
    const llvm::SmallDenseMap<ForInst *, uint64_t, 8> &sliceTripCountMap) {
  uint64_t iterCount = 1;
  for (const auto &count : sliceTripCountMap) {
    iterCount *= count.second;
  }
  return iterCount;
}

// Checks the profitability of fusing a backwards slice of the loop nest
// surrounding 'srcOpInst' into the loop nest surrounding 'dstOpInsts'.
// Returns true if it is profitable to fuse the candidate loop nests. Returns
// false otherwise. `dstLoopDepth` is set to the most profitable depth at which
// to materialize the source loop nest slice.
// The profitability model executes the following steps:
// *) Computes the backward computation slice at 'srcOpInst'. This
//    computation slice of the loop nest surrounding 'srcOpInst' is
//    represented by modified src loop bounds in 'sliceState', which are
//    functions of loop IVs in the loop nest surrounding 'srcOpInst'.
// *) Computes the cost of unfused src/dst loop nests (currently the cost of a
//    loop nest is the total number of dynamic operation instances in the loop
//    nest).
// *) Computes the cost of fusing a slice of the src loop nest into the dst
//    loop nest at various values of dst loop depth, attempting to fuse
//    the largest computation slice at the maximal dst loop depth (closest to
//    the load) to minimize reuse distance and potentially enable subsequent
//    load/store forwarding.
//    NOTE: If the dst loop nest includes multiple loads in 'dstOpInsts' for
//    the same memref as is written by 'srcOpInst', then the union of slice
//    loop bounds is used to compute the slice and associated slice cost.
//    NOTE: 'dstLoopDepth' refers to the loop depth within the destination loop
//    nest, at which the src computation slice is inserted/fused.
//    NOTE: We attempt to maximize the dst loop depth, but there are cases
//    where a particular setting for 'dstLoopDepth' might fuse an unsliced
//    loop (within the src computation slice) at a depth which results in
//    excessive recomputation (see unit tests for examples).
// *) Compares the total cost of the unfused loop nests to the min cost fused
//    loop nest computed in the previous step, and returns true if the latter
//    is lower.
static bool isFusionProfitable(OperationInst *srcOpInst,
                               ArrayRef<OperationInst *> dstOpInsts,
                               ComputationSliceState *sliceState,
                               unsigned *dstLoopDepth) {
  LLVM_DEBUG({
    llvm::dbgs() << "Checking whether fusion is profitable between:\n";
    llvm::dbgs() << " ";
    srcOpInst->dump();
    llvm::dbgs() << " and \n";
    for (auto dstOpInst : dstOpInsts) {
      llvm::dbgs() << " ";
      dstOpInst->dump();
    };
  });

  // Compute cost of sliced and unsliced src loop nest.
  SmallVector<ForInst *, 4> srcLoopIVs;
  getLoopIVs(*srcOpInst, &srcLoopIVs);
  unsigned numSrcLoopIVs = srcLoopIVs.size();

  // Walk src loop nest and collect stats.
  LoopNestStats srcLoopNestStats;
  LoopNestStatsCollector srcStatsCollector(&srcLoopNestStats);
  srcStatsCollector.walk(srcLoopIVs[0]);
  // Currently only constant trip count loop nests are supported.
  if (srcStatsCollector.hasLoopWithNonConstTripCount)
    return false;

  // Compute cost of dst loop nest.
  SmallVector<ForInst *, 4> dstLoopIVs;
  getLoopIVs(*dstOpInsts[0], &dstLoopIVs);

  LoopNestStats dstLoopNestStats;
  LoopNestStatsCollector dstStatsCollector(&dstLoopNestStats);
  dstStatsCollector.walk(dstLoopIVs[0]);
  // Currently only constant trip count loop nests are supported.
  if (dstStatsCollector.hasLoopWithNonConstTripCount)
    return false;

  // Compute the innermost common loop depth for ops in 'dstOpInsts'.
  unsigned maxDstLoopDepth = getInnermostCommonLoopDepth(dstOpInsts);
  if (maxDstLoopDepth == 0)
    return false;

  // Search for min cost value for 'dstLoopDepth'. At each value of
  // 'dstLoopDepth' from 'maxDstLoopDepth' to '1', compute computation slice
  // bounds between 'srcOpInst' and each op in 'dstOpInsts' (taking the union
  // of these bounds). Next the union slice bounds are used to calculate
  // the cost of the slice and the cost of the slice inserted into the dst
  // loop nest at 'dstLoopDepth'.
  uint64_t minFusedLoopNestComputeCost = std::numeric_limits<uint64_t>::max();
  uint64_t maxStorageReduction = 0;
  Optional<uint64_t> sliceMemEstimate = None;

  SmallVector<ComputationSliceState, 4> sliceStates;
  sliceStates.resize(maxDstLoopDepth);
  // The best loop depth at which to materialize the slice.
  Optional<unsigned> bestDstLoopDepth = None;

  // Compute op instance count for the src loop nest without iteration slicing.
  uint64_t srcLoopNestCost = getComputeCost(srcLoopIVs[0], &srcLoopNestStats,
                                            /*tripCountOverrideMap=*/nullptr,
                                            /*computeCostMap=*/nullptr);

  // Compute op instance count for the dst loop nest.
  uint64_t dstLoopNestCost = getComputeCost(dstLoopIVs[0], &dstLoopNestStats,
                                            /*tripCountOverrideMap=*/nullptr,
                                            /*computeCostMap=*/nullptr);

  llvm::SmallDenseMap<ForInst *, uint64_t, 8> sliceTripCountMap;
  DenseMap<ForInst *, int64_t> computeCostMap;
  for (unsigned i = maxDstLoopDepth; i >= 1; --i) {
    MemRefAccess srcAccess(srcOpInst);
    // Handle the common case of one dst load without a copy.
    if (!mlir::getBackwardComputationSliceState(
            srcAccess, MemRefAccess(dstOpInsts[0]), i, &sliceStates[i - 1]))
      return false;
    // Compute the union of slice bounds of all ops in 'dstOpInsts'.
    for (int j = 1, e = dstOpInsts.size(); j < e; ++j) {
      MemRefAccess dstAccess(dstOpInsts[j]);
      ComputationSliceState tmpSliceState;
      if (!mlir::getBackwardComputationSliceState(srcAccess, dstAccess, i,
                                                  &tmpSliceState))
        return false;
      // Compute slice bound union of 'tmpSliceState' and 'sliceStates[i - 1]'.
      getSliceUnion(tmpSliceState, &sliceStates[i - 1]);
    }
    // Build trip count map for computation slice. We'll skip cases where the
    // trip count was non-constant.
    sliceTripCountMap.clear();
    if (!buildSliceTripCountMap(srcOpInst, &sliceStates[i - 1],
                                &sliceTripCountMap))
      continue;

    // Check whether store-to-load forwarding will happen.
    int64_t sliceIterationCount = getSliceIterationCount(sliceTripCountMap);
    assert(sliceIterationCount > 0);
    bool storeLoadFwdGuaranteed = (sliceIterationCount == 1);

    // Compute cost of fusion for this dest loop depth.

    computeCostMap.clear();

    // The store and loads to this memref will disappear.
    if (storeLoadFwdGuaranteed) {
      // A single store disappears: -1 for that.
      computeCostMap[srcLoopIVs[numSrcLoopIVs - 1]] = -1;
      for (auto *loadOp : dstOpInsts) {
        if (auto *loadLoop = dyn_cast_or_null<ForInst>(loadOp->getParentInst()))
          computeCostMap[loadLoop] = -1;
      }
    }

    // Compute op instance count for the src loop nest with iteration slicing.
    int64_t sliceComputeCost =
        getComputeCost(srcLoopIVs[0], &srcLoopNestStats,
                       /*tripCountOverrideMap=*/&sliceTripCountMap,
                       /*computeCostMap=*/&computeCostMap);

    // Compute cost of fusion for this depth.
    computeCostMap[dstLoopIVs[i - 1]] = sliceComputeCost;

    int64_t fusedLoopNestComputeCost =
        getComputeCost(dstLoopIVs[0], &dstLoopNestStats,
                       /*tripCountOverrideMap=*/nullptr, &computeCostMap);

    double additionalComputeFraction =
        fusedLoopNestComputeCost /
            (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
        1;

    // TODO(bondhugula): This is an ugly approximation. Fix this by finding a
    // good way to calculate the footprint of the memref in the slice and
    // divide it by the total memory footprint of the fused computation.
    double storageReduction =
        static_cast<double>(srcLoopNestCost) / sliceIterationCount;

    LLVM_DEBUG({
      std::stringstream msg;
      msg << " evaluating fusion profitability at depth : " << i << "\n"
          << std::setprecision(2) << " additional compute fraction: "
          << 100.0 * additionalComputeFraction << "%\n"
          << " storage reduction factor: " << storageReduction << "x\n"
          << " fused nest cost: " << fusedLoopNestComputeCost << "\n"
          << " slice iteration count: " << sliceIterationCount << "\n";
      llvm::dbgs() << msg.str();
    });

    double computeToleranceThreshold =
        clFusionAddlComputeTolerance.getNumOccurrences() > 0
            ? clFusionAddlComputeTolerance
            : LoopFusion::kComputeToleranceThreshold;

    // TODO(b/123247369): This is a placeholder cost model.
    // Among all choices that add an acceptable amount of redundant computation
    // (as per computeToleranceThreshold), we will simply pick the one that
    // reduces the intermediary size the most.
    if ((storageReduction > maxStorageReduction) &&
        (clMaximalLoopFusion ||
         (additionalComputeFraction < computeToleranceThreshold))) {
      maxStorageReduction = storageReduction;
      bestDstLoopDepth = i;
      minFusedLoopNestComputeCost = fusedLoopNestComputeCost;
      // TODO(bondhugula,andydavis): find a good way to compute the memory
      // footprint of the materialized slice.
      // Approximating this to the compute cost of the slice. This could be an
      // under-approximation or an overapproximation, but in many cases
      // accurate.
      sliceMemEstimate = sliceIterationCount;
    }
  }

  // A simple cost model: fuse if it reduces the memory footprint. If
  // -maximal-fusion is set, fuse nevertheless.

  if (!clMaximalLoopFusion && !bestDstLoopDepth.hasValue()) {
    LLVM_DEBUG(llvm::dbgs()
               << "All fusion choices involve more than the threshold amount "
                  "of redundant computation; NOT fusing.\n");
    return false;
  }

  assert(bestDstLoopDepth.hasValue() &&
         "expected to have a value per logic above");

  // Set dstLoopDepth based on best values from search.
  *dstLoopDepth = bestDstLoopDepth.getValue();

  LLVM_DEBUG(
      llvm::dbgs() << " LoopFusion fusion stats:"
                   << "\n best loop depth: " << bestDstLoopDepth
                   << "\n src loop nest compute cost: " << srcLoopNestCost
                   << "\n dst loop nest compute cost: " << dstLoopNestCost
                   << "\n fused loop nest compute cost: "
                   << minFusedLoopNestComputeCost << "\n");

  auto dstMemSize = getMemoryFootprintBytes(*dstLoopIVs[0]);
  auto srcMemSize = getMemoryFootprintBytes(*srcLoopIVs[0]);

  Optional<double> storageReduction = None;

  if (!clMaximalLoopFusion) {
    if (!dstMemSize.hasValue() || !srcMemSize.hasValue()) {
      LLVM_DEBUG(
          llvm::dbgs()
          << " fusion memory benefit cannot be evaluated; NOT fusing.\n");
      return false;
    }

    auto srcMemSizeVal = srcMemSize.getValue();
    auto dstMemSizeVal = dstMemSize.getValue();

    assert(sliceMemEstimate.hasValue() && "expected value");
    // This is an inaccurate estimate since sliceMemEstimate is inaccurate.
    auto fusedMem = dstMemSizeVal + sliceMemEstimate.getValue();

    LLVM_DEBUG(llvm::dbgs() << " src mem: " << srcMemSizeVal << "\n"
                            << " dst mem: " << dstMemSizeVal << "\n"
                            << " fused mem: " << fusedMem << "\n"
                            << " slice mem: " << sliceMemEstimate << "\n");

    if (fusedMem > srcMemSizeVal + dstMemSizeVal) {
      LLVM_DEBUG(llvm::dbgs() << "Fusion is not profitable; NOT fusing.\n");
      return false;
    }
    storageReduction =
        100.0 *
        (1.0 - fusedMem / (static_cast<double>(srcMemSizeVal) + dstMemSizeVal));
  }

  double additionalComputeFraction =
      100.0 * (minFusedLoopNestComputeCost /
                   (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
               1);
  (void)additionalComputeFraction;
  LLVM_DEBUG({
    std::stringstream msg;
    msg << " fusion is most profitable at depth " << *dstLoopDepth << " with "
        << std::setprecision(2) << additionalComputeFraction
        << "% redundant computation and a ";
    msg << (storageReduction.hasValue()
                ? std::to_string(storageReduction.getValue())
                : "<unknown>");
    msg << "% storage reduction.\n";
    llvm::dbgs() << msg.str();
  });

  // Update return parameter 'sliceState' with 'bestSliceState'.
  ComputationSliceState *bestSliceState = &sliceStates[*dstLoopDepth - 1];
  sliceState->lbs = bestSliceState->lbs;
  sliceState->ubs = bestSliceState->ubs;
  sliceState->lbOperands = bestSliceState->lbOperands;
  sliceState->ubOperands = bestSliceState->ubOperands;

  // Canonicalize slice bound affine maps.
  for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
    if (sliceState->lbs[i] != AffineMap()) {
      canonicalizeMapAndOperands(&sliceState->lbs[i],
                                 &sliceState->lbOperands[i]);
    }
    if (sliceState->ubs[i] != AffineMap()) {
      canonicalizeMapAndOperands(&sliceState->ubs[i],
                                 &sliceState->ubOperands[i]);
    }
  }
  return true;
}

// GreedyFusion greedily fuses loop nests which have a producer/consumer
// relationship on a memref, with the goal of improving locality. Currently,
// the producer/consumer relationship is required to be unique in the
// Function (there are TODOs to relax this constraint in the future).
//
// The steps of the algorithm are as follows:
//
// *) A worklist is initialized with node ids from the dependence graph.
// *) For each node id in the worklist:
//    *) Pop a ForInst off the worklist. This 'dstForInst' will be a candidate
//       destination ForInst into which fusion will be attempted.
//    *) Add each LoadOp currently in 'dstForInst' into list 'dstLoadOps'.
//    *) For each LoadOp in 'dstLoadOps' do:
//       *) Look up dependent loop nests at earlier positions in the Function
//          which have a single store op to the same memref.
//       *) Check if dependences would be violated by the fusion. For example,
//          the src loop nest may load from memrefs which are different than
//          the producer-consumer memref between the src and dest loop nests.
//       *) Get a computation slice of 'srcLoopNest', which adjusts its loop
//          bounds to be functions of 'dstLoopNest' IVs and symbols.
//       *) Fuse the 'srcLoopNest' computation slice into the 'dstLoopNest',
//          just before the dst load op user.
//       *) Add the newly fused load/store operation instructions to the
//          state, and also add newly fused load ops to 'dstLoadOps' to be
//          considered as fusion dst load ops in another iteration.
//       *) Remove the old src loop nest and its associated state.
//
// Given a graph where top-level instructions are vertices in the set 'V' and
// edges in the set 'E' are dependences between vertices, this algorithm
// takes O(V) time for initialization, and has runtime O(V + E).
//
// This greedy algorithm is not 'maximal' due to the current restriction of
// fusing along single producer-consumer edges, but there is a TODO to fix
// this. A simplified example of the overall transformation appears right
// after the TODOs below.
//
// TODO(andydavis) Experiment with other fusion policies.
// TODO(andydavis) Add support for fusing for input reuse (perhaps by
// constructing a graph with edges which represent loads from the same memref
// in two different loop nests).
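//
// Simplified example (a rough illustration only: the op names, memref
// shapes, and loop bounds are hypothetical, and the actual output depends on
// the profitability analysis and the computed slice bounds). Input:
//
//     for %i0 = 0 to 100 {
//       %v = "compute"() : () -> f32
//       store %v, %m[%i0] : memref<100xf32>
//     }
//     for %i1 = 0 to 100 {
//       %w = load %m[%i1] : memref<100xf32>
//       "use"(%w) : (f32) -> ()
//     }
//
// After fusing the producer slice into the consumer and privatizing %m to a
// single-element buffer indexed by a constant %c0, this may become:
//
//     %m2 = alloc() : memref<1xf32>
//     for %i1 = 0 to 100 {
//       %v = "compute"() : () -> f32
//       store %v, %m2[%c0] : memref<1xf32>
//       %w = load %m2[%c0] : memref<1xf32>
//       "use"(%w) : (f32) -> ()
//     }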
struct GreedyFusion {
public:
  MemRefDependenceGraph *mdg;
  SmallVector<unsigned, 4> worklist;

  GreedyFusion(MemRefDependenceGraph *mdg) : mdg(mdg) {
    // Initialize worklist with nodes from 'mdg'.
    worklist.resize(mdg->nodes.size());
    std::iota(worklist.begin(), worklist.end(), 0);
  }

  void run() {
    while (!worklist.empty()) {
      unsigned dstId = worklist.back();
      worklist.pop_back();
      // Skip if this node was removed (fused into another node).
      if (mdg->nodes.count(dstId) == 0)
        continue;
      // Get 'dstNode' into which to attempt fusion.
      auto *dstNode = mdg->getNode(dstId);
      // Skip if 'dstNode' is not a loop nest.
      if (!isa<ForInst>(dstNode->inst))
        continue;

      SmallVector<OperationInst *, 4> loads = dstNode->loads;
      SmallVector<OperationInst *, 4> dstLoadOpInsts;
      DenseSet<Value *> visitedMemrefs;
      while (!loads.empty()) {
        // Get memref of load on top of the stack.
        auto *memref = loads.back()->cast<LoadOp>()->getMemRef();
        // Pop and skip this load if its memref has already been visited;
        // popping here is required for this worklist loop to terminate.
        if (visitedMemrefs.count(memref) > 0) {
          loads.pop_back();
          continue;
        }
        visitedMemrefs.insert(memref);
        // Move all loads in 'loads' accessing 'memref' to 'dstLoadOpInsts'.
        moveLoadsAccessingMemrefTo(memref, &loads, &dstLoadOpInsts);
        // Skip if no input edges along which to fuse.
        if (mdg->inEdges.count(dstId) == 0)
          continue;
        // Iterate through in edges for 'dstId'.
        for (auto &srcEdge : mdg->inEdges[dstId]) {
          // Skip 'srcEdge' if not for 'memref'.
          if (srcEdge.memref != memref)
            continue;

          auto *srcNode = mdg->getNode(srcEdge.id);
          // Skip if 'srcNode' is not a loop nest.
          if (!isa<ForInst>(srcNode->inst))
            continue;
          // Skip unless 'srcNode' has exactly one store (to any memref).
          // TODO(andydavis) Support fusing multi-output src loop nests.
          if (srcNode->stores.size() != 1)
            continue;

          // Skip 'srcNode' if it has in dependence edges on 'memref'. NOTE:
          // This check is overly conservative.
          // TODO(andydavis) Track dependence type with edges, and just check
          // for WAW dependence edge here.
          if (mdg->getInEdgeCount(srcNode->id, memref) != 0)
            continue;

          // Skip if 'srcNode' has out edges on memrefs other than 'memref'
          // for nodes in instruction list range (srcNode.inst, dstNode.inst).
          if (mdg->hasDependenceTargetInRange(srcNode->id, dstNode->id, memref))
            continue;

          // Check if fusion would be profitable and at what depth.
          // Get unique 'srcNode' store op.
          auto *srcStoreOpInst = srcNode->stores.front();
          unsigned bestDstLoopDepth;
          mlir::ComputationSliceState sliceState;
          if (!isFusionProfitable(srcStoreOpInst, dstLoadOpInsts, &sliceState,
                                  &bestDstLoopDepth))
            continue;

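          // A backward computation slice (computed below) re-expresses the
          // src loop bounds as affine functions of dst loop IVs and symbols
          // at 'bestDstLoopDepth'; e.g. (bounds are hypothetical) a src loop
          // over [0, 100) which produces exactly the value consumed at dst
          // IV %i1 gets the single-iteration bounds [%i1, %i1 + 1).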
          // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'.
          auto *sliceLoopNest = mlir::insertBackwardComputationSlice(
              srcStoreOpInst, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState);
          if (sliceLoopNest != nullptr) {
            // Update edges between 'srcNode' and 'dstNode'.
            mdg->updateEdges(srcNode->id, dstNode->id);

            // Collect slice loop stats.
            LoopNestStateCollector sliceCollector;
            sliceCollector.walkForInst(sliceLoopNest);
            // Promote single iteration slice loops to single IV value.
            for (auto *forInst : sliceCollector.forInsts) {
              promoteIfSingleIteration(forInst);
            }

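            // Privatization (below) replaces 'memref' in 'dstForInst' with a
            // new buffer sized to the region the fused nest actually
            // accesses; e.g. (shapes are hypothetical) a shared
            // memref<100xf32> may shrink to memref<1xf32> when each dst
            // iteration touches a single element.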
            // Create private memref for 'memref' in 'dstForInst'.
            auto *dstForInst = cast<ForInst>(dstNode->inst);
            SmallVector<OperationInst *, 4> storesForMemref;
            for (auto *storeOpInst : sliceCollector.storeOpInsts) {
              if (storeOpInst->cast<StoreOp>()->getMemRef() == memref)
                storesForMemref.push_back(storeOpInst);
            }
            assert(storesForMemref.size() == 1);
            auto *newMemRef = createPrivateMemRef(
                dstForInst, storesForMemref[0], bestDstLoopDepth);
            visitedMemrefs.insert(newMemRef);

            // Collect dst loop stats after the memref privatization
            // transformation.
            LoopNestStateCollector dstLoopCollector;
            dstLoopCollector.walkForInst(dstForInst);

            // Add new load ops to current Node load op list 'loads' to
            // continue fusing based on new operands.
            for (auto *loadOpInst : dstLoopCollector.loadOpInsts) {
              auto *loadMemRef = loadOpInst->cast<LoadOp>()->getMemRef();
              if (visitedMemrefs.count(loadMemRef) == 0)
                loads.push_back(loadOpInst);
            }

            // Clear and re-add the node's loads and stores.
            mdg->clearNodeLoadAndStores(dstNode->id);
            mdg->addToNode(dstId, dstLoopCollector.loadOpInsts,
                           dstLoopCollector.storeOpInsts);
            // Remove old src loop nest if it no longer has outgoing dependence
            // edges, and it does not write to a memref which escapes the
            // function.
            if (!mdg->hasOutEdges(srcNode->id) &&
                !mdg->writesToLiveInOrEscapingMemrefs(srcNode->id)) {
              mdg->removeNode(srcNode->id);
              cast<ForInst>(srcNode->inst)->erase();
            }
          }
        }
      }
    }
    // Clean up any allocs with no users.
    for (auto &pair : mdg->memrefEdgeCount) {
      if (pair.second > 0)
        continue;
      auto *memref = pair.first;
      // Skip if there exist other uses (return instruction or function calls).
      if (!memref->use_empty())
        continue;
      // Use list expected to match the dep graph info.
      auto *inst = memref->getDefiningInst();
      auto *opInst = dyn_cast_or_null<OperationInst>(inst);
      if (opInst && opInst->isa<AllocOp>())
        opInst->erase();
    }
  }
};

} // end anonymous namespace

PassResult LoopFusion::runOnFunction(Function *f) {
  MemRefDependenceGraph g;
  if (g.init(f))
    GreedyFusion(&g).run();
  return success();
}

static PassRegistration<LoopFusion> pass("loop-fusion", "Fuse loop nests");
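
// A typical invocation of this pass (assuming the standard mlir-opt driver):
//
//   mlir-opt -loop-fusion [-fusion-maximal] input.mlir
//
// where -fusion-maximal disables the profitability check and fuses whenever
// it is valid to do so (see the flag definitions above).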