Factor fusion compute cost calculation out of LoopFusion and into LoopFusionUtils (NFC).

PiperOrigin-RevId: 253797886
diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
index 95890a6..8d2e75b 100644
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -732,156 +732,6 @@
   return true;
 }
 
-namespace {
-
-// LoopNestStats aggregates various per-loop statistics (eg. loop trip count
-// and operation count) for a loop nest up until the innermost loop body.
-struct LoopNestStats {
-  // Map from AffineForOp to immediate child AffineForOps in its loop body.
-  DenseMap<Operation *, SmallVector<AffineForOp, 2>> loopMap;
-  // Map from AffineForOp to count of operations in its loop body.
-  DenseMap<Operation *, uint64_t> opCountMap;
-  // Map from AffineForOp to its constant trip count.
-  DenseMap<Operation *, uint64_t> tripCountMap;
-};
-
-// LoopNestStatsCollector walks a single loop nest and gathers per-loop
-// trip count and operation count statistics and records them in 'stats'.
-struct LoopNestStatsCollector {
-  LoopNestStats *stats;
-  bool hasLoopWithNonConstTripCount = false;
-
-  LoopNestStatsCollector(LoopNestStats *stats) : stats(stats) {}
-
-  void collect(Operation *op) {
-    op->walk<AffineForOp>([&](AffineForOp forOp) {
-      auto *forInst = forOp.getOperation();
-      auto *parentInst = forOp.getOperation()->getParentOp();
-      if (parentInst != nullptr) {
-        assert(isa<AffineForOp>(parentInst) && "Expected parent AffineForOp");
-        // Add mapping to 'forOp' from its parent AffineForOp.
-        stats->loopMap[parentInst].push_back(forOp);
-      }
-
-      // Record the number of op operations in the body of 'forOp'.
-      unsigned count = 0;
-      stats->opCountMap[forInst] = 0;
-      for (auto &op : *forOp.getBody()) {
-        if (!isa<AffineForOp>(op) && !isa<AffineIfOp>(op))
-          ++count;
-      }
-      stats->opCountMap[forInst] = count;
-      // Record trip count for 'forOp'. Set flag if trip count is not
-      // constant.
-      Optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp);
-      if (!maybeConstTripCount.hasValue()) {
-        hasLoopWithNonConstTripCount = true;
-        return;
-      }
-      stats->tripCountMap[forInst] = maybeConstTripCount.getValue();
-    });
-  }
-};
-
-// Computes the total cost of the loop nest rooted at 'forOp'.
-// Currently, the total cost is computed by counting the total operation
-// instance count (i.e. total number of operations in the loop bodyloop
-// operation count * loop trip count) for the entire loop nest.
-// If 'tripCountOverrideMap' is non-null, overrides the trip count for loops
-// specified in the map when computing the total op instance count.
-// NOTEs: 1) This is used to compute the cost of computation slices, which are
-// sliced along the iteration dimension, and thus reduce the trip count.
-// If 'computeCostMap' is non-null, the total op count for forOps specified
-// in the map is increased (not overridden) by adding the op count from the
-// map to the existing op count for the for loop. This is done before
-// multiplying by the loop's trip count, and is used to model the cost of
-// inserting a sliced loop nest of known cost into the loop's body.
-// 2) This is also used to compute the cost of fusing a slice of some loop nest
-// within another loop.
-static int64_t getComputeCost(
-    Operation *forInst, LoopNestStats *stats,
-    llvm::SmallDenseMap<Operation *, uint64_t, 8> *tripCountOverrideMap,
-    DenseMap<Operation *, int64_t> *computeCostMap) {
-  // 'opCount' is the total number operations in one iteration of 'forOp' body,
-  // minus terminator op which is a no-op.
-  int64_t opCount = stats->opCountMap[forInst] - 1;
-  if (stats->loopMap.count(forInst) > 0) {
-    for (auto childForOp : stats->loopMap[forInst]) {
-      opCount += getComputeCost(childForOp.getOperation(), stats,
-                                tripCountOverrideMap, computeCostMap);
-    }
-  }
-  // Add in additional op instances from slice (if specified in map).
-  if (computeCostMap != nullptr) {
-    auto it = computeCostMap->find(forInst);
-    if (it != computeCostMap->end()) {
-      opCount += it->second;
-    }
-  }
-  // Override trip count (if specified in map).
-  int64_t tripCount = stats->tripCountMap[forInst];
-  if (tripCountOverrideMap != nullptr) {
-    auto it = tripCountOverrideMap->find(forInst);
-    if (it != tripCountOverrideMap->end()) {
-      tripCount = it->second;
-    }
-  }
-  // Returns the total number of dynamic instances of operations in loop body.
-  return tripCount * opCount;
-}
-
-} // end anonymous namespace
-
-// TODO(andydavis,b/126426796): extend this to handle multiple result maps.
-static Optional<uint64_t> getConstDifference(AffineMap lbMap, AffineMap ubMap) {
-  assert(lbMap.getNumResults() == 1 && "expected single result bound map");
-  assert(ubMap.getNumResults() == 1 && "expected single result bound map");
-  assert(lbMap.getNumDims() == ubMap.getNumDims());
-  assert(lbMap.getNumSymbols() == ubMap.getNumSymbols());
-  AffineExpr lbExpr(lbMap.getResult(0));
-  AffineExpr ubExpr(ubMap.getResult(0));
-  auto loopSpanExpr = simplifyAffineExpr(ubExpr - lbExpr, lbMap.getNumDims(),
-                                         lbMap.getNumSymbols());
-  auto cExpr = loopSpanExpr.dyn_cast<AffineConstantExpr>();
-  if (!cExpr)
-    return None;
-  return cExpr.getValue();
-}
-
-// Builds a map 'tripCountMap' from AffineForOp to constant trip count for loop
-// nest surrounding 'srcAccess' utilizing slice loop bounds in 'sliceState'.
-// Returns true on success, false otherwise (if a non-constant trip count
-// was encountered).
-// TODO(andydavis) Make this work with non-unit step loops.
-static bool buildSliceTripCountMap(
-    Operation *srcOpInst, ComputationSliceState *sliceState,
-    llvm::SmallDenseMap<Operation *, uint64_t, 8> *tripCountMap) {
-  SmallVector<AffineForOp, 4> srcLoopIVs;
-  getLoopIVs(*srcOpInst, &srcLoopIVs);
-  unsigned numSrcLoopIVs = srcLoopIVs.size();
-  // Populate map from AffineForOp -> trip count
-  for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
-    AffineMap lbMap = sliceState->lbs[i];
-    AffineMap ubMap = sliceState->ubs[i];
-    if (lbMap == AffineMap() || ubMap == AffineMap()) {
-      // The iteration of src loop IV 'i' was not sliced. Use full loop bounds.
-      if (srcLoopIVs[i].hasConstantLowerBound() &&
-          srcLoopIVs[i].hasConstantUpperBound()) {
-        (*tripCountMap)[srcLoopIVs[i].getOperation()] =
-            srcLoopIVs[i].getConstantUpperBound() -
-            srcLoopIVs[i].getConstantLowerBound();
-        continue;
-      }
-      return false;
-    }
-    Optional<uint64_t> tripCount = getConstDifference(lbMap, ubMap);
-    if (!tripCount.hasValue())
-      return false;
-    (*tripCountMap)[srcLoopIVs[i].getOperation()] = tripCount.getValue();
-  }
-  return true;
-}
-
 // Removes load operations from 'srcLoads' which operate on 'memref', and
 // adds them to 'dstLoads'.
 static void moveLoadsAccessingMemrefTo(Value *memref,
@@ -1110,16 +960,6 @@
   return newMemRef;
 }
 
-// Return the number of iterations in the given slice.
-static uint64_t getSliceIterationCount(
-    const llvm::SmallDenseMap<Operation *, uint64_t, 8> &sliceTripCountMap) {
-  uint64_t iterCount = 1;
-  for (const auto &count : sliceTripCountMap) {
-    iterCount *= count.second;
-  }
-  return iterCount;
-}
-
 // Checks if node 'srcId' (which writes to a live out memref), can be safely
 // fused into node 'dstId'. Returns true if the following conditions are met:
 // *) 'srcNode' only writes to live out 'memref'.
@@ -1250,25 +1090,16 @@
 
   // Walk src loop nest and collect stats.
   LoopNestStats srcLoopNestStats;
-  LoopNestStatsCollector srcStatsCollector(&srcLoopNestStats);
-  srcStatsCollector.collect(srcLoopIVs[0].getOperation());
-  // Currently only constant trip count loop nests are supported.
-  if (srcStatsCollector.hasLoopWithNonConstTripCount) {
-    LLVM_DEBUG(llvm::dbgs() << "Non-constant trip count loops unsupported.\n");
+  if (!getLoopNestStats(srcLoopIVs[0], &srcLoopNestStats))
     return false;
-  }
+
   // Compute cost of dst loop nest.
   SmallVector<AffineForOp, 4> dstLoopIVs;
   getLoopIVs(*dstLoadOpInsts[0], &dstLoopIVs);
 
   LoopNestStats dstLoopNestStats;
-  LoopNestStatsCollector dstStatsCollector(&dstLoopNestStats);
-  dstStatsCollector.collect(dstLoopIVs[0].getOperation());
-  // Currently only constant trip count loop nests are supported.
-  if (dstStatsCollector.hasLoopWithNonConstTripCount) {
-    LLVM_DEBUG(llvm::dbgs() << "Non-constant trip count loops unsupported.\n");
+  if (!getLoopNestStats(dstLoopIVs[0], &dstLoopNestStats))
     return false;
-  }
 
   // Compute the maximum loop depth at which we can can insert the src slice
   // and still satisfy dest loop nest dependences, for producer-consumer fusion.
