rocPRIM
device_scan_by_key.hpp
1 // Copyright (c) 2022-2023 Advanced Micro Devices, Inc. All rights reserved.
2 //
3 // Permission is hereby granted, free of charge, to any person obtaining a copy
4 // of this software and associated documentation files (the "Software"), to deal
5 // in the Software without restriction, including without limitation the rights
6 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 // copies of the Software, and to permit persons to whom the Software is
8 // furnished to do so, subject to the following conditions:
9 //
10 // The above copyright notice and this permission notice shall be included in
11 // all copies or substantial portions of the Software.
12 //
13 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19 // THE SOFTWARE.
20 
21 #ifndef ROCPRIM_DEVICE_DETAIL_DEVICE_SCAN_BY_KEY_HPP_
22 #define ROCPRIM_DEVICE_DETAIL_DEVICE_SCAN_BY_KEY_HPP_
23 
24 #include "device_scan_common.hpp"
25 #include "lookback_scan_state.hpp"
26 
27 #include "../../block/block_discontinuity.hpp"
28 #include "../../block/block_load.hpp"
29 #include "../../block/block_scan.hpp"
30 #include "../../block/block_store.hpp"
31 #include "../../config.hpp"
32 #include "../../detail/binary_op_wrappers.hpp"
33 #include "../../intrinsics/thread.hpp"
34 #include "../../types/tuple.hpp"
35 
36 #include <type_traits>
37 
38 BEGIN_ROCPRIM_NAMESPACE
39 
40 namespace detail
41 {
42  template <bool Exclusive,
43  unsigned int block_size,
44  unsigned int items_per_thread,
45  typename key_type,
46  typename result_type,
47  ::rocprim::block_load_method load_keys_method,
48  ::rocprim::block_load_method load_values_method>
50  {
51  using block_load_keys
52  = ::rocprim::block_load<key_type, block_size, items_per_thread, load_keys_method>;
53 
54  using block_discontinuity = ::rocprim::block_discontinuity<key_type, block_size>;
55 
56  using block_load_values
57  = ::rocprim::block_load<result_type, block_size, items_per_thread, load_keys_method>;
58 
59  union storage_type {
60  struct {
61  typename block_load_keys::storage_type load;
63  } keys;
64  typename block_load_values::storage_type load_values;
65  };
66 
67  // Load flagged values
68  // - if the scan is exclusive, the last item of each segment (range where the keys compare equal)
69  // is flagged and reset to the initial value. Adding the last item of the range to the
70  // second to last using `headflag_scan_op_wrapper` will return the initial_value,
71  // which is exactly what should be saved at the start of the next range.
72  // - if the scan is inclusive, then the first item of each segment is marked, and it will
73  // restart the scan from that value
74  template <typename KeyIterator, typename ValueIterator, typename CompareFunction>
75  ROCPRIM_DEVICE void
76  load(KeyIterator keys_input,
77  ValueIterator values_input,
78  CompareFunction compare,
79  const result_type initial_value,
80  const unsigned int flat_block_id,
81  const size_t starting_block,
82  const size_t number_of_blocks,
83  const unsigned int flat_thread_id,
84  const size_t size,
85  rocprim::tuple<result_type, bool> (&wrapped_values)[items_per_thread],
86  storage_type& storage)
87  {
88  constexpr static unsigned int items_per_block = items_per_thread * block_size;
89  const unsigned int block_offset = flat_block_id * items_per_block;
90  KeyIterator block_keys = keys_input + block_offset;
91  ValueIterator block_values = values_input + block_offset;
92 
93  key_type keys[items_per_thread];
94  result_type values[items_per_thread];
95  bool flags[items_per_thread];
96 
97  auto not_equal
98  = [compare](const auto& a, const auto& b) mutable { return !compare(a, b); };
99 
100  const auto flag_segment_boundaries = [&]() {
101  if(Exclusive)
102  {
103  const key_type tile_successor
104  = starting_block + flat_block_id < number_of_blocks - 1
105  ? block_keys[items_per_block]
106  : *block_keys;
108  flags, tile_successor, keys, not_equal, storage.keys.flag);
109  }
110  else
111  {
112  const key_type tile_predecessor = starting_block + flat_block_id > 0
113  ? block_keys[-1]
114  : *block_keys;
116  flags, tile_predecessor, keys, not_equal, storage.keys.flag);
117  }
118  };
119 
120  if(starting_block + flat_block_id < number_of_blocks - 1)
121  {
122  block_load_keys{}.load(
123  block_keys,
124  keys,
125  storage.keys.load
126  );
127 
128  flag_segment_boundaries();
129  // Reusing shared memory for loading values
131 
132  block_load_values{}.load(
133  block_values,
134  values,
135  storage.load_values
136  );
137 
138  ROCPRIM_UNROLL
139  for(unsigned int i = 0; i < items_per_thread; ++i) {
140  rocprim::get<0>(wrapped_values[i])
141  = (Exclusive && flags[i]) ? initial_value : values[i];
142  rocprim::get<1>(wrapped_values[i]) = flags[i];
143  }
144  }
145  else
146  {
147  const unsigned int valid_in_last_block
148  = static_cast<unsigned int>(size - items_per_block * (number_of_blocks - 1));
149 
150  block_load_keys {}.load(
151  block_keys,
152  keys,
153  valid_in_last_block,
154  *block_keys, // Any value is okay, so discontinuity doesn't access undefined items
155  storage.keys.load);
156 
157  flag_segment_boundaries();
158  // Reusing shared memory for loading values
160 
161  block_load_values{}.load(
162  block_values,
163  values,
164  valid_in_last_block,
165  storage.load_values
166  );
167 
168  ROCPRIM_UNROLL
169  for(unsigned int i = 0; i < items_per_thread; ++i) {
170  if(flat_thread_id * items_per_thread + i >= valid_in_last_block) {
171  break;
172  }
173 
174  rocprim::get<0>(wrapped_values[i])
175  = (Exclusive && flags[i]) ? initial_value : values[i];
176  rocprim::get<1>(wrapped_values[i]) = flags[i];
177  }
178  }
179  }
180  };
181 
182  template <unsigned int block_size,
183  unsigned int items_per_thread,
184  typename result_type,
185  ::rocprim::block_store_method store_method>
187  {
188  using block_store_values
189  = ::rocprim::block_store<result_type, block_size, items_per_thread, store_method>;
190 
191  using storage_type = typename block_store_values::storage_type;
192 
193  template <typename OutputIterator>
194  ROCPRIM_DEVICE void
195  store(OutputIterator output,
196  const unsigned int flat_block_id,
197  const size_t starting_block,
198  const size_t number_of_blocks,
199  const unsigned int flat_thread_id,
200  const size_t size,
201  const rocprim::tuple<result_type, bool> (&wrapped_values)[items_per_thread],
202  storage_type& storage)
203  {
204  constexpr static unsigned int items_per_block = items_per_thread * block_size;
205  const unsigned int block_offset = flat_block_id * items_per_block;
206  OutputIterator block_output = output + block_offset;
207 
208  result_type thread_values[items_per_thread];
209 
210  if(starting_block + flat_block_id < number_of_blocks - 1)
211  {
212  ROCPRIM_UNROLL
213  for(unsigned int i = 0; i < items_per_thread; ++i) {
214  thread_values[i] = rocprim::get<0>(wrapped_values[i]);
215  }
216 
217  // Reusing shared memory from scan to perform store
219 
220  block_store_values {}.store(block_output, thread_values, storage);
221  }
222  else
223  {
224  const unsigned int valid_in_last_block
225  = static_cast<unsigned int>(size - items_per_block * (number_of_blocks - 1));
226 
227  ROCPRIM_UNROLL
228  for(unsigned int i = 0; i < items_per_thread; ++i) {
229  if(flat_thread_id * items_per_thread + i >= valid_in_last_block) {
230  break;
231  }
232 
233  thread_values[i] = rocprim::get<0>(wrapped_values[i]);
234  }
235 
236  // Reusing shared memory from scan to perform store
238 
239  block_store_values {}.store(
240  block_output, thread_values, valid_in_last_block, storage);
241  }
242  }
243  };
244 
245  template<bool Exclusive,
246  typename Config,
247  typename KeyInputIterator,
248  typename InputIterator,
249  typename OutputIterator,
250  typename ResultType,
251  typename CompareFunction,
252  typename BinaryFunction,
253  typename LookbackScanState>
254  ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void device_scan_by_key_kernel_impl(
255  KeyInputIterator keys,
256  InputIterator values,
257  OutputIterator output,
258  ResultType initial_value,
259  const CompareFunction compare,
260  const BinaryFunction scan_op,
261  LookbackScanState scan_state,
262  const size_t size,
263  const size_t starting_block,
264  const size_t number_of_blocks,
265  const rocprim::tuple<ResultType, bool>* const previous_last_value)
266  {
267  using result_type = ResultType;
268  static_assert(std::is_same<rocprim::tuple<ResultType, bool>,
269  typename LookbackScanState::value_type>::value,
270  "value_type of LookbackScanState must be tuple of result type and flag");
271  static constexpr scan_by_key_config_params params = device_params<Config>();
272 
273  constexpr auto block_size = params.kernel_config.block_size;
274  constexpr auto items_per_thread = params.kernel_config.items_per_thread;
275  constexpr auto load_keys_method = params.block_load_method;
276  constexpr auto load_values_method = load_keys_method;
277 
278  using key_type = typename std::iterator_traits<KeyInputIterator>::value_type;
279  using load_flagged = load_values_flagged<Exclusive,
280  block_size,
281  items_per_thread,
282  key_type,
283  result_type,
284  load_keys_method,
285  load_values_method>;
286 
288  using wrapped_type = rocprim::tuple<result_type, bool>;
289 
290  using block_scan_type
291  = ::rocprim::block_scan<wrapped_type, block_size, params.block_scan_method>;
292 
293  constexpr auto store_method = params.block_store_method;
295 
296  ROCPRIM_SHARED_MEMORY union
297  {
298  typename load_flagged::storage_type load;
299  typename block_scan_type::storage_type scan;
300  typename store_unwrap::storage_type store;
301  } storage;
302 
303  const auto flat_thread_id = ::rocprim::detail::block_thread_id<0>();
304  const auto flat_block_id = ::rocprim::detail::block_id<0>();
305 
306  // Load input
307  wrapped_type wrapped_values[items_per_thread];
308  load_flagged {}.load(keys,
309  values,
310  compare,
311  initial_value,
312  flat_block_id,
313  starting_block,
314  number_of_blocks,
315  flat_thread_id,
316  size,
317  wrapped_values,
318  storage.load);
319 
320  // Reusing the storage from load to perform the scan
322 
323  // Perform look back scan scan
324  if(flat_block_id == 0)
325  {
326  auto wrapped_initial_value = rocprim::make_tuple(initial_value, false);
327 
328  // previous_last_value is used to pass the value from the previous grid, if this is a
329  // multi grid launch
330  if(previous_last_value != nullptr)
331  {
332  if(Exclusive) {
333  rocprim::get<0>(wrapped_initial_value) = rocprim::get<0>(*previous_last_value);
334  } else if (flat_thread_id == 0) {
335  wrapped_values[0] = wrapped_op(*previous_last_value, wrapped_values[0]);
336  }
337  }
338 
339  wrapped_type reduction;
340  lookback_block_scan<Exclusive, block_scan_type>(wrapped_values,
341  wrapped_initial_value,
342  reduction,
343  storage.scan,
344  wrapped_op);
345 
346  if(flat_thread_id == 0)
347  {
348  scan_state.set_complete(flat_block_id, reduction);
349  }
350  }
351  else
352  {
353  auto prefix_op = lookback_scan_prefix_op<wrapped_type,
354  decltype(wrapped_op),
355  decltype(scan_state)> {
356  flat_block_id, wrapped_op, scan_state};
357 
358  // Scan of block values
359  lookback_block_scan<Exclusive, block_scan_type>(
360  wrapped_values,
361  storage.scan,
362  prefix_op,
363  wrapped_op);
364  }
365 
366  // Store output
367  // synchronization is inside the function after unwrapping
368  store_unwrap {}.store(output,
369  flat_block_id,
370  starting_block,
371  number_of_blocks,
372  flat_thread_id,
373  size,
374  wrapped_values,
375  storage.store);
376  }
377 } // namespace detail
378 
379 END_ROCPRIM_NAMESPACE
380 
381 #endif // ROCPRIM_DEVICE_DETAIL_DEVICE_SCAN_BY_KEY_HPP_
Definition: device_scan_by_key.hpp:59
Definition: device_scan_by_key.hpp:186
The block_discontinuity class is a block level parallel primitive which provides methods for flagging...
Definition: block_discontinuity.hpp:82
block_store_method
block_store_method enumerates the methods available to store a striped arrangement of items into a bl...
Definition: block_store.hpp:41
Definition: lookback_scan_state.hpp:356
Deprecated: Configuration of device-level scan primitives.
Definition: block_histogram.hpp:62
ROCPRIM_DEVICE ROCPRIM_INLINE void flag_heads(Flag(&head_flags)[ItemsPerThread], const T(&input)[ItemsPerThread], FlagOp flag_op, storage_type &storage)
Tags head_flags that indicate discontinuities between items partitioned across the thread block...
Definition: block_discontinuity.hpp:156
ROCPRIM_DEVICE ROCPRIM_INLINE void syncthreads()
Synchronize all threads in a block (tile)
Definition: thread.hpp:216
block_load_method
block_load_method enumerates the methods available to load data from continuous memory into a blocked...
Definition: block_load.hpp:41
ROCPRIM_DEVICE ROCPRIM_INLINE unsigned int flat_block_id()
Returns flat (linear, 1D) block identifier in a multidimensional grid.
Definition: thread.hpp:178
ROCPRIM_DEVICE ROCPRIM_INLINE unsigned int block_size()
Returns block size in a multidimensional grid by dimension.
Definition: thread.hpp:268
Definition: binary_op_wrappers.hpp:72
Definition: test_device_binary_search.cpp:37
Provides the kernel parameters for exclusive_scan_by_key and inclusive_scan_by_key based on autotuned...
Definition: device_config_helper.hpp:393
Definition: device_scan_by_key.hpp:49
ROCPRIM_DEVICE ROCPRIM_INLINE void flag_tails(Flag(&tail_flags)[ItemsPerThread], const T(&input)[ItemsPerThread], FlagOp flag_op, storage_type &storage)
Tags tail_flags that indicate discontinuities between items partitioned across the thread block...
Definition: block_discontinuity.hpp:304