rocPRIM
warp_reduce.hpp
1 // Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved.
2 //
3 // Permission is hereby granted, free of charge, to any person obtaining a copy
4 // of this software and associated documentation files (the "Software"), to deal
5 // in the Software without restriction, including without limitation the rights
6 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 // copies of the Software, and to permit persons to whom the Software is
8 // furnished to do so, subject to the following conditions:
9 //
10 // The above copyright notice and this permission notice shall be included in
11 // all copies or substantial portions of the Software.
12 //
13 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19 // THE SOFTWARE.
20 
21 #ifndef ROCPRIM_WARP_WARP_REDUCE_HPP_
22 #define ROCPRIM_WARP_WARP_REDUCE_HPP_
23 
24 #include <type_traits>
25 
26 #include "../config.hpp"
27 #include "../detail/various.hpp"
28 
29 #include "../intrinsics.hpp"
30 #include "../functional.hpp"
31 #include "../types.hpp"
32 
33 #include "detail/warp_reduce_crosslane.hpp"
34 #include "detail/warp_reduce_shared_mem.hpp"
35 
38 
39 BEGIN_ROCPRIM_NAMESPACE
40 
41 namespace detail
42 {
43 
44 // Select warp_reduce implementation based WarpSize
45 template<class T, unsigned int WarpSize, bool UseAllReduce>
47 {
48  typedef typename std::conditional<
49  // can we use crosslane (DPP or shuffle-based) implementation?
51  detail::warp_reduce_crosslane<T, WarpSize, UseAllReduce>, // yes
53  >::type type;
54 };
55 
56 } // end namespace detail
57 
109 template<
110  class T,
111  unsigned int WarpSize = device_warp_size(),
112  bool UseAllReduce = false
113 >
115 #ifndef DOXYGEN_SHOULD_SKIP_THIS
116  : private detail::select_warp_reduce_impl<T, WarpSize, UseAllReduce>::type
117 #endif
118 {
119  using base_type = typename detail::select_warp_reduce_impl<T, WarpSize, UseAllReduce>::type;
120 
121  // Check if WarpSize is valid for the targets
122  static_assert(WarpSize <= ROCPRIM_MAX_WARP_SIZE, "WarpSize can't be greater than hardware warp size.");
123 
124 public:
133  using storage_type = typename base_type::storage_type;
134 
179  template<class BinaryFunction = ::rocprim::plus<T>, unsigned int FunctionWarpSize = WarpSize>
180  ROCPRIM_DEVICE ROCPRIM_INLINE
181  auto reduce(T input,
182  T& output,
183  storage_type& storage,
184  BinaryFunction reduce_op = BinaryFunction())
185  -> typename std::enable_if<(FunctionWarpSize <= __AMDGCN_WAVEFRONT_SIZE), void>::type
186  {
187  base_type::reduce(input, output, storage, reduce_op);
188  }
189 
192  template<class BinaryFunction = ::rocprim::plus<T>, unsigned int FunctionWarpSize = WarpSize>
193  ROCPRIM_DEVICE ROCPRIM_INLINE
194  auto reduce(T ,
195  T& ,
196  storage_type& ,
197  BinaryFunction reduce_op = BinaryFunction())
198  -> typename std::enable_if<(FunctionWarpSize > __AMDGCN_WAVEFRONT_SIZE), void>::type
199  {
200  (void) reduce_op;
201  ROCPRIM_PRINT_ERROR_ONCE("Specified warp size exceeds current hardware supported warp size. Aborting warp sort.");
202  return;
203  }
204 
251  template<class BinaryFunction = ::rocprim::plus<T>, unsigned int FunctionWarpSize = WarpSize>
252  ROCPRIM_DEVICE ROCPRIM_INLINE
253  auto reduce(T input,
254  T& output,
255  int valid_items,
256  storage_type& storage,
257  BinaryFunction reduce_op = BinaryFunction())
258  -> typename std::enable_if<(FunctionWarpSize <= __AMDGCN_WAVEFRONT_SIZE), void>::type
259  {
260  base_type::reduce(input, output, valid_items, storage, reduce_op);
261  }
262 
265  template<class BinaryFunction = ::rocprim::plus<T>, unsigned int FunctionWarpSize = WarpSize>
266  ROCPRIM_DEVICE ROCPRIM_INLINE
267  auto reduce(T ,
268  T& ,
269  int ,
270  storage_type& ,
271  BinaryFunction reduce_op = BinaryFunction())
272  -> typename std::enable_if<(FunctionWarpSize > __AMDGCN_WAVEFRONT_SIZE), void>::type
273  {
274  (void) reduce_op;
275  ROCPRIM_PRINT_ERROR_ONCE("Specified warp size exceeds current hardware supported warp size. Aborting warp sort.");
276  return;
277  }
278 
297  template<class Flag, class BinaryFunction = ::rocprim::plus<T>, unsigned int FunctionWarpSize = WarpSize>
298  ROCPRIM_DEVICE ROCPRIM_INLINE
299  auto head_segmented_reduce(T input,
300  T& output,
301  Flag flag,
302  storage_type& storage,
303  BinaryFunction reduce_op = BinaryFunction())
304  -> typename std::enable_if<(FunctionWarpSize <= __AMDGCN_WAVEFRONT_SIZE), void>::type
305  {
306  base_type::head_segmented_reduce(input, output, flag, storage, reduce_op);
307  }
308 
311  template<class Flag, class BinaryFunction = ::rocprim::plus<T>, unsigned int FunctionWarpSize = WarpSize>
312  ROCPRIM_DEVICE ROCPRIM_INLINE
314  T& ,
315  Flag ,
316  storage_type& ,
317  BinaryFunction reduce_op = BinaryFunction())
318  -> typename std::enable_if<(FunctionWarpSize > __AMDGCN_WAVEFRONT_SIZE), void>::type
319  {
320  (void) reduce_op;
321  ROCPRIM_PRINT_ERROR_ONCE("Specified warp size exceeds current hardware supported warp size. Aborting warp sort.");
322  return;
323  }
324 
343  template<class Flag, class BinaryFunction = ::rocprim::plus<T>, unsigned int FunctionWarpSize = WarpSize>
344  ROCPRIM_DEVICE ROCPRIM_INLINE
345  auto tail_segmented_reduce(T input,
346  T& output,
347  Flag flag,
348  storage_type& storage,
349  BinaryFunction reduce_op = BinaryFunction())
350  -> typename std::enable_if<(FunctionWarpSize <= __AMDGCN_WAVEFRONT_SIZE), void>::type
351  {
352  base_type::tail_segmented_reduce(input, output, flag, storage, reduce_op);
353  }
354 
357  template<class Flag, class BinaryFunction = ::rocprim::plus<T>, unsigned int FunctionWarpSize = WarpSize>
358  ROCPRIM_DEVICE ROCPRIM_INLINE
360  T& ,
361  Flag ,
362  storage_type& ,
363  BinaryFunction reduce_op = BinaryFunction())
364  -> typename std::enable_if<(FunctionWarpSize > __AMDGCN_WAVEFRONT_SIZE), void>::type
365  {
366  (void) reduce_op;
367  ROCPRIM_PRINT_ERROR_ONCE("Specified warp size exceeds current hardware supported warp size. Aborting warp sort.");
368  return;
369  }
370 };
371 
372 END_ROCPRIM_NAMESPACE
373 
375 // end of group warpmodule
376 
377 #endif // ROCPRIM_WARP_WARP_REDUCE_HPP_
ROCPRIM_DEVICE ROCPRIM_INLINE auto head_segmented_reduce(T, T &, Flag, storage_type &, BinaryFunction reduce_op=BinaryFunction()) -> typename std::enable_if<(FunctionWarpSize > __AMDGCN_WAVEFRONT_SIZE), void >::type
Performs head-segmented reduction across threads in a logical warp.
Definition: warp_reduce.hpp:313
ROCPRIM_DEVICE ROCPRIM_INLINE auto reduce(T, T &, storage_type &, BinaryFunction reduce_op=BinaryFunction()) -> typename std::enable_if<(FunctionWarpSize > __AMDGCN_WAVEFRONT_SIZE), void >::type
Performs reduction across threads in a logical warp.
Definition: warp_reduce.hpp:194
ROCPRIM_DEVICE ROCPRIM_INLINE constexpr unsigned int device_warp_size()
Returns the number of threads in a hardware warp for the current target.
Definition: thread.hpp:70
typename base_type::storage_type storage_type
Struct used to allocate the temporary memory required for thread communication during the operations provided by the related parallel primitive.
Definition: warp_reduce.hpp:133
hipError_t reduce(void *temporary_storage, size_t &storage_size, InputIterator input, OutputIterator output, const InitValueType initial_value, const size_t size, BinaryFunction reduce_op=BinaryFunction(), const hipStream_t stream=0, bool debug_synchronous=false)
Parallel reduction primitive for device level.
Definition: device_reduce.hpp:374
Deprecated: Configuration of device-level scan primitives.
Definition: block_histogram.hpp:62
Definition: warp_reduce.hpp:46
ROCPRIM_DEVICE ROCPRIM_INLINE auto reduce(T input, T &output, int valid_items, storage_type &storage, BinaryFunction reduce_op=BinaryFunction()) -> typename std::enable_if<(FunctionWarpSize<=__AMDGCN_WAVEFRONT_SIZE), void >::type
Performs reduction across threads in a logical warp.
Definition: warp_reduce.hpp:253
ROCPRIM_DEVICE ROCPRIM_INLINE auto tail_segmented_reduce(T, T &, Flag, storage_type &, BinaryFunction reduce_op=BinaryFunction()) -> typename std::enable_if<(FunctionWarpSize > __AMDGCN_WAVEFRONT_SIZE), void >::type
Performs tail-segmented reduction across threads in a logical warp.
Definition: warp_reduce.hpp:359
#define ROCPRIM_PRINT_ERROR_ONCE(message)
Prints the supplied error message only once (using only one of the active threads).
Definition: functional.hpp:42
The warp_reduce class is a warp-level parallel primitive that provides methods for performing reduction operations on items partitioned across threads in a warp.
Definition: warp_reduce.hpp:114
ROCPRIM_DEVICE ROCPRIM_INLINE auto reduce(T, T &, int, storage_type &, BinaryFunction reduce_op=BinaryFunction()) -> typename std::enable_if<(FunctionWarpSize > __AMDGCN_WAVEFRONT_SIZE), void >::type
Performs reduction across threads in a logical warp.
Definition: warp_reduce.hpp:267
Definition: warp_reduce_shared_mem.hpp:43
Definition: various.hpp:108
ROCPRIM_DEVICE ROCPRIM_INLINE auto tail_segmented_reduce(T input, T &output, Flag flag, storage_type &storage, BinaryFunction reduce_op=BinaryFunction()) -> typename std::enable_if<(FunctionWarpSize<=__AMDGCN_WAVEFRONT_SIZE), void >::type
Performs tail-segmented reduction across threads in a logical warp.
Definition: warp_reduce.hpp:345
ROCPRIM_DEVICE ROCPRIM_INLINE auto head_segmented_reduce(T input, T &output, Flag flag, storage_type &storage, BinaryFunction reduce_op=BinaryFunction()) -> typename std::enable_if<(FunctionWarpSize<=__AMDGCN_WAVEFRONT_SIZE), void >::type
Performs head-segmented reduction across threads in a logical warp.
Definition: warp_reduce.hpp:299
ROCPRIM_DEVICE ROCPRIM_INLINE auto reduce(T input, T &output, storage_type &storage, BinaryFunction reduce_op=BinaryFunction()) -> typename std::enable_if<(FunctionWarpSize<=__AMDGCN_WAVEFRONT_SIZE), void >::type
Performs reduction across threads in a logical warp.
Definition: warp_reduce.hpp:181