[OpenMP][NFCI] Cleanup new device RT mapping interface
Minimize the `impl` interface and clean up some uses of mapping
functions.
Reviewed By: jhuber6
Differential Revision: https://reviews.llvm.org/D112154
diff --git a/openmp/libomptarget/DeviceRTL/src/Mapping.cpp b/openmp/libomptarget/DeviceRTL/src/Mapping.cpp
index be937d1..bece294 100644
--- a/openmp/libomptarget/DeviceRTL/src/Mapping.cpp
+++ b/openmp/libomptarget/DeviceRTL/src/Mapping.cpp
@@ -10,6 +10,7 @@
//===----------------------------------------------------------------------===//
#include "Mapping.h"
+#include "Interface.h"
#include "State.h"
#include "Types.h"
#include "Utils.h"
@@ -43,6 +44,12 @@
return (r < group_size) ? r : group_size;
}
+uint32_t getNumHardwareThreadsInBlock() { // AMDGCN builtins, x-dimension only
+ return getWorkgroupDim(__builtin_amdgcn_workgroup_id_x(),
+ __builtin_amdgcn_grid_size_x(),
+ __builtin_amdgcn_workgroup_size_x());
+} // NOTE(review): getWorkgroupDim appears to clamp to the group size — confirm
+
LaneMaskTy activemask() { return __builtin_amdgcn_read_exec(); }
LaneMaskTy lanemaskLT() {
@@ -67,13 +74,6 @@
uint32_t getThreadIdInBlock() { return __builtin_amdgcn_workitem_id_x(); }
-uint32_t getBlockSize() {
- // TODO: verify this logic for generic mode.
- return getWorkgroupDim(__builtin_amdgcn_workgroup_id_x(),
- __builtin_amdgcn_grid_size_x(),
- __builtin_amdgcn_workgroup_size_x());
-}
-
uint32_t getKernelSize() { return __builtin_amdgcn_grid_size_x(); }
uint32_t getBlockId() { return __builtin_amdgcn_workgroup_id_x(); }
@@ -83,12 +83,8 @@
__builtin_amdgcn_workgroup_size_x());
}
-uint32_t getNumberOfProcessorElements() {
- return getBlockSize();
-}
-
uint32_t getWarpId() {
- return mapping::getThreadIdInBlock() / mapping::getWarpSize();
+ return impl::getThreadIdInBlock() / mapping::getWarpSize(); // impl:: call skips the mapping:: ASSERT wrapper
}
uint32_t getNumberOfWarpsInBlock() {
@@ -104,6 +100,10 @@
#pragma omp begin declare variant match( \
device = {arch(nvptx, nvptx64)}, implementation = {extension(match_any)})
+uint32_t getNumHardwareThreadsInBlock() { // variant for nvptx/nvptx64
+ return __nvvm_read_ptx_sreg_ntid_x(); // x-dimension only, no adjustment
+}
+
constexpr const llvm::omp::GV &getGridValue() {
return llvm::omp::NVPTXGridValues;
}
@@ -126,29 +126,23 @@
return Res;
}
-uint32_t getThreadIdInWarp() {
- return mapping::getThreadIdInBlock() & (mapping::getWarpSize() - 1);
-}
-
uint32_t getThreadIdInBlock() { return __nvvm_read_ptx_sreg_tid_x(); }
-uint32_t getBlockSize() {
- return __nvvm_read_ptx_sreg_ntid_x() -
- (!mapping::isSPMDMode() * mapping::getWarpSize());
+uint32_t getThreadIdInWarp() { // block thread id masked by (warp size - 1)
+ return impl::getThreadIdInBlock() & (mapping::getWarpSize() - 1); // NOTE(review): mask assumes a power-of-two warp size — confirm
}
-uint32_t getKernelSize() { return __nvvm_read_ptx_sreg_nctaid_x(); }
+uint32_t getKernelSize() { // block count (nctaid.x) * hardware threads per block
+ return __nvvm_read_ptx_sreg_nctaid_x() *
+ mapping::getNumberOfProcessorElements(); // mapping:: call also runs its ASSERT
+}
uint32_t getBlockId() { return __nvvm_read_ptx_sreg_ctaid_x(); }
uint32_t getNumberOfBlocks() { return __nvvm_read_ptx_sreg_nctaid_x(); }
-uint32_t getNumberOfProcessorElements() {
- return __nvvm_read_ptx_sreg_ntid_x();
-}
-
uint32_t getWarpId() {
- return mapping::getThreadIdInBlock() / mapping::getWarpSize();
+ return impl::getThreadIdInBlock() / mapping::getWarpSize(); // impl:: call skips the mapping:: ASSERT wrapper
}
uint32_t getNumberOfWarpsInBlock() {
@@ -164,6 +158,10 @@
} // namespace impl
} // namespace _OMP
+/// We have to be deliberate about the distinction between `mapping::` and
+/// `impl::` below to avoid repeating assumptions or including irrelevant ones.
+///{
+
static bool isInLastWarp() {
uint32_t MainTId = (mapping::getNumberOfProcessorElements() - 1) &
~(mapping::getWarpSize() - 1);
@@ -200,30 +198,60 @@
LaneMaskTy mapping::lanemaskGT() { return impl::lanemaskGT(); }
-uint32_t mapping::getThreadIdInWarp() { return impl::getThreadIdInWarp(); }
-
-uint32_t mapping::getThreadIdInBlock() { return impl::getThreadIdInBlock(); }
-
-uint32_t mapping::getBlockSize() { return impl::getBlockSize(); }
-
-uint32_t mapping::getKernelSize() { return impl::getKernelSize(); }
-
-uint32_t mapping::getBlockId() { return impl::getBlockId(); }
-
-uint32_t mapping::getNumberOfBlocks() { return impl::getNumberOfBlocks(); }
-
-uint32_t mapping::getNumberOfProcessorElements() {
- return impl::getNumberOfProcessorElements();
+uint32_t mapping::getThreadIdInWarp() { // checked wrapper over impl::
+ uint32_t ThreadIdInWarp = impl::getThreadIdInWarp();
+ ASSERT(ThreadIdInWarp < impl::getWarpSize()); // id must be below the warp size
+ return ThreadIdInWarp;
}
-uint32_t mapping::getWarpId() { return impl::getWarpId(); }
+uint32_t mapping::getThreadIdInBlock() { // checked wrapper over impl::
+ uint32_t ThreadIdInBlock = impl::getThreadIdInBlock();
+ ASSERT(ThreadIdInBlock < impl::getNumHardwareThreadsInBlock()); // bound is the hardware thread count
+ return ThreadIdInBlock;
+}
uint32_t mapping::getWarpSize() { return impl::getWarpSize(); }
-uint32_t mapping::getNumberOfWarpsInBlock() {
- return impl::getNumberOfWarpsInBlock();
+uint32_t mapping::getBlockSize() { // hardware threads, minus one warp unless SPMD
+ uint32_t BlockSize = mapping::getNumberOfProcessorElements() -
+ (!mapping::isSPMDMode() * impl::getWarpSize()); // factor is 0 in SPMD mode, 1 otherwise
+ return BlockSize; // NOTE(review): presumably the excluded warp hosts the main thread — confirm
}
+uint32_t mapping::getKernelSize() { return impl::getKernelSize(); } // plain forward, no ASSERT
+
+uint32_t mapping::getWarpId() { // checked wrapper over impl::
+ uint32_t WarpID = impl::getWarpId();
+ ASSERT(WarpID < impl::getNumberOfWarpsInBlock()); // id must be below the warp count
+ return WarpID;
+}
+
+uint32_t mapping::getBlockId() { // checked wrapper over impl::
+ uint32_t BlockId = impl::getBlockId();
+ ASSERT(BlockId < impl::getNumberOfBlocks()); // id must be below the block count
+ return BlockId;
+}
+
+uint32_t mapping::getNumberOfWarpsInBlock() { // checked wrapper over impl::
+ uint32_t NumberOfWarpsInBlocks = impl::getNumberOfWarpsInBlock();
+ ASSERT(impl::getWarpId() < NumberOfWarpsInBlocks); // same bound as in mapping::getWarpId
+ return NumberOfWarpsInBlocks;
+}
+
+uint32_t mapping::getNumberOfBlocks() { // checked wrapper over impl::
+ uint32_t NumberOfBlocks = impl::getNumberOfBlocks();
+ ASSERT(impl::getBlockId() < NumberOfBlocks); // same bound as in mapping::getBlockId
+ return NumberOfBlocks;
+}
+
+uint32_t mapping::getNumberOfProcessorElements() { // hardware threads per block, checked
+ uint32_t NumberOfProcessorElements = impl::getNumHardwareThreadsInBlock();
+ ASSERT(impl::getThreadIdInBlock() < NumberOfProcessorElements); // calling thread's id must be in range
+ return NumberOfProcessorElements;
+}
+
+///}
+
/// Execution mode
///
///{
@@ -247,7 +275,7 @@
__attribute__((noinline)) uint32_t __kmpc_get_hardware_num_threads_in_block() {
FunctionTracingRAII();
- return mapping::getNumberOfProcessorElements();
+ return impl::getNumHardwareThreadsInBlock(); // direct impl:: call, no mapping:: ASSERT
}
}
#pragma omp end declare target