Factor fusion compute cost calculation out of LoopFusion and into LoopFusionUtils (NFC).

PiperOrigin-RevId: 253797886
diff --git a/mlir/include/mlir/Transforms/LoopFusionUtils.h b/mlir/include/mlir/Transforms/LoopFusionUtils.h
index ccda669..b6d1ea4 100644
--- a/mlir/include/mlir/Transforms/LoopFusionUtils.h
+++ b/mlir/include/mlir/Transforms/LoopFusionUtils.h
@@ -24,9 +24,13 @@
 #ifndef MLIR_TRANSFORMS_LOOP_FUSION_UTILS_H
 #define MLIR_TRANSFORMS_LOOP_FUSION_UTILS_H
 
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SmallVector.h"
+
 namespace mlir {
 class AffineForOp;
 struct ComputationSliceState;
+class Operation;
 
 // TODO(andydavis) Extend this module to include utility functions for querying
 // fusion cost/storage reduction, and for performing the loop fusion
@@ -54,6 +58,43 @@
 FusionResult canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
                           unsigned dstLoopDepth,
                           ComputationSliceState *srcSlice);
+
+/// LoopNestStats aggregates various per-loop statistics (eg. loop trip count
+/// and operation count) for a loop nest up until (and including) the innermost
+/// loop body.
+struct LoopNestStats {
+  /// Map from AffineForOp to immediate child AffineForOps in its loop body.
+  llvm::DenseMap<Operation *, llvm::SmallVector<AffineForOp, 2>> loopMap;
+  /// Map from AffineForOp to count of operations in its loop body.
+  llvm::DenseMap<Operation *, uint64_t> opCountMap;
+  /// Map from AffineForOp to its constant trip count.
+  llvm::DenseMap<Operation *, uint64_t> tripCountMap;
+};
+
+/// Collect loop nest statistics (eg. loop trip count and operation count)
+/// in 'stats' for loop nest rooted at 'forOp'. Returns true on success,
+/// returns false otherwise.
+// TODO(andydavis) Consider moving this to LoopUtils.
+bool getLoopNestStats(AffineForOp forOp, LoopNestStats *stats);
+
+/// Computes the total cost of the loop nest rooted at 'forOp' using 'stats'.
+/// Currently, the total cost is computed by counting the total operation
+/// instance count (i.e. total number of operations in the loop body * loop
+/// trip count) for the entire loop nest.
+// TODO(andydavis) Improve this cost model.
+int64_t getComputeCost(AffineForOp forOp, LoopNestStats &stats);
+
+/// Computes and returns in 'computeCost', the total compute cost of fusing the
+/// 'slice' of the loop nest rooted at 'srcForOp' into 'dstForOp'. Currently,
+/// the total cost is computed by counting the total operation instance count
+/// (i.e. total number of operations in the loop body * loop trip count) for
+/// the entire loop nest.
+/// Returns true on success, failure otherwise (e.g. non-constant trip counts).
+// TODO(andydavis) Improve this cost model.
+bool getFusionComputeCost(AffineForOp srcForOp, LoopNestStats &srcStats,
+                          AffineForOp dstForOp, LoopNestStats &dstStats,
+                          ComputationSliceState *slice, int64_t *computeCost);
+
 } // end namespace mlir
 
 #endif // MLIR_TRANSFORMS_LOOP_FUSION_UTILS_H
diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
index 95890a6..8d2e75b 100644
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -732,156 +732,6 @@
   return true;
 }
 
-namespace {
-
-// LoopNestStats aggregates various per-loop statistics (eg. loop trip count
-// and operation count) for a loop nest up until the innermost loop body.
-struct LoopNestStats {
-  // Map from AffineForOp to immediate child AffineForOps in its loop body.
-  DenseMap<Operation *, SmallVector<AffineForOp, 2>> loopMap;
-  // Map from AffineForOp to count of operations in its loop body.
-  DenseMap<Operation *, uint64_t> opCountMap;
-  // Map from AffineForOp to its constant trip count.
-  DenseMap<Operation *, uint64_t> tripCountMap;
-};
-
-// LoopNestStatsCollector walks a single loop nest and gathers per-loop
-// trip count and operation count statistics and records them in 'stats'.
-struct LoopNestStatsCollector {
-  LoopNestStats *stats;
-  bool hasLoopWithNonConstTripCount = false;
-
-  LoopNestStatsCollector(LoopNestStats *stats) : stats(stats) {}
-
-  void collect(Operation *op) {
-    op->walk<AffineForOp>([&](AffineForOp forOp) {
-      auto *forInst = forOp.getOperation();
-      auto *parentInst = forOp.getOperation()->getParentOp();
-      if (parentInst != nullptr) {
-        assert(isa<AffineForOp>(parentInst) && "Expected parent AffineForOp");
-        // Add mapping to 'forOp' from its parent AffineForOp.
-        stats->loopMap[parentInst].push_back(forOp);
-      }
-
-      // Record the number of op operations in the body of 'forOp'.
-      unsigned count = 0;
-      stats->opCountMap[forInst] = 0;
-      for (auto &op : *forOp.getBody()) {
-        if (!isa<AffineForOp>(op) && !isa<AffineIfOp>(op))
-          ++count;
-      }
-      stats->opCountMap[forInst] = count;
-      // Record trip count for 'forOp'. Set flag if trip count is not
-      // constant.
-      Optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp);
-      if (!maybeConstTripCount.hasValue()) {
-        hasLoopWithNonConstTripCount = true;
-        return;
-      }
-      stats->tripCountMap[forInst] = maybeConstTripCount.getValue();
-    });
-  }
-};
-
-// Computes the total cost of the loop nest rooted at 'forOp'.
-// Currently, the total cost is computed by counting the total operation
-// instance count (i.e. total number of operations in the loop bodyloop
-// operation count * loop trip count) for the entire loop nest.
-// If 'tripCountOverrideMap' is non-null, overrides the trip count for loops
-// specified in the map when computing the total op instance count.
-// NOTEs: 1) This is used to compute the cost of computation slices, which are
-// sliced along the iteration dimension, and thus reduce the trip count.
-// If 'computeCostMap' is non-null, the total op count for forOps specified
-// in the map is increased (not overridden) by adding the op count from the
-// map to the existing op count for the for loop. This is done before
-// multiplying by the loop's trip count, and is used to model the cost of
-// inserting a sliced loop nest of known cost into the loop's body.
-// 2) This is also used to compute the cost of fusing a slice of some loop nest
-// within another loop.
-static int64_t getComputeCost(
-    Operation *forInst, LoopNestStats *stats,
-    llvm::SmallDenseMap<Operation *, uint64_t, 8> *tripCountOverrideMap,
-    DenseMap<Operation *, int64_t> *computeCostMap) {
-  // 'opCount' is the total number operations in one iteration of 'forOp' body,
-  // minus terminator op which is a no-op.
-  int64_t opCount = stats->opCountMap[forInst] - 1;
-  if (stats->loopMap.count(forInst) > 0) {
-    for (auto childForOp : stats->loopMap[forInst]) {
-      opCount += getComputeCost(childForOp.getOperation(), stats,
-                                tripCountOverrideMap, computeCostMap);
-    }
-  }
-  // Add in additional op instances from slice (if specified in map).
-  if (computeCostMap != nullptr) {
-    auto it = computeCostMap->find(forInst);
-    if (it != computeCostMap->end()) {
-      opCount += it->second;
-    }
-  }
-  // Override trip count (if specified in map).
-  int64_t tripCount = stats->tripCountMap[forInst];
-  if (tripCountOverrideMap != nullptr) {
-    auto it = tripCountOverrideMap->find(forInst);
-    if (it != tripCountOverrideMap->end()) {
-      tripCount = it->second;
-    }
-  }
-  // Returns the total number of dynamic instances of operations in loop body.
-  return tripCount * opCount;
-}
-
-} // end anonymous namespace
-
-// TODO(andydavis,b/126426796): extend this to handle multiple result maps.
-static Optional<uint64_t> getConstDifference(AffineMap lbMap, AffineMap ubMap) {
-  assert(lbMap.getNumResults() == 1 && "expected single result bound map");
-  assert(ubMap.getNumResults() == 1 && "expected single result bound map");
-  assert(lbMap.getNumDims() == ubMap.getNumDims());
-  assert(lbMap.getNumSymbols() == ubMap.getNumSymbols());
-  AffineExpr lbExpr(lbMap.getResult(0));
-  AffineExpr ubExpr(ubMap.getResult(0));
-  auto loopSpanExpr = simplifyAffineExpr(ubExpr - lbExpr, lbMap.getNumDims(),
-                                         lbMap.getNumSymbols());
-  auto cExpr = loopSpanExpr.dyn_cast<AffineConstantExpr>();
-  if (!cExpr)
-    return None;
-  return cExpr.getValue();
-}
-
-// Builds a map 'tripCountMap' from AffineForOp to constant trip count for loop
-// nest surrounding 'srcAccess' utilizing slice loop bounds in 'sliceState'.
-// Returns true on success, false otherwise (if a non-constant trip count
-// was encountered).
-// TODO(andydavis) Make this work with non-unit step loops.
-static bool buildSliceTripCountMap(
-    Operation *srcOpInst, ComputationSliceState *sliceState,
-    llvm::SmallDenseMap<Operation *, uint64_t, 8> *tripCountMap) {
-  SmallVector<AffineForOp, 4> srcLoopIVs;
-  getLoopIVs(*srcOpInst, &srcLoopIVs);
-  unsigned numSrcLoopIVs = srcLoopIVs.size();
-  // Populate map from AffineForOp -> trip count
-  for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
-    AffineMap lbMap = sliceState->lbs[i];
-    AffineMap ubMap = sliceState->ubs[i];
-    if (lbMap == AffineMap() || ubMap == AffineMap()) {
-      // The iteration of src loop IV 'i' was not sliced. Use full loop bounds.
-      if (srcLoopIVs[i].hasConstantLowerBound() &&
-          srcLoopIVs[i].hasConstantUpperBound()) {
-        (*tripCountMap)[srcLoopIVs[i].getOperation()] =
-            srcLoopIVs[i].getConstantUpperBound() -
-            srcLoopIVs[i].getConstantLowerBound();
-        continue;
-      }
-      return false;
-    }
-    Optional<uint64_t> tripCount = getConstDifference(lbMap, ubMap);
-    if (!tripCount.hasValue())
-      return false;
-    (*tripCountMap)[srcLoopIVs[i].getOperation()] = tripCount.getValue();
-  }
-  return true;
-}
-
 // Removes load operations from 'srcLoads' which operate on 'memref', and
 // adds them to 'dstLoads'.
 static void moveLoadsAccessingMemrefTo(Value *memref,
@@ -1110,16 +960,6 @@
   return newMemRef;
 }
 
-// Return the number of iterations in the given slice.
-static uint64_t getSliceIterationCount(
-    const llvm::SmallDenseMap<Operation *, uint64_t, 8> &sliceTripCountMap) {
-  uint64_t iterCount = 1;
-  for (const auto &count : sliceTripCountMap) {
-    iterCount *= count.second;
-  }
-  return iterCount;
-}
-
 // Checks if node 'srcId' (which writes to a live out memref), can be safely
 // fused into node 'dstId'. Returns true if the following conditions are met:
 // *) 'srcNode' only writes to live out 'memref'.
@@ -1250,25 +1090,16 @@
 
   // Walk src loop nest and collect stats.
   LoopNestStats srcLoopNestStats;
-  LoopNestStatsCollector srcStatsCollector(&srcLoopNestStats);
-  srcStatsCollector.collect(srcLoopIVs[0].getOperation());
-  // Currently only constant trip count loop nests are supported.
-  if (srcStatsCollector.hasLoopWithNonConstTripCount) {
-    LLVM_DEBUG(llvm::dbgs() << "Non-constant trip count loops unsupported.\n");
+  if (!getLoopNestStats(srcLoopIVs[0], &srcLoopNestStats))
     return false;
-  }
+
   // Compute cost of dst loop nest.
   SmallVector<AffineForOp, 4> dstLoopIVs;
   getLoopIVs(*dstLoadOpInsts[0], &dstLoopIVs);
 
   LoopNestStats dstLoopNestStats;
-  LoopNestStatsCollector dstStatsCollector(&dstLoopNestStats);
-  dstStatsCollector.collect(dstLoopIVs[0].getOperation());
-  // Currently only constant trip count loop nests are supported.
-  if (dstStatsCollector.hasLoopWithNonConstTripCount) {
-    LLVM_DEBUG(llvm::dbgs() << "Non-constant trip count loops unsupported.\n");
+  if (!getLoopNestStats(dstLoopIVs[0], &dstLoopNestStats))
     return false;
-  }
 
   // Compute the maximum loop depth at which we can can insert the src slice
   // and still satisfy dest loop nest dependences, for producer-consumer fusion.
@@ -1297,10 +1128,7 @@
   Optional<unsigned> bestDstLoopDepth = None;
 
   // Compute op instance count for the src loop nest without iteration slicing.
-  uint64_t srcLoopNestCost =
-      getComputeCost(srcLoopIVs[0].getOperation(), &srcLoopNestStats,
-                     /*tripCountOverrideMap=*/nullptr,
-                     /*computeCostMap=*/nullptr);
+  uint64_t srcLoopNestCost = getComputeCost(srcLoopIVs[0], srcLoopNestStats);
 
   // Compute src loop nest write region size.
   MemRefRegion srcWriteRegion(srcStoreOpInst->getLoc());
@@ -1317,15 +1145,10 @@
   int64_t srcWriteRegionSizeBytes = maybeSrcWriteRegionSizeBytes.getValue();
 
   // Compute op instance count for the src loop nest.
-  uint64_t dstLoopNestCost =
-      getComputeCost(dstLoopIVs[0].getOperation(), &dstLoopNestStats,
-                     /*tripCountOverrideMap=*/nullptr,
-                     /*computeCostMap=*/nullptr);
+  uint64_t dstLoopNestCost = getComputeCost(dstLoopIVs[0], dstLoopNestStats);
 
   // Evaluate all depth choices for materializing the slice in the destination
   // loop nest.
-  llvm::SmallDenseMap<Operation *, uint64_t, 8> sliceTripCountMap;
-  DenseMap<Operation *, int64_t> computeCostMap;
   for (unsigned i = maxDstLoopDepth; i >= 1; --i) {
     // Compute the union of slice bounds of all ops in 'dstLoadOpInsts'.
     if (failed(mlir::computeSliceUnion({srcOpInst}, dstLoadOpInsts,
@@ -1338,47 +1161,14 @@
       continue;
     }
 
-    // Build trip count map for computation slice. We'll skip cases where the
-    // trip count was non-constant.
-    sliceTripCountMap.clear();
-    if (!buildSliceTripCountMap(srcOpInst, &sliceStates[i - 1],
-                                &sliceTripCountMap)) {
-      LLVM_DEBUG(llvm::dbgs() << "Unable to build slice trip count map.\n.");
+    int64_t fusedLoopNestComputeCost;
+    if (!getFusionComputeCost(srcLoopIVs[0], srcLoopNestStats, dstLoopIVs[0],
+                              dstLoopNestStats, &sliceStates[i - 1],
+                              &fusedLoopNestComputeCost)) {
+      LLVM_DEBUG(llvm::dbgs() << "Unable to compute fusion compute cost.\n.");
       continue;
     }
 
-    // Checks whether a store to load forwarding will happen.
-    int64_t sliceIterationCount = getSliceIterationCount(sliceTripCountMap);
-    assert(sliceIterationCount > 0);
-    bool storeLoadFwdGuaranteed = (sliceIterationCount == 1);
-
-    // Compute cost of fusion for this dest loop depth.
-
-    computeCostMap.clear();
-
-    // The store and loads to this memref will disappear.
-    // TODO(andydavis) Add load coalescing to memref data flow opt pass.
-    if (storeLoadFwdGuaranteed) {
-      // A single store disappears: -1 for that.
-      computeCostMap[srcLoopIVs[numSrcLoopIVs - 1].getOperation()] = -1;
-      for (auto *loadOp : dstLoadOpInsts)
-        if (auto forOp = dyn_cast_or_null<AffineForOp>(loadOp->getParentOp()))
-          computeCostMap[forOp] = -1;
-    }
-
-    // Compute op instance count for the src loop nest with iteration slicing.
-    int64_t sliceComputeCost =
-        getComputeCost(srcLoopIVs[0].getOperation(), &srcLoopNestStats,
-                       /*tripCountOverrideMap=*/&sliceTripCountMap,
-                       /*computeCostMap=*/&computeCostMap);
-
-    // Compute cost of fusion for this depth.
-    computeCostMap[dstLoopIVs[i - 1].getOperation()] = sliceComputeCost;
-
-    int64_t fusedLoopNestComputeCost =
-        getComputeCost(dstLoopIVs[0].getOperation(), &dstLoopNestStats,
-                       /*tripCountOverrideMap=*/nullptr, &computeCostMap);
-
     double additionalComputeFraction =
         fusedLoopNestComputeCost /
             (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
@@ -1427,7 +1217,6 @@
           << 100.0 * additionalComputeFraction << "%\n"
           << "   storage reduction factor: " << storageReduction << "x\n"
           << "   fused nest cost: " << fusedLoopNestComputeCost << "\n"
-          << "   slice iteration count: " << sliceIterationCount << "\n"
           << "   src write region size: " << srcWriteRegionSizeBytes << "\n"
           << "   slice write region size: " << sliceWriteRegionSizeBytes
           << "\n";
diff --git a/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp b/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp
index 1fb41a2..93503d1 100644
--- a/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp
@@ -24,6 +24,7 @@
 #include "mlir/AffineOps/AffineOps.h"
 #include "mlir/Analysis/AffineAnalysis.h"
 #include "mlir/Analysis/AffineStructures.h"
+#include "mlir/Analysis/LoopAnalysis.h"
 #include "mlir/Analysis/Utils.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
@@ -250,3 +251,236 @@
 
   return FusionResult::Success;
 }
+
+/// Collect loop nest statistics (eg. loop trip count and operation count)
+/// in 'stats' for loop nest rooted at 'forOp'. Returns true on success,
+/// returns false otherwise.
+bool mlir::getLoopNestStats(AffineForOp forOpRoot, LoopNestStats *stats) {
+  bool ret = true;
+  forOpRoot.getOperation()->walk<AffineForOp>([&](AffineForOp forOp) {
+    auto *childForOp = forOp.getOperation();
+    auto *parentForOp = forOp.getOperation()->getParentOp();
+    if (parentForOp != nullptr) {
+      if (!isa<AffineForOp>(parentForOp)) {
+        LLVM_DEBUG(llvm::dbgs() << "Expected parent AffineForOp");
+        ret = false;
+        return;
+      }
+      // Add mapping to 'forOp' from its parent AffineForOp.
+      stats->loopMap[parentForOp].push_back(forOp);
+    }
+
+    // Record the number of op operations in the body of 'forOp'.
+    unsigned count = 0;
+    stats->opCountMap[childForOp] = 0;
+    for (auto &op : *forOp.getBody()) {
+      if (!isa<AffineForOp>(op) && !isa<AffineIfOp>(op))
+        ++count;
+    }
+    stats->opCountMap[childForOp] = count;
+    // Record trip count for 'forOp'. Set flag if trip count is not
+    // constant.
+    Optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp);
+    if (!maybeConstTripCount.hasValue()) {
+      // Currently only constant trip count loop nests are supported.
+      LLVM_DEBUG(llvm::dbgs() << "Non-constant trip count unsupported");
+      ret = false;
+      return;
+    }
+    stats->tripCountMap[childForOp] = maybeConstTripCount.getValue();
+  });
+  return ret;
+}
+
+// Computes the total cost of the loop nest rooted at 'forOp'.
+// Currently, the total cost is computed by counting the total operation
+// instance count (i.e. total number of operations in the loop bodyloop
+// operation count * loop trip count) for the entire loop nest.
+// If 'tripCountOverrideMap' is non-null, overrides the trip count for loops
+// specified in the map when computing the total op instance count.
+// NOTEs: 1) This is used to compute the cost of computation slices, which are
+// sliced along the iteration dimension, and thus reduce the trip count.
+// If 'computeCostMap' is non-null, the total op count for forOps specified
+// in the map is increased (not overridden) by adding the op count from the
+// map to the existing op count for the for loop. This is done before
+// multiplying by the loop's trip count, and is used to model the cost of
+// inserting a sliced loop nest of known cost into the loop's body.
+// 2) This is also used to compute the cost of fusing a slice of some loop nest
+// within another loop.
+static int64_t getComputeCostHelper(
+    Operation *forOp, LoopNestStats &stats,
+    llvm::SmallDenseMap<Operation *, uint64_t, 8> *tripCountOverrideMap,
+    DenseMap<Operation *, int64_t> *computeCostMap) {
+  // 'opCount' is the total number operations in one iteration of 'forOp' body,
+  // minus terminator op which is a no-op.
+  int64_t opCount = stats.opCountMap[forOp] - 1;
+  if (stats.loopMap.count(forOp) > 0) {
+    for (auto childForOp : stats.loopMap[forOp]) {
+      opCount += getComputeCostHelper(childForOp.getOperation(), stats,
+                                      tripCountOverrideMap, computeCostMap);
+    }
+  }
+  // Add in additional op instances from slice (if specified in map).
+  if (computeCostMap != nullptr) {
+    auto it = computeCostMap->find(forOp);
+    if (it != computeCostMap->end()) {
+      opCount += it->second;
+    }
+  }
+  // Override trip count (if specified in map).
+  int64_t tripCount = stats.tripCountMap[forOp];
+  if (tripCountOverrideMap != nullptr) {
+    auto it = tripCountOverrideMap->find(forOp);
+    if (it != tripCountOverrideMap->end()) {
+      tripCount = it->second;
+    }
+  }
+  // Returns the total number of dynamic instances of operations in loop body.
+  return tripCount * opCount;
+}
+
+// TODO(andydavis,b/126426796): extend this to handle multiple result maps.
+static Optional<uint64_t> getConstDifference(AffineMap lbMap, AffineMap ubMap) {
+  assert(lbMap.getNumResults() == 1 && "expected single result bound map");
+  assert(ubMap.getNumResults() == 1 && "expected single result bound map");
+  assert(lbMap.getNumDims() == ubMap.getNumDims());
+  assert(lbMap.getNumSymbols() == ubMap.getNumSymbols());
+  AffineExpr lbExpr(lbMap.getResult(0));
+  AffineExpr ubExpr(ubMap.getResult(0));
+  auto loopSpanExpr = simplifyAffineExpr(ubExpr - lbExpr, lbMap.getNumDims(),
+                                         lbMap.getNumSymbols());
+  auto cExpr = loopSpanExpr.dyn_cast<AffineConstantExpr>();
+  if (!cExpr)
+    return None;
+  return cExpr.getValue();
+}
+
+// Return the number of iterations in the given slice.
+static uint64_t getSliceIterationCount(
+    const llvm::SmallDenseMap<Operation *, uint64_t, 8> &sliceTripCountMap) {
+  uint64_t iterCount = 1;
+  for (const auto &count : sliceTripCountMap) {
+    iterCount *= count.second;
+  }
+  return iterCount;
+}
+
+// Builds a map 'tripCountMap' from AffineForOp to constant trip count for loop
+// nest surrounding represented by slice loop bounds in 'slice'.
+// Returns true on success, false otherwise (if a non-constant trip count
+// was encountered).
+// TODO(andydavis) Make this work with non-unit step loops.
+static bool buildSliceTripCountMap(
+    ComputationSliceState *slice,
+    llvm::SmallDenseMap<Operation *, uint64_t, 8> *tripCountMap) {
+  unsigned numSrcLoopIVs = slice->ivs.size();
+  // Populate map from AffineForOp -> trip count
+  for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
+    AffineForOp forOp = getForInductionVarOwner(slice->ivs[i]);
+    auto *op = forOp.getOperation();
+    AffineMap lbMap = slice->lbs[i];
+    AffineMap ubMap = slice->ubs[i];
+    if (lbMap == AffineMap() || ubMap == AffineMap()) {
+      // The iteration of src loop IV 'i' was not sliced. Use full loop bounds.
+      if (forOp.hasConstantLowerBound() && forOp.hasConstantUpperBound()) {
+        (*tripCountMap)[op] =
+            forOp.getConstantUpperBound() - forOp.getConstantLowerBound();
+        continue;
+      }
+      Optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp);
+      if (maybeConstTripCount.hasValue()) {
+        (*tripCountMap)[op] = maybeConstTripCount.getValue();
+        continue;
+      }
+      return false;
+    }
+    Optional<uint64_t> tripCount = getConstDifference(lbMap, ubMap);
+    // Slice bounds are created with a constant ub - lb difference.
+    if (!tripCount.hasValue())
+      return false;
+    (*tripCountMap)[op] = tripCount.getValue();
+  }
+  return true;
+}
+
+/// Computes the total cost of the loop nest rooted at 'forOp' using 'stats'.
+/// Currently, the total cost is computed by counting the total operation
+/// instance count (i.e. total number of operations in the loop body * loop
+/// trip count) for the entire loop nest.
+int64_t mlir::getComputeCost(AffineForOp forOp, LoopNestStats &stats) {
+  return getComputeCostHelper(forOp.getOperation(), stats,
+                              /*tripCountOverrideMap=*/nullptr,
+                              /*computeCostMap=*/nullptr);
+}
+
+/// Computes and returns in 'computeCost', the total compute cost of fusing the
+/// 'slice' of the loop nest rooted at 'srcForOp' into 'dstForOp'. Currently,
+/// the total cost is computed by counting the total operation instance count
+/// (i.e. total number of operations in the loop body * loop trip count) for
+/// the entire loop nest.
+bool mlir::getFusionComputeCost(AffineForOp srcForOp, LoopNestStats &srcStats,
+                                AffineForOp dstForOp, LoopNestStats &dstStats,
+                                ComputationSliceState *slice,
+                                int64_t *computeCost) {
+  llvm::SmallDenseMap<Operation *, uint64_t, 8> sliceTripCountMap;
+  DenseMap<Operation *, int64_t> computeCostMap;
+
+  // Build trip count map for computation slice.
+  if (!buildSliceTripCountMap(slice, &sliceTripCountMap))
+    return false;
+  // Checks whether a store to load forwarding will happen.
+  int64_t sliceIterationCount = getSliceIterationCount(sliceTripCountMap);
+  assert(sliceIterationCount > 0);
+  bool storeLoadFwdGuaranteed = (sliceIterationCount == 1);
+  auto *insertPointParent = slice->insertPoint->getParentOp();
+
+  // The store and loads to this memref will disappear.
+  // TODO(andydavis) Add load coalescing to memref data flow opt pass.
+  if (storeLoadFwdGuaranteed) {
+    // Subtract from operation count the loads/store we expect load/store
+    // forwarding to remove.
+    unsigned storeCount = 0;
+    llvm::SmallDenseSet<Value *, 4> storeMemrefs;
+    srcForOp.getOperation()->walk([&](Operation *op) {
+      if (auto storeOp = dyn_cast<StoreOp>(op)) {
+        storeMemrefs.insert(storeOp.getMemRef());
+        ++storeCount;
+      }
+    });
+    // Subtract out any store ops in single-iteration src slice loop nest.
+    if (storeCount > 0)
+      computeCostMap[insertPointParent] = -storeCount;
+    // Subtract out any load users of 'storeMemrefs' nested below
+    // 'insertPointParent'.
+    for (auto *value : storeMemrefs) {
+      for (auto *user : value->getUsers()) {
+        if (auto loadOp = dyn_cast<LoadOp>(user)) {
+          SmallVector<AffineForOp, 4> loops;
+          // Check if any loop in loop nest surrounding 'user' is
+          // 'insertPointParent'.
+          getLoopIVs(*user, &loops);
+          if (llvm::is_contained(loops, cast<AffineForOp>(insertPointParent))) {
+            if (auto forOp =
+                    dyn_cast_or_null<AffineForOp>(user->getParentOp())) {
+              if (computeCostMap.count(forOp) == 0)
+                computeCostMap[forOp] = 0;
+              computeCostMap[forOp] -= 1;
+            }
+          }
+        }
+      }
+    }
+  }
+
+  // Compute op instance count for the src loop nest with iteration slicing.
+  int64_t sliceComputeCost = getComputeCostHelper(
+      srcForOp.getOperation(), srcStats, &sliceTripCountMap, &computeCostMap);
+
+  // Compute cost of fusion for this depth.
+  computeCostMap[insertPointParent] = sliceComputeCost;
+
+  *computeCost =
+      getComputeCostHelper(dstForOp.getOperation(), dstStats,
+                           /*tripCountOverrideMap=*/nullptr, &computeCostMap);
+  return true;
+}