rocPRIM
warp_reduce_shuffle.hpp
1 // Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved.
2 //
3 // Permission is hereby granted, free of charge, to any person obtaining a copy
4 // of this software and associated documentation files (the "Software"), to deal
5 // in the Software without restriction, including without limitation the rights
6 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 // copies of the Software, and to permit persons to whom the Software is
8 // furnished to do so, subject to the following conditions:
9 //
10 // The above copyright notice and this permission notice shall be included in
11 // all copies or substantial portions of the Software.
12 //
13 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19 // THE SOFTWARE.
20 
21 #ifndef ROCPRIM_WARP_DETAIL_WARP_REDUCE_SHUFFLE_HPP_
22 #define ROCPRIM_WARP_DETAIL_WARP_REDUCE_SHUFFLE_HPP_
23 
24 #include <type_traits>
25 
26 #include "../../config.hpp"
27 #include "../../intrinsics.hpp"
28 #include "../../types.hpp"
29 #include "../../detail/various.hpp"
30 
31 #include "warp_segment_bounds.hpp"
32 
33 BEGIN_ROCPRIM_NAMESPACE
34 
35 namespace detail
36 {
37 
38 template<
39  class T,
40  unsigned int WarpSize,
41  bool UseAllReduce
42 >
44 {
45 public:
46  static_assert(detail::is_power_of_two(WarpSize), "WarpSize must be power of 2");
47 
49 
50  template<class BinaryFunction>
51  ROCPRIM_DEVICE ROCPRIM_INLINE
52  void reduce(T input, T& output, BinaryFunction reduce_op)
53  {
54  output = input;
55 
56  T value;
57  ROCPRIM_UNROLL
58  for(unsigned int offset = 1; offset < WarpSize; offset *= 2)
59  {
60  value = warp_shuffle_down(output, offset, WarpSize);
61  output = reduce_op(output, value);
62  }
63  set_output<UseAllReduce>(output);
64  }
65 
66  template<class BinaryFunction>
67  ROCPRIM_DEVICE ROCPRIM_INLINE
68  void reduce(T input, T& output, storage_type& storage, BinaryFunction reduce_op)
69  {
70  (void) storage; // disables unused parameter warning
71  this->reduce(input, output, reduce_op);
72  }
73 
74  template<bool UseAllReduceDummy = UseAllReduce, class BinaryFunction>
75  ROCPRIM_DEVICE ROCPRIM_INLINE
76  void reduce(T input, T& output, unsigned int valid_items, BinaryFunction reduce_op)
77  {
78  output = input;
79 
80  T value;
81  ROCPRIM_UNROLL
82  for(unsigned int offset = 1; offset < WarpSize; offset *= 2)
83  {
84  value = warp_shuffle_down(output, offset, WarpSize);
85  unsigned int id = detail::logical_lane_id<WarpSize>();
86  if (id + offset < valid_items) output = reduce_op(output, value);
87  }
88  set_output<UseAllReduceDummy>(output);
89  }
90 
91  template<class BinaryFunction>
92  ROCPRIM_DEVICE ROCPRIM_INLINE
93  void reduce(T input, T& output, unsigned int valid_items,
94  storage_type& storage, BinaryFunction reduce_op)
95  {
96  (void) storage; // disables unused parameter warning
97  this->reduce(input, output, valid_items, reduce_op);
98  }
99 
100  template<class Flag, class BinaryFunction>
101  ROCPRIM_DEVICE ROCPRIM_INLINE
102  void head_segmented_reduce(T input, T& output, Flag flag, BinaryFunction reduce_op)
103  {
104  this->segmented_reduce<true>(input, output, flag, reduce_op);
105  }
106 
107  template<class Flag, class BinaryFunction>
108  ROCPRIM_DEVICE ROCPRIM_INLINE
109  void tail_segmented_reduce(T input, T& output, Flag flag, BinaryFunction reduce_op)
110  {
111  this->segmented_reduce<false>(input, output, flag, reduce_op);
112  }
113 
114  template<class Flag, class BinaryFunction>
115  ROCPRIM_DEVICE ROCPRIM_INLINE
116  void head_segmented_reduce(T input, T& output, Flag flag,
117  storage_type& storage, BinaryFunction reduce_op)
118  {
119  (void) storage;
120  this->segmented_reduce<true>(input, output, flag, reduce_op);
121  }
122 
123  template<class Flag, class BinaryFunction>
124  ROCPRIM_DEVICE ROCPRIM_INLINE
125  void tail_segmented_reduce(T input, T& output, Flag flag,
126  storage_type& storage, BinaryFunction reduce_op)
127  {
128  (void) storage;
129  this->segmented_reduce<false>(input, output, flag, reduce_op);
130  }
131 
132 private:
133  template<bool HeadSegmented, class Flag, class BinaryFunction>
134  ROCPRIM_DEVICE ROCPRIM_INLINE
135  void segmented_reduce(T input, T& output, Flag flag, BinaryFunction reduce_op)
136  {
137  // Get logical lane id of the last valid value in the segment,
138  // and convert it to number of valid values in segment.
139  auto valid_items_in_segment = last_in_warp_segment<HeadSegmented, WarpSize>(flag) + 1U;
140  this->reduce<false>(input, output, valid_items_in_segment, reduce_op);
141  }
142 
143  template<bool Switch>
144  ROCPRIM_DEVICE ROCPRIM_INLINE
145  typename std::enable_if<(Switch == false)>::type
146  set_output(T& output)
147  {
148  (void) output;
149  // output already set correctly
150  }
151 
152  template<bool Switch>
153  ROCPRIM_DEVICE ROCPRIM_INLINE
154  typename std::enable_if<(Switch == true)>::type
155  set_output(T& output)
156  {
157  output = warp_shuffle(output, 0, WarpSize);
158  }
159 };
160 
161 } // end namespace detail
162 
163 END_ROCPRIM_NAMESPACE
164 
165 #endif // ROCPRIM_WARP_DETAIL_WARP_REDUCE_SHUFFLE_HPP_
ROCPRIM_DEVICE ROCPRIM_INLINE T warp_shuffle(const T &input, const int src_lane, const int width=device_warp_size())
Shuffle for any data type.
Definition: warp_shuffle.hpp:172
Definition: benchmark_block_reduce.cpp:63
ROCPRIM_DEVICE ROCPRIM_INLINE T warp_shuffle_down(const T &input, const unsigned int delta, const int width=device_warp_size())
Shuffle down for any data type.
Definition: warp_shuffle.hpp:222
Deprecated: Configuration of device-level scan primitives.
Definition: block_histogram.hpp:62
Definition: warp_reduce_shuffle.hpp:43
Definition: various.hpp:52