rocPRIM
warp_reduce_shared_mem.hpp
1 // Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved.
2 //
3 // Permission is hereby granted, free of charge, to any person obtaining a copy
4 // of this software and associated documentation files (the "Software"), to deal
5 // in the Software without restriction, including without limitation the rights
6 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 // copies of the Software, and to permit persons to whom the Software is
8 // furnished to do so, subject to the following conditions:
9 //
10 // The above copyright notice and this permission notice shall be included in
11 // all copies or substantial portions of the Software.
12 //
13 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19 // THE SOFTWARE.
20 
21 #ifndef ROCPRIM_WARP_DETAIL_WARP_REDUCE_SHARED_MEM_HPP_
22 #define ROCPRIM_WARP_DETAIL_WARP_REDUCE_SHARED_MEM_HPP_
23 
24 #include <type_traits>
25 
26 #include "../../config.hpp"
27 #include "../../intrinsics.hpp"
28 #include "../../types.hpp"
29 #include "../../detail/various.hpp"
30 
31 #include "warp_segment_bounds.hpp"
32 
33 BEGIN_ROCPRIM_NAMESPACE
34 
35 namespace detail
36 {
37 
38 template<
39  class T,
40  unsigned int WarpSize,
41  bool UseAllReduce
42 >
44 {
45  struct storage_type_
46  {
47  T values[WarpSize];
48  };
49 
50 public:
52 
53  template<class BinaryFunction>
54  ROCPRIM_DEVICE ROCPRIM_INLINE
55  void reduce(T input, T& output, storage_type& storage, BinaryFunction reduce_op)
56  {
57  constexpr unsigned int ceiling = next_power_of_two(WarpSize);
58  const unsigned int lid = detail::logical_lane_id<WarpSize>();
59  storage_type_& storage_ = storage.get();
60 
61  output = input;
62  storage_.values[lid] = output;
64  ROCPRIM_UNROLL
65  for(unsigned int i = ceiling >> 1; i > 0; i >>= 1)
66  {
67  const bool do_op = lid + i < WarpSize && lid < i;
68  if(do_op)
69  {
70  output = storage_.values[lid];
71  T other = storage_.values[lid + i];
72  output = reduce_op(output, other);
73  }
75  if(do_op)
76  {
77  storage_.values[lid] = output;
78  }
80  }
81  set_output<UseAllReduce>(output, storage);
82  }
83 
84  template<class BinaryFunction>
85  ROCPRIM_DEVICE ROCPRIM_INLINE
86  void reduce(T input, T& output, unsigned int valid_items,
87  storage_type& storage, BinaryFunction reduce_op)
88  {
89  constexpr unsigned int ceiling = next_power_of_two(WarpSize);
90  const unsigned int lid = detail::logical_lane_id<WarpSize>();
91  storage_type_& storage_ = storage.get();
92 
93  output = input;
94  storage_.values[lid] = output;
96  ROCPRIM_UNROLL
97  for(unsigned int i = ceiling >> 1; i > 0; i >>= 1)
98  {
99  const bool do_op = (lid + i) < WarpSize && lid < i && (lid + i) < valid_items;
100  if(do_op)
101  {
102  output = storage_.values[lid];
103  T other = storage_.values[lid + i];
104  output = reduce_op(output, other);
105  }
107  if(do_op)
108  {
109  storage_.values[lid] = output;
110  }
112  }
113  set_output<UseAllReduce>(output, storage);
114  }
115 
116  template<class Flag, class BinaryFunction>
117  ROCPRIM_DEVICE ROCPRIM_INLINE
118  void head_segmented_reduce(T input, T& output, Flag flag,
119  storage_type& storage, BinaryFunction reduce_op)
120  {
121  this->segmented_reduce<true>(input, output, flag, storage, reduce_op);
122  }
123 
124  template<class Flag, class BinaryFunction>
125  ROCPRIM_DEVICE ROCPRIM_INLINE
126  void tail_segmented_reduce(T input, T& output, Flag flag,
127  storage_type& storage, BinaryFunction reduce_op)
128  {
129  this->segmented_reduce<false>(input, output, flag, storage, reduce_op);
130  }
131 
132 private:
133  template<bool HeadSegmented, class Flag, class BinaryFunction>
134  ROCPRIM_DEVICE ROCPRIM_INLINE
135  void segmented_reduce(T input, T& output, Flag flag,
136  storage_type& storage, BinaryFunction reduce_op)
137  {
138  const unsigned int lid = detail::logical_lane_id<WarpSize>();
139  constexpr unsigned int ceiling = next_power_of_two(WarpSize);
140  storage_type_& storage_ = storage.get();
141  // Get logical lane id of the last valid value in the segment
142  auto last = last_in_warp_segment<HeadSegmented, WarpSize>(flag);
143 
144  output = input;
145  ROCPRIM_UNROLL
146  for(unsigned int i = 1; i < ceiling; i *= 2)
147  {
148  storage_.values[lid] = output;
150  if((lid + i) <= last)
151  {
152  T other = storage_.values[lid + i];
153  output = reduce_op(output, other);
154  }
156  }
157  }
158 
159  template<bool Switch>
160  ROCPRIM_DEVICE ROCPRIM_INLINE
161  typename std::enable_if<(Switch == false)>::type
162  set_output(T& output, storage_type& storage)
163  {
164  (void) output;
165  (void) storage;
166  // output already set correctly
167  }
168 
169  template<bool Switch>
170  ROCPRIM_DEVICE ROCPRIM_INLINE
171  typename std::enable_if<(Switch == true)>::type
172  set_output(T& output, storage_type& storage)
173  {
174  storage_type_& storage_ = storage.get();
175  output = storage_.values[0];
176  }
177 };
178 
179 } // end namespace detail
180 
181 END_ROCPRIM_NAMESPACE
182 
183 #endif // ROCPRIM_WARP_DETAIL_WARP_REDUCE_SHARED_MEM_HPP_
Definition: benchmark_block_reduce.cpp:63
ROCPRIM_DEVICE ROCPRIM_INLINE void wave_barrier()
Synchronize all threads in the wavefront.
Definition: thread.hpp:235
Deprecated: Configuration of device-level scan primitives.
Definition: block_histogram.hpp:62
Definition: warp_reduce_shared_mem.hpp:43