Factor fusion compute cost calculation out of LoopFusion and into LoopFusionUtils (NFC).
PiperOrigin-RevId: 253797886
diff --git a/mlir/include/mlir/Transforms/LoopFusionUtils.h b/mlir/include/mlir/Transforms/LoopFusionUtils.h
index ccda669..b6d1ea4 100644
--- a/mlir/include/mlir/Transforms/LoopFusionUtils.h
+++ b/mlir/include/mlir/Transforms/LoopFusionUtils.h
@@ -24,9 +24,13 @@
#ifndef MLIR_TRANSFORMS_LOOP_FUSION_UTILS_H
#define MLIR_TRANSFORMS_LOOP_FUSION_UTILS_H
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SmallVector.h"
+
namespace mlir {
class AffineForOp;
struct ComputationSliceState;
+class Operation;
// TODO(andydavis) Extend this module to include utility functions for querying
// fusion cost/storage reduction, and for performing the loop fusion
@@ -54,6 +58,43 @@
FusionResult canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
unsigned dstLoopDepth,
ComputationSliceState *srcSlice);
+
+/// LoopNestStats aggregates various per-loop statistics (e.g. loop trip count
+/// and operation count) for a loop nest up until (and including) the innermost
+/// loop body. Every map is keyed by the Operation* of an AffineForOp.
+struct LoopNestStats {
+ /// Map from AffineForOp to immediate child AffineForOps in its loop body.
+ llvm::DenseMap<Operation *, llvm::SmallVector<AffineForOp, 2>> loopMap;
+ /// Map from AffineForOp to count of operations directly in its loop body.
+ /// As populated by getLoopNestStats, nested AffineForOp and AffineIfOp
+ /// operations are excluded from this count.
+ llvm::DenseMap<Operation *, uint64_t> opCountMap;
+ /// Map from AffineForOp to its constant trip count.
+ llvm::DenseMap<Operation *, uint64_t> tripCountMap;
+};
+
+/// Collect loop nest statistics (e.g. loop trip count and operation count)
+/// in 'stats' for the loop nest rooted at 'forOp'. Returns true on success,
+/// false otherwise (e.g. when a loop in the nest has a non-constant trip
+/// count).
+// TODO(andydavis) Consider moving this to LoopUtils.
+bool getLoopNestStats(AffineForOp forOp, LoopNestStats *stats);
+
+/// Computes the total cost of the loop nest rooted at 'forOp' using 'stats'
+/// (as collected by getLoopNestStats for the same loop nest).
+/// Currently, the total cost is computed by counting the total operation
+/// instance count (i.e. total number of operations in the loop body * loop
+/// trip count) for the entire loop nest.
+// TODO(andydavis) Improve this cost model.
+int64_t getComputeCost(AffineForOp forOp, LoopNestStats &stats);
+
+/// Computes and returns in 'computeCost', the total compute cost of fusing the
+/// 'slice' of the loop nest rooted at 'srcForOp' into 'dstForOp'. Currently,
+/// the total cost is computed by counting the total operation instance count
+/// (i.e. total number of operations in the loop body * loop trip count) for
+/// the entire loop nest.
+/// Returns true on success, false otherwise (e.g. non-constant trip counts).
+// TODO(andydavis) Improve this cost model.
+bool getFusionComputeCost(AffineForOp srcForOp, LoopNestStats &srcStats,
+ AffineForOp dstForOp, LoopNestStats &dstStats,
+ ComputationSliceState *slice, int64_t *computeCost);
+
} // end namespace mlir
#endif // MLIR_TRANSFORMS_LOOP_FUSION_UTILS_H
diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
index 95890a6..8d2e75b 100644
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -732,156 +732,6 @@
return true;
}
-namespace {
-
-// LoopNestStats aggregates various per-loop statistics (eg. loop trip count
-// and operation count) for a loop nest up until the innermost loop body.
-struct LoopNestStats {
- // Map from AffineForOp to immediate child AffineForOps in its loop body.
- DenseMap<Operation *, SmallVector<AffineForOp, 2>> loopMap;
- // Map from AffineForOp to count of operations in its loop body.
- DenseMap<Operation *, uint64_t> opCountMap;
- // Map from AffineForOp to its constant trip count.
- DenseMap<Operation *, uint64_t> tripCountMap;
-};
-
-// LoopNestStatsCollector walks a single loop nest and gathers per-loop
-// trip count and operation count statistics and records them in 'stats'.
-struct LoopNestStatsCollector {
- LoopNestStats *stats;
- bool hasLoopWithNonConstTripCount = false;
-
- LoopNestStatsCollector(LoopNestStats *stats) : stats(stats) {}
-
- void collect(Operation *op) {
- op->walk<AffineForOp>([&](AffineForOp forOp) {
- auto *forInst = forOp.getOperation();
- auto *parentInst = forOp.getOperation()->getParentOp();
- if (parentInst != nullptr) {
- assert(isa<AffineForOp>(parentInst) && "Expected parent AffineForOp");
- // Add mapping to 'forOp' from its parent AffineForOp.
- stats->loopMap[parentInst].push_back(forOp);
- }
-
- // Record the number of op operations in the body of 'forOp'.
- unsigned count = 0;
- stats->opCountMap[forInst] = 0;
- for (auto &op : *forOp.getBody()) {
- if (!isa<AffineForOp>(op) && !isa<AffineIfOp>(op))
- ++count;
- }
- stats->opCountMap[forInst] = count;
- // Record trip count for 'forOp'. Set flag if trip count is not
- // constant.
- Optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp);
- if (!maybeConstTripCount.hasValue()) {
- hasLoopWithNonConstTripCount = true;
- return;
- }
- stats->tripCountMap[forInst] = maybeConstTripCount.getValue();
- });
- }
-};
-
-// Computes the total cost of the loop nest rooted at 'forOp'.
-// Currently, the total cost is computed by counting the total operation
-// instance count (i.e. total number of operations in the loop bodyloop
-// operation count * loop trip count) for the entire loop nest.
-// If 'tripCountOverrideMap' is non-null, overrides the trip count for loops
-// specified in the map when computing the total op instance count.
-// NOTEs: 1) This is used to compute the cost of computation slices, which are
-// sliced along the iteration dimension, and thus reduce the trip count.
-// If 'computeCostMap' is non-null, the total op count for forOps specified
-// in the map is increased (not overridden) by adding the op count from the
-// map to the existing op count for the for loop. This is done before
-// multiplying by the loop's trip count, and is used to model the cost of
-// inserting a sliced loop nest of known cost into the loop's body.
-// 2) This is also used to compute the cost of fusing a slice of some loop nest
-// within another loop.
-static int64_t getComputeCost(
- Operation *forInst, LoopNestStats *stats,
- llvm::SmallDenseMap<Operation *, uint64_t, 8> *tripCountOverrideMap,
- DenseMap<Operation *, int64_t> *computeCostMap) {
- // 'opCount' is the total number operations in one iteration of 'forOp' body,
- // minus terminator op which is a no-op.
- int64_t opCount = stats->opCountMap[forInst] - 1;
- if (stats->loopMap.count(forInst) > 0) {
- for (auto childForOp : stats->loopMap[forInst]) {
- opCount += getComputeCost(childForOp.getOperation(), stats,
- tripCountOverrideMap, computeCostMap);
- }
- }
- // Add in additional op instances from slice (if specified in map).
- if (computeCostMap != nullptr) {
- auto it = computeCostMap->find(forInst);
- if (it != computeCostMap->end()) {
- opCount += it->second;
- }
- }
- // Override trip count (if specified in map).
- int64_t tripCount = stats->tripCountMap[forInst];
- if (tripCountOverrideMap != nullptr) {
- auto it = tripCountOverrideMap->find(forInst);
- if (it != tripCountOverrideMap->end()) {
- tripCount = it->second;
- }
- }
- // Returns the total number of dynamic instances of operations in loop body.
- return tripCount * opCount;
-}
-
-} // end anonymous namespace
-
-// TODO(andydavis,b/126426796): extend this to handle multiple result maps.
-static Optional<uint64_t> getConstDifference(AffineMap lbMap, AffineMap ubMap) {
- assert(lbMap.getNumResults() == 1 && "expected single result bound map");
- assert(ubMap.getNumResults() == 1 && "expected single result bound map");
- assert(lbMap.getNumDims() == ubMap.getNumDims());
- assert(lbMap.getNumSymbols() == ubMap.getNumSymbols());
- AffineExpr lbExpr(lbMap.getResult(0));
- AffineExpr ubExpr(ubMap.getResult(0));
- auto loopSpanExpr = simplifyAffineExpr(ubExpr - lbExpr, lbMap.getNumDims(),
- lbMap.getNumSymbols());
- auto cExpr = loopSpanExpr.dyn_cast<AffineConstantExpr>();
- if (!cExpr)
- return None;
- return cExpr.getValue();
-}
-
-// Builds a map 'tripCountMap' from AffineForOp to constant trip count for loop
-// nest surrounding 'srcAccess' utilizing slice loop bounds in 'sliceState'.
-// Returns true on success, false otherwise (if a non-constant trip count
-// was encountered).
-// TODO(andydavis) Make this work with non-unit step loops.
-static bool buildSliceTripCountMap(
- Operation *srcOpInst, ComputationSliceState *sliceState,
- llvm::SmallDenseMap<Operation *, uint64_t, 8> *tripCountMap) {
- SmallVector<AffineForOp, 4> srcLoopIVs;
- getLoopIVs(*srcOpInst, &srcLoopIVs);
- unsigned numSrcLoopIVs = srcLoopIVs.size();
- // Populate map from AffineForOp -> trip count
- for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
- AffineMap lbMap = sliceState->lbs[i];
- AffineMap ubMap = sliceState->ubs[i];
- if (lbMap == AffineMap() || ubMap == AffineMap()) {
- // The iteration of src loop IV 'i' was not sliced. Use full loop bounds.
- if (srcLoopIVs[i].hasConstantLowerBound() &&
- srcLoopIVs[i].hasConstantUpperBound()) {
- (*tripCountMap)[srcLoopIVs[i].getOperation()] =
- srcLoopIVs[i].getConstantUpperBound() -
- srcLoopIVs[i].getConstantLowerBound();
- continue;
- }
- return false;
- }
- Optional<uint64_t> tripCount = getConstDifference(lbMap, ubMap);
- if (!tripCount.hasValue())
- return false;
- (*tripCountMap)[srcLoopIVs[i].getOperation()] = tripCount.getValue();
- }
- return true;
-}
-
// Removes load operations from 'srcLoads' which operate on 'memref', and
// adds them to 'dstLoads'.
static void moveLoadsAccessingMemrefTo(Value *memref,
@@ -1110,16 +960,6 @@
return newMemRef;
}
-// Return the number of iterations in the given slice.
-static uint64_t getSliceIterationCount(
- const llvm::SmallDenseMap<Operation *, uint64_t, 8> &sliceTripCountMap) {
- uint64_t iterCount = 1;
- for (const auto &count : sliceTripCountMap) {
- iterCount *= count.second;
- }
- return iterCount;
-}
-
// Checks if node 'srcId' (which writes to a live out memref), can be safely
// fused into node 'dstId'. Returns true if the following conditions are met:
// *) 'srcNode' only writes to live out 'memref'.
@@ -1250,25 +1090,16 @@
// Walk src loop nest and collect stats.
LoopNestStats srcLoopNestStats;
- LoopNestStatsCollector srcStatsCollector(&srcLoopNestStats);
- srcStatsCollector.collect(srcLoopIVs[0].getOperation());
- // Currently only constant trip count loop nests are supported.
- if (srcStatsCollector.hasLoopWithNonConstTripCount) {
- LLVM_DEBUG(llvm::dbgs() << "Non-constant trip count loops unsupported.\n");
+ if (!getLoopNestStats(srcLoopIVs[0], &srcLoopNestStats))
return false;
- }
+
// Compute cost of dst loop nest.
SmallVector<AffineForOp, 4> dstLoopIVs;
getLoopIVs(*dstLoadOpInsts[0], &dstLoopIVs);
LoopNestStats dstLoopNestStats;
- LoopNestStatsCollector dstStatsCollector(&dstLoopNestStats);
- dstStatsCollector.collect(dstLoopIVs[0].getOperation());
- // Currently only constant trip count loop nests are supported.
- if (dstStatsCollector.hasLoopWithNonConstTripCount) {
- LLVM_DEBUG(llvm::dbgs() << "Non-constant trip count loops unsupported.\n");
+ if (!getLoopNestStats(dstLoopIVs[0], &dstLoopNestStats))
return false;
- }
// Compute the maximum loop depth at which we can can insert the src slice
// and still satisfy dest loop nest dependences, for producer-consumer fusion.
@@ -1297,10 +1128,7 @@
Optional<unsigned> bestDstLoopDepth = None;
// Compute op instance count for the src loop nest without iteration slicing.
- uint64_t srcLoopNestCost =
- getComputeCost(srcLoopIVs[0].getOperation(), &srcLoopNestStats,
- /*tripCountOverrideMap=*/nullptr,
- /*computeCostMap=*/nullptr);
+ uint64_t srcLoopNestCost = getComputeCost(srcLoopIVs[0], srcLoopNestStats);
// Compute src loop nest write region size.
MemRefRegion srcWriteRegion(srcStoreOpInst->getLoc());
@@ -1317,15 +1145,10 @@
int64_t srcWriteRegionSizeBytes = maybeSrcWriteRegionSizeBytes.getValue();
// Compute op instance count for the src loop nest.
- uint64_t dstLoopNestCost =
- getComputeCost(dstLoopIVs[0].getOperation(), &dstLoopNestStats,
- /*tripCountOverrideMap=*/nullptr,
- /*computeCostMap=*/nullptr);
+ uint64_t dstLoopNestCost = getComputeCost(dstLoopIVs[0], dstLoopNestStats);
// Evaluate all depth choices for materializing the slice in the destination
// loop nest.
- llvm::SmallDenseMap<Operation *, uint64_t, 8> sliceTripCountMap;
- DenseMap<Operation *, int64_t> computeCostMap;
for (unsigned i = maxDstLoopDepth; i >= 1; --i) {
// Compute the union of slice bounds of all ops in 'dstLoadOpInsts'.
if (failed(mlir::computeSliceUnion({srcOpInst}, dstLoadOpInsts,
@@ -1338,47 +1161,14 @@
continue;
}
- // Build trip count map for computation slice. We'll skip cases where the
- // trip count was non-constant.
- sliceTripCountMap.clear();
- if (!buildSliceTripCountMap(srcOpInst, &sliceStates[i - 1],
- &sliceTripCountMap)) {
- LLVM_DEBUG(llvm::dbgs() << "Unable to build slice trip count map.\n.");
+ int64_t fusedLoopNestComputeCost;
+ if (!getFusionComputeCost(srcLoopIVs[0], srcLoopNestStats, dstLoopIVs[0],
+ dstLoopNestStats, &sliceStates[i - 1],
+ &fusedLoopNestComputeCost)) {
+ LLVM_DEBUG(llvm::dbgs() << "Unable to compute fusion compute cost.\n.");
continue;
}
- // Checks whether a store to load forwarding will happen.
- int64_t sliceIterationCount = getSliceIterationCount(sliceTripCountMap);
- assert(sliceIterationCount > 0);
- bool storeLoadFwdGuaranteed = (sliceIterationCount == 1);
-
- // Compute cost of fusion for this dest loop depth.
-
- computeCostMap.clear();
-
- // The store and loads to this memref will disappear.
- // TODO(andydavis) Add load coalescing to memref data flow opt pass.
- if (storeLoadFwdGuaranteed) {
- // A single store disappears: -1 for that.
- computeCostMap[srcLoopIVs[numSrcLoopIVs - 1].getOperation()] = -1;
- for (auto *loadOp : dstLoadOpInsts)
- if (auto forOp = dyn_cast_or_null<AffineForOp>(loadOp->getParentOp()))
- computeCostMap[forOp] = -1;
- }
-
- // Compute op instance count for the src loop nest with iteration slicing.
- int64_t sliceComputeCost =
- getComputeCost(srcLoopIVs[0].getOperation(), &srcLoopNestStats,
- /*tripCountOverrideMap=*/&sliceTripCountMap,
- /*computeCostMap=*/&computeCostMap);
-
- // Compute cost of fusion for this depth.
- computeCostMap[dstLoopIVs[i - 1].getOperation()] = sliceComputeCost;
-
- int64_t fusedLoopNestComputeCost =
- getComputeCost(dstLoopIVs[0].getOperation(), &dstLoopNestStats,
- /*tripCountOverrideMap=*/nullptr, &computeCostMap);
-
double additionalComputeFraction =
fusedLoopNestComputeCost /
(static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
@@ -1427,7 +1217,6 @@
<< 100.0 * additionalComputeFraction << "%\n"
<< " storage reduction factor: " << storageReduction << "x\n"
<< " fused nest cost: " << fusedLoopNestComputeCost << "\n"
- << " slice iteration count: " << sliceIterationCount << "\n"
<< " src write region size: " << srcWriteRegionSizeBytes << "\n"
<< " slice write region size: " << sliceWriteRegionSizeBytes
<< "\n";
diff --git a/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp b/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp
index 1fb41a2..93503d1 100644
--- a/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp
@@ -24,6 +24,7 @@
#include "mlir/AffineOps/AffineOps.h"
#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/AffineStructures.h"
+#include "mlir/Analysis/LoopAnalysis.h"
#include "mlir/Analysis/Utils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
@@ -250,3 +251,236 @@
return FusionResult::Success;
}
+
+/// Gathers per-loop statistics for the loop nest rooted at 'forOpRoot' into
+/// 'stats': the immediate child loops of each loop, the number of operations
+/// directly in each loop body (nested AffineForOp/AffineIfOp ops excluded),
+/// and each loop's constant trip count.
+/// Returns false if any visited loop has a non-null parent operation that is
+/// not an AffineForOp, or has a non-constant trip count; returns true
+/// otherwise. A failing loop only ends its own visit, so 'stats' may be
+/// partially populated when false is returned.
+bool mlir::getLoopNestStats(AffineForOp forOpRoot, LoopNestStats *stats) {
+  bool allLoopsSupported = true;
+  forOpRoot.getOperation()->walk<AffineForOp>([&](AffineForOp loop) {
+    auto *loopOp = loop.getOperation();
+    auto *enclosingOp = loop.getOperation()->getParentOp();
+    if (enclosingOp != nullptr) {
+      if (!isa<AffineForOp>(enclosingOp)) {
+        LLVM_DEBUG(llvm::dbgs() << "Expected parent AffineForOp");
+        allLoopsSupported = false;
+        return;
+      }
+      // Register 'loop' as an immediate child of its enclosing AffineForOp.
+      stats->loopMap[enclosingOp].push_back(loop);
+    }
+
+    // Count the operations sitting directly in the body of 'loop', leaving
+    // out nested loops and affine.if operations.
+    unsigned numBodyOps = 0;
+    for (auto &bodyOp : *loop.getBody()) {
+      if (isa<AffineForOp>(bodyOp) || isa<AffineIfOp>(bodyOp))
+        continue;
+      ++numBodyOps;
+    }
+    stats->opCountMap[loopOp] = numBodyOps;
+
+    // Record the trip count of 'loop' when it is a known constant.
+    Optional<uint64_t> constTripCount = getConstantTripCount(loop);
+    if (constTripCount.hasValue()) {
+      stats->tripCountMap[loopOp] = constTripCount.getValue();
+      return;
+    }
+    // Currently only constant trip count loop nests are supported.
+    LLVM_DEBUG(llvm::dbgs() << "Non-constant trip count unsupported");
+    allLoopsSupported = false;
+  });
+  return allLoopsSupported;
+}
+
+// Computes the total cost of the loop nest rooted at 'forOp' using 'stats'.
+// The cost is the total operation instance count of the nest: for each loop,
+// (number of operations in one iteration of its body, including the cost of
+// its child loops) * (loop trip count).
+//
+// If 'tripCountOverrideMap' is non-null, the trip count of every loop found in
+// the map is replaced by the mapped value. This is used to compute the cost of
+// computation slices, which are sliced along the iteration dimension and thus
+// have reduced trip counts.
+//
+// If 'computeCostMap' is non-null, the per-iteration op count of every loop
+// found in the map is increased (not overridden) by the mapped value before it
+// is multiplied by the loop's trip count. This is used to model the cost of
+// inserting a sliced loop nest of known cost into the loop's body, i.e. the
+// cost of fusing a slice of some loop nest within another loop.
+static int64_t getComputeCostHelper(
+    Operation *forOp, LoopNestStats &stats,
+    llvm::SmallDenseMap<Operation *, uint64_t, 8> *tripCountOverrideMap,
+    DenseMap<Operation *, int64_t> *computeCostMap) {
+  // Start from the number of operations in one iteration of the 'forOp' body,
+  // discounting the terminator op, which is a no-op.
+  int64_t opsPerIteration = stats.opCountMap[forOp] - 1;
+
+  // Fold in the full cost of every immediate child loop.
+  auto childLoopsIt = stats.loopMap.find(forOp);
+  if (childLoopsIt != stats.loopMap.end()) {
+    for (auto childLoop : childLoopsIt->second)
+      opsPerIteration +=
+          getComputeCostHelper(childLoop.getOperation(), stats,
+                               tripCountOverrideMap, computeCostMap);
+  }
+
+  // Apply the extra per-iteration cost registered for this loop, if any.
+  if (computeCostMap != nullptr) {
+    auto extraCostIt = computeCostMap->find(forOp);
+    if (extraCostIt != computeCostMap->end())
+      opsPerIteration += extraCostIt->second;
+  }
+
+  // Use the recorded trip count unless an override is registered.
+  int64_t numIterations = stats.tripCountMap[forOp];
+  if (tripCountOverrideMap != nullptr) {
+    auto overrideIt = tripCountOverrideMap->find(forOp);
+    if (overrideIt != tripCountOverrideMap->end())
+      numIterations = overrideIt->second;
+  }
+
+  // Total number of dynamic operation instances executed by this loop.
+  return numIterations * opsPerIteration;
+}
+
+// Returns the constant value of 'ubMap' - 'lbMap' when the simplified
+// difference of their (single) results is a constant expression, and None
+// otherwise. Both maps must have exactly one result and agree on their number
+// of dims and symbols.
+// TODO(andydavis,b/126426796): extend this to handle multiple result maps.
+static Optional<uint64_t> getConstDifference(AffineMap lbMap, AffineMap ubMap) {
+  assert(lbMap.getNumResults() == 1 && "expected single result bound map");
+  assert(ubMap.getNumResults() == 1 && "expected single result bound map");
+  assert(lbMap.getNumDims() == ubMap.getNumDims());
+  assert(lbMap.getNumSymbols() == ubMap.getNumSymbols());
+  AffineExpr lowerBound(lbMap.getResult(0));
+  AffineExpr upperBound(ubMap.getResult(0));
+  // Simplify 'ub - lb' so that a constant span shows up as a constant expr.
+  auto spanExpr = simplifyAffineExpr(upperBound - lowerBound,
+                                     lbMap.getNumDims(), lbMap.getNumSymbols());
+  auto constSpan = spanExpr.dyn_cast<AffineConstantExpr>();
+  if (!constSpan)
+    return None;
+  return constSpan.getValue();
+}
+
+// Returns the number of iterations in the given slice, i.e. the product of
+// all trip counts stored in 'sliceTripCountMap' (1 for an empty map).
+static uint64_t getSliceIterationCount(
+    const llvm::SmallDenseMap<Operation *, uint64_t, 8> &sliceTripCountMap) {
+  uint64_t totalIterations = 1;
+  for (auto it = sliceTripCountMap.begin(), end = sliceTripCountMap.end();
+       it != end; ++it)
+    totalIterations *= it->second;
+  return totalIterations;
+}
+
+// Builds a map 'tripCountMap' from AffineForOp to constant trip count for the
+// loop nest represented by the slice loop bounds in 'slice'.
+// For a loop whose iteration space was sliced, the trip count is the constant
+// difference of its slice bounds. For a loop that was not sliced, the full
+// loop bounds are used instead.
+// Returns true on success, false otherwise (if a non-constant trip count
+// was encountered).
+// TODO(andydavis) Make this work with non-unit step loops.
+static bool buildSliceTripCountMap(
+    ComputationSliceState *slice,
+    llvm::SmallDenseMap<Operation *, uint64_t, 8> *tripCountMap) {
+  const unsigned numSliceIVs = slice->ivs.size();
+  // Populate map from AffineForOp -> trip count
+  for (unsigned idx = 0; idx < numSliceIVs; ++idx) {
+    AffineForOp loop = getForInductionVarOwner(slice->ivs[idx]);
+    auto *loopOp = loop.getOperation();
+    AffineMap lbMap = slice->lbs[idx];
+    AffineMap ubMap = slice->ubs[idx];
+    const bool notSliced = lbMap == AffineMap() || ubMap == AffineMap();
+
+    if (!notSliced) {
+      // Slice bounds are created with a constant ub - lb difference.
+      Optional<uint64_t> sliceTripCount = getConstDifference(lbMap, ubMap);
+      if (!sliceTripCount.hasValue())
+        return false;
+      (*tripCountMap)[loopOp] = sliceTripCount.getValue();
+      continue;
+    }
+
+    // The iteration of src loop IV 'idx' was not sliced. Use full loop bounds.
+    if (loop.hasConstantLowerBound() && loop.hasConstantUpperBound()) {
+      (*tripCountMap)[loopOp] =
+          loop.getConstantUpperBound() - loop.getConstantLowerBound();
+      continue;
+    }
+    // Bounds are not both constant: fall back to the loop's constant trip
+    // count, if it has one.
+    Optional<uint64_t> fullTripCount = getConstantTripCount(loop);
+    if (!fullTripCount.hasValue())
+      return false;
+    (*tripCountMap)[loopOp] = fullTripCount.getValue();
+  }
+  return true;
+}
+
+/// Returns the total cost of the loop nest rooted at 'forOp', derived from
+/// 'stats'. The cost is currently the total operation instance count of the
+/// whole nest (number of operations in each loop body * loop trip count),
+/// computed with no trip count overrides and no additional per-loop costs.
+int64_t mlir::getComputeCost(AffineForOp forOp, LoopNestStats &stats) {
+  auto *rootOp = forOp.getOperation();
+  return getComputeCostHelper(rootOp, stats,
+                              /*tripCountOverrideMap=*/nullptr,
+                              /*computeCostMap=*/nullptr);
+}
+
+/// Computes and returns in 'computeCost', the total compute cost of fusing the
+/// 'slice' of the loop nest rooted at 'srcForOp' into 'dstForOp'. Currently,
+/// the total cost is computed by counting the total operation instance count
+/// (i.e. total number of operations in the loop body * loop trip count) for
+/// the entire loop nest.
+///
+/// 'srcStats' and 'dstStats' are the loop nest statistics of the source and
+/// destination loop nests respectively. The slice cost is added to the loop
+/// enclosing 'slice->insertPoint'. When the slice executes exactly one
+/// iteration, loads and stores expected to be removed by store-to-load
+/// forwarding are discounted.
+///
+/// Returns true on success. Returns false, leaving 'computeCost' untouched, if
+/// a constant trip count could not be determined for a loop of the slice.
+bool mlir::getFusionComputeCost(AffineForOp srcForOp, LoopNestStats &srcStats,
+                                AffineForOp dstForOp, LoopNestStats &dstStats,
+                                ComputationSliceState *slice,
+                                int64_t *computeCost) {
+  llvm::SmallDenseMap<Operation *, uint64_t, 8> sliceTripCountMap;
+  DenseMap<Operation *, int64_t> computeCostMap;
+
+  // Build trip count map for computation slice.
+  if (!buildSliceTripCountMap(slice, &sliceTripCountMap))
+    return false;
+  // Checks whether a store to load forwarding will happen.
+  int64_t sliceIterationCount = getSliceIterationCount(sliceTripCountMap);
+  assert(sliceIterationCount > 0);
+  bool storeLoadFwdGuaranteed = (sliceIterationCount == 1);
+  auto *insertPointParent = slice->insertPoint->getParentOp();
+
+  // The store and loads to this memref will disappear.
+  // TODO(andydavis) Add load coalescing to memref data flow opt pass.
+  if (storeLoadFwdGuaranteed) {
+    // Subtract from operation count the loads/store we expect load/store
+    // forwarding to remove.
+    // The counter is signed: negating an 'unsigned' counter would wrap to a
+    // large positive value before being widened to the map's int64_t.
+    int64_t storeCount = 0;
+    llvm::SmallDenseSet<Value *, 4> storeMemrefs;
+    srcForOp.getOperation()->walk([&](Operation *op) {
+      if (auto storeOp = dyn_cast<StoreOp>(op)) {
+        storeMemrefs.insert(storeOp.getMemRef());
+        ++storeCount;
+      }
+    });
+    // Subtract out any store ops in single-iteration src slice loop nest.
+    if (storeCount > 0)
+      computeCostMap[insertPointParent] = -storeCount;
+    // Subtract out any load users of 'storeMemrefs' nested below
+    // 'insertPointParent'.
+    for (auto *value : storeMemrefs) {
+      for (auto *user : value->getUsers()) {
+        if (!isa<LoadOp>(user))
+          continue;
+        SmallVector<AffineForOp, 4> loops;
+        // Check if any loop in loop nest surrounding 'user' is
+        // 'insertPointParent'.
+        getLoopIVs(*user, &loops);
+        if (!llvm::is_contained(loops, cast<AffineForOp>(insertPointParent)))
+          continue;
+        // A missing map entry is value-initialized to zero by operator[], so
+        // the decrement alone is sufficient.
+        if (auto forOp = dyn_cast_or_null<AffineForOp>(user->getParentOp()))
+          computeCostMap[forOp] -= 1;
+      }
+    }
+  }
+
+  // Compute op instance count for the src loop nest with iteration slicing.
+  int64_t sliceComputeCost = getComputeCostHelper(
+      srcForOp.getOperation(), srcStats, &sliceTripCountMap, &computeCostMap);
+
+  // Compute cost of fusion for this depth.
+  // NOTE(review): this assignment replaces any store/load adjustment recorded
+  // above for 'insertPointParent' itself; only adjustments keyed on other
+  // loops survive into the final cost. Confirm whether the slice cost should
+  // be accumulated ('+=') instead.
+  computeCostMap[insertPointParent] = sliceComputeCost;
+
+  *computeCost =
+      getComputeCostHelper(dstForOp.getOperation(), dstStats,
+                           /*tripCountOverrideMap=*/nullptr, &computeCostMap);
+  return true;
+}