@@ -1297,10 +1128,7 @@
   Optional<unsigned> bestDstLoopDepth = None;
 
   // Compute op instance count for the src loop nest without iteration slicing.
-  uint64_t srcLoopNestCost =
-      getComputeCost(srcLoopIVs[0].getOperation(), &srcLoopNestStats,
-                     /*tripCountOverrideMap=*/nullptr,
-                     /*computeCostMap=*/nullptr);
+  uint64_t srcLoopNestCost = getComputeCost(srcLoopIVs[0], srcLoopNestStats);
 
   // Compute src loop nest write region size.
   MemRefRegion srcWriteRegion(srcStoreOpInst->getLoc());
@@ -1317,15 +1145,10 @@
   int64_t srcWriteRegionSizeBytes = maybeSrcWriteRegionSizeBytes.getValue();
 
   // Compute op instance count for the src loop nest.
-  uint64_t dstLoopNestCost =
-      getComputeCost(dstLoopIVs[0].getOperation(), &dstLoopNestStats,
-                     /*tripCountOverrideMap=*/nullptr,
-                     /*computeCostMap=*/nullptr);
+  uint64_t dstLoopNestCost = getComputeCost(dstLoopIVs[0], dstLoopNestStats);
 
   // Evaluate all depth choices for materializing the slice in the destination
   // loop nest.
-  llvm::SmallDenseMap<Operation *, uint64_t, 8> sliceTripCountMap;
-  DenseMap<Operation *, int64_t> computeCostMap;
   for (unsigned i = maxDstLoopDepth; i >= 1; --i) {
     // Compute the union of slice bounds of all ops in 'dstLoadOpInsts'.
     if (failed(mlir::computeSliceUnion({srcOpInst}, dstLoadOpInsts,
@@ -1338,47 +1161,14 @@
       continue;
     }
 
-    // Build trip count map for computation slice. We'll skip cases where the
-    // trip count was non-constant.
-    sliceTripCountMap.clear();
-    if (!buildSliceTripCountMap(srcOpInst, &sliceStates[i - 1],
-                                &sliceTripCountMap)) {
-      LLVM_DEBUG(llvm::dbgs() << "Unable to build slice trip count map.\n.");
+    int64_t fusedLoopNestComputeCost;
+    if (!getFusionComputeCost(srcLoopIVs[0], srcLoopNestStats, dstLoopIVs[0],
+                              dstLoopNestStats, &sliceStates[i - 1],
+                              &fusedLoopNestComputeCost)) {
+      LLVM_DEBUG(llvm::dbgs() << "Unable to compute fusion compute cost.\n.");
       continue;
     }
 
-    // Checks whether a store to load forwarding will happen.
-    int64_t sliceIterationCount = getSliceIterationCount(sliceTripCountMap);
-    assert(sliceIterationCount > 0);
-    bool storeLoadFwdGuaranteed = (sliceIterationCount == 1);
-
-    // Compute cost of fusion for this dest loop depth.
-
-    computeCostMap.clear();
-
-    // The store and loads to this memref will disappear.
-    // TODO(andydavis) Add load coalescing to memref data flow opt pass.
-    if (storeLoadFwdGuaranteed) {
-      // A single store disappears: -1 for that.
-      computeCostMap[srcLoopIVs[numSrcLoopIVs - 1].getOperation()] = -1;
-      for (auto *loadOp : dstLoadOpInsts)
-        if (auto forOp = dyn_cast_or_null<AffineForOp>(loadOp->getParentOp()))
-          computeCostMap[forOp] = -1;
-    }
-
-    // Compute op instance count for the src loop nest with iteration slicing.
-    int64_t sliceComputeCost =
-        getComputeCost(srcLoopIVs[0].getOperation(), &srcLoopNestStats,
-                       /*tripCountOverrideMap=*/&sliceTripCountMap,
-                       /*computeCostMap=*/&computeCostMap);
-
-    // Compute cost of fusion for this depth.
-    computeCostMap[dstLoopIVs[i - 1].getOperation()] = sliceComputeCost;
-
-    int64_t fusedLoopNestComputeCost =
-        getComputeCost(dstLoopIVs[0].getOperation(), &dstLoopNestStats,
-                       /*tripCountOverrideMap=*/nullptr, &computeCostMap);
-
     double additionalComputeFraction =
         fusedLoopNestComputeCost /
             (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
@@ -1427,7 +1217,6 @@
           << 100.0 * additionalComputeFraction << "%\n"
           << "   storage reduction factor: " << storageReduction << "x\n"
           << "   fused nest cost: " << fusedLoopNestComputeCost << "\n"
-          << "   slice iteration count: " << sliceIterationCount << "\n"
           << "   src write region size: " << srcWriteRegionSizeBytes << "\n"
           << "   slice write region size: " << sliceWriteRegionSizeBytes
           << "\n";