blob: 521fca8979fa696b763c68566842600845d47283 [file] [log] [blame]
MLIR Teamf28e4df2018-11-01 14:26:001//===- LoopFusion.cpp - Code to perform loop fusion -----------------------===//
2//
3// Copyright 2019 The MLIR Authors.
4//
5// Licensed under the Apache License, Version 2.0 (the "License");
6// you may not use this file except in compliance with the License.
7// You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing, software
12// distributed under the License is distributed on an "AS IS" BASIS,
13// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14// See the License for the specific language governing permissions and
15// limitations under the License.
16// =============================================================================
17//
18// This file implements loop fusion.
19//
20//===----------------------------------------------------------------------===//
21
22#include "mlir/Analysis/AffineAnalysis.h"
MLIR Team3b692302018-12-17 17:57:1423#include "mlir/Analysis/AffineStructures.h"
MLIR Teamf28e4df2018-11-01 14:26:0024#include "mlir/Analysis/LoopAnalysis.h"
MLIR Team3b692302018-12-17 17:57:1425#include "mlir/Analysis/Utils.h"
MLIR Teamf28e4df2018-11-01 14:26:0026#include "mlir/IR/AffineExpr.h"
27#include "mlir/IR/AffineMap.h"
28#include "mlir/IR/Builders.h"
29#include "mlir/IR/BuiltinOps.h"
30#include "mlir/IR/StmtVisitor.h"
31#include "mlir/Pass.h"
32#include "mlir/StandardOps/StandardOps.h"
33#include "mlir/Transforms/LoopUtils.h"
34#include "mlir/Transforms/Passes.h"
35#include "llvm/ADT/DenseMap.h"
MLIR Team3b692302018-12-17 17:57:1436#include "llvm/ADT/DenseSet.h"
37#include "llvm/ADT/SetVector.h"
38#include "llvm/Support/raw_ostream.h"
39
40using llvm::SetVector;
MLIR Teamf28e4df2018-11-01 14:26:0041
42using namespace mlir;
43
44namespace {
45
MLIR Team3b692302018-12-17 17:57:1446/// Loop fusion pass. This pass currently supports a greedy fusion policy,
47/// which fuses loop nests with single-writer/single-reader memref dependences
48/// with the goal of improving locality.
49
50// TODO(andydavis) Support fusion of source loop nests which write to multiple
51// memrefs, where each memref can have multiple users (if profitable).
MLIR Teamf28e4df2018-11-01 14:26:0052// TODO(andydavis) Extend this pass to check for fusion preventing dependences,
53// and add support for more general loop fusion algorithms.
MLIR Team3b692302018-12-17 17:57:1454
/// Loop fusion pass. Runs a greedy producer/consumer fusion policy (see
/// GreedyFusionPolicy below) over each MLFunction.
struct LoopFusion : public FunctionPass {
  LoopFusion() : FunctionPass(&LoopFusion::passID) {}

  // Runs greedy fusion over 'f'; always reports success.
  PassResult runOnMLFunction(MLFunction *f) override;
  // Unique pass identity token; its address identifies this pass.
  static char passID;
};
61
MLIR Teamf28e4df2018-11-01 14:26:0062} // end anonymous namespace
63
// Definition of the pass identity token declared in LoopFusion; the value is
// irrelevant, only the address is used for registration/lookup.
char LoopFusion::passID = 0;
65
MLIR Teamf28e4df2018-11-01 14:26:0066FunctionPass *mlir::createLoopFusionPass() { return new LoopFusion; }
67
MLIR Teamf28e4df2018-11-01 14:26:0068static void getSingleMemRefAccess(OperationStmt *loadOrStoreOpStmt,
69 MemRefAccess *access) {
70 if (auto loadOp = loadOrStoreOpStmt->dyn_cast<LoadOp>()) {
71 access->memref = cast<MLValue>(loadOp->getMemRef());
72 access->opStmt = loadOrStoreOpStmt;
73 auto loadMemrefType = loadOp->getMemRefType();
74 access->indices.reserve(loadMemrefType.getRank());
75 for (auto *index : loadOp->getIndices()) {
76 access->indices.push_back(cast<MLValue>(index));
77 }
78 } else {
79 assert(loadOrStoreOpStmt->isa<StoreOp>());
80 auto storeOp = loadOrStoreOpStmt->dyn_cast<StoreOp>();
81 access->opStmt = loadOrStoreOpStmt;
82 access->memref = cast<MLValue>(storeOp->getMemRef());
83 auto storeMemrefType = storeOp->getMemRefType();
84 access->indices.reserve(storeMemrefType.getRank());
85 for (auto *index : storeOp->getIndices()) {
86 access->indices.push_back(cast<MLValue>(index));
87 }
88 }
89}
90
// FusionCandidate encapsulates a pair of memref accesses forming a
// producer/consumer candidate for loop fusion: the store access in a source
// loop nest and the load access in a destination loop nest.
struct FusionCandidate {
  // Load or store access within src loop nest to be fused into dst loop nest.
  MemRefAccess srcAccess;
  // Load or store access within dst loop nest.
  MemRefAccess dstAccess;
};
MLIR Teamf28e4df2018-11-01 14:26:0099
MLIR Team3b692302018-12-17 17:57:14100static FusionCandidate buildFusionCandidate(OperationStmt *srcStoreOpStmt,
101 OperationStmt *dstLoadOpStmt) {
102 FusionCandidate candidate;
103 // Get store access for src loop nest.
104 getSingleMemRefAccess(srcStoreOpStmt, &candidate.srcAccess);
105 // Get load access for dst loop nest.
106 getSingleMemRefAccess(dstLoadOpStmt, &candidate.dstAccess);
107 return candidate;
MLIR Teamf28e4df2018-11-01 14:26:00108}
109
MLIR Team3b692302018-12-17 17:57:14110namespace {
MLIR Teamf28e4df2018-11-01 14:26:00111
MLIR Team3b692302018-12-17 17:57:14112// LoopNestStateCollector walks loop nests and collects load and store
113// operations, and whether or not an IfStmt was encountered in the loop nest.
114class LoopNestStateCollector : public StmtWalker<LoopNestStateCollector> {
115public:
116 SmallVector<ForStmt *, 4> forStmts;
117 SmallVector<OperationStmt *, 4> loadOpStmts;
118 SmallVector<OperationStmt *, 4> storeOpStmts;
119 bool hasIfStmt = false;
120
121 void visitForStmt(ForStmt *forStmt) { forStmts.push_back(forStmt); }
122
123 void visitIfStmt(IfStmt *ifStmt) { hasIfStmt = true; }
124
125 void visitOperationStmt(OperationStmt *opStmt) {
126 if (opStmt->isa<LoadOp>())
127 loadOpStmts.push_back(opStmt);
128 if (opStmt->isa<StoreOp>())
129 storeOpStmts.push_back(opStmt);
130 }
131};
132
133// GreedyFusionPolicy greedily fuses loop nests which have a producer/consumer
134// relationship on a memref, with the goal of improving locality. Currently,
// the producer/consumer relationship is required to be unique in the
136// MLFunction (there are TODOs to relax this constraint in the future).
MLIR Teamf28e4df2018-11-01 14:26:00137//
MLIR Team3b692302018-12-17 17:57:14138// The steps of the algorithm are as follows:
139//
140// *) Initialize. While visiting each statement in the MLFunction do:
141// *) Assign each top-level ForStmt a 'position' which is its initial
142// position in the MLFunction's StmtBlock at the start of the pass.
143// *) Gather memref load/store state aggregated by top-level statement. For
144// example, all loads and stores contained in a loop nest are aggregated
145// under the loop nest's top-level ForStmt.
146// *) Add each top-level ForStmt to a worklist.
147//
148// *) Run. The algorithm processes the worklist with the following steps:
149// *) The worklist is processed in reverse order (starting from the last
150// top-level ForStmt in the MLFunction).
151// *) Pop a ForStmt of the worklist. This 'dstForStmt' will be a candidate
152// destination ForStmt into which fusion will be attempted.
153// *) Add each LoadOp currently in 'dstForStmt' into list 'dstLoadOps'.
154// *) For each LoadOp in 'dstLoadOps' do:
155// *) Lookup dependent loop nests at earlier positions in the MLFunction
156// which have a single store op to the same memref.
157// *) Check if dependences would be violated by the fusion. For example,
158// the src loop nest may load from memrefs which are different than
159// the producer-consumer memref between src and dest loop nests.
// *) Get a computation slice of 'srcLoopNest', which adjusts its loop
161// bounds to be functions of 'dstLoopNest' IVs and symbols.
162// *) Fuse the 'srcLoopNest' computation slice into the 'dstLoopNest',
163// just before the dst load op user.
164// *) Add the newly fused load/store operation statements to the state,
165// and also add newly fuse load ops to 'dstLoopOps' to be considered
166// as fusion dst load ops in another iteration.
167// *) Remove old src loop nest and its associated state.
168//
169// Given a graph where top-level statements are vertices in the set 'V' and
170// edges in the set 'E' are dependences between vertices, this algorithm
171// takes O(V) time for initialization, and has runtime O(V * E).
172// TODO(andydavis) Reduce this time complexity to O(V + E).
173//
// This greedy algorithm is not 'maximal' but there is a TODO to fix this.
175//
176// TODO(andydavis) Experiment with other fusion policies.
struct GreedyFusionPolicy {
  // Convenience wrapper with information about 'stmt' ready to access.
  struct StmtInfo {
    // The top-level statement this entry describes.
    Statement *stmt;
    // True if 'stmt' is an IfStmt or contains one; fusion is disabled for
    // functions containing any such statement (see hasIfStmts()).
    bool isOrContainsIfStmt = false;
  };
  // The worklist of top-level loop nest positions.
  SmallVector<unsigned, 4> worklist;
  // Mapping from top-level position to StmtInfo.
  DenseMap<unsigned, StmtInfo> posToStmtInfo;
  // Mapping from memref MLValue to set of top-level positions of loop nests
  // which contain load ops on that memref.
  DenseMap<MLValue *, DenseSet<unsigned>> memrefToLoadPosSet;
  // Mapping from memref MLValue to set of top-level positions of loop nests
  // which contain store ops on that memref.
  DenseMap<MLValue *, DenseSet<unsigned>> memrefToStorePosSet;
  // Mapping from top-level loop nest to the set of load ops it contains.
  DenseMap<ForStmt *, SetVector<OperationStmt *>> forStmtToLoadOps;
  // Mapping from top-level loop nest to the set of store ops it contains.
  DenseMap<ForStmt *, SetVector<OperationStmt *>> forStmtToStoreOps;

  GreedyFusionPolicy(MLFunction *f) { init(f); }

  // Runs the greedy fusion algorithm described in the comment above this
  // struct until the worklist is exhausted.
  void run() {
    // Conservatively bail out if the function contains any IfStmt: this pass
    // does not analyze dependences through conditional code.
    if (hasIfStmts())
      return;

    while (!worklist.empty()) {
      // Pop the position of a loop nest into which fusion will be attempted.
      unsigned dstPos = worklist.back();
      worklist.pop_back();
      // Skip if 'dstPos' is not tracked (was fused into another loop nest).
      if (posToStmtInfo.count(dstPos) == 0)
        continue;
      // Get the top-level ForStmt at 'dstPos'.
      auto *dstForStmt = getForStmtAtPos(dstPos);
      // Skip if this ForStmt contains no load ops.
      if (forStmtToLoadOps.count(dstForStmt) == 0)
        continue;

      // Greedy Policy: iterate through load ops in 'dstForStmt', greedily
      // fusing in src loop nests which have a single store op on the same
      // memref, until a fixed point is reached where there is nothing left to
      // fuse. Note: this takes a copy of the load op set, since the map entry
      // is updated as slices are fused in below.
      SetVector<OperationStmt *> dstLoadOps = forStmtToLoadOps[dstForStmt];
      while (!dstLoadOps.empty()) {
        auto *dstLoadOpStmt = dstLoadOps.pop_back_val();

        auto dstLoadOp = dstLoadOpStmt->cast<LoadOp>();
        auto *memref = cast<MLValue>(dstLoadOp->getMemRef());
        // Skip if not single src store / dst load pair on 'memref'.
        if (memrefToLoadPosSet[memref].size() != 1 ||
            memrefToStorePosSet[memref].size() != 1)
          continue;
        // The single storing nest is the fusion source; it must precede dst.
        unsigned srcPos = *memrefToStorePosSet[memref].begin();
        if (srcPos >= dstPos)
          continue;
        auto *srcForStmt = getForStmtAtPos(srcPos);
        // Skip if 'srcForStmt' has more than one store op.
        if (forStmtToStoreOps[srcForStmt].size() > 1)
          continue;
        // Skip if fusion would violate dependences between 'memref' access
        // for loop nests between 'srcPos' and 'dstPos':
        // For each src load op: check for store ops in range (srcPos, dstPos).
        // For each src store op: check for load ops in range (srcPos, dstPos).
        if (moveWouldViolateDependences(srcPos, dstPos))
          continue;
        auto *srcStoreOpStmt = forStmtToStoreOps[srcForStmt].front();
        // Build fusion candidate out of 'srcStoreOpStmt' and 'dstLoadOpStmt'.
        FusionCandidate candidate =
            buildFusionCandidate(srcStoreOpStmt, dstLoadOpStmt);
        // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'; returns
        // nullptr on failure, in which case all state is left untouched.
        auto *sliceLoopNest = mlir::insertBackwardComputationSlice(
            &candidate.srcAccess, &candidate.dstAccess);
        if (sliceLoopNest != nullptr) {
          // Remove 'srcPos' mappings from 'state'.
          moveAccessesAndRemovePos(srcPos, dstPos);
          // Record all load/store accesses in 'sliceLoopNest' at 'dstPos'.
          LoopNestStateCollector collector;
          collector.walkForStmt(sliceLoopNest);
          // Record mappings for loads and stores from 'collector'.
          for (auto *opStmt : collector.loadOpStmts) {
            addLoadOpStmtAt(dstPos, opStmt, dstForStmt);
            // Add newly fused load ops to 'dstLoadOps' to be considered for
            // fusion on subsequent iterations.
            dstLoadOps.insert(opStmt);
          }
          for (auto *opStmt : collector.storeOpStmts) {
            addStoreOpStmtAt(dstPos, opStmt, dstForStmt);
          }
          // Loops in the slice may have become single-iteration after bound
          // adjustment; promote their bodies in place.
          for (auto *forStmt : collector.forStmts) {
            promoteIfSingleIteration(forStmt);
          }
          // Remove old src loop nest.
          srcForStmt->erase();
        }
      }
    }
  }

  // Walk MLFunction 'f' assigning each top-level statement a position, and
  // gathering state on load and store ops.
  void init(MLFunction *f) {
    unsigned pos = 0;
    for (auto &stmt : *f) {
      if (auto *forStmt = dyn_cast<ForStmt>(&stmt)) {
        // Record all loads and store accesses in 'forStmt' at 'pos'.
        LoopNestStateCollector collector;
        collector.walkForStmt(forStmt);
        // Create StmtInfo for 'forStmt' for top-level loop nests.
        addStmtInfoAt(pos, forStmt, collector.hasIfStmt);
        // Record mappings for loads and stores from 'collector'.
        for (auto *opStmt : collector.loadOpStmts) {
          addLoadOpStmtAt(pos, opStmt, forStmt);
        }
        for (auto *opStmt : collector.storeOpStmts) {
          addStoreOpStmtAt(pos, opStmt, forStmt);
        }
        // Add 'pos' associated with 'forStmt' to worklist.
        worklist.push_back(pos);
      }
      // Top-level loads/stores (not nested in any loop) are tracked so that
      // the dependence checks above see them, but they are never fused.
      if (auto *opStmt = dyn_cast<OperationStmt>(&stmt)) {
        if (auto loadOp = opStmt->dyn_cast<LoadOp>()) {
          // Create StmtInfo for top-level load op.
          addStmtInfoAt(pos, &stmt, /*hasIfStmt=*/false);
          addLoadOpStmtAt(pos, opStmt, /*containingForStmt=*/nullptr);
        }
        if (auto storeOp = opStmt->dyn_cast<StoreOp>()) {
          // Create StmtInfo for top-level store op.
          addStmtInfoAt(pos, &stmt, /*hasIfStmt=*/false);
          addStoreOpStmtAt(pos, opStmt, /*containingForStmt=*/nullptr);
        }
      }
      if (auto *ifStmt = dyn_cast<IfStmt>(&stmt)) {
        // A top-level IfStmt disables the whole pass (see hasIfStmts()).
        addStmtInfoAt(pos, &stmt, /*hasIfStmt=*/true);
      }
      ++pos;
    }
  }

  // Check if fusing loop nest at 'srcPos' into the loop nest at 'dstPos'
  // would violate any dependences w.r.t other loop nests in that range.
  bool moveWouldViolateDependences(unsigned srcPos, unsigned dstPos) {
    // Lookup src ForStmt at 'srcPos'.
    auto *srcForStmt = getForStmtAtPos(srcPos);
    // For each src load op: check for store ops in range (srcPos, dstPos).
    if (forStmtToLoadOps.count(srcForStmt) > 0) {
      for (auto *opStmt : forStmtToLoadOps[srcForStmt]) {
        auto loadOp = opStmt->cast<LoadOp>();
        auto *memref = cast<MLValue>(loadOp->getMemRef());
        for (unsigned pos = srcPos + 1; pos < dstPos; ++pos) {
          if (memrefToStorePosSet.count(memref) > 0 &&
              memrefToStorePosSet[memref].count(pos) > 0)
            return true;
        }
      }
    }
    // For each src store op: check for load ops in range (srcPos, dstPos).
    if (forStmtToStoreOps.count(srcForStmt) > 0) {
      for (auto *opStmt : forStmtToStoreOps[srcForStmt]) {
        auto storeOp = opStmt->cast<StoreOp>();
        auto *memref = cast<MLValue>(storeOp->getMemRef());
        for (unsigned pos = srcPos + 1; pos < dstPos; ++pos) {
          if (memrefToLoadPosSet.count(memref) > 0 &&
              memrefToLoadPosSet[memref].count(pos) > 0)
            return true;
        }
      }
    }
    return false;
  }

  // Removes all load/store mappings for the loop nest at 'srcPos', which is
  // being fused into the nest at 'dstPos'. Note: this only erases 'srcPos'
  // state; the caller re-records the fused slice's accesses at 'dstPos'
  // afterwards, so 'dstPos' is currently unused here.
  void moveAccessesAndRemovePos(unsigned srcPos, unsigned dstPos) {
    // Lookup ForStmt at 'srcPos'.
    auto *srcForStmt = getForStmtAtPos(srcPos);
    // Erase 'srcPos' from the load-position set of each loaded memref.
    if (forStmtToLoadOps.count(srcForStmt) > 0) {
      for (auto *opStmt : forStmtToLoadOps[srcForStmt]) {
        auto loadOp = opStmt->cast<LoadOp>();
        auto *memref = cast<MLValue>(loadOp->getMemRef());
        // Remove 'memref' to 'srcPos' mapping.
        memrefToLoadPosSet[memref].erase(srcPos);
      }
    }
    // Erase 'srcPos' from the store-position set of each stored memref.
    if (forStmtToStoreOps.count(srcForStmt) > 0) {
      for (auto *opStmt : forStmtToStoreOps[srcForStmt]) {
        auto storeOp = opStmt->cast<StoreOp>();
        auto *memref = cast<MLValue>(storeOp->getMemRef());
        // Remove 'memref' to 'srcPos' mapping.
        memrefToStorePosSet[memref].erase(srcPos);
      }
    }
    // Remove old state.
    forStmtToLoadOps.erase(srcForStmt);
    forStmtToStoreOps.erase(srcForStmt);
    posToStmtInfo.erase(srcPos);
  }

  // Returns the top-level ForStmt at 'pos'; asserts that 'pos' is tracked
  // and that the statement there is a ForStmt.
  ForStmt *getForStmtAtPos(unsigned pos) {
    assert(posToStmtInfo.count(pos) > 0);
    assert(isa<ForStmt>(posToStmtInfo[pos].stmt));
    return cast<ForStmt>(posToStmtInfo[pos].stmt);
  }

  // Records StmtInfo for 'stmt' at top-level position 'pos'.
  void addStmtInfoAt(unsigned pos, Statement *stmt, bool hasIfStmt) {
    StmtInfo stmtInfo;
    stmtInfo.stmt = stmt;
    stmtInfo.isOrContainsIfStmt = hasIfStmt;
    // Add mapping from 'pos' to StmtInfo for 'stmt'.
    posToStmtInfo[pos] = stmtInfo;
  }

  // Adds the following mappings:
  // *) 'containingForStmt' to load 'opStmt'
  // *) 'memref' of load 'opStmt' to 'topLevelPos'.
  void addLoadOpStmtAt(unsigned topLevelPos, OperationStmt *opStmt,
                       ForStmt *containingForStmt) {
    // 'containingForStmt' is null for top-level (non-nested) loads.
    if (containingForStmt != nullptr) {
      // Add mapping from 'containingForStmt' to 'opStmt' for load op.
      forStmtToLoadOps[containingForStmt].insert(opStmt);
    }
    auto loadOp = opStmt->cast<LoadOp>();
    auto *memref = cast<MLValue>(loadOp->getMemRef());
    // Add mapping from 'memref' to 'topLevelPos' for load.
    memrefToLoadPosSet[memref].insert(topLevelPos);
  }

  // Adds the following mappings:
  // *) 'containingForStmt' to store 'opStmt'
  // *) 'memref' of store 'opStmt' to 'topLevelPos'.
  void addStoreOpStmtAt(unsigned topLevelPos, OperationStmt *opStmt,
                        ForStmt *containingForStmt) {
    // 'containingForStmt' is null for top-level (non-nested) stores.
    if (containingForStmt != nullptr) {
      // Add mapping from 'containingForStmt' to 'opStmt' for store op.
      forStmtToStoreOps[containingForStmt].insert(opStmt);
    }
    auto storeOp = opStmt->cast<StoreOp>();
    auto *memref = cast<MLValue>(storeOp->getMemRef());
    // Add mapping from 'memref' to 'topLevelPos' for store.
    memrefToStorePosSet[memref].insert(topLevelPos);
  }

  // Returns true if any tracked top-level statement is, or contains, an
  // IfStmt.
  bool hasIfStmts() {
    for (auto &pair : posToStmtInfo)
      if (pair.second.isOrContainsIfStmt)
        return true;
    return false;
  }
};
428
429} // end anonymous namespace
MLIR Teamf28e4df2018-11-01 14:26:00430
431PassResult LoopFusion::runOnMLFunction(MLFunction *f) {
MLIR Team3b692302018-12-17 17:57:14432 GreedyFusionPolicy(f).run();
MLIR Teamf28e4df2018-11-01 14:26:00433 return success();
434}
Jacques Pienaar6f0fb222018-11-07 02:34:18435
// Registers the pass with the global registry under the '-loop-fusion' flag.
static PassRegistration<LoopFusion> pass("loop-fusion", "Fuse loop nests");