//===- LoopFusion.cpp - Code to perform loop fusion -----------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements loop fusion.
//
//===----------------------------------------------------------------------===//

#include "mlir/AffineOps/AffineOps.h"
#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/AffineStructures.h"
#include "mlir/Analysis/LoopAnalysis.h"
#include "mlir/Analysis/Utils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/Pass.h"
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/Transforms/LoopUtils.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/Utils.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <iomanip>

#define DEBUG_TYPE "loop-fusion"

using llvm::SetVector;

using namespace mlir;

static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options");

/// Disables fusion profitability check and fuses if valid.
static llvm::cl::opt<bool>
    clMaximalLoopFusion("fusion-maximal", llvm::cl::Hidden,
                        llvm::cl::desc("Enables maximal loop fusion"),
                        llvm::cl::cat(clOptionsCategory));

/// A threshold in percent of additional computation allowed when fusing.
static llvm::cl::opt<double> clFusionAddlComputeTolerance(
    "fusion-compute-tolerance", llvm::cl::Hidden,
    llvm::cl::desc("Fractional increase in additional"
                   " computation tolerated while fusing"),
    llvm::cl::cat(clOptionsCategory));

namespace {

/// Loop fusion pass. This pass currently supports a greedy fusion policy,
/// which fuses loop nests with single-writer/single-reader memref dependences
/// with the goal of improving locality.

// TODO(andydavis) Support fusion of source loop nests which write to multiple
// memrefs, where each memref can have multiple users (if profitable).
// TODO(andydavis) Extend this pass to check for fusion preventing dependences,
// and add support for more general loop fusion algorithms.

struct LoopFusion : public FunctionPass {
  LoopFusion() : FunctionPass(&LoopFusion::passID) {}

  PassResult runOnFunction(Function *f) override;
  static char passID;

  // The amount of additional computation that is tolerated while fusing
  // pair-wise as a fraction of the total computation.
  constexpr static double kComputeToleranceThreshold = 0.30f;
};

} // end anonymous namespace

char LoopFusion::passID = 0;

FunctionPass *mlir::createLoopFusionPass() { return new LoopFusion; }

namespace {

// LoopNestStateCollector walks loop nests and collects load and store
// operations, and records whether a region-holding operation other than a
// ForInst was encountered in the loop nest.
class LoopNestStateCollector : public InstWalker<LoopNestStateCollector> {
public:
  SmallVector<ForInst *, 4> forInsts;
  SmallVector<OperationInst *, 4> loadOpInsts;
  SmallVector<OperationInst *, 4> storeOpInsts;
  bool hasNonForRegion = false;

  void visitForInst(ForInst *forInst) { forInsts.push_back(forInst); }

  void visitOperationInst(OperationInst *opInst) {
    if (opInst->getNumBlockLists() != 0)
      hasNonForRegion = true;
    else if (opInst->isa<LoadOp>())
      loadOpInsts.push_back(opInst);
    else if (opInst->isa<StoreOp>())
      storeOpInsts.push_back(opInst);
  }
};

// TODO(b/117228571) Replace when this is modeled through side-effects/op traits
static bool isMemRefDereferencingOp(const OperationInst &op) {
  if (op.isa<LoadOp>() || op.isa<StoreOp>() || op.isa<DmaStartOp>() ||
      op.isa<DmaWaitOp>())
    return true;
  return false;
}

// MemRefDependenceGraph is a graph data structure where graph nodes are
// top-level instructions in a Function which contain load/store ops, and edges
// are memref dependences between the nodes.
// TODO(andydavis) Add a more flexible dependence graph representation.
// TODO(andydavis) Add a depth parameter to dependence graph construction.
struct MemRefDependenceGraph {
public:
  // Node represents a node in the graph. A Node is either an entire loop nest
  // rooted at the top level which contains loads/stores, or a top level
  // load/store.
  struct Node {
    // The unique identifier of this node in the graph.
    unsigned id;
    // The top-level instruction which is (or contains) loads/stores.
    Instruction *inst;
    // List of load operations.
    SmallVector<OperationInst *, 4> loads;
    // List of store op insts.
    SmallVector<OperationInst *, 4> stores;
    Node(unsigned id, Instruction *inst) : id(id), inst(inst) {}

    // Returns the load op count for 'memref'.
    unsigned getLoadOpCount(Value *memref) {
      unsigned loadOpCount = 0;
      for (auto *loadOpInst : loads) {
        if (memref == loadOpInst->cast<LoadOp>()->getMemRef())
          ++loadOpCount;
      }
      return loadOpCount;
    }

    // Returns the store op count for 'memref'.
    unsigned getStoreOpCount(Value *memref) {
      unsigned storeOpCount = 0;
      for (auto *storeOpInst : stores) {
        if (memref == storeOpInst->cast<StoreOp>()->getMemRef())
          ++storeOpCount;
      }
      return storeOpCount;
    }
  };

  // Edge represents a data dependence between nodes in the graph.
  struct Edge {
    // The id of the node at the other end of the edge.
    unsigned id;
    // The SSA value on which this edge represents a dependence.
    // If the value is a memref, then the dependence is between graph nodes
    // which contain accesses to the same memref 'value'. If the value is a
    // non-memref value, then the dependence is between a graph node which
    // defines an SSA value and another graph node which uses the SSA value
    // (e.g. a constant instruction defining a value which is used inside a loop
    // nest).
    Value *value;
  };

  // Map from node id to Node.
  DenseMap<unsigned, Node> nodes;
  // Map from node id to list of input edges.
  DenseMap<unsigned, SmallVector<Edge, 2>> inEdges;
  // Map from node id to list of output edges.
  DenseMap<unsigned, SmallVector<Edge, 2>> outEdges;
  // Map from memref to a count on the dependence edges associated with that
  // memref.
  DenseMap<Value *, unsigned> memrefEdgeCount;
  // The next unique identifier to use for newly created graph nodes.
  unsigned nextNodeId = 0;

  MemRefDependenceGraph() {}

  // Initializes the dependence graph based on operations in 'f'.
  // Returns true on success, false otherwise.
  bool init(Function *f);

  // Returns the graph node for 'id'.
  Node *getNode(unsigned id) {
    auto it = nodes.find(id);
    assert(it != nodes.end());
    return &it->second;
  }

  // Adds a node with 'inst' to the graph and returns its unique identifier.
  unsigned addNode(Instruction *inst) {
    Node node(nextNodeId++, inst);
    nodes.insert({node.id, node});
    return node.id;
  }

  // Remove node 'id' (and its associated edges) from graph.
  void removeNode(unsigned id) {
    // Remove each edge in 'inEdges[id]'.
    if (inEdges.count(id) > 0) {
      SmallVector<Edge, 2> oldInEdges = inEdges[id];
      for (auto &inEdge : oldInEdges) {
        removeEdge(inEdge.id, id, inEdge.value);
      }
    }
    // Remove each edge in 'outEdges[id]'.
    if (outEdges.count(id) > 0) {
      SmallVector<Edge, 2> oldOutEdges = outEdges[id];
      for (auto &outEdge : oldOutEdges) {
        removeEdge(id, outEdge.id, outEdge.value);
      }
    }
    // Erase remaining node state.
    inEdges.erase(id);
    outEdges.erase(id);
    nodes.erase(id);
  }

  // Returns true if node 'id' can be removed from the graph. Returns false
  // otherwise. A node can be removed from the graph iff the following
  // conditions are met:
  // *) The node does not write to any memref which escapes (or is an argument
  //    to) the function/block.
  // *) The node has no successors in the dependence graph.
  bool canRemoveNode(unsigned id) {
    Node *node = getNode(id);
    for (auto *storeOpInst : node->stores) {
      auto *memref = storeOpInst->cast<StoreOp>()->getMemRef();
      auto *inst = memref->getDefiningInst();
      auto *opInst = dyn_cast_or_null<OperationInst>(inst);
      // Return false if 'memref' is a function argument.
      if (opInst == nullptr)
        return false;
      // Return false if any use of 'memref' escapes the function.
      for (auto &use : memref->getUses()) {
        auto *user = dyn_cast<OperationInst>(use.getOwner());
        if (!user || !isMemRefDereferencingOp(*user))
          return false;
      }
      // Return false if there exist out edges from 'id' on 'memref'.
      if (getOutEdgeCount(id, memref) > 0)
        return false;
    }
    return true;
  }

  // Returns true iff there is an edge from node 'srcId' to node 'dstId' for
  // 'value'. Returns false otherwise.
  bool hasEdge(unsigned srcId, unsigned dstId, Value *value) {
    if (outEdges.count(srcId) == 0 || inEdges.count(dstId) == 0) {
      return false;
    }
    bool hasOutEdge = llvm::any_of(outEdges[srcId], [=](Edge &edge) {
      return edge.id == dstId && edge.value == value;
    });
    bool hasInEdge = llvm::any_of(inEdges[dstId], [=](Edge &edge) {
      return edge.id == srcId && edge.value == value;
    });
    return hasOutEdge && hasInEdge;
  }

  // Adds an edge from node 'srcId' to node 'dstId' for 'value'.
  void addEdge(unsigned srcId, unsigned dstId, Value *value) {
    if (!hasEdge(srcId, dstId, value)) {
      outEdges[srcId].push_back({dstId, value});
      inEdges[dstId].push_back({srcId, value});
      if (value->getType().isa<MemRefType>())
        memrefEdgeCount[value]++;
    }
  }

  // Removes an edge from node 'srcId' to node 'dstId' for 'value'.
  void removeEdge(unsigned srcId, unsigned dstId, Value *value) {
    assert(inEdges.count(dstId) > 0);
    assert(outEdges.count(srcId) > 0);
    if (value->getType().isa<MemRefType>()) {
      assert(memrefEdgeCount.count(value) > 0);
      memrefEdgeCount[value]--;
    }
    // Remove 'srcId' from 'inEdges[dstId]'.
    for (auto it = inEdges[dstId].begin(); it != inEdges[dstId].end(); ++it) {
      if ((*it).id == srcId && (*it).value == value) {
        inEdges[dstId].erase(it);
        break;
      }
    }
    // Remove 'dstId' from 'outEdges[srcId]'.
    for (auto it = outEdges[srcId].begin(); it != outEdges[srcId].end(); ++it) {
      if ((*it).id == dstId && (*it).value == value) {
        outEdges[srcId].erase(it);
        break;
      }
    }
  }

  // Returns the input edge count for node 'id' and 'memref' from src nodes
  // which access 'memref'.
  unsigned getIncomingMemRefAccesses(unsigned id, Value *memref) {
    unsigned inEdgeCount = 0;
    if (inEdges.count(id) > 0)
      for (auto &inEdge : inEdges[id])
        if (inEdge.value == memref) {
          Node *srcNode = getNode(inEdge.id);
          // Only count in edges from 'srcNode' if 'srcNode' accesses 'memref'.
          if (srcNode->getLoadOpCount(memref) > 0 ||
              srcNode->getStoreOpCount(memref) > 0)
            ++inEdgeCount;
        }
    return inEdgeCount;
  }

  // Returns the output edge count for node 'id' and 'memref'.
  unsigned getOutEdgeCount(unsigned id, Value *memref) {
    unsigned outEdgeCount = 0;
    if (outEdges.count(id) > 0)
      for (auto &outEdge : outEdges[id])
        if (outEdge.value == memref)
          ++outEdgeCount;
    return outEdgeCount;
  }

  // Computes and returns an insertion point instruction, before which the
  // fused <srcId, dstId> loop nest can be inserted while preserving
  // dependences. Returns nullptr if no such insertion point is found.
  Instruction *getFusedLoopNestInsertionPoint(unsigned srcId, unsigned dstId,
                                              Value *memrefToSkip) {
    if (outEdges.count(srcId) == 0)
      return getNode(dstId)->inst;

    // Build set of insts in range (srcId, dstId) which depend on 'srcId'.
    SmallPtrSet<Instruction *, 2> srcDepInsts;
    for (auto &outEdge : outEdges[srcId])
      if (outEdge.id != dstId && outEdge.value != memrefToSkip)
        srcDepInsts.insert(getNode(outEdge.id)->inst);

    // Build set of insts in range (srcId, dstId) on which 'dstId' depends.
    SmallPtrSet<Instruction *, 2> dstDepInsts;
    for (auto &inEdge : inEdges[dstId])
      if (inEdge.id != srcId && inEdge.value != memrefToSkip)
        dstDepInsts.insert(getNode(inEdge.id)->inst);

    Instruction *srcNodeInst = getNode(srcId)->inst;
    Instruction *dstNodeInst = getNode(dstId)->inst;

    // Computing insertion point:
    // *) Walk all instruction positions in Block instruction list in the
    //    range (src, dst). For each instruction 'inst' visited in this search:
    //   *) Store in 'firstSrcDepPos' the first position where 'inst' has a
    //      dependence edge from 'srcNode'.
    //   *) Store in 'lastDstDepPos' the last position where 'inst' has a
    //      dependence edge to 'dstNode'.
    // *) Compare 'firstSrcDepPos' and 'lastDstDepPos' to determine the
    //    instruction insertion point (or return null pointer if no such
    //    insertion point exists: 'firstSrcDepPos' <= 'lastDstDepPos').
    SmallVector<Instruction *, 2> depInsts;
    Optional<unsigned> firstSrcDepPos;
    Optional<unsigned> lastDstDepPos;
    unsigned pos = 0;
    for (Block::iterator it = std::next(Block::iterator(srcNodeInst));
         it != Block::iterator(dstNodeInst); ++it) {
      Instruction *inst = &(*it);
      if (srcDepInsts.count(inst) > 0 && firstSrcDepPos == None)
        firstSrcDepPos = pos;
      if (dstDepInsts.count(inst) > 0)
        lastDstDepPos = pos;
      depInsts.push_back(inst);
      ++pos;
    }

    if (firstSrcDepPos.hasValue()) {
      if (lastDstDepPos.hasValue()) {
        if (firstSrcDepPos.getValue() <= lastDstDepPos.getValue()) {
          // No valid insertion point exists which preserves dependences.
          return nullptr;
        }
      }
      // Return the insertion point at 'firstSrcDepPos'.
      return depInsts[firstSrcDepPos.getValue()];
    }
    // No dependence targets in range (or only dst deps in range), return
    // 'dstNodeInst' insertion point.
    return dstNodeInst;
  }
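
  // Worked example (an illustrative sketch): if the block between 'srcId' and
  // 'dstId' contains 'instA' (which depends on 'srcId') followed by 'instB'
  // (on which 'dstId' depends), then firstSrcDepPos <= lastDstDepPos holds, no
  // insertion point preserves both dependences, and nullptr is returned. With
  // the order reversed ('instB' before 'instA'), the fused loop nest can be
  // inserted just before 'instA', the instruction at 'firstSrcDepPos'.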

  // Updates edge mappings from node 'srcId' to node 'dstId' after 'oldMemRef'
  // has been replaced in node at 'dstId' by a private memref.
  void updateEdges(unsigned srcId, unsigned dstId, Value *oldMemRef) {
    // For each edge in 'inEdges[srcId]': add new edge remapping to 'dstId'.
    if (inEdges.count(srcId) > 0) {
      SmallVector<Edge, 2> oldInEdges = inEdges[srcId];
      for (auto &inEdge : oldInEdges) {
        // Add edge from 'inEdge.id' to 'dstId' if not for 'oldMemRef'.
        if (inEdge.value != oldMemRef)
          addEdge(inEdge.id, dstId, inEdge.value);
      }
    }
    // For each edge in 'outEdges[srcId]': remove edge from 'srcId' to 'dstId'.
    if (outEdges.count(srcId) > 0) {
      SmallVector<Edge, 2> oldOutEdges = outEdges[srcId];
      for (auto &outEdge : oldOutEdges) {
        // Remove any out edges from 'srcId' to 'dstId' across memrefs.
        if (outEdge.id == dstId)
          removeEdge(srcId, outEdge.id, outEdge.value);
      }
    }
    // Remove any edges in 'inEdges[dstId]' on 'oldMemRef' (which is being
    // replaced by a private memref). These edges could come from nodes
    // other than 'srcId' which were removed in the previous step.
    if (inEdges.count(dstId) > 0) {
      SmallVector<Edge, 2> oldInEdges = inEdges[dstId];
      for (auto &inEdge : oldInEdges)
        if (inEdge.value == oldMemRef)
          removeEdge(inEdge.id, dstId, inEdge.value);
    }
  }

  // Adds ops in 'loads' and 'stores' to node at 'id'.
  void addToNode(unsigned id, const SmallVectorImpl<OperationInst *> &loads,
                 const SmallVectorImpl<OperationInst *> &stores) {
    Node *node = getNode(id);
    for (auto *loadOpInst : loads)
      node->loads.push_back(loadOpInst);
    for (auto *storeOpInst : stores)
      node->stores.push_back(storeOpInst);
  }

  void clearNodeLoadAndStores(unsigned id) {
    Node *node = getNode(id);
    node->loads.clear();
    node->stores.clear();
  }

  void print(raw_ostream &os) const {
    os << "\nMemRefDependenceGraph\n";
    os << "\nNodes:\n";
    for (auto &idAndNode : nodes) {
      os << "Node: " << idAndNode.first << "\n";
      auto it = inEdges.find(idAndNode.first);
      if (it != inEdges.end()) {
        for (const auto &e : it->second)
          os << " InEdge: " << e.id << " " << e.value << "\n";
      }
      it = outEdges.find(idAndNode.first);
      if (it != outEdges.end()) {
        for (const auto &e : it->second)
          os << " OutEdge: " << e.id << " " << e.value << "\n";
      }
    }
  }
  void dump() const { print(llvm::errs()); }
};

// Initializes the data dependence graph by walking instructions in 'f'.
// Assigns each node in the graph a node id based on program order in 'f'.
// TODO(andydavis) Add support for taking a Block arg to construct the
// dependence graph at a different depth.
bool MemRefDependenceGraph::init(Function *f) {
  DenseMap<Value *, SetVector<unsigned>> memrefAccesses;

  // TODO: support multi-block functions.
  if (f->getBlocks().size() != 1)
    return false;

  DenseMap<ForInst *, unsigned> forToNodeMap;
  for (auto &inst : f->front()) {
    if (auto *forInst = dyn_cast<ForInst>(&inst)) {
      // Create graph node 'id' to represent top-level 'forInst' and record
      // all loads and store accesses it contains.
      LoopNestStateCollector collector;
      collector.walkForInst(forInst);
      // Return false if a region-holding op other than 'for' was found (not
      // currently supported).
      if (collector.hasNonForRegion)
        return false;
      Node node(nextNodeId++, &inst);
      for (auto *opInst : collector.loadOpInsts) {
        node.loads.push_back(opInst);
        auto *memref = opInst->cast<LoadOp>()->getMemRef();
        memrefAccesses[memref].insert(node.id);
      }
      for (auto *opInst : collector.storeOpInsts) {
        node.stores.push_back(opInst);
        auto *memref = opInst->cast<StoreOp>()->getMemRef();
        memrefAccesses[memref].insert(node.id);
      }
      forToNodeMap[forInst] = node.id;
      nodes.insert({node.id, node});
    }
    if (auto *opInst = dyn_cast<OperationInst>(&inst)) {
      if (auto loadOp = opInst->dyn_cast<LoadOp>()) {
        // Create graph node for top-level load op.
        Node node(nextNodeId++, &inst);
        node.loads.push_back(opInst);
        auto *memref = opInst->cast<LoadOp>()->getMemRef();
        memrefAccesses[memref].insert(node.id);
        nodes.insert({node.id, node});
      } else if (auto storeOp = opInst->dyn_cast<StoreOp>()) {
        // Create graph node for top-level store op.
        Node node(nextNodeId++, &inst);
        node.stores.push_back(opInst);
        auto *memref = opInst->cast<StoreOp>()->getMemRef();
        memrefAccesses[memref].insert(node.id);
        nodes.insert({node.id, node});
      } else if (opInst->getNumBlockLists() != 0) {
        // Return false if another region is found (not currently supported).
        return false;
      } else if (opInst->getNumResults() > 0 && !opInst->use_empty()) {
        // Create graph node for top-level producer of SSA values, which
        // could be used by loop nest nodes.
        Node node(nextNodeId++, &inst);
        nodes.insert({node.id, node});
      }
    }
  }

  // Add dependence edges between nodes which produce SSA values and their
  // users.
  for (auto &idAndNode : nodes) {
    const Node &node = idAndNode.second;
    if (!node.loads.empty() || !node.stores.empty())
      continue;
    auto *opInst = cast<OperationInst>(node.inst);
    for (auto *value : opInst->getResults()) {
      for (auto &use : value->getUses()) {
        auto *userOpInst = cast<OperationInst>(use.getOwner());
        SmallVector<ForInst *, 4> loops;
        getLoopIVs(*userOpInst, &loops);
        if (loops.empty())
          continue;
        assert(forToNodeMap.count(loops[0]) > 0);
        unsigned userLoopNestId = forToNodeMap[loops[0]];
        addEdge(node.id, userLoopNestId, value);
      }
    }
  }

  // Walk memref access lists and add graph edges between dependent nodes.
  for (auto &memrefAndList : memrefAccesses) {
    unsigned n = memrefAndList.second.size();
    for (unsigned i = 0; i < n; ++i) {
      unsigned srcId = memrefAndList.second[i];
      bool srcHasStore =
          getNode(srcId)->getStoreOpCount(memrefAndList.first) > 0;
      for (unsigned j = i + 1; j < n; ++j) {
        unsigned dstId = memrefAndList.second[j];
        bool dstHasStore =
            getNode(dstId)->getStoreOpCount(memrefAndList.first) > 0;
        if (srcHasStore || dstHasStore)
          addEdge(srcId, dstId, memrefAndList.first);
      }
    }
  }
  return true;
}

namespace {

// LoopNestStats aggregates various per-loop statistics (e.g. loop trip count
// and operation count) for a loop nest up until the innermost loop body.
struct LoopNestStats {
  // Map from ForInst to immediate child ForInsts in its loop body.
  DenseMap<ForInst *, SmallVector<ForInst *, 2>> loopMap;
  // Map from ForInst to count of operations in its loop body.
  DenseMap<ForInst *, uint64_t> opCountMap;
  // Map from ForInst to its constant trip count.
  DenseMap<ForInst *, uint64_t> tripCountMap;
};

// LoopNestStatsCollector walks a single loop nest and gathers per-loop
// trip count and operation count statistics and records them in 'stats'.
class LoopNestStatsCollector : public InstWalker<LoopNestStatsCollector> {
public:
  LoopNestStats *stats;
  bool hasLoopWithNonConstTripCount = false;

  LoopNestStatsCollector(LoopNestStats *stats) : stats(stats) {}

  void visitForInst(ForInst *forInst) {
    auto *parentInst = forInst->getParentInst();
    if (parentInst != nullptr) {
      assert(isa<ForInst>(parentInst) && "Expected parent ForInst");
      // Add mapping to 'forInst' from its parent ForInst.
      stats->loopMap[cast<ForInst>(parentInst)].push_back(forInst);
    }
    // Record the number of op instructions in the body of 'forInst'.
    unsigned count = 0;
    stats->opCountMap[forInst] = 0;
    for (auto &inst : *forInst->getBody()) {
      if (isa<OperationInst>(&inst))
        ++count;
    }
    stats->opCountMap[forInst] = count;
    // Record trip count for 'forInst'. Set flag if trip count is not constant.
    Optional<uint64_t> maybeConstTripCount = getConstantTripCount(*forInst);
    if (!maybeConstTripCount.hasValue()) {
      hasLoopWithNonConstTripCount = true;
      return;
    }
    stats->tripCountMap[forInst] = maybeConstTripCount.getValue();
  }
};
// Computes the total cost of the loop nest rooted at 'forInst'.
// Currently, the total cost is computed by counting the total operation
// instance count (i.e. the loop body operation count * the loop trip count)
// for the entire loop nest.
// If 'tripCountOverrideMap' is non-null, overrides the trip count for loops
// specified in the map when computing the total op instance count.
// NOTE: this is used to compute the cost of computation slices, which are
// sliced along the iteration dimension, and thus reduce the trip count.
// If 'computeCostMap' is non-null, the total op count for forInsts specified
// in the map is increased (not overridden) by adding the op count from the
// map to the existing op count for the for loop. This is done before
// multiplying by the loop's trip count, and is used to model the cost of
// inserting a sliced loop nest of known cost into the loop's body.
// NOTE: this is used to compute the cost of fusing a slice of some loop nest
// within another loop.
static int64_t getComputeCost(
    ForInst *forInst, LoopNestStats *stats,
    llvm::SmallDenseMap<ForInst *, uint64_t, 8> *tripCountOverrideMap,
    DenseMap<ForInst *, int64_t> *computeCostMap) {
  // 'opCount' is the total number of operations in one iteration of the
  // 'forInst' body.
  int64_t opCount = stats->opCountMap[forInst];
  if (stats->loopMap.count(forInst) > 0) {
    for (auto *childForInst : stats->loopMap[forInst]) {
      opCount += getComputeCost(childForInst, stats, tripCountOverrideMap,
                                computeCostMap);
    }
  }
  // Add in additional op instances from slice (if specified in map).
  if (computeCostMap != nullptr) {
    auto it = computeCostMap->find(forInst);
    if (it != computeCostMap->end()) {
      opCount += it->second;
    }
  }
  // Override trip count (if specified in map).
  int64_t tripCount = stats->tripCountMap[forInst];
  if (tripCountOverrideMap != nullptr) {
    auto it = tripCountOverrideMap->find(forInst);
    if (it != tripCountOverrideMap->end()) {
      tripCount = it->second;
    }
  }
  // Returns the total number of dynamic instances of operations in loop body.
  return tripCount * opCount;
}
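
// Worked example for getComputeCost (an illustrative sketch): for a nest
// 'for %i = 0 to 100' whose body contains only 'for %j = 0 to 50' with 2 op
// instructions in its body, the inner loop contributes 50 * 2 = 100 op
// instances, so the whole nest costs 100 * 100 = 10000 dynamic operation
// instances.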

} // end anonymous namespace

static Optional<uint64_t> getConstDifference(AffineMap lbMap, AffineMap ubMap) {
  assert(lbMap.getNumResults() == 1 && "expected single result bound map");
  assert(ubMap.getNumResults() == 1 && "expected single result bound map");
  assert(lbMap.getNumDims() == ubMap.getNumDims());
  assert(lbMap.getNumSymbols() == ubMap.getNumSymbols());
  // TODO(andydavis) Merge this code with 'mlir::getTripCountExpr'.
  // ub_expr - lb_expr
  AffineExpr lbExpr(lbMap.getResult(0));
  AffineExpr ubExpr(ubMap.getResult(0));
  auto loopSpanExpr = simplifyAffineExpr(ubExpr - lbExpr, lbMap.getNumDims(),
                                         lbMap.getNumSymbols());
  auto cExpr = loopSpanExpr.dyn_cast<AffineConstantExpr>();
  if (!cExpr)
    return None;
  return cExpr.getValue();
}
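
// For example (illustrative): with lbMap = (d0) -> (d0) and
// ubMap = (d0) -> (d0 + 4), 'ub_expr - lb_expr' simplifies to the constant 4;
// with ubMap = (d0) -> (2 * d0), the difference simplifies to 'd0', which is
// not an AffineConstantExpr, so None is returned.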

// Builds a map 'tripCountMap' from ForInst to constant trip count for the
// loop nest surrounding 'srcOpInst', utilizing slice loop bounds in
// 'sliceState'.
// Returns true on success, false otherwise (if a non-constant trip count
// was encountered).
// TODO(andydavis) Make this work with non-unit step loops.
static bool buildSliceTripCountMap(
    OperationInst *srcOpInst, ComputationSliceState *sliceState,
    llvm::SmallDenseMap<ForInst *, uint64_t, 8> *tripCountMap) {
  SmallVector<ForInst *, 4> srcLoopIVs;
  getLoopIVs(*srcOpInst, &srcLoopIVs);
  unsigned numSrcLoopIVs = srcLoopIVs.size();
  // Populate map from ForInst -> trip count.
  for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
    AffineMap lbMap = sliceState->lbs[i];
    AffineMap ubMap = sliceState->ubs[i];
    if (lbMap == AffineMap() || ubMap == AffineMap()) {
      // The iteration of src loop IV 'i' was not sliced. Use full loop bounds.
      if (srcLoopIVs[i]->hasConstantLowerBound() &&
          srcLoopIVs[i]->hasConstantUpperBound()) {
        (*tripCountMap)[srcLoopIVs[i]] =
            srcLoopIVs[i]->getConstantUpperBound() -
            srcLoopIVs[i]->getConstantLowerBound();
        continue;
      }
      return false;
    }
    Optional<uint64_t> tripCount = getConstDifference(lbMap, ubMap);
    if (!tripCount.hasValue())
      return false;
    (*tripCountMap)[srcLoopIVs[i]] = tripCount.getValue();
  }
  return true;
}

// Removes load operations from 'srcLoads' which operate on 'memref', and
// adds them to 'dstLoads'.
static void
moveLoadsAccessingMemrefTo(Value *memref,
                           SmallVectorImpl<OperationInst *> *srcLoads,
                           SmallVectorImpl<OperationInst *> *dstLoads) {
  dstLoads->clear();
  SmallVector<OperationInst *, 4> srcLoadsToKeep;
  for (auto *load : *srcLoads) {
    if (load->cast<LoadOp>()->getMemRef() == memref)
      dstLoads->push_back(load);
    else
      srcLoadsToKeep.push_back(load);
  }
  srcLoads->swap(srcLoadsToKeep);
}

// Returns the innermost common loop depth for the set of operations in 'ops'.
static unsigned getInnermostCommonLoopDepth(ArrayRef<OperationInst *> ops) {
  unsigned numOps = ops.size();
  assert(numOps > 0);

  std::vector<SmallVector<ForInst *, 4>> loops(numOps);
  unsigned loopDepthLimit = std::numeric_limits<unsigned>::max();
  for (unsigned i = 0; i < numOps; ++i) {
    getLoopIVs(*ops[i], &loops[i]);
    loopDepthLimit =
        std::min(loopDepthLimit, static_cast<unsigned>(loops[i].size()));
  }

  unsigned loopDepth = 0;
  for (unsigned d = 0; d < loopDepthLimit; ++d) {
    unsigned i;
    for (i = 1; i < numOps; ++i) {
      if (loops[i - 1][d] != loops[i][d]) {
        break;
      }
    }
    if (i != numOps)
      break;
    ++loopDepth;
  }
  return loopDepth;
}
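
// For example (illustrative): if one op in 'ops' is nested under loops
// (%i, %j) and another under (%i, %k), getInnermostCommonLoopDepth returns 1,
// since the ops share only the outermost loop '%i'.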

// Returns the slice union of 'sliceStateA' and 'sliceStateB' in 'sliceStateB'
// using a rectangular bounding box.
// TODO(andydavis) This function assumes that lower bounds for 'sliceStateA'
// and 'sliceStateB' are aligned.
// Specifically, when taking the union of overlapping intervals, it assumes
// that both intervals start at zero. Support needs to be added to take into
// account interval start offset when computing the union.
// TODO(andydavis) Move this function to an analysis library.
static bool getSliceUnion(const ComputationSliceState &sliceStateA,
                          ComputationSliceState *sliceStateB) {
  assert(sliceStateA.lbs.size() == sliceStateB->lbs.size());
  assert(sliceStateA.ubs.size() == sliceStateB->ubs.size());

  for (unsigned i = 0, e = sliceStateA.lbs.size(); i < e; ++i) {
    AffineMap lbMapA = sliceStateA.lbs[i];
    AffineMap ubMapA = sliceStateA.ubs[i];
    if (lbMapA == AffineMap()) {
      assert(ubMapA == AffineMap());
      continue;
    }
    assert(ubMapA && "expected non-null ub map");

    AffineMap lbMapB = sliceStateB->lbs[i];
    AffineMap ubMapB = sliceStateB->ubs[i];
    if (lbMapB == AffineMap()) {
      assert(ubMapB == AffineMap());
      // Union 'sliceStateB' does not have a bound for 'i' so copy from A.
      sliceStateB->lbs[i] = lbMapA;
      sliceStateB->ubs[i] = ubMapA;
      continue;
    }

    // TODO(andydavis) Change this code to take the min across all lower bounds
    // and max across all upper bounds for each dimension. This code can fail
    // for cases where a unique min or max could not be statically determined.

    // Assumption: both lower bounds are the same.
    if (lbMapA != lbMapB)
      return false;

    // Add bound with the largest trip count to union.
    Optional<uint64_t> tripCountA = getConstDifference(lbMapA, ubMapA);
    Optional<uint64_t> tripCountB = getConstDifference(lbMapB, ubMapB);
    if (!tripCountA.hasValue() || !tripCountB.hasValue())
      return false;

    if (tripCountA.getValue() > tripCountB.getValue()) {
      sliceStateB->lbs[i] = lbMapA;
      sliceStateB->ubs[i] = ubMapA;
    }
  }
  return true;
}
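
// For example (illustrative): if along dimension 'i' slice A spans [0, 16)
// and slice B spans [0, 8) with identical lower bound maps, the union keeps
// A's bounds since A has the larger trip count, yielding the rectangular
// bounding box [0, 16) in 'sliceStateB'.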

// Creates and returns a private (single-user) memref for fused loop rooted
// at 'forInst', with (potentially reduced) memref size based on the
// MemRefRegion written to by 'srcStoreOpInst' at depth 'dstLoopDepth'.
// TODO(bondhugula): consider refactoring the common code from generateDma and
// this one.
static Value *createPrivateMemRef(ForInst *forInst,
                                  OperationInst *srcStoreOpInst,
                                  unsigned dstLoopDepth) {
  // Create builder to insert alloc op just before 'forInst'.
  FuncBuilder b(forInst);
  // Builder to create constants at the top level.
  FuncBuilder top(forInst->getFunction());
  // Create new memref type based on slice bounds.
  auto *oldMemRef = srcStoreOpInst->cast<StoreOp>()->getMemRef();
  auto oldMemRefType = oldMemRef->getType().cast<MemRefType>();
  unsigned rank = oldMemRefType.getRank();

  // Compute MemRefRegion for 'srcStoreOpInst' at depth 'dstLoopDepth'.
  MemRefRegion region;
  getMemRefRegion(srcStoreOpInst, dstLoopDepth, &region);
  SmallVector<int64_t, 4> newShape;
  std::vector<SmallVector<int64_t, 4>> lbs;
  SmallVector<int64_t, 8> lbDivisors;
  lbs.reserve(rank);
  // Query 'region' for 'newShape' and lower bounds of MemRefRegion accessed
  // by 'srcStoreOpInst' at depth 'dstLoopDepth'.
  Optional<int64_t> numElements =
      region.getConstantBoundingSizeAndShape(&newShape, &lbs, &lbDivisors);
  assert(numElements.hasValue());

  const FlatAffineConstraints *cst = region.getConstraints();
  // 'outerIVs' holds the values that this memory region is symbolic/parametric
  // on; this would correspond to loop IVs surrounding the level at which the
  // slice is being materialized.
  SmallVector<Value *, 8> outerIVs;
  cst->getIdValues(rank, cst->getNumIds(), &outerIVs);

  // Build 'rank' AffineExprs from MemRefRegion 'lbs'.
  SmallVector<AffineExpr, 4> offsets;
  offsets.reserve(rank);
  for (unsigned d = 0; d < rank; ++d) {
    assert(lbs[d].size() == cst->getNumCols() - rank && "incorrect bound size");

    AffineExpr offset = top.getAffineConstantExpr(0);
    for (unsigned j = 0, e = cst->getNumCols() - rank - 1; j < e; j++) {
      offset = offset + lbs[d][j] * top.getAffineDimExpr(j);
    }
    assert(lbDivisors[d] > 0);
    offset =
        (offset + lbs[d][cst->getNumCols() - 1 - rank]).floorDiv(lbDivisors[d]);
    offsets.push_back(offset);
  }

  // Create 'newMemRefType' using 'newShape' from MemRefRegion accessed
  // by 'srcStoreOpInst'.
  auto newMemRefType =
      top.getMemRefType(newShape, oldMemRefType.getElementType(), {},
                        oldMemRefType.getMemorySpace());
  // Gather alloc operands for the dynamic dimensions of the memref.
  SmallVector<Value *, 4> allocOperands;
  unsigned dynamicDimCount = 0;
  for (auto dimSize : oldMemRefType.getShape()) {
    if (dimSize == -1)
      allocOperands.push_back(
          top.create<DimOp>(forInst->getLoc(), oldMemRef, dynamicDimCount++));
  }

  // Create new private memref for fused loop 'forInst'.
  // TODO(andydavis) Create/move alloc ops for private memrefs closer to their
  // consumer loop nests to reduce their live range. Currently they are added
  // at the beginning of the function, because loop nests can be reordered
  // during the fusion pass.
  Value *newMemRef =
      top.create<AllocOp>(forInst->getLoc(), newMemRefType, allocOperands);

  // Build an AffineMap to remap access functions based on lower bound offsets.
  SmallVector<AffineExpr, 4> remapExprs;
  remapExprs.reserve(rank);
  unsigned zeroOffsetCount = 0;
  for (unsigned i = 0; i < rank; i++) {
    if (auto constExpr = offsets[i].dyn_cast<AffineConstantExpr>())
      if (constExpr.getValue() == 0)
        ++zeroOffsetCount;
    auto dimExpr = b.getAffineDimExpr(outerIVs.size() + i);

    auto remapExpr =
        simplifyAffineExpr(dimExpr - offsets[i], outerIVs.size() + rank, 0);
    remapExprs.push_back(remapExpr);
  }
  auto indexRemap =
      zeroOffsetCount == rank
          ? AffineMap()
          : b.getAffineMap(outerIVs.size() + rank, 0, remapExprs, {});
  // Replace all users of 'oldMemRef' with 'newMemRef'.
  bool ret =
      replaceAllMemRefUsesWith(oldMemRef, newMemRef, {}, indexRemap,
                               /*extraOperands=*/outerIVs,
                               /*domInstFilter=*/&*forInst->getBody()->begin());
  assert(ret && "replaceAllMemrefUsesWith should always succeed here");
  (void)ret;
  return newMemRef;
}
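
// For example (an illustrative sketch): if the source nest wrote to all of
// 'memref<100xf32>' but the slice fused at 'dstLoopDepth' writes only a
// single element per destination iteration, the private memref created above
// would be 'memref<1xf32>', with accesses in the fused nest remapped by
// subtracting the region's lower bound offsets.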

// Returns the iteration count of the computation slice: the product of the
// trip counts in 'sliceTripCountMap'. A count of one means the slice has a
// single iteration.
static uint64_t getSliceIterationCount(
    const llvm::SmallDenseMap<ForInst *, uint64_t, 8> &sliceTripCountMap) {
  uint64_t iterCount = 1;
  for (const auto &count : sliceTripCountMap) {
    iterCount *= count.second;
  }
  return iterCount;
}
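
// For example (illustrative): a slice with trip counts {4, 8} in
// 'sliceTripCountMap' has 4 * 8 = 32 iterations; a count of 1 is what
// isFusionProfitable below uses to conclude that store-to-load forwarding
// is guaranteed.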

// Checks the profitability of fusing a backwards slice of the loop nest
// surrounding 'srcOpInst' into the loop nest surrounding 'dstOpInsts'.
// Returns true if it is profitable to fuse the candidate loop nests. Returns
// false otherwise. `dstLoopDepth` is set to the most profitable depth at which
// to materialize the source loop nest slice.
// The profitability model executes the following steps:
// *) Computes the backward computation slice at 'srcOpInst'. This
//    computation slice of the loop nest surrounding 'srcOpInst' is
//    represented by modified src loop bounds in 'sliceState', which are
//    functions of loop IVs in the loop nest surrounding 'srcOpInst'.
// *) Computes the cost of unfused src/dst loop nests (currently the cost of a
//    loop nest is the total number of dynamic operation instances in the loop
//    nest).
// *) Computes the cost of fusing a slice of the src loop nest into the dst
//    loop nest at various values of dst loop depth, attempting to fuse
//    the largest computation slice at the maximal dst loop depth (closest to
//    the load) to minimize reuse distance and potentially enable subsequent
//    load/store forwarding.
//    NOTE: If the dst loop nest includes multiple loads in 'dstOpInsts' for
//    the same memref as is written by 'srcOpInst', then the union of slice
//    loop bounds is used to compute the slice and associated slice cost.
//    NOTE: 'dstLoopDepth' refers to the loop depth within the destination loop
//    nest, at which the src computation slice is inserted/fused.
//    NOTE: We attempt to maximize the dst loop depth, but there are cases
//    where a particular setting for 'dstLoopNest' might fuse an unsliced
//    loop (within the src computation slice) at a depth which results in
//    excessive recomputation (see unit tests for examples).
// *) Compares the total cost of the unfused loop nests to the min cost fused
//    loop nest computed in the previous step, and returns true if the latter
//    is lower.
static bool isFusionProfitable(OperationInst *srcOpInst,
                               ArrayRef<OperationInst *> dstOpInsts,
                               ComputationSliceState *sliceState,
                               unsigned *dstLoopDepth) {
  LLVM_DEBUG({
    llvm::dbgs() << "Checking whether fusion is profitable between:\n";
    llvm::dbgs() << " ";
    srcOpInst->dump();
    llvm::dbgs() << " and \n";
    for (auto dstOpInst : dstOpInsts) {
      llvm::dbgs() << " ";
      dstOpInst->dump();
    }
  });

  // Compute cost of sliced and unsliced src loop nest.
  SmallVector<ForInst *, 4> srcLoopIVs;
  getLoopIVs(*srcOpInst, &srcLoopIVs);
  unsigned numSrcLoopIVs = srcLoopIVs.size();

  // Walk src loop nest and collect stats.
  LoopNestStats srcLoopNestStats;
  LoopNestStatsCollector srcStatsCollector(&srcLoopNestStats);
  srcStatsCollector.walk(srcLoopIVs[0]);
  // Currently only constant trip count loop nests are supported.
  if (srcStatsCollector.hasLoopWithNonConstTripCount)
    return false;

  // Compute cost of dst loop nest.
  SmallVector<ForInst *, 4> dstLoopIVs;
  getLoopIVs(*dstOpInsts[0], &dstLoopIVs);

  LoopNestStats dstLoopNestStats;
  LoopNestStatsCollector dstStatsCollector(&dstLoopNestStats);
  dstStatsCollector.walk(dstLoopIVs[0]);
  // Currently only constant trip count loop nests are supported.
  if (dstStatsCollector.hasLoopWithNonConstTripCount)
    return false;

  // Compute the innermost common loop depth for ops in 'dstOpInsts'.
  unsigned maxDstLoopDepth = getInnermostCommonLoopDepth(dstOpInsts);
  if (maxDstLoopDepth == 0)
    return false;

  // Search for min cost value for 'dstLoopDepth'. At each value of
  // 'dstLoopDepth' from 'maxDstLoopDepth' to '1', compute computation slice
  // bounds between 'srcOpInst' and each op in 'dstOpInsts' (taking the union
  // of these bounds). Next the union slice bounds are used to calculate
  // the cost of the slice and the cost of the slice inserted into the dst
  // loop nest at 'dstLoopDepth'.
  uint64_t minFusedLoopNestComputeCost = std::numeric_limits<uint64_t>::max();
  uint64_t maxStorageReduction = 0;
  Optional<uint64_t> sliceMemEstimate = None;

  SmallVector<ComputationSliceState, 4> sliceStates;
  sliceStates.resize(maxDstLoopDepth);
  // The best loop depth at which to materialize the slice.
  Optional<unsigned> bestDstLoopDepth = None;

  // Compute op instance count for the src loop nest without iteration slicing.
  uint64_t srcLoopNestCost = getComputeCost(srcLoopIVs[0], &srcLoopNestStats,
                                            /*tripCountOverrideMap=*/nullptr,
                                            /*computeCostMap=*/nullptr);

  // Compute op instance count for the dst loop nest.
  uint64_t dstLoopNestCost = getComputeCost(dstLoopIVs[0], &dstLoopNestStats,
                                            /*tripCountOverrideMap=*/nullptr,
                                            /*computeCostMap=*/nullptr);

  llvm::SmallDenseMap<ForInst *, uint64_t, 8> sliceTripCountMap;
  DenseMap<ForInst *, int64_t> computeCostMap;
  for (unsigned i = maxDstLoopDepth; i >= 1; --i) {
    MemRefAccess srcAccess(srcOpInst);
    // Handle the common case of one dst load without a copy.
    if (!mlir::getBackwardComputationSliceState(
            srcAccess, MemRefAccess(dstOpInsts[0]), i, &sliceStates[i - 1]))
      return false;
    // Compute the union of slice bounds of all ops in 'dstOpInsts'.
    for (int j = 1, e = dstOpInsts.size(); j < e; ++j) {
      MemRefAccess dstAccess(dstOpInsts[j]);
      ComputationSliceState tmpSliceState;
      if (!mlir::getBackwardComputationSliceState(srcAccess, dstAccess, i,
                                                  &tmpSliceState))
        return false;
      // Compute slice bound union of 'tmpSliceState' and 'sliceStates[i - 1]'.
      getSliceUnion(tmpSliceState, &sliceStates[i - 1]);
    }
    // Build trip count map for computation slice. We'll skip cases where the
    // trip count was non-constant.
    sliceTripCountMap.clear();
    if (!buildSliceTripCountMap(srcOpInst, &sliceStates[i - 1],
                                &sliceTripCountMap))
      continue;

    // Checks whether a store to load forwarding will happen.
    int64_t sliceIterationCount = getSliceIterationCount(sliceTripCountMap);
    assert(sliceIterationCount > 0);
    bool storeLoadFwdGuaranteed = (sliceIterationCount == 1);

    // Compute cost of fusion for this dest loop depth.

    computeCostMap.clear();

    // The store and loads to this memref will disappear.
    if (storeLoadFwdGuaranteed) {
      // A single store disappears: -1 for that.
      computeCostMap[srcLoopIVs[numSrcLoopIVs - 1]] = -1;
      for (auto *loadOp : dstOpInsts) {
        if (auto *loadLoop = dyn_cast_or_null<ForInst>(loadOp->getParentInst()))
          computeCostMap[loadLoop] = -1;
      }
    }

    // Compute op instance count for the src loop nest with iteration slicing.
    int64_t sliceComputeCost =
        getComputeCost(srcLoopIVs[0], &srcLoopNestStats,
                       /*tripCountOverrideMap=*/&sliceTripCountMap,
                       /*computeCostMap=*/&computeCostMap);

    // Compute cost of fusion for this depth.
    computeCostMap[dstLoopIVs[i - 1]] = sliceComputeCost;

    int64_t fusedLoopNestComputeCost =
        getComputeCost(dstLoopIVs[0], &dstLoopNestStats,
                       /*tripCountOverrideMap=*/nullptr, &computeCostMap);

    double additionalComputeFraction =
        fusedLoopNestComputeCost /
            (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
        1;

    // TODO(bondhugula): This is an ugly approximation. Fix this by finding a
    // good way to calculate the footprint of the memref in the slice and
    // divide it by the total memory footprint of the fused computation.
    double storageReduction =
        static_cast<double>(srcLoopNestCost) / sliceIterationCount;

    LLVM_DEBUG({
      std::stringstream msg;
      msg << " evaluating fusion profitability at depth : " << i << "\n"
          << std::setprecision(2) << " additional compute fraction: "
          << 100.0 * additionalComputeFraction << "%\n"
          << " storage reduction factor: " << storageReduction << "x\n"
          << " fused nest cost: " << fusedLoopNestComputeCost << "\n"
          << " slice iteration count: " << sliceIterationCount << "\n";
      llvm::dbgs() << msg.str();
    });

    double computeToleranceThreshold =
        clFusionAddlComputeTolerance.getNumOccurrences() > 0
            ? clFusionAddlComputeTolerance
            : LoopFusion::kComputeToleranceThreshold;

    // TODO(b/123247369): This is a placeholder cost model.
    // Among all choices that add an acceptable amount of redundant computation
    // (as per computeToleranceThreshold), we will simply pick the one that
    // reduces the intermediary size the most.
    if ((storageReduction > maxStorageReduction) &&
        (clMaximalLoopFusion ||
         (additionalComputeFraction < computeToleranceThreshold))) {
      maxStorageReduction = storageReduction;
      bestDstLoopDepth = i;
      minFusedLoopNestComputeCost = fusedLoopNestComputeCost;
      // TODO(bondhugula,andydavis): find a good way to compute the memory
      // footprint of the materialized slice.
      // Approximating this to the compute cost of the slice. This could be an
      // under-approximation or an over-approximation, but in many cases it is
      // accurate.
      sliceMemEstimate = sliceIterationCount;
    }
  }

  // A simple cost model: fuse if it reduces the memory footprint. If
  // -fusion-maximal is set, fuse regardless.

  if (!clMaximalLoopFusion && !bestDstLoopDepth.hasValue()) {
    LLVM_DEBUG(llvm::dbgs()
               << "All fusion choices involve more than the threshold amount "
                  "of redundant computation; NOT fusing.\n");
    return false;
  }

  assert(bestDstLoopDepth.hasValue() &&
         "expected to have a value per logic above");

  // Set dstLoopDepth based on best values from search.
  *dstLoopDepth = bestDstLoopDepth.getValue();

  LLVM_DEBUG(
      llvm::dbgs() << " LoopFusion fusion stats:"
                   << "\n best loop depth: " << bestDstLoopDepth
                   << "\n src loop nest compute cost: " << srcLoopNestCost
                   << "\n dst loop nest compute cost: " << dstLoopNestCost
                   << "\n fused loop nest compute cost: "
                   << minFusedLoopNestComputeCost << "\n");

  auto dstMemSize = getMemoryFootprintBytes(*dstLoopIVs[0]);
  auto srcMemSize = getMemoryFootprintBytes(*srcLoopIVs[0]);

  Optional<double> storageReduction = None;

  if (!clMaximalLoopFusion) {
    if (!dstMemSize.hasValue() || !srcMemSize.hasValue()) {
      LLVM_DEBUG(
          llvm::dbgs()
          << " fusion memory benefit cannot be evaluated; NOT fusing.\n");
      return false;
    }

    auto srcMemSizeVal = srcMemSize.getValue();
    auto dstMemSizeVal = dstMemSize.getValue();

    assert(sliceMemEstimate.hasValue() && "expected value");
    // This is a rough estimate, since sliceMemEstimate is itself approximate.
    auto fusedMem = dstMemSizeVal + sliceMemEstimate.getValue();

    LLVM_DEBUG(llvm::dbgs() << " src mem: " << srcMemSizeVal << "\n"
                            << " dst mem: " << dstMemSizeVal << "\n"
                            << " fused mem: " << fusedMem << "\n"
                            << " slice mem: " << sliceMemEstimate << "\n");

    if (fusedMem > srcMemSizeVal + dstMemSizeVal) {
      LLVM_DEBUG(llvm::dbgs() << "Fusion is not profitable; NOT fusing.\n");
      return false;
    }
    storageReduction =
        100.0 *
        (1.0 - fusedMem / (static_cast<double>(srcMemSizeVal) + dstMemSizeVal));
  }
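  // Illustration with made-up sizes: srcMemSizeVal = 800, dstMemSizeVal = 800,
  // and sliceMemEstimate = 64 give fusedMem = 864 <= 1600, so fusion proceeds
  // with storageReduction = 100 * (1 - 864 / 1600) = 46%.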

  double additionalComputeFraction =
      100.0 * (minFusedLoopNestComputeCost /
                   (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
               1);
  (void)additionalComputeFraction;
  LLVM_DEBUG({
    std::stringstream msg;
    msg << " fusion is most profitable at depth " << *dstLoopDepth << " with "
        << std::setprecision(2) << additionalComputeFraction
        << "% redundant computation and a ";
    msg << (storageReduction.hasValue()
                ? std::to_string(storageReduction.getValue())
                : "<unknown>");
    msg << "% storage reduction.\n";
    llvm::dbgs() << msg.str();
  });

  // Update return parameter 'sliceState' with 'bestSliceState'.
  ComputationSliceState *bestSliceState = &sliceStates[*dstLoopDepth - 1];
  sliceState->lbs = bestSliceState->lbs;
  sliceState->ubs = bestSliceState->ubs;
  sliceState->lbOperands = bestSliceState->lbOperands;
  sliceState->ubOperands = bestSliceState->ubOperands;

  // Canonicalize slice bound affine maps.
  for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
    if (sliceState->lbs[i] != AffineMap()) {
      canonicalizeMapAndOperands(&sliceState->lbs[i],
                                 &sliceState->lbOperands[i]);
    }
    if (sliceState->ubs[i] != AffineMap()) {
      canonicalizeMapAndOperands(&sliceState->ubs[i],
                                 &sliceState->ubOperands[i]);
    }
  }
  return true;
}

// GreedyFusion greedily fuses loop nests which have a producer/consumer
// relationship on a memref, with the goal of improving locality. Currently,
// the producer/consumer relationship is required to be unique in the
// Function (there are TODOs to relax this constraint in the future).
//
// The steps of the algorithm are as follows:
//
// *) A worklist is initialized with node ids from the dependence graph.
// *) For each node id in the worklist:
//   *) Pop a ForInst off the worklist. This 'dstForInst' will be a candidate
//      destination ForInst into which fusion will be attempted.
//   *) Add each LoadOp currently in 'dstForInst' into list 'dstLoadOps'.
//   *) For each LoadOp in 'dstLoadOps' do:
//      *) Lookup dependent loop nests at earlier positions in the Function
//         which have a single store op to the same memref.
//      *) Check if dependences would be violated by the fusion. For example,
//         the src loop nest may load from memrefs which are different than
//         the producer-consumer memref between the src and dest loop nests.
//      *) Get a computation slice of 'srcLoopNest', which adjusts its loop
//         bounds to be functions of 'dstLoopNest' IVs and symbols.
//      *) Fuse the 'srcLoopNest' computation slice into the 'dstLoopNest',
//         just before the dst load op user.
//      *) Add the newly fused load/store operation instructions to the state,
//         and also add newly fused load ops to 'dstLoadOps' to be considered
//         as fusion dst load ops in another iteration.
//      *) Remove the old src loop nest and its associated state.
//
// Given a graph where top-level instructions are vertices in the set 'V' and
// edges in the set 'E' are dependences between vertices, this algorithm
// takes O(V) time for initialization, and has runtime O(V + E).
//
// This greedy algorithm is not 'maximal' due to the current restriction of
// fusing along single producer-consumer edges, but there is a TODO to fix
// this.
//
// TODO(andydavis) Experiment with other fusion policies.
// TODO(andydavis) Add support for fusing for input reuse (perhaps by
// constructing a graph with edges which represent loads from the same memref
// in two different loop nests).
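//
// As a rough illustration (schematic IR, not taken from an actual test case),
// fusing a producer of memref '%m' into its consumer:
//
//   for %i0 = 0 to 100 {
//     store %v, %m[%i0]          // src loop nest (producer)
//   }
//   for %i1 = 0 to 100 {
//     %u = load %m[%i1]          // dst loop nest (consumer)
//   }
//
// becomes, after slicing the producer into the consumer, promoting the
// single-iteration slice loop, and privatizing the memref:
//
//   for %i1 = 0 to 100 {
//     store %v, %m_priv[0]       // slice writes a small private buffer
//     %u = load %m_priv[0]
//   }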
struct GreedyFusion {
public:
  MemRefDependenceGraph *mdg;
  SmallVector<unsigned, 4> worklist;

  GreedyFusion(MemRefDependenceGraph *mdg) : mdg(mdg) {
    // Initialize worklist with nodes from 'mdg'.
    worklist.resize(mdg->nodes.size());
    std::iota(worklist.begin(), worklist.end(), 0);
  }

  void run() {
    while (!worklist.empty()) {
      unsigned dstId = worklist.back();
      worklist.pop_back();
      // Skip if this node was removed (fused into another node).
      if (mdg->nodes.count(dstId) == 0)
        continue;
      // Get 'dstNode' into which to attempt fusion.
      auto *dstNode = mdg->getNode(dstId);
      // Skip if 'dstNode' is not a loop nest.
      if (!isa<ForInst>(dstNode->inst))
        continue;

      SmallVector<OperationInst *, 4> loads = dstNode->loads;
      SmallVector<OperationInst *, 4> dstLoadOpInsts;
      DenseSet<Value *> visitedMemrefs;
      while (!loads.empty()) {
        // Get memref of load on top of the stack.
        auto *memref = loads.back()->cast<LoadOp>()->getMemRef();
        if (visitedMemrefs.count(memref) > 0) {
          // Pop the load; 'continue' alone would spin on the same stack top.
          loads.pop_back();
          continue;
        }
        visitedMemrefs.insert(memref);
        // Move all loads in 'loads' accessing 'memref' to 'dstLoadOpInsts'.
        moveLoadsAccessingMemrefTo(memref, &loads, &dstLoadOpInsts);
        // Skip if no input edges along which to fuse.
        if (mdg->inEdges.count(dstId) == 0)
          continue;
        // Iterate through in edges for 'dstId'.
        for (auto &srcEdge : mdg->inEdges[dstId]) {
          // Skip 'srcEdge' if not for 'memref'.
          if (srcEdge.value != memref)
            continue;

          auto *srcNode = mdg->getNode(srcEdge.id);
          // Skip if 'srcNode' is not a loop nest.
          if (!isa<ForInst>(srcNode->inst))
            continue;
          // Skip unless 'srcNode' has exactly one store (to any memref).
          // TODO(andydavis) Support fusing multi-output src loop nests.
          if (srcNode->stores.size() != 1)
            continue;

          // Skip 'srcNode' if it has in edges on 'memref'.
          // TODO(andydavis) Track dependence type with edges, and just check
          // for WAW dependence edge here. Note that this check is overly
          // conservative and will be removed in the future.
          if (mdg->getIncomingMemRefAccesses(srcNode->id, memref) != 0)
            continue;

          // Compute an instruction list insertion point for the fused loop
          // nest which preserves dependences.
          Instruction *insertPointInst = mdg->getFusedLoopNestInsertionPoint(
              srcNode->id, dstNode->id, memref);
          if (insertPointInst == nullptr)
            continue;

          // Get unique 'srcNode' store op.
          auto *srcStoreOpInst = srcNode->stores.front();
          unsigned bestDstLoopDepth;
          mlir::ComputationSliceState sliceState;
          // Check if fusion would be profitable.
          if (!isFusionProfitable(srcStoreOpInst, dstLoadOpInsts, &sliceState,
                                  &bestDstLoopDepth))
            continue;

          // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'.
          auto *sliceLoopNest = mlir::insertBackwardComputationSlice(
              srcStoreOpInst, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState);
          if (sliceLoopNest != nullptr) {
            // Move 'dstForInst' before 'insertPointInst' if needed.
            auto *dstForInst = cast<ForInst>(dstNode->inst);
            if (insertPointInst != dstForInst) {
              dstForInst->moveBefore(insertPointInst);
            }
            // Update edges between 'srcNode' and 'dstNode'.
            mdg->updateEdges(srcNode->id, dstNode->id, memref);

            // Collect slice loop stats.
            LoopNestStateCollector sliceCollector;
            sliceCollector.walkForInst(sliceLoopNest);
            // Promote single iteration slice loops to single IV value.
            for (auto *forInst : sliceCollector.forInsts) {
              promoteIfSingleIteration(forInst);
            }
            // Create private memref for 'memref' in 'dstForInst'.
            SmallVector<OperationInst *, 4> storesForMemref;
            for (auto *storeOpInst : sliceCollector.storeOpInsts) {
              if (storeOpInst->cast<StoreOp>()->getMemRef() == memref)
                storesForMemref.push_back(storeOpInst);
            }
            assert(storesForMemref.size() == 1);
            auto *newMemRef = createPrivateMemRef(
                dstForInst, storesForMemref[0], bestDstLoopDepth);
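            // The fused nest now reads and writes a fresh buffer sized to the
            // slice rather than the original 'memref'; if nothing else uses
            // the original buffer, its alloc is cleaned up below.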
            visitedMemrefs.insert(newMemRef);
            // Create new node in dependence graph for 'newMemRef' alloc op.
            unsigned newMemRefNodeId =
                mdg->addNode(newMemRef->getDefiningInst());
            // Add edge from 'newMemRef' node to dstNode.
            mdg->addEdge(newMemRefNodeId, dstId, newMemRef);

            // Collect dst loop stats after the memref privatization
            // transformation.
            LoopNestStateCollector dstLoopCollector;
            dstLoopCollector.walkForInst(dstForInst);

            // Add new load ops to current Node load op list 'loads' to
            // continue fusing based on new operands.
            for (auto *loadOpInst : dstLoopCollector.loadOpInsts) {
              auto *loadMemRef = loadOpInst->cast<LoadOp>()->getMemRef();
              if (visitedMemrefs.count(loadMemRef) == 0)
                loads.push_back(loadOpInst);
            }

            // Clear and add back loads and stores.
            mdg->clearNodeLoadAndStores(dstNode->id);
            mdg->addToNode(dstId, dstLoopCollector.loadOpInsts,
                           dstLoopCollector.storeOpInsts);
            // Remove the old src loop nest if it no longer has outgoing
            // dependence edges, and it does not write to a memref which
            // escapes the function.
            if (mdg->canRemoveNode(srcNode->id)) {
              mdg->removeNode(srcNode->id);
              cast<ForInst>(srcNode->inst)->erase();
            }
          }
        }
      }
    }
    // Clean up any allocs with no users.
    for (auto &pair : mdg->memrefEdgeCount) {
      if (pair.second > 0)
        continue;
      auto *memref = pair.first;
      // Skip if there exist other uses (return instruction or function calls).
      if (!memref->use_empty())
        continue;
      // Use list expected to match the dep graph info.
      auto *inst = memref->getDefiningInst();
      auto *opInst = dyn_cast_or_null<OperationInst>(inst);
      if (opInst && opInst->isa<AllocOp>())
        opInst->erase();
    }
  }
};

} // end anonymous namespace

PassResult LoopFusion::runOnFunction(Function *f) {
  MemRefDependenceGraph g;
  if (g.init(f))
    GreedyFusion(&g).run();
  return success();
}

static PassRegistration<LoopFusion> pass("loop-fusion", "Fuse loop nests");
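
// Usage sketch: once registered, the pass can be driven from the mlir-opt
// tool, e.g. `mlir-opt -loop-fusion input.mlir`, optionally with
// -fusion-maximal or -fusion-compute-tolerance=<tol>; the file name here is
// illustrative.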