rocPRIM
block_adjacent_difference_impl.hpp
1 // Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved.
2 //
3 // Permission is hereby granted, free of charge, to any person obtaining a copy
4 // of this software and associated documentation files (the "Software"), to deal
5 // in the Software without restriction, including without limitation the rights
6 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 // copies of the Software, and to permit persons to whom the Software is
8 // furnished to do so, subject to the following conditions:
9 //
10 // The above copyright notice and this permission notice shall be included in
11 // all copies or substantial portions of the Software.
12 //
13 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19 // THE SOFTWARE.
20 
21 #ifndef ROCPRIM_BLOCK_DETAIL_BLOCK_ADJACENT_DIFFERENCE_IMPL_HPP_
22 #define ROCPRIM_BLOCK_DETAIL_BLOCK_ADJACENT_DIFFERENCE_IMPL_HPP_
23 
24 #include "../../config.hpp"
25 #include "../../detail/various.hpp"
26 #include "../../intrinsics/thread.hpp"
27 
28 #include <type_traits>
29 
30 #include <cassert>
31 
32 BEGIN_ROCPRIM_NAMESPACE
33 
34 namespace detail
35 {
36 
// Wrapper that lets callers invoke a BinaryFunction with either of these signatures:
// with an index, op(a, b, b_index), or without one, op(a, b).
// The operator is only allowed to take an index in the discontinuity case
// (when as_flags is true).
// block_discontinuity and block_adjacent_difference differ in their implementations only by
// the order in which the operator's parameters are passed, so this function deals with that
// as well (selected through the reversed tag).
43 template <class T, class BinaryFunction>
44 ROCPRIM_DEVICE ROCPRIM_INLINE auto apply(BinaryFunction op,
45  const T& a,
46  const T& b,
47  unsigned int index,
48  bool_constant<true> /*as_flags*/,
49  bool_constant<false> /*reversed*/) -> decltype(op(b, a, index))
50 {
51  return op(a, b, index);
52 }
53 
54 template <class T, class BinaryFunction>
55 ROCPRIM_DEVICE ROCPRIM_INLINE auto apply(BinaryFunction op,
56  const T& a,
57  const T& b,
58  unsigned int index,
59  bool_constant<true> /*as_flags*/,
60  bool_constant<true> /*reversed*/)
61  -> decltype(op(b, a, index))
62 {
63  return op(b, a, index);
64 }
65 
66 template <typename T, typename BinaryFunction, bool AsFlags>
67 ROCPRIM_DEVICE ROCPRIM_INLINE auto apply(BinaryFunction op,
68  const T& a,
69  const T& b,
70  unsigned int,
71  bool_constant<AsFlags> /*as_flags*/,
72  bool_constant<false> /*reversed*/) -> decltype(op(b, a))
73 {
74  return op(a, b);
75 }
76 
77 template <typename T, typename BinaryFunction, bool AsFlags>
78 ROCPRIM_DEVICE ROCPRIM_INLINE auto apply(BinaryFunction op,
79  const T& a,
80  const T& b,
81  unsigned int,
82  bool_constant<AsFlags> /*as_flags*/,
83  bool_constant<true> /*reversed*/) -> decltype(op(b, a))
84 {
85  return op(b, a);
86 }
87 
88 template <typename T,
89  unsigned int BlockSizeX,
90  unsigned int BlockSizeY = 1,
91  unsigned int BlockSizeZ = 1>
93 {
94 public:
95  static constexpr unsigned int BlockSize = BlockSizeX * BlockSizeY * BlockSizeZ;
96  struct storage_type
97  {
98  T items[BlockSize];
99  };
100 
101  template <bool AsFlags,
102  bool Reversed,
103  bool WithTilePredecessor,
104  unsigned int ItemsPerThread,
105  typename Output,
106  typename BinaryFunction>
107  ROCPRIM_DEVICE void apply_left(const T (&input)[ItemsPerThread],
108  Output (&output)[ItemsPerThread],
109  BinaryFunction op,
110  const T tile_predecessor_item,
111  storage_type& storage)
112  {
113  static constexpr auto as_flags = bool_constant<AsFlags> {};
114  static constexpr auto reversed = bool_constant<Reversed> {};
115 
116  const unsigned int flat_id
117  = ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
118 
119  // Save the last item of each thread
120  storage.items[flat_id] = input[ItemsPerThread - 1];
121 
122  ROCPRIM_UNROLL
123  for(unsigned int i = ItemsPerThread - 1; i > 0; --i)
124  {
125  output[i] = detail::apply(
126  op, input[i - 1], input[i], flat_id * ItemsPerThread + i, as_flags, reversed);
127  }
129 
130  if ROCPRIM_IF_CONSTEXPR (WithTilePredecessor)
131  {
132  T predecessor_item = tile_predecessor_item;
133  if(flat_id != 0) {
134  predecessor_item = storage.items[flat_id - 1];
135  }
136 
137  output[0] = detail::apply(
138  op, predecessor_item, input[0], flat_id * ItemsPerThread, as_flags, reversed);
139  }
140  else
141  {
142  output[0] = get_default_item(input, 0, as_flags);
143  if(flat_id != 0) {
144  output[0] = detail::apply(op,
145  storage.items[flat_id - 1],
146  input[0],
147  flat_id * ItemsPerThread,
148  as_flags,
149  reversed);
150  }
151  }
152  }
153 
154  template <bool AsFlags,
155  bool Reversed,
156  bool WithTilePredecessor,
157  unsigned int ItemsPerThread,
158  typename Output,
159  typename BinaryFunction>
160  ROCPRIM_DEVICE void apply_left_partial(const T (&input)[ItemsPerThread],
161  Output (&output)[ItemsPerThread],
162  BinaryFunction op,
163  const T tile_predecessor_item,
164  const unsigned int valid_items,
165  storage_type& storage)
166  {
167  static constexpr auto as_flags = bool_constant<AsFlags> {};
168  static constexpr auto reversed = bool_constant<Reversed> {};
169 
170  const unsigned int flat_id
171  = ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
172 
173  // Save the last item of each thread
174  storage.items[flat_id] = input[ItemsPerThread - 1];
175 
176  ROCPRIM_UNROLL
177  for(unsigned int i = ItemsPerThread - 1; i > 0; --i)
178  {
179  const unsigned int index = flat_id * ItemsPerThread + i;
180  output[i] = get_default_item(input, i, as_flags);
181  if(index < valid_items) {
182  output[i] = detail::apply(op, input[i - 1], input[i], index, as_flags, reversed);
183  }
184  }
186 
187  const unsigned int index = flat_id * ItemsPerThread;
188 
189  if ROCPRIM_IF_CONSTEXPR (WithTilePredecessor)
190  {
191  T predecessor_item = tile_predecessor_item;
192  if(flat_id != 0) {
193  predecessor_item = storage.items[flat_id - 1];
194  }
195 
196  output[0] = get_default_item(input, 0, as_flags);
197  if(index < valid_items)
198  {
199  output[0]
200  = detail::apply(op, predecessor_item, input[0], index, as_flags, reversed);
201  }
202  }
203  else
204  {
205  output[0] = get_default_item(input, 0, as_flags);
206  if(flat_id != 0 && index < valid_items)
207  {
208  output[0] = detail::apply(op,
209  storage.items[flat_id - 1],
210  input[0],
211  flat_id * ItemsPerThread,
212  as_flags,
213  reversed);
214  }
215  }
216  }
217 
218  template <bool AsFlags,
219  bool Reversed,
220  bool WithTileSuccessor,
221  unsigned int ItemsPerThread,
222  typename Output,
223  typename BinaryFunction>
224  ROCPRIM_DEVICE void apply_right(const T (&input)[ItemsPerThread],
225  Output (&output)[ItemsPerThread],
226  BinaryFunction op,
227  const T tile_successor_item,
228  storage_type& storage)
229  {
230  static constexpr auto as_flags = bool_constant<AsFlags> {};
231  static constexpr auto reversed = bool_constant<Reversed> {};
232 
233  const unsigned int flat_id
234  = ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
235 
236  // Save the first item of each thread
237  storage.items[flat_id] = input[0];
238 
239  ROCPRIM_UNROLL
240  for(unsigned int i = 0; i < ItemsPerThread - 1; ++i)
241  {
242  output[i] = detail::apply(
243  op, input[i], input[i + 1], flat_id * ItemsPerThread + i + 1, as_flags, reversed);
244  }
246 
247  if ROCPRIM_IF_CONSTEXPR (WithTileSuccessor)
248  {
249  T successor_item = tile_successor_item;
250  if(flat_id != BlockSize - 1) {
251  successor_item = storage.items[flat_id + 1];
252  }
253 
254  output[ItemsPerThread - 1] = detail::apply(op,
255  input[ItemsPerThread - 1],
256  successor_item,
257  flat_id * ItemsPerThread + ItemsPerThread,
258  as_flags,
259  reversed);
260  }
261  else
262  {
263  output[ItemsPerThread - 1] = get_default_item(input, ItemsPerThread - 1, as_flags);
264  if(flat_id != BlockSize - 1) {
265  output[ItemsPerThread - 1]
266  = detail::apply(op,
267  input[ItemsPerThread - 1],
268  storage.items[flat_id + 1],
269  flat_id * ItemsPerThread + ItemsPerThread,
270  as_flags,
271  reversed);
272  }
273  }
274  }
275  template <bool AsFlags,
276  bool Reversed,
277  unsigned int ItemsPerThread,
278  typename Output,
279  typename BinaryFunction>
280  ROCPRIM_DEVICE void apply_right_partial(const T (&input)[ItemsPerThread],
281  Output (&output)[ItemsPerThread],
282  BinaryFunction op,
283  const unsigned int valid_items,
284  storage_type& storage)
285  {
286  static constexpr auto as_flags = bool_constant<AsFlags> {};
287  static constexpr auto reversed = bool_constant<Reversed> {};
288 
289  const unsigned int flat_id
290  = ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
291 
292  // Save the first item of each thread
293  storage.items[flat_id] = input[0];
294 
295  ROCPRIM_UNROLL
296  for(unsigned int i = 0; i < ItemsPerThread - 1; ++i)
297  {
298  const unsigned int index = flat_id * ItemsPerThread + i + 1;
299  output[i] = get_default_item(input, i, as_flags);
300  if(index < valid_items)
301  {
302  output[i] = detail::apply(op, input[i], input[i + 1], index, as_flags, reversed);
303  }
304  }
306 
307  output[ItemsPerThread - 1] = get_default_item(input, ItemsPerThread - 1, as_flags);
308 
309  const unsigned int next_thread_index = flat_id * ItemsPerThread + ItemsPerThread;
310  if(next_thread_index < valid_items)
311  {
312  output[ItemsPerThread - 1] = detail::apply(op,
313  input[ItemsPerThread - 1],
314  storage.items[flat_id + 1],
315  next_thread_index,
316  as_flags,
317  reversed);
318  }
319  }
320 
321 private:
322  template <unsigned int ItemsPerThread>
323  ROCPRIM_DEVICE int get_default_item(const T (&)[ItemsPerThread],
324  unsigned int /*index*/,
325  bool_constant<true> /*as_flags*/)
326  {
327  return 1;
328  }
329 
330  template <unsigned int ItemsPerThread>
331  ROCPRIM_DEVICE T get_default_item(const T (&input)[ItemsPerThread],
332  const unsigned int index,
333  bool_constant<false> /*as_flags*/)
334  {
335  return input[index];
336  }
337 };
338 
339 } // namespace detail
340 
341 END_ROCPRIM_NAMESPACE
342 
343 #endif // ROCPRIM_BLOCK_DETAIL_BLOCK_ADJACENT_DIFFERENCE_IMPL_HPP_
Deprecated: Configuration of device-level scan primitives.
Definition: block_histogram.hpp:62
ROCPRIM_DEVICE ROCPRIM_INLINE void syncthreads()
Synchronize all threads in a block (tile)
Definition: thread.hpp:216
Definition: block_adjacent_difference_impl.hpp:96
Definition: block_adjacent_difference_impl.hpp:92