#ifndef HIP_INCLUDE_HIP_AMD_DETAIL_HIP_COOPERATIVE_GROUPS_HELPER_H
#define HIP_INCLUDE_HIP_AMD_DETAIL_HIP_COOPERATIVE_GROUPS_HELPER_H

#if !defined(__HIPCC_RTC__)
#include <hip/amd_detail/amd_hip_runtime.h>
#include <hip/amd_detail/amd_device_functions.h>
#endif

#if !defined(__align__)
#define __align__(x) __attribute__((aligned(x)))
#endif

#if !defined(__CG_QUALIFIER__)
#define __CG_QUALIFIER__ __device__ __forceinline__
#endif

#if !defined(__CG_STATIC_QUALIFIER__)
#define __CG_STATIC_QUALIFIER__ __device__ static __forceinline__
#endif

#if !defined(_CG_STATIC_CONST_DECL_)
#define _CG_STATIC_CONST_DECL_ static constexpr
#endif

#if __AMDGCN_WAVEFRONT_SIZE == 32
using lane_mask = unsigned int;
#else
using lane_mask = unsigned long long int;
#endif

namespace cooperative_groups {

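// Compile-time predicates: a tile size must be a power of two and must not
// exceed the wavefront size, and value types handled by the helpers must be
// integral or floating point.
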
template <unsigned int size>
using is_power_of_2 = std::integral_constant<bool, (size & (size - 1)) == 0>;

template <unsigned int size>
using is_valid_wavefront = std::integral_constant<bool, (size <= __AMDGCN_WAVEFRONT_SIZE)>;

template <unsigned int size>
using is_valid_tile_size =
    std::integral_constant<bool, is_power_of_2<size>::value && is_valid_wavefront<size>::value>;

template <typename T>
using is_valid_type =
    std::integral_constant<bool, std::is_integral<T>::value || std::is_floating_point<T>::value>;

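// adjust_mask() compacts input_mask onto the lanes selected by base_mask: the
// i-th set bit of base_mask contributes bit i of the result. For example,
// base_mask = 0b0101 and input_mask = 0b0100 produce 0b10.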
__CG_STATIC_QUALIFIER__ unsigned long long adjust_mask(unsigned long long base_mask,
                                                       unsigned long long input_mask) {
  unsigned long long out = 0;
  for (unsigned int i = 0, index = 0; i < __AMDGCN_WAVEFRONT_SIZE; i++) {
    auto lane_active = base_mask & (1ull << i);
    if (lane_active) {
      auto result = input_mask & (1ull << i);
      out |= ((result ? 1ull : 0ull) << index);
      index++;
    }
  }
  return out;
}

namespace multi_grid {

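// Thin wrappers over the __ockl multi-grid intrinsics from the ROCm device
// library, which back multi-device cooperative launches.
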
__CG_STATIC_QUALIFIER__ uint32_t num_grids() {
  return static_cast<uint32_t>(__ockl_multi_grid_num_grids()); }

__CG_STATIC_QUALIFIER__ uint32_t grid_rank() {
  return static_cast<uint32_t>(__ockl_multi_grid_grid_rank()); }

__CG_STATIC_QUALIFIER__ uint32_t size() {
  return static_cast<uint32_t>(__ockl_multi_grid_size()); }

__CG_STATIC_QUALIFIER__ uint32_t thread_rank() {
  return static_cast<uint32_t>(__ockl_multi_grid_thread_rank()); }

__CG_STATIC_QUALIFIER__ bool is_valid() {
  return static_cast<bool>(__ockl_multi_grid_is_valid()); }

__CG_STATIC_QUALIFIER__ void sync() { __ockl_multi_grid_sync(); }

}  // namespace multi_grid

namespace grid {

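// Grid-wide helpers: sizes and ranks are computed from the builtin
// blockDim/gridDim/blockIdx/threadIdx vectors; validity checks and
// synchronization go through the __ockl cooperative-launch support routines.
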
__CG_STATIC_QUALIFIER__ uint32_t size() {
  return static_cast<uint32_t>((blockDim.z * gridDim.z) * (blockDim.y * gridDim.y) *
                               (blockDim.x * gridDim.x));
}

__CG_STATIC_QUALIFIER__ uint32_t thread_rank() {
  // Global index of the workgroup to which the current thread belongs.
  uint32_t blkIdx = static_cast<uint32_t>((blockIdx.z * gridDim.y * gridDim.x) +
                                          (blockIdx.y * gridDim.x) + (blockIdx.x));

  // Number of threads in all workgroups that precede the current one.
  uint32_t num_threads_till_current_workgroup =
      static_cast<uint32_t>(blkIdx * (blockDim.x * blockDim.y * blockDim.z));

  // Rank of the current thread within its own workgroup.
  uint32_t local_thread_rank = static_cast<uint32_t>((threadIdx.z * blockDim.y * blockDim.x) +
                                                     (threadIdx.y * blockDim.x) + (threadIdx.x));

  return (num_threads_till_current_workgroup + local_thread_rank);
}

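// For a one-dimensional launch (y == z == 1) thread_rank() reduces to the
// familiar blockIdx.x * blockDim.x + threadIdx.x.
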
__CG_STATIC_QUALIFIER__ bool is_valid() { return static_cast<bool>(__ockl_grid_is_valid()); }

__CG_STATIC_QUALIFIER__ void sync() { __ockl_grid_sync(); }

}  // namespace grid

namespace workgroup {

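// Workgroup (thread block) helpers built directly on the builtin coordinate
// vectors; sync() is an ordinary __syncthreads() barrier.
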
__CG_STATIC_QUALIFIER__ dim3 group_index() {
  return (dim3(static_cast<uint32_t>(blockIdx.x), static_cast<uint32_t>(blockIdx.y),
               static_cast<uint32_t>(blockIdx.z)));
}

__CG_STATIC_QUALIFIER__ dim3 thread_index() {
  return (dim3(static_cast<uint32_t>(threadIdx.x), static_cast<uint32_t>(threadIdx.y),
               static_cast<uint32_t>(threadIdx.z)));
}

__CG_STATIC_QUALIFIER__ uint32_t size() {
  return (static_cast<uint32_t>(blockDim.x * blockDim.y * blockDim.z));
}

__CG_STATIC_QUALIFIER__ uint32_t thread_rank() {
  return (static_cast<uint32_t>((threadIdx.z * blockDim.y * blockDim.x) +
                                (threadIdx.y * blockDim.x) + (threadIdx.x)));
}

// Every thread that reaches this point belongs to a complete workgroup.
__CG_STATIC_QUALIFIER__ bool is_valid() { return true; }

__CG_STATIC_QUALIFIER__ void sync() { __syncthreads(); }

__CG_STATIC_QUALIFIER__ dim3 block_dim() {
  return (dim3(static_cast<uint32_t>(blockDim.x), static_cast<uint32_t>(blockDim.y),
               static_cast<uint32_t>(blockDim.z)));
}

}  // namespace workgroup

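// Tiles are restricted to at most one wavefront (see is_valid_tile_size), so
// member lanes advance together and sync() only has to order memory accesses,
// which an acquire-release fence at agent (device) scope provides.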
namespace tiled_group {

// Enforce ordering of memory operations across the device.
__CG_STATIC_QUALIFIER__ void sync() { __builtin_amdgcn_fence(__ATOMIC_ACQ_REL, "agent"); }

}  // namespace tiled_group

namespace coalesced_group {

// Enforce ordering of memory operations across the device.
__CG_STATIC_QUALIFIER__ void sync() { __builtin_amdgcn_fence(__ATOMIC_ACQ_REL, "agent"); }

__CG_STATIC_QUALIFIER__ unsigned int masked_bit_count(lane_mask x, unsigned int add = 0) {
  unsigned int counter = 0;
#if __AMDGCN_WAVEFRONT_SIZE == 32
  counter = __builtin_amdgcn_mbcnt_lo(x, add);
#else
  // The mbcnt builtins operate on 32-bit mask words, so split the 64-bit mask.
  counter = __builtin_amdgcn_mbcnt_lo(static_cast<unsigned int>(x), add);
  counter = __builtin_amdgcn_mbcnt_hi(static_cast<unsigned int>(x >> 32), counter);
#endif
  return counter;
}
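
// Usage sketch (illustrative, not part of this header): given the ballot mask
// of the currently active lanes, masked_bit_count(mask) returns how many
// active lanes precede this one, i.e. the lane's rank within its coalesced
// group.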