[IR] Remove the AtomicMem*Inst helper classes (#138710)

Migrate their usage to the `AnyMem*Inst` family, and add an isAtomic()
query on the base class of that hierarchy. This matches the idiom we
use for isAtomic() on load, store, etc. instructions and the existing
isVolatile() idiom on the mem* routines, and it lets us more easily
share code between the atomic and non-atomic variants.

As with #138568, the goal here is to simplify the class hierarchy and
make it easier to reason about. I'm moving from easiest to hardest, and
will stop at some point when I hit "good enough". Longer term, I'd sort
of like to merge or reverse the naming of the plain Mem*Inst and the
AnyMem*Inst classes, but that's a much larger and riskier change, and
I'm not sure I'll actually do it.
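
As a minimal sketch of what the new idiom looks like at a call site (not
part of the patch; the helper name below is hypothetical, and `MI` is
assumed to already be an `AnyMemIntrinsic *`):

```cpp
#include "llvm/IR/IntrinsicInst.h"
using namespace llvm;

// Hypothetical helper illustrating the new idiom: instead of dispatching on
// the removed AtomicMem*Inst classes, callers keep an AnyMemIntrinsic and
// query atomicity directly, mirroring the existing isVolatile() idiom.
static uint32_t getAtomicElementSizeOrZero(const AnyMemIntrinsic *MI) {
  // Old: if (auto *AMI = dyn_cast<AtomicMemIntrinsic>(MI))
  //        return AMI->getElementSizeInBytes();
  // New: the element-size accessors now live on AnyMemIntrinsic and assert
  // isAtomic(), so gate them on the query.
  if (MI->isAtomic())
    return MI->getElementSizeInBytes();
  return 0; // Plain llvm.mem* intrinsics have no element size.
}
```

The InstCombine and DeadStoreElimination hunks below use exactly this
pattern, e.g. replacing `isa<AtomicMemTransferInst>(MI)` with
`MI->isAtomic()`.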
diff --git a/llvm/include/llvm/Analysis/MemoryLocation.h b/llvm/include/llvm/Analysis/MemoryLocation.h
index c046e0e..2de5601 100644
--- a/llvm/include/llvm/Analysis/MemoryLocation.h
+++ b/llvm/include/llvm/Analysis/MemoryLocation.h
@@ -30,8 +30,6 @@
 class MemTransferInst;
 class MemIntrinsic;
 class AtomicCmpXchgInst;
-class AtomicMemTransferInst;
-class AtomicMemIntrinsic;
 class AtomicRMWInst;
 class AnyMemTransferInst;
 class AnyMemIntrinsic;
@@ -253,13 +251,11 @@
 
   /// Return a location representing the source of a memory transfer.
   static MemoryLocation getForSource(const MemTransferInst *MTI);
-  static MemoryLocation getForSource(const AtomicMemTransferInst *MTI);
   static MemoryLocation getForSource(const AnyMemTransferInst *MTI);
 
   /// Return a location representing the destination of a memory set or
   /// transfer.
   static MemoryLocation getForDest(const MemIntrinsic *MI);
-  static MemoryLocation getForDest(const AtomicMemIntrinsic *MI);
   static MemoryLocation getForDest(const AnyMemIntrinsic *MI);
   static std::optional<MemoryLocation> getForDest(const CallBase *CI,
                                                   const TargetLibraryInfo &TLI);
diff --git a/llvm/include/llvm/IR/IntrinsicInst.h b/llvm/include/llvm/IR/IntrinsicInst.h
index 48b3067..ea9257b 100644
--- a/llvm/include/llvm/IR/IntrinsicInst.h
+++ b/llvm/include/llvm/IR/IntrinsicInst.h
@@ -1107,100 +1107,6 @@
   }
 };
 
-// The common base class for the atomic memset/memmove/memcpy intrinsics
-// i.e. llvm.element.unordered.atomic.memset/memcpy/memmove
-class AtomicMemIntrinsic : public MemIntrinsicBase<AtomicMemIntrinsic> {
-private:
-  enum { ARG_ELEMENTSIZE = 3 };
-
-public:
-  Value *getRawElementSizeInBytes() const {
-    return const_cast<Value *>(getArgOperand(ARG_ELEMENTSIZE));
-  }
-
-  ConstantInt *getElementSizeInBytesCst() const {
-    return cast<ConstantInt>(getRawElementSizeInBytes());
-  }
-
-  uint32_t getElementSizeInBytes() const {
-    return getElementSizeInBytesCst()->getZExtValue();
-  }
-
-  void setElementSizeInBytes(Constant *V) {
-    assert(V->getType() == Type::getInt8Ty(getContext()) &&
-           "setElementSizeInBytes called with value of wrong type!");
-    setArgOperand(ARG_ELEMENTSIZE, V);
-  }
-
-  static bool classof(const IntrinsicInst *I) {
-    switch (I->getIntrinsicID()) {
-    case Intrinsic::memcpy_element_unordered_atomic:
-    case Intrinsic::memmove_element_unordered_atomic:
-    case Intrinsic::memset_element_unordered_atomic:
-      return true;
-    default:
-      return false;
-    }
-  }
-  static bool classof(const Value *V) {
-    return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V));
-  }
-};
-
-/// This class represents atomic memset intrinsic
-// i.e. llvm.element.unordered.atomic.memset
-class AtomicMemSetInst : public MemSetBase<AtomicMemIntrinsic> {
-public:
-  static bool classof(const IntrinsicInst *I) {
-    return I->getIntrinsicID() == Intrinsic::memset_element_unordered_atomic;
-  }
-  static bool classof(const Value *V) {
-    return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V));
-  }
-};
-
-// This class wraps the atomic memcpy/memmove intrinsics
-// i.e. llvm.element.unordered.atomic.memcpy/memmove
-class AtomicMemTransferInst : public MemTransferBase<AtomicMemIntrinsic> {
-public:
-  static bool classof(const IntrinsicInst *I) {
-    switch (I->getIntrinsicID()) {
-    case Intrinsic::memcpy_element_unordered_atomic:
-    case Intrinsic::memmove_element_unordered_atomic:
-      return true;
-    default:
-      return false;
-    }
-  }
-  static bool classof(const Value *V) {
-    return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V));
-  }
-};
-
-/// This class represents the atomic memcpy intrinsic
-/// i.e. llvm.element.unordered.atomic.memcpy
-class AtomicMemCpyInst : public AtomicMemTransferInst {
-public:
-  static bool classof(const IntrinsicInst *I) {
-    return I->getIntrinsicID() == Intrinsic::memcpy_element_unordered_atomic;
-  }
-  static bool classof(const Value *V) {
-    return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V));
-  }
-};
-
-/// This class represents the atomic memmove intrinsic
-/// i.e. llvm.element.unordered.atomic.memmove
-class AtomicMemMoveInst : public AtomicMemTransferInst {
-public:
-  static bool classof(const IntrinsicInst *I) {
-    return I->getIntrinsicID() == Intrinsic::memmove_element_unordered_atomic;
-  }
-  static bool classof(const Value *V) {
-    return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V));
-  }
-};
-
 /// This is the common base class for memset/memcpy/memmove.
 class MemIntrinsic : public MemIntrinsicBase<MemIntrinsic> {
 private:
@@ -1345,6 +1251,9 @@
 // i.e. llvm.element.unordered.atomic.memset/memcpy/memmove
 //  and llvm.memset/memcpy/memmove
 class AnyMemIntrinsic : public MemIntrinsicBase<AnyMemIntrinsic> {
+private:
+  enum { ARG_ELEMENTSIZE = 3 };
+
 public:
   bool isVolatile() const {
     // Only the non-atomic intrinsics can be volatile
@@ -1353,6 +1262,17 @@
     return false;
   }
 
+  bool isAtomic() const {
+    switch (getIntrinsicID()) {
+    case Intrinsic::memcpy_element_unordered_atomic:
+    case Intrinsic::memmove_element_unordered_atomic:
+    case Intrinsic::memset_element_unordered_atomic:
+      return true;
+    default:
+      return false;
+    }
+  }
+
   static bool classof(const IntrinsicInst *I) {
     switch (I->getIntrinsicID()) {
     case Intrinsic::memcpy:
@@ -1371,6 +1291,16 @@
   static bool classof(const Value *V) {
     return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V));
   }
+
+  Value *getRawElementSizeInBytes() const {
+    assert(isAtomic());
+    return const_cast<Value *>(getArgOperand(ARG_ELEMENTSIZE));
+  }
+
+  uint32_t getElementSizeInBytes() const {
+    assert(isAtomic());
+    return cast<ConstantInt>(getRawElementSizeInBytes())->getZExtValue();
+  }
 };
 
 /// This class represents any memset intrinsic
diff --git a/llvm/include/llvm/Transforms/Utils/LowerMemIntrinsics.h b/llvm/include/llvm/Transforms/Utils/LowerMemIntrinsics.h
index 1007d28..6a2c44a 100644
--- a/llvm/include/llvm/Transforms/Utils/LowerMemIntrinsics.h
+++ b/llvm/include/llvm/Transforms/Utils/LowerMemIntrinsics.h
@@ -19,7 +19,7 @@
 
 namespace llvm {
 
-class AtomicMemCpyInst;
+class AnyMemCpyInst;
 class ConstantInt;
 class Instruction;
 class MemCpyInst;
@@ -62,10 +62,10 @@
 void expandMemSetPatternAsLoop(MemSetPatternInst *MemSet);
 
 /// Expand \p AtomicMemCpy as a loop. \p AtomicMemCpy is not deleted.
-void expandAtomicMemCpyAsLoop(AtomicMemCpyInst *AtomicMemCpy,
+void expandAtomicMemCpyAsLoop(AnyMemCpyInst *AtomicMemCpy,
                               const TargetTransformInfo &TTI,
                               ScalarEvolution *SE);
 
-} // End llvm namespace
+} // namespace llvm
 
 #endif
diff --git a/llvm/lib/Analysis/MemoryLocation.cpp b/llvm/lib/Analysis/MemoryLocation.cpp
index 6e32327..3b42bb4 100644
--- a/llvm/lib/Analysis/MemoryLocation.cpp
+++ b/llvm/lib/Analysis/MemoryLocation.cpp
@@ -95,10 +95,6 @@
   return getForSource(cast<AnyMemTransferInst>(MTI));
 }
 
-MemoryLocation MemoryLocation::getForSource(const AtomicMemTransferInst *MTI) {
-  return getForSource(cast<AnyMemTransferInst>(MTI));
-}
-
 MemoryLocation MemoryLocation::getForSource(const AnyMemTransferInst *MTI) {
   assert(MTI->getRawSource() == MTI->getArgOperand(1));
   return getForArgument(MTI, 1, nullptr);
@@ -108,10 +104,6 @@
   return getForDest(cast<AnyMemIntrinsic>(MI));
 }
 
-MemoryLocation MemoryLocation::getForDest(const AtomicMemIntrinsic *MI) {
-  return getForDest(cast<AnyMemIntrinsic>(MI));
-}
-
 MemoryLocation MemoryLocation::getForDest(const AnyMemIntrinsic *MI) {
   assert(MI->getRawDest() == MI->getArgOperand(0));
   return getForArgument(MI, 0, nullptr);
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 97ce20b..9d138d3 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -6525,7 +6525,7 @@
     return;
   }
   case Intrinsic::memcpy_element_unordered_atomic: {
-    const AtomicMemCpyInst &MI = cast<AtomicMemCpyInst>(I);
+    auto &MI = cast<AnyMemCpyInst>(I);
     SDValue Dst = getValue(MI.getRawDest());
     SDValue Src = getValue(MI.getRawSource());
     SDValue Length = getValue(MI.getLength());
@@ -6541,7 +6541,7 @@
     return;
   }
   case Intrinsic::memmove_element_unordered_atomic: {
-    auto &MI = cast<AtomicMemMoveInst>(I);
+    auto &MI = cast<AnyMemMoveInst>(I);
     SDValue Dst = getValue(MI.getRawDest());
     SDValue Src = getValue(MI.getRawSource());
     SDValue Length = getValue(MI.getLength());
@@ -6557,7 +6557,7 @@
     return;
   }
   case Intrinsic::memset_element_unordered_atomic: {
-    auto &MI = cast<AtomicMemSetInst>(I);
+    auto &MI = cast<AnyMemSetInst>(I);
     SDValue Dst = getValue(MI.getRawDest());
     SDValue Val = getValue(MI.getValue());
     SDValue Length = getValue(MI.getLength());
diff --git a/llvm/lib/IR/IRBuilder.cpp b/llvm/lib/IR/IRBuilder.cpp
index 089bd99..8adb85e 100644
--- a/llvm/lib/IR/IRBuilder.cpp
+++ b/llvm/lib/IR/IRBuilder.cpp
@@ -228,7 +228,7 @@
   CallInst *CI =
       CreateIntrinsic(Intrinsic::memset_element_unordered_atomic, Tys, Ops);
 
-  cast<AtomicMemSetInst>(CI)->setDestAlignment(Alignment);
+  cast<AnyMemSetInst>(CI)->setDestAlignment(Alignment);
 
   // Set the TBAA info if present.
   if (TBAATag)
@@ -293,7 +293,7 @@
       CreateIntrinsic(Intrinsic::memcpy_element_unordered_atomic, Tys, Ops);
 
   // Set the alignment of the pointer args.
-  auto *AMCI = cast<AtomicMemCpyInst>(CI);
+  auto *AMCI = cast<AnyMemCpyInst>(CI);
   AMCI->setDestAlignment(DstAlign);
   AMCI->setSourceAlignment(SrcAlign);
 
diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp
index a798808..83c1264 100644
--- a/llvm/lib/IR/Verifier.cpp
+++ b/llvm/lib/IR/Verifier.cpp
@@ -5617,7 +5617,7 @@
   case Intrinsic::memcpy_element_unordered_atomic:
   case Intrinsic::memmove_element_unordered_atomic:
   case Intrinsic::memset_element_unordered_atomic: {
-    const auto *AMI = cast<AtomicMemIntrinsic>(&Call);
+    const auto *AMI = cast<AnyMemIntrinsic>(&Call);
 
     ConstantInt *ElementSizeCI =
         cast<ConstantInt>(AMI->getRawElementSizeInBytes());
@@ -5632,7 +5632,7 @@
     };
     Check(IsValidAlignment(AMI->getDestAlign()),
           "incorrect alignment of the destination argument", Call);
-    if (const auto *AMT = dyn_cast<AtomicMemTransferInst>(AMI)) {
+    if (const auto *AMT = dyn_cast<AnyMemTransferInst>(AMI)) {
       Check(IsValidAlignment(AMT->getSourceAlign()),
             "incorrect alignment of the source argument", Call);
     }
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 6ea09ed..3e78b20 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -164,7 +164,7 @@
   // introduce the unaligned memory access which will be later transformed
   // into libcall in CodeGen. This is not evident performance gain so disable
   // it now.
-  if (isa<AtomicMemTransferInst>(MI))
+  if (MI->isAtomic())
     if (*CopyDstAlign < Size || *CopySrcAlign < Size)
       return nullptr;
 
@@ -204,7 +204,7 @@
     L->setVolatile(MT->isVolatile());
     S->setVolatile(MT->isVolatile());
   }
-  if (isa<AtomicMemTransferInst>(MI)) {
+  if (MI->isAtomic()) {
     // atomics have to be unordered
     L->setOrdering(AtomicOrdering::Unordered);
     S->setOrdering(AtomicOrdering::Unordered);
@@ -255,9 +255,8 @@
   // introduce the unaligned memory access which will be later transformed
   // into libcall in CodeGen. This is not evident performance gain so disable
   // it now.
-  if (isa<AtomicMemSetInst>(MI))
-    if (Alignment < Len)
-      return nullptr;
+  if (MI->isAtomic() && Alignment < Len)
+    return nullptr;
 
   // memset(s,c,n) -> store s, c (for n=1,2,4,8)
   if (Len <= 8 && isPowerOf2_32((uint32_t)Len)) {
@@ -276,7 +275,7 @@
     for_each(at::getDVRAssignmentMarkers(S), replaceOpForAssignmentMarkers);
 
     S->setAlignment(Alignment);
-    if (isa<AtomicMemSetInst>(MI))
+    if (MI->isAtomic())
       S->setOrdering(AtomicOrdering::Unordered);
 
     // Set the size of the copy to 0, it will be deleted on the next iteration.
@@ -1654,27 +1653,27 @@
   }
 
   IntrinsicInst *II = dyn_cast<IntrinsicInst>(&CI);
-  if (!II) return visitCallBase(CI);
-
-  // For atomic unordered mem intrinsics if len is not a positive or
-  // not a multiple of element size then behavior is undefined.
-  if (auto *AMI = dyn_cast<AtomicMemIntrinsic>(II))
-    if (ConstantInt *NumBytes = dyn_cast<ConstantInt>(AMI->getLength()))
-      if (NumBytes->isNegative() ||
-          (NumBytes->getZExtValue() % AMI->getElementSizeInBytes() != 0)) {
-        CreateNonTerminatorUnreachable(AMI);
-        assert(AMI->getType()->isVoidTy() &&
-               "non void atomic unordered mem intrinsic");
-        return eraseInstFromFunction(*AMI);
-      }
+  if (!II)
+    return visitCallBase(CI);
 
   // Intrinsics cannot occur in an invoke or a callbr, so handle them here
   // instead of in visitCallBase.
   if (auto *MI = dyn_cast<AnyMemIntrinsic>(II)) {
-    // memmove/cpy/set of zero bytes is a noop.
-    if (Constant *NumBytes = dyn_cast<Constant>(MI->getLength())) {
+    if (ConstantInt *NumBytes = dyn_cast<ConstantInt>(MI->getLength())) {
+      // memmove/cpy/set of zero bytes is a noop.
       if (NumBytes->isNullValue())
         return eraseInstFromFunction(CI);
+
+      // For atomic unordered mem intrinsics if len is not a positive or
+      // not a multiple of element size then behavior is undefined.
+      if (MI->isAtomic() &&
+          (NumBytes->isNegative() ||
+           (NumBytes->getZExtValue() % MI->getElementSizeInBytes() != 0))) {
+        CreateNonTerminatorUnreachable(MI);
+        assert(MI->getType()->isVoidTy() &&
+               "non void atomic unordered mem intrinsic");
+        return eraseInstFromFunction(*MI);
+      }
     }
 
     // No other transformations apply to volatile transfers.
@@ -1719,7 +1718,7 @@
         if (GVSrc->isConstant()) {
           Module *M = CI.getModule();
           Intrinsic::ID MemCpyID =
-              isa<AtomicMemMoveInst>(MMI)
+              MMI->isAtomic()
                   ? Intrinsic::memcpy_element_unordered_atomic
                   : Intrinsic::memcpy;
           Type *Tys[3] = { CI.getArgOperand(0)->getType(),
diff --git a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
index 66168a9..15192ed 100644
--- a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
+++ b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
@@ -670,10 +670,10 @@
   assert(DeadSize > ToRemoveSize && "Can't remove more than original size");
 
   uint64_t NewSize = DeadSize - ToRemoveSize;
-  if (auto *AMI = dyn_cast<AtomicMemIntrinsic>(DeadI)) {
+  if (DeadIntrinsic->isAtomic()) {
     // When shortening an atomic memory intrinsic, the newly shortened
     // length must remain an integer multiple of the element size.
-    const uint32_t ElementSize = AMI->getElementSizeInBytes();
+    const uint32_t ElementSize = DeadIntrinsic->getElementSizeInBytes();
     if (0 != NewSize % ElementSize)
       return false;
   }
diff --git a/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp b/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp
index 8005be8..d9805d8 100644
--- a/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp
+++ b/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp
@@ -3054,7 +3054,8 @@
       // non-leaf memcpy/memmove without deopt state just treat it as a leaf
       // copy and don't produce a statepoint.
       if (!AllowStatepointWithNoDeoptInfo && !Call->hasDeoptState()) {
-        assert((isa<AtomicMemCpyInst>(Call) || isa<AtomicMemMoveInst>(Call)) &&
+        assert(isa<AnyMemTransferInst>(Call) &&
+               cast<AnyMemTransferInst>(Call)->isAtomic() &&
                "Don't expect any other calls here!");
         return false;
       }
diff --git a/llvm/lib/Transforms/Utils/LowerMemIntrinsics.cpp b/llvm/lib/Transforms/Utils/LowerMemIntrinsics.cpp
index dbab56a..18b0f61 100644
--- a/llvm/lib/Transforms/Utils/LowerMemIntrinsics.cpp
+++ b/llvm/lib/Transforms/Utils/LowerMemIntrinsics.cpp
@@ -982,9 +982,10 @@
                    Memset->isVolatile());
 }
 
-void llvm::expandAtomicMemCpyAsLoop(AtomicMemCpyInst *AtomicMemcpy,
+void llvm::expandAtomicMemCpyAsLoop(AnyMemCpyInst *AtomicMemcpy,
                                     const TargetTransformInfo &TTI,
                                     ScalarEvolution *SE) {
+  assert(AtomicMemcpy->isAtomic());
   if (ConstantInt *CI = dyn_cast<ConstantInt>(AtomicMemcpy->getLength())) {
     createMemCpyLoopKnownSize(
         /* InsertBefore */ AtomicMemcpy,
diff --git a/llvm/unittests/Transforms/Utils/MemTransferLowering.cpp b/llvm/unittests/Transforms/Utils/MemTransferLowering.cpp
index 68c6364..b97bc31 100644
--- a/llvm/unittests/Transforms/Utils/MemTransferLowering.cpp
+++ b/llvm/unittests/Transforms/Utils/MemTransferLowering.cpp
@@ -198,9 +198,9 @@
         TargetTransformInfo TTI(M->getDataLayout());
         auto *MemCpyBB = getBasicBlockByName(F, "memcpy");
         Instruction *Inst = &MemCpyBB->front();
-        assert(isa<AtomicMemCpyInst>(Inst) &&
+        assert(isa<AnyMemCpyInst>(Inst) &&
                "Expecting llvm.memcpy.p0i8.i64 instructon");
-        AtomicMemCpyInst *MemCpyI = cast<AtomicMemCpyInst>(Inst);
+        AnyMemCpyInst *MemCpyI = cast<AnyMemCpyInst>(Inst);
         auto &SE = FAM.getResult<ScalarEvolutionAnalysis>(F);
         expandAtomicMemCpyAsLoop(MemCpyI, TTI, &SE);
         auto *CopyLoopBB = getBasicBlockByName(F, "load-store-loop");
@@ -243,9 +243,9 @@
         TargetTransformInfo TTI(M->getDataLayout());
         auto *MemCpyBB = getBasicBlockByName(F, "memcpy");
         Instruction *Inst = &MemCpyBB->front();
-        assert(isa<AtomicMemCpyInst>(Inst) &&
+        assert(isa<AnyMemCpyInst>(Inst) &&
                "Expecting llvm.memcpy.p0i8.i64 instructon");
-        AtomicMemCpyInst *MemCpyI = cast<AtomicMemCpyInst>(Inst);
+        auto *MemCpyI = cast<AnyMemCpyInst>(Inst);
         auto &SE = FAM.getResult<ScalarEvolutionAnalysis>(F);
         expandAtomicMemCpyAsLoop(MemCpyI, TTI, &SE);
         auto *CopyLoopBB = getBasicBlockByName(F, "loop-memcpy-expansion");