[OpenMP][NFCI] Cleanup new device RT mapping interface

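Minimize the `impl` interface and clean up some uses of mapping
functions.

Note that two results change: on NVPTX, `getKernelSize` now returns the
total number of threads (blocks times threads per block), matching AMDGPU,
and on AMDGPU `mapping::getBlockSize` now excludes one warp in generic
(non-SPMD) mode, matching NVPTX.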

Reviewed By: jhuber6

Differential Revision: https://reviews.llvm.org/D112154
diff --git a/openmp/libomptarget/DeviceRTL/src/Mapping.cpp b/openmp/libomptarget/DeviceRTL/src/Mapping.cpp
index be937d1..bece294 100644
--- a/openmp/libomptarget/DeviceRTL/src/Mapping.cpp
+++ b/openmp/libomptarget/DeviceRTL/src/Mapping.cpp
@@ -10,6 +10,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "Mapping.h"
+#include "Interface.h"
 #include "State.h"
 #include "Types.h"
 #include "Utils.h"
@@ -43,6 +44,14 @@
   return (r < group_size) ? r : group_size;
 }
 
+/// Number of hardware threads in the current workgroup, computed by
+/// getWorkgroupDim from the workgroup id, grid size, and workgroup size.
+uint32_t getNumHardwareThreadsInBlock() {
+  return getWorkgroupDim(__builtin_amdgcn_workgroup_id_x(),
+                         __builtin_amdgcn_grid_size_x(),
+                         __builtin_amdgcn_workgroup_size_x());
+}
+
 LaneMaskTy activemask() { return __builtin_amdgcn_read_exec(); }
 
 LaneMaskTy lanemaskLT() {
@@ -67,13 +74,6 @@
 
 uint32_t getThreadIdInBlock() { return __builtin_amdgcn_workitem_id_x(); }
 
-uint32_t getBlockSize() {
-  // TODO: verify this logic for generic mode.
-  return getWorkgroupDim(__builtin_amdgcn_workgroup_id_x(),
-                         __builtin_amdgcn_grid_size_x(),
-                         __builtin_amdgcn_workgroup_size_x());
-}
-
 uint32_t getKernelSize() { return __builtin_amdgcn_grid_size_x(); }
 
 uint32_t getBlockId() { return __builtin_amdgcn_workgroup_id_x(); }
@@ -83,12 +83,8 @@
                     __builtin_amdgcn_workgroup_size_x());
 }
 
-uint32_t getNumberOfProcessorElements() {
-  return getBlockSize();
-}
-
 uint32_t getWarpId() {
-  return mapping::getThreadIdInBlock() / mapping::getWarpSize();
+  return impl::getThreadIdInBlock() / mapping::getWarpSize();
 }
 
 uint32_t getNumberOfWarpsInBlock() {
@@ -104,6 +100,12 @@
 #pragma omp begin declare variant match(                                       \
     device = {arch(nvptx, nvptx64)}, implementation = {extension(match_any)})
 
+/// Number of hardware threads in the current block, read from the PTX
+/// `%ntid.x` special register.
+uint32_t getNumHardwareThreadsInBlock() {
+  return __nvvm_read_ptx_sreg_ntid_x();
+}
+
 constexpr const llvm::omp::GV &getGridValue() {
   return llvm::omp::NVPTXGridValues;
 }
@@ -126,29 +126,24 @@
   return Res;
 }
 
-uint32_t getThreadIdInWarp() {
-  return mapping::getThreadIdInBlock() & (mapping::getWarpSize() - 1);
-}
-
 uint32_t getThreadIdInBlock() { return __nvvm_read_ptx_sreg_tid_x(); }
 
-uint32_t getBlockSize() {
-  return __nvvm_read_ptx_sreg_ntid_x() -
-         (!mapping::isSPMDMode() * mapping::getWarpSize());
+uint32_t getThreadIdInWarp() {
+  return impl::getThreadIdInBlock() & (mapping::getWarpSize() - 1);
 }
 
-uint32_t getKernelSize() { return __nvvm_read_ptx_sreg_nctaid_x(); }
+/// Total number of hardware threads in the kernel: blocks times block size.
+uint32_t getKernelSize() {
+  return __nvvm_read_ptx_sreg_nctaid_x() *
+         impl::getNumHardwareThreadsInBlock();
+}
 
 uint32_t getBlockId() { return __nvvm_read_ptx_sreg_ctaid_x(); }
 
 uint32_t getNumberOfBlocks() { return __nvvm_read_ptx_sreg_nctaid_x(); }
 
-uint32_t getNumberOfProcessorElements() {
-  return __nvvm_read_ptx_sreg_ntid_x();
-}
-
 uint32_t getWarpId() {
-  return mapping::getThreadIdInBlock() / mapping::getWarpSize();
+  return impl::getThreadIdInBlock() / mapping::getWarpSize();
 }
 
 uint32_t getNumberOfWarpsInBlock() {
@@ -164,6 +158,12 @@
 } // namespace impl
 } // namespace _OMP
 
+/// We have to be deliberate about the distinction between `mapping::` and
+/// `impl::` below to avoid repeating assumptions or including irrelevant ones.
+/// Several `mapping::` functions ASSERT on the value returned by `impl::`, and
+/// those assertions query `impl::` directly so that the wrappers (e.g.
+/// `getWarpId` and `getNumberOfWarpsInBlock`) never call each other.
+///{
+
 static bool isInLastWarp() {
   uint32_t MainTId = (mapping::getNumberOfProcessorElements() - 1) &
                      ~(mapping::getWarpSize() - 1);
@@ -200,30 +200,62 @@
 
 LaneMaskTy mapping::lanemaskGT() { return impl::lanemaskGT(); }
 
-uint32_t mapping::getThreadIdInWarp() { return impl::getThreadIdInWarp(); }
-
-uint32_t mapping::getThreadIdInBlock() { return impl::getThreadIdInBlock(); }
-
-uint32_t mapping::getBlockSize() { return impl::getBlockSize(); }
-
-uint32_t mapping::getKernelSize() { return impl::getKernelSize(); }
-
-uint32_t mapping::getBlockId() { return impl::getBlockId(); }
-
-uint32_t mapping::getNumberOfBlocks() { return impl::getNumberOfBlocks(); }
-
-uint32_t mapping::getNumberOfProcessorElements() {
-  return impl::getNumberOfProcessorElements();
+uint32_t mapping::getThreadIdInWarp() {
+  uint32_t ThreadIdInWarp = impl::getThreadIdInWarp();
+  ASSERT(ThreadIdInWarp < impl::getWarpSize());
+  return ThreadIdInWarp;
 }
 
-uint32_t mapping::getWarpId() { return impl::getWarpId(); }
+uint32_t mapping::getThreadIdInBlock() {
+  uint32_t ThreadIdInBlock = impl::getThreadIdInBlock();
+  ASSERT(ThreadIdInBlock < impl::getNumHardwareThreadsInBlock());
+  return ThreadIdInBlock;
+}
 
 uint32_t mapping::getWarpSize() { return impl::getWarpSize(); }
 
-uint32_t mapping::getNumberOfWarpsInBlock() {
-  return impl::getNumberOfWarpsInBlock();
+uint32_t mapping::getBlockSize() {
+  // Outside SPMD mode one warp's worth of threads is not counted, presumably
+  // the warp hosting the main thread (cf. `isInLastWarp`).
+  uint32_t BlockSize = mapping::getNumberOfProcessorElements() -
+                       (!mapping::isSPMDMode() * impl::getWarpSize());
+  return BlockSize;
 }
 
+uint32_t mapping::getKernelSize() { return impl::getKernelSize(); }
+
+uint32_t mapping::getWarpId() {
+  uint32_t WarpId = impl::getWarpId();
+  ASSERT(WarpId < impl::getNumberOfWarpsInBlock());
+  return WarpId;
+}
+
+uint32_t mapping::getBlockId() {
+  uint32_t BlockId = impl::getBlockId();
+  ASSERT(BlockId < impl::getNumberOfBlocks());
+  return BlockId;
+}
+
+uint32_t mapping::getNumberOfWarpsInBlock() {
+  uint32_t NumberOfWarpsInBlock = impl::getNumberOfWarpsInBlock();
+  ASSERT(impl::getWarpId() < NumberOfWarpsInBlock);
+  return NumberOfWarpsInBlock;
+}
+
+uint32_t mapping::getNumberOfBlocks() {
+  uint32_t NumberOfBlocks = impl::getNumberOfBlocks();
+  ASSERT(impl::getBlockId() < NumberOfBlocks);
+  return NumberOfBlocks;
+}
+
+uint32_t mapping::getNumberOfProcessorElements() {
+  uint32_t NumberOfProcessorElements = impl::getNumHardwareThreadsInBlock();
+  ASSERT(impl::getThreadIdInBlock() < NumberOfProcessorElements);
+  return NumberOfProcessorElements;
+}
+
+///}
+
 /// Execution mode
 ///
 ///{
@@ -247,7 +279,7 @@
 
 __attribute__((noinline)) uint32_t __kmpc_get_hardware_num_threads_in_block() {
   FunctionTracingRAII();
-  return mapping::getNumberOfProcessorElements();
+  return impl::getNumHardwareThreadsInBlock();
 }
 }
 #pragma omp end declare target