Promote local buffers created post fusion to higher memory space

- fusion already includes the analysis necessary to create small/local buffers
  post fusion; allocate these buffers in a faster memory space whenever the
  relevant pass parameters are provided (threshold size, memory space id); see
  the example invocation below

- although there will eventually be a separate utility to directly detect and
  promote small local buffers to faster memory spaces, doing this during fusion
  when possible is much less expensive, comes free with the fusion analysis,
  and covers a key common case.
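
A hypothetical invocation (the two flag names below are added by this change;
the -loop-fusion pass name, the mlir-opt driver, and the choice of memory
space 2 as the fast space are assumptions, since space numbering is
target-dependent):

  mlir-opt -loop-fusion -fusion-fast-mem-space=2 \
      -fusion-local-buf-threshold=2048 input.mlir

With these flags, a private buffer such as memref<16x16xf32> (16 * 16 * 4 =
1024 bytes, under the 2048-byte threshold) is allocated as
memref<16x16xf32, 2>, while larger buffers keep their original memory space.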

PiperOrigin-RevId: 232063894
diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
index 7d4ff03..5091e3c 100644
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -63,6 +63,17 @@
                    " computation tolerated while fusing"),
     llvm::cl::cat(clOptionsCategory));
 
+static llvm::cl::opt<unsigned> clFusionFastMemorySpace(
+    "fusion-fast-mem-space", llvm::cl::Hidden,
+    llvm::cl::desc("Faster memory space number to promote fusion buffers to"),
+    llvm::cl::cat(clOptionsCategory));
+
+static llvm::cl::opt<unsigned> clFusionLocalBufThreshold(
+    "fusion-local-buf-threshold", llvm::cl::Hidden,
+    llvm::cl::desc("Threshold size (bytes) for promoting local buffers to fast "
+                   "memory space"),
+    llvm::cl::cat(clOptionsCategory));
+
 namespace {
 
 /// Loop fusion pass. This pass currently supports a greedy fusion policy,
@@ -80,6 +91,11 @@
   PassResult runOnFunction(Function *f) override;
   static char passID;
 
+  // Any local buffers smaller than this size (in bytes) will be created in
+  // `fastMemorySpace` if provided.
+  unsigned localBufSizeThreshold = 1024;
+  Optional<unsigned> fastMemorySpace = None;
+
   // The amount of additional computation that is tolerated while fusing
   // pair-wise as a fraction of the total computation.
   constexpr static double kComputeToleranceThreshold = 0.30f;
@@ -876,6 +892,21 @@
   return true;
 }
 
+//  TODO(mlir-team): improve/complete this when we have target data.
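+// Returns the size of 'memRefType's element type in whole bytes, e.g. 4 for
+// f32 and 16 for vector<4xf32>.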
+static unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
+  auto elementType = memRefType.getElementType();
+
+  unsigned sizeInBits;
+  if (elementType.isIntOrFloat()) {
+    sizeInBits = elementType.getIntOrFloatBitWidth();
+  } else {
+    auto vectorType = elementType.cast<VectorType>();
+    sizeInBits =
+        vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
+  }
+  return llvm::divideCeil(sizeInBits, 8);
+}
+
 // Creates and returns a private (single-user) memref for fused loop rooted
 // at 'forOp', with (potentially reduced) memref size based on the
 // MemRefRegion written to by 'srcStoreOpInst' at depth 'dstLoopDepth'.
@@ -883,7 +914,9 @@
 // this one.
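+// Buffers smaller than 'localBufSizeThreshold' bytes are placed in
+// 'fastMemorySpace' if one was provided.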
 static Value *createPrivateMemRef(OpPointer<AffineForOp> forOp,
                                   OperationInst *srcStoreOpInst,
-                                  unsigned dstLoopDepth) {
+                                  unsigned dstLoopDepth,
+                                  Optional<unsigned> fastMemorySpace,
+                                  unsigned localBufSizeThreshold) {
   auto *forInst = forOp->getInstruction();
 
   // Create builder to insert alloc op just before 'forOp'.
@@ -906,7 +939,8 @@
   // by 'srcStoreOpInst' at depth 'dstLoopDepth'.
   Optional<int64_t> numElements =
       region.getConstantBoundingSizeAndShape(&newShape, &lbs, &lbDivisors);
-  assert(numElements.hasValue());
+  assert(numElements.hasValue() &&
+         "non-constant number of elts in local buffer");
 
   const FlatAffineConstraints *cst = region.getConstraints();
  // 'outerIVs' holds the values that this memory region is symbolic/parametric
@@ -933,9 +967,16 @@
 
   // Create 'newMemRefType' using 'newShape' from MemRefRegion accessed
   // by 'srcStoreOpInst'.
-  auto newMemRefType =
-      top.getMemRefType(newShape, oldMemRefType.getElementType(), {},
-                        oldMemRefType.getMemorySpace());
+  uint64_t bufSize =
+      getMemRefEltSizeInBytes(oldMemRefType) * numElements.getValue();
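+  // Place the buffer in the faster memory space if its size is under the
+  // threshold and a fast memory space was provided; otherwise keep the
+  // original memref's memory space.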
+  unsigned newMemSpace;
+  if (bufSize < localBufSizeThreshold && fastMemorySpace.hasValue()) {
+    newMemSpace = fastMemorySpace.getValue();
+  } else {
+    newMemSpace = oldMemRefType.getMemorySpace();
+  }
+  auto newMemRefType = top.getMemRefType(
+      newShape, oldMemRefType.getElementType(), {}, newMemSpace);
   // Gather alloc operands for the dynamic dimensions of the memref.
   SmallVector<Value *, 4> allocOperands;
   unsigned dynamicDimCount = 0;
@@ -1343,7 +1384,7 @@
     std::iota(worklist.begin(), worklist.end(), 0);
   }
 
-  void run() {
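+  // Fuses candidate loop nests greedily; private buffers smaller than
+  // 'localBufSizeThreshold' bytes are placed in 'fastMemorySpace' if set.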
+  void run(unsigned localBufSizeThreshold, Optional<unsigned> fastMemorySpace) {
     while (!worklist.empty()) {
       unsigned dstId = worklist.back();
       worklist.pop_back();
@@ -1455,7 +1496,8 @@
             }
             assert(storesForMemref.size() == 1);
             auto *newMemRef = createPrivateMemRef(
-                dstAffineForOp, storesForMemref[0], bestDstLoopDepth);
+                dstAffineForOp, storesForMemref[0], bestDstLoopDepth,
+                fastMemorySpace, localBufSizeThreshold);
             visitedMemrefs.insert(newMemRef);
             // Create new node in dependence graph for 'newMemRef' alloc op.
             unsigned newMemRefNodeId =
@@ -1510,9 +1552,13 @@
 } // end anonymous namespace
 
 PassResult LoopFusion::runOnFunction(Function *f) {
+  if (clFusionFastMemorySpace.getNumOccurrences() > 0) {
+    fastMemorySpace = clFusionFastMemorySpace.getValue();
+  }
+  if (clFusionLocalBufThreshold.getNumOccurrences() > 0) {
+    localBufSizeThreshold = clFusionLocalBufThreshold.getValue();
+  }
+
   MemRefDependenceGraph g;
   if (g.init(f))
-    GreedyFusion(&g).run();
+    GreedyFusion(&g).run(localBufSizeThreshold, fastMemorySpace);
   return success();
 }