LoopFusion: Add support for computing forward computation slices, which will enable fusing consumer loop nests into their producers in subsequent CLs.
PiperOrigin-RevId: 253601994
diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
index 829b1b22..95890a6 100644
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -1329,7 +1329,9 @@
for (unsigned i = maxDstLoopDepth; i >= 1; --i) {
// Compute the union of slice bounds of all ops in 'dstLoadOpInsts'.
if (failed(mlir::computeSliceUnion({srcOpInst}, dstLoadOpInsts,
- /*dstLoopDepth=*/i,
+ /*loopDepth=*/i,
+ /*numCommonLoops=*/0,
+ /*isBackwardSlice=*/true,
&sliceStates[i - 1]))) {
LLVM_DEBUG(llvm::dbgs()
<< "computeSliceUnion failed for loopDepth: " << i << "\n");
@@ -1736,15 +1738,16 @@
dstLoadOpInsts, dstStoreOpInsts, &sliceState,
&bestDstLoopDepth, maximalFusion))
continue;
- // TODO(andydavis) Remove assert and surrounding code when
- // canFuseLoops is fully functional.
+ // TODO(andydavis) Remove the following test code when canFuseLoops
+ // is fully functional.
mlir::ComputationSliceState sliceUnion;
- FusionResult result = mlir::canFuseLoops(
- cast<AffineForOp>(srcNode->op), cast<AffineForOp>(dstNode->op),
- bestDstLoopDepth, &sliceUnion);
- assert(result.value == FusionResult::Success);
- (void)result;
-
+ if (!maximalFusion) {
+ FusionResult result = mlir::canFuseLoops(
+ cast<AffineForOp>(srcNode->op), cast<AffineForOp>(dstNode->op),
+ bestDstLoopDepth, &sliceUnion);
+ assert(result.value == FusionResult::Success);
+ (void)result;
+ }
// Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'.
auto sliceLoopNest = mlir::insertBackwardComputationSlice(
srcStoreOpInst, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState);