Factor fusion compute cost calculation out of LoopFusion and into LoopFusionUtils (NFC).
PiperOrigin-RevId: 253797886
diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
index 95890a6..8d2e75b 100644
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -732,156 +732,6 @@
return true;
}
-namespace {
-
-// LoopNestStats aggregates various per-loop statistics (eg. loop trip count
-// and operation count) for a loop nest up until the innermost loop body.
-struct LoopNestStats {
- // Map from AffineForOp to immediate child AffineForOps in its loop body.
- DenseMap<Operation *, SmallVector<AffineForOp, 2>> loopMap;
- // Map from AffineForOp to count of operations in its loop body.
- DenseMap<Operation *, uint64_t> opCountMap;
- // Map from AffineForOp to its constant trip count.
- DenseMap<Operation *, uint64_t> tripCountMap;
-};
-
-// LoopNestStatsCollector walks a single loop nest and gathers per-loop
-// trip count and operation count statistics and records them in 'stats'.
-struct LoopNestStatsCollector {
- LoopNestStats *stats;
- bool hasLoopWithNonConstTripCount = false;
-
- LoopNestStatsCollector(LoopNestStats *stats) : stats(stats) {}
-
- void collect(Operation *op) {
- op->walk<AffineForOp>([&](AffineForOp forOp) {
- auto *forInst = forOp.getOperation();
- auto *parentInst = forOp.getOperation()->getParentOp();
- if (parentInst != nullptr) {
- assert(isa<AffineForOp>(parentInst) && "Expected parent AffineForOp");
- // Add mapping to 'forOp' from its parent AffineForOp.
- stats->loopMap[parentInst].push_back(forOp);
- }
-
- // Record the number of operations in the body of 'forOp'.
- unsigned count = 0;
- stats->opCountMap[forInst] = 0;
- for (auto &op : *forOp.getBody()) {
- if (!isa<AffineForOp>(op) && !isa<AffineIfOp>(op))
- ++count;
- }
- stats->opCountMap[forInst] = count;
- // Record trip count for 'forOp'. Set flag if trip count is not
- // constant.
- Optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp);
- if (!maybeConstTripCount.hasValue()) {
- hasLoopWithNonConstTripCount = true;
- return;
- }
- stats->tripCountMap[forInst] = maybeConstTripCount.getValue();
- });
- }
-};
-
-// Computes the total cost of the loop nest rooted at 'forOp'.
-// Currently, the total cost is computed by counting the total operation
-// instance count (i.e. total number of operations in the loop body *
-// loop trip count) for the entire loop nest.
-// If 'tripCountOverrideMap' is non-null, overrides the trip count for loops
-// specified in the map when computing the total op instance count.
-// NOTEs: 1) This is used to compute the cost of computation slices, which are
-// sliced along the iteration dimension, and thus reduce the trip count.
-// If 'computeCostMap' is non-null, the total op count for forOps specified
-// in the map is increased (not overridden) by adding the op count from the
-// map to the existing op count for the for loop. This is done before
-// multiplying by the loop's trip count, and is used to model the cost of
-// inserting a sliced loop nest of known cost into the loop's body.
-// 2) This is also used to compute the cost of fusing a slice of some loop nest
-// within another loop.
-static int64_t getComputeCost(
- Operation *forInst, LoopNestStats *stats,
- llvm::SmallDenseMap<Operation *, uint64_t, 8> *tripCountOverrideMap,
- DenseMap<Operation *, int64_t> *computeCostMap) {
- // 'opCount' is the total number of operations in one iteration of 'forOp' body,
- // minus terminator op which is a no-op.
- int64_t opCount = stats->opCountMap[forInst] - 1;
- if (stats->loopMap.count(forInst) > 0) {
- for (auto childForOp : stats->loopMap[forInst]) {
- opCount += getComputeCost(childForOp.getOperation(), stats,
- tripCountOverrideMap, computeCostMap);
- }
- }
- // Add in additional op instances from slice (if specified in map).
- if (computeCostMap != nullptr) {
- auto it = computeCostMap->find(forInst);
- if (it != computeCostMap->end()) {
- opCount += it->second;
- }
- }
- // Override trip count (if specified in map).
- int64_t tripCount = stats->tripCountMap[forInst];
- if (tripCountOverrideMap != nullptr) {
- auto it = tripCountOverrideMap->find(forInst);
- if (it != tripCountOverrideMap->end()) {
- tripCount = it->second;
- }
- }
- // Returns the total number of dynamic instances of operations in loop body.
- return tripCount * opCount;
-}
-
-} // end anonymous namespace
-
-// TODO(andydavis,b/126426796): extend this to handle multiple result maps.
-static Optional<uint64_t> getConstDifference(AffineMap lbMap, AffineMap ubMap) {
- assert(lbMap.getNumResults() == 1 && "expected single result bound map");
- assert(ubMap.getNumResults() == 1 && "expected single result bound map");
- assert(lbMap.getNumDims() == ubMap.getNumDims());
- assert(lbMap.getNumSymbols() == ubMap.getNumSymbols());
- AffineExpr lbExpr(lbMap.getResult(0));
- AffineExpr ubExpr(ubMap.getResult(0));
- auto loopSpanExpr = simplifyAffineExpr(ubExpr - lbExpr, lbMap.getNumDims(),
- lbMap.getNumSymbols());
- auto cExpr = loopSpanExpr.dyn_cast<AffineConstantExpr>();
- if (!cExpr)
- return None;
- return cExpr.getValue();
-}
-
-// Builds a map 'tripCountMap' from AffineForOp to constant trip count for loop
-// nest surrounding 'srcOpInst' utilizing slice loop bounds in 'sliceState'.
-// Returns true on success, false otherwise (if a non-constant trip count
-// was encountered).
-// TODO(andydavis) Make this work with non-unit step loops.
-static bool buildSliceTripCountMap(
- Operation *srcOpInst, ComputationSliceState *sliceState,
- llvm::SmallDenseMap<Operation *, uint64_t, 8> *tripCountMap) {
- SmallVector<AffineForOp, 4> srcLoopIVs;
- getLoopIVs(*srcOpInst, &srcLoopIVs);
- unsigned numSrcLoopIVs = srcLoopIVs.size();
- // Populate map from AffineForOp -> trip count
- for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
- AffineMap lbMap = sliceState->lbs[i];
- AffineMap ubMap = sliceState->ubs[i];
- if (lbMap == AffineMap() || ubMap == AffineMap()) {
- // The iteration of src loop IV 'i' was not sliced. Use full loop bounds.
- if (srcLoopIVs[i].hasConstantLowerBound() &&
- srcLoopIVs[i].hasConstantUpperBound()) {
- (*tripCountMap)[srcLoopIVs[i].getOperation()] =
- srcLoopIVs[i].getConstantUpperBound() -
- srcLoopIVs[i].getConstantLowerBound();
- continue;
- }
- return false;
- }
- Optional<uint64_t> tripCount = getConstDifference(lbMap, ubMap);
- if (!tripCount.hasValue())
- return false;
- (*tripCountMap)[srcLoopIVs[i].getOperation()] = tripCount.getValue();
- }
- return true;
-}
-
// Removes load operations from 'srcLoads' which operate on 'memref', and
// adds them to 'dstLoads'.
static void moveLoadsAccessingMemrefTo(Value *memref,
@@ -1110,16 +960,6 @@
return newMemRef;
}
-// Return the number of iterations in the given slice.
-static uint64_t getSliceIterationCount(
- const llvm::SmallDenseMap<Operation *, uint64_t, 8> &sliceTripCountMap) {
- uint64_t iterCount = 1;
- for (const auto &count : sliceTripCountMap) {
- iterCount *= count.second;
- }
- return iterCount;
-}
-
// Checks if node 'srcId' (which writes to a live out memref), can be safely
// fused into node 'dstId'. Returns true if the following conditions are met:
// *) 'srcNode' only writes to live out 'memref'.
@@ -1250,25 +1090,16 @@
// Walk src loop nest and collect stats.
LoopNestStats srcLoopNestStats;
- LoopNestStatsCollector srcStatsCollector(&srcLoopNestStats);
- srcStatsCollector.collect(srcLoopIVs[0].getOperation());
- // Currently only constant trip count loop nests are supported.
- if (srcStatsCollector.hasLoopWithNonConstTripCount) {
- LLVM_DEBUG(llvm::dbgs() << "Non-constant trip count loops unsupported.\n");
+ if (!getLoopNestStats(srcLoopIVs[0], &srcLoopNestStats))
return false;
- }
+
// Compute cost of dst loop nest.
SmallVector<AffineForOp, 4> dstLoopIVs;
getLoopIVs(*dstLoadOpInsts[0], &dstLoopIVs);
LoopNestStats dstLoopNestStats;
- LoopNestStatsCollector dstStatsCollector(&dstLoopNestStats);
- dstStatsCollector.collect(dstLoopIVs[0].getOperation());
- // Currently only constant trip count loop nests are supported.
- if (dstStatsCollector.hasLoopWithNonConstTripCount) {
- LLVM_DEBUG(llvm::dbgs() << "Non-constant trip count loops unsupported.\n");
+ if (!getLoopNestStats(dstLoopIVs[0], &dstLoopNestStats))
return false;
- }
// Compute the maximum loop depth at which we can insert the src slice
// and still satisfy dest loop nest dependences, for producer-consumer fusion.
@@ -1297,10 +1128,7 @@
Optional<unsigned> bestDstLoopDepth = None;
// Compute op instance count for the src loop nest without iteration slicing.
- uint64_t srcLoopNestCost =
- getComputeCost(srcLoopIVs[0].getOperation(), &srcLoopNestStats,
- /*tripCountOverrideMap=*/nullptr,
- /*computeCostMap=*/nullptr);
+ uint64_t srcLoopNestCost = getComputeCost(srcLoopIVs[0], srcLoopNestStats);
// Compute src loop nest write region size.
MemRefRegion srcWriteRegion(srcStoreOpInst->getLoc());
@@ -1317,15 +1145,10 @@
int64_t srcWriteRegionSizeBytes = maybeSrcWriteRegionSizeBytes.getValue();
// Compute op instance count for the dst loop nest.
- uint64_t dstLoopNestCost =
- getComputeCost(dstLoopIVs[0].getOperation(), &dstLoopNestStats,
- /*tripCountOverrideMap=*/nullptr,
- /*computeCostMap=*/nullptr);
+ uint64_t dstLoopNestCost = getComputeCost(dstLoopIVs[0], dstLoopNestStats);
// Evaluate all depth choices for materializing the slice in the destination
// loop nest.
- llvm::SmallDenseMap<Operation *, uint64_t, 8> sliceTripCountMap;
- DenseMap<Operation *, int64_t> computeCostMap;
for (unsigned i = maxDstLoopDepth; i >= 1; --i) {
// Compute the union of slice bounds of all ops in 'dstLoadOpInsts'.
if (failed(mlir::computeSliceUnion({srcOpInst}, dstLoadOpInsts,
@@ -1338,47 +1161,14 @@
continue;
}
- // Build trip count map for computation slice. We'll skip cases where the
- // trip count was non-constant.
- sliceTripCountMap.clear();
- if (!buildSliceTripCountMap(srcOpInst, &sliceStates[i - 1],
- &sliceTripCountMap)) {
- LLVM_DEBUG(llvm::dbgs() << "Unable to build slice trip count map.\n.");
+ int64_t fusedLoopNestComputeCost;
+ if (!getFusionComputeCost(srcLoopIVs[0], srcLoopNestStats, dstLoopIVs[0],
+ dstLoopNestStats, &sliceStates[i - 1],
+ &fusedLoopNestComputeCost)) {
+ LLVM_DEBUG(llvm::dbgs() << "Unable to compute fusion compute cost.\n.");
continue;
}
- // Checks whether a store to load forwarding will happen.
- int64_t sliceIterationCount = getSliceIterationCount(sliceTripCountMap);
- assert(sliceIterationCount > 0);
- bool storeLoadFwdGuaranteed = (sliceIterationCount == 1);
-
- // Compute cost of fusion for this dest loop depth.
-
- computeCostMap.clear();
-
- // The store and loads to this memref will disappear.
- // TODO(andydavis) Add load coalescing to memref data flow opt pass.
- if (storeLoadFwdGuaranteed) {
- // A single store disappears: -1 for that.
- computeCostMap[srcLoopIVs[numSrcLoopIVs - 1].getOperation()] = -1;
- for (auto *loadOp : dstLoadOpInsts)
- if (auto forOp = dyn_cast_or_null<AffineForOp>(loadOp->getParentOp()))
- computeCostMap[forOp] = -1;
- }
-
- // Compute op instance count for the src loop nest with iteration slicing.
- int64_t sliceComputeCost =
- getComputeCost(srcLoopIVs[0].getOperation(), &srcLoopNestStats,
- /*tripCountOverrideMap=*/&sliceTripCountMap,
- /*computeCostMap=*/&computeCostMap);
-
- // Compute cost of fusion for this depth.
- computeCostMap[dstLoopIVs[i - 1].getOperation()] = sliceComputeCost;
-
- int64_t fusedLoopNestComputeCost =
- getComputeCost(dstLoopIVs[0].getOperation(), &dstLoopNestStats,
- /*tripCountOverrideMap=*/nullptr, &computeCostMap);
-
double additionalComputeFraction =
fusedLoopNestComputeCost /
(static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
@@ -1427,7 +1217,6 @@
<< 100.0 * additionalComputeFraction << "%\n"
<< " storage reduction factor: " << storageReduction << "x\n"
<< " fused nest cost: " << fusedLoopNestComputeCost << "\n"
- << " slice iteration count: " << sliceIterationCount << "\n"
<< " src write region size: " << srcWriteRegionSizeBytes << "\n"
<< " slice write region size: " << sliceWriteRegionSizeBytes
<< "\n";