rocPRIM
block_sort_merge.hpp
1 // Copyright (c) 2022-2023 Advanced Micro Devices, Inc. All rights reserved.
2 //
3 // Permission is hereby granted, free of charge, to any person obtaining a copy
4 // of this software and associated documentation files (the "Software"), to deal
5 // in the Software without restriction, including without limitation the rights
6 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 // copies of the Software, and to permit persons to whom the Software is
8 // furnished to do so, subject to the following conditions:
9 //
10 // The above copyright notice and this permission notice shall be included in
11 // all copies or substantial portions of the Software.
12 //
13 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19 // THE SOFTWARE.
20 
21 #ifndef ROCPRIM_BLOCK_DETAIL_BLOCK_SORT_MERGE_HPP_
22 #define ROCPRIM_BLOCK_DETAIL_BLOCK_SORT_MERGE_HPP_
23 
24 #include "../../config.hpp"
25 #include "../../detail/merge_path.hpp"
26 #include "../../detail/various.hpp"
27 #include "../../warp/detail/warp_sort_stable.hpp"
28 #include "../../warp/warp_sort.hpp"
29 
30 BEGIN_ROCPRIM_NAMESPACE
31 
32 namespace detail
33 {
34 
35 template<class Key,
36  unsigned int BlockSizeX,
37  unsigned int BlockSizeY,
38  unsigned int BlockSizeZ,
39  unsigned int ItemsPerThread,
40  class Value>
42 {
43  static constexpr const unsigned int BlockSize = BlockSizeX * BlockSizeY * BlockSizeZ;
44  static constexpr const unsigned int ItemsPerBlock = BlockSize * ItemsPerThread;
45  static constexpr const unsigned int WarpSortSize = std::min(BlockSize, 16u);
46  static constexpr const bool with_values = !std::is_same<Value, rocprim::empty_type>::value;
47  // stable sort gives superior performance over the non-stable variant
48  using warp_sort_type
49  = rocprim::detail::warp_sort_stable<Key, BlockSize, WarpSortSize, ItemsPerThread, Value>;
50 
51  static_assert(rocprim::detail::is_power_of_two(BlockSize),
52  "BlockSize must be a power of two for block_sort_merge!");
53 
54  static_assert(rocprim::detail::is_power_of_two(ItemsPerThread),
55  "ItemsPerThread must be a power of two for block_sort_merge!");
56 
57  template<bool with_values>
58  union storage_type_
59  {
60  typename warp_sort_type::storage_type warp_sort;
62  };
63 
64  template<>
65  union storage_type_<true>
66  {
67  typename warp_sort_type::storage_type warp_sort;
68  struct
69  {
72  };
73  };
74 
75 public:
76  using storage_type = storage_type_<with_values>;
77 
78  template<class BinaryFunction>
79  ROCPRIM_DEVICE ROCPRIM_INLINE void
80  sort(Key& thread_key, storage_type& storage, BinaryFunction compare_function)
81  {
82  Key thread_keys[] = {thread_key};
83  this->sort_impl<ItemsPerBlock>(
84  ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>(),
85  storage,
86  compare_function,
87  thread_keys);
88  thread_key = thread_keys[0];
89  }
90 
91  template<class BinaryFunction>
92  ROCPRIM_DEVICE ROCPRIM_INLINE void sort(Key (&thread_keys)[ItemsPerThread],
93  storage_type& storage,
94  BinaryFunction compare_function)
95  {
96  this->sort_impl<ItemsPerBlock>(
97  ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>(),
98  storage,
99  compare_function,
100  thread_keys);
101  }
102 
103  template<class BinaryFunction>
104  ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void sort(Key& thread_key, BinaryFunction compare_function)
105  {
106  ROCPRIM_SHARED_MEMORY storage_type storage;
107  this->sort(thread_key, storage, compare_function);
108  }
109 
110  template<class BinaryFunction>
111  ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void sort(Key (&thread_keys)[ItemsPerThread],
112  BinaryFunction compare_function)
113  {
114  ROCPRIM_SHARED_MEMORY storage_type storage;
115  this->sort(thread_keys, storage, compare_function);
116  }
117 
118  template<class BinaryFunction>
119  ROCPRIM_DEVICE ROCPRIM_INLINE void sort(Key& thread_key,
120  Value& thread_value,
121  storage_type& storage,
122  BinaryFunction compare_function)
123  {
124  Key thread_keys[] = {thread_key};
125  Value thread_values[] = {thread_value};
126  this->sort_impl<ItemsPerBlock>(
127  ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>(),
128  storage,
129  compare_function,
130  thread_keys,
131  thread_values);
132  thread_key = thread_keys[0];
133  thread_value = thread_values[0];
134  }
135 
136  template<class BinaryFunction>
137  ROCPRIM_DEVICE ROCPRIM_INLINE void sort(Key (&thread_keys)[ItemsPerThread],
138  Value (&thread_values)[ItemsPerThread],
139  storage_type& storage,
140  BinaryFunction compare_function)
141  {
142  this->sort_impl<ItemsPerBlock>(
143  ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>(),
144  storage,
145  compare_function,
146  thread_keys,
147  thread_values);
148  }
149 
150  template<class BinaryFunction>
151  ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void
152  sort(Key& thread_key, Value& thread_value, BinaryFunction compare_function)
153  {
154  ROCPRIM_SHARED_MEMORY storage_type storage;
155  this->sort(thread_key, thread_value, storage, compare_function);
156  }
157 
158  template<class BinaryFunction>
159  ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void sort(Key (&thread_keys)[ItemsPerThread],
160  Value (&thread_values)[ItemsPerThread],
161  BinaryFunction compare_function)
162  {
163  ROCPRIM_SHARED_MEMORY storage_type storage;
164  this->sort(thread_keys, thread_values, storage, compare_function);
165  }
166 
167  template<class BinaryFunction>
168  ROCPRIM_DEVICE ROCPRIM_INLINE void sort(Key& thread_key,
169  storage_type& storage,
170  unsigned int size,
171  BinaryFunction compare_function)
172  {
173  Key thread_keys[] = {thread_key};
174  this->sort_impl(::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>(),
175  size,
176  storage,
177  compare_function,
178  thread_keys);
179  thread_key = thread_keys[0];
180  }
181 
182  template<class BinaryFunction>
183  ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void sort(Key (&thread_keys)[ItemsPerThread],
184  storage_type& storage,
185  unsigned int size,
186  BinaryFunction compare_function)
187  {
188  this->sort_impl(::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>(),
189  size,
190  storage,
191  compare_function,
192  thread_keys);
193  }
194 
195  template<class BinaryFunction>
196  ROCPRIM_DEVICE ROCPRIM_INLINE void sort(Key& thread_key,
197  Value& thread_value,
198  storage_type& storage,
199  unsigned int size,
200  BinaryFunction compare_function)
201  {
202  Key thread_keys[] = {thread_key};
203  Value thread_values[] = {thread_value};
204  this->sort_impl(::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>(),
205  size,
206  storage,
207  compare_function,
208  thread_keys,
209  thread_values);
210  thread_key = thread_keys[0];
211  thread_value = thread_values[0];
212  }
213 
214  template<class BinaryFunction>
215  ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void sort(Key (&thread_keys)[ItemsPerThread],
216  Value (&thread_values)[ItemsPerThread],
217  storage_type& storage,
218  unsigned int size,
219  BinaryFunction compare_function)
220  {
221  this->sort_impl(::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>(),
222  size,
223  storage,
224  compare_function,
225  thread_keys,
226  thread_values);
227  }
228 
229 private:
230  ROCPRIM_DEVICE ROCPRIM_INLINE void
231  copy_to_shared(Key& k, const unsigned int flat_tid, Key* keys_shared)
232  {
233  keys_shared[flat_tid] = k;
235  }
236 
237  ROCPRIM_DEVICE ROCPRIM_INLINE void
238  copy_to_shared(Key (&k)[ItemsPerThread], const unsigned int flat_tid, Key* keys_shared)
239  {
240  ROCPRIM_UNROLL
241  for(unsigned int item = 0; item < ItemsPerThread; ++item)
242  {
243  keys_shared[ItemsPerThread * flat_tid + item] = k[item];
244  }
246  }
247 
248  ROCPRIM_DEVICE ROCPRIM_INLINE void copy_to_shared(
249  Key& k, Value& v, const unsigned int flat_tid, Key* keys_shared, Value* values_shared)
250  {
251  keys_shared[flat_tid] = k;
252  values_shared[flat_tid] = v;
254  }
255 
256  ROCPRIM_DEVICE ROCPRIM_INLINE void copy_to_shared(Key (&k)[ItemsPerThread],
257  Value (&v)[ItemsPerThread],
258  const unsigned int flat_tid,
259  Key* keys_shared,
260  Value* values_shared)
261  {
262  ROCPRIM_UNROLL
263  for(unsigned int item = 0; item < ItemsPerThread; ++item)
264  {
265  keys_shared[ItemsPerThread * flat_tid + item] = k[item];
266  values_shared[ItemsPerThread * flat_tid + item] = v[item];
267  }
269  }
270 
271  template<unsigned int Size, class BinaryFunction>
272  ROCPRIM_DEVICE ROCPRIM_INLINE void sort_impl(const unsigned int flat_tid,
273  storage_type& storage,
274  BinaryFunction compare_function,
275  Key (&keys)[ItemsPerThread])
276  {
277  if(Size > ItemsPerBlock)
278  {
279  return;
280  }
281  warp_sort_type ws;
282  ws.sort(keys, storage.warp_sort, compare_function);
283  sort_merge_impl(flat_tid,
284  Size,
285  ItemsPerThread * WarpSortSize,
286  storage,
287  compare_function,
288  keys);
289  }
290 
291  template<unsigned int Size, class BinaryFunction>
292  ROCPRIM_DEVICE ROCPRIM_INLINE void sort_impl(const unsigned int flat_tid,
293  storage_type& storage,
294  BinaryFunction compare_function,
295  Key (&keys)[ItemsPerThread],
296  Value (&values)[ItemsPerThread])
297  {
298  if(Size > ItemsPerBlock)
299  {
300  return;
301  }
302  warp_sort_type ws;
303  ws.sort(keys, values, storage.warp_sort, compare_function);
304  sort_merge_impl(flat_tid,
305  Size,
306  ItemsPerThread * WarpSortSize,
307  storage,
308  compare_function,
309  keys,
310  values);
311  }
312 
313  template<class BinaryFunction>
314  ROCPRIM_DEVICE ROCPRIM_INLINE void sort_impl(const unsigned int flat_tid,
315  const unsigned int input_size,
316  storage_type& storage,
317  BinaryFunction compare_function,
318  Key (&keys)[ItemsPerThread])
319  {
320  warp_sort_type ws;
321  ws.sort(keys, storage.warp_sort, input_size, compare_function);
322  sort_merge_impl(flat_tid,
323  input_size,
324  ItemsPerThread * WarpSortSize,
325  storage,
326  compare_function,
327  keys);
328  }
329 
330  template<class BinaryFunction>
331  ROCPRIM_DEVICE ROCPRIM_INLINE void sort_impl(const unsigned int flat_tid,
332  const unsigned int input_size,
333  storage_type& storage,
334  BinaryFunction compare_function,
335  Key (&keys)[ItemsPerThread],
336  Value (&values)[ItemsPerThread])
337  {
338  warp_sort_type ws;
339  ws.sort(keys, values, storage.warp_sort, input_size, compare_function);
340  sort_merge_impl(flat_tid,
341  input_size,
342  ItemsPerThread * WarpSortSize,
343  storage,
344  compare_function,
345  keys,
346  values);
347  }
348 
349  template<class BinaryFunction>
350  ROCPRIM_DEVICE ROCPRIM_INLINE void sort_merge_impl(const unsigned int flat_tid,
351  const unsigned int input_size,
352  unsigned int sorted_block_size,
353  storage_type& storage,
354  BinaryFunction compare_function,
355  Key (&thread_keys)[ItemsPerThread])
356  {
357  const unsigned int thread_offset = flat_tid * ItemsPerThread;
358  auto& keys_shared = storage.keys.get();
359 
360  if(ItemsPerThread == 1 && thread_offset > input_size)
361  return;
362  // loop as long as sorted_block_size < input_size
363  while(sorted_block_size < input_size)
364  {
365  copy_to_shared(thread_keys, flat_tid, keys_shared);
366  const unsigned int target_sorted_block_size = sorted_block_size * 2;
367  const unsigned int mask = target_sorted_block_size - 1;
368  const unsigned int keys1_beg = ~mask & thread_offset;
369  const unsigned int keys1_end = std::min(input_size, keys1_beg + sorted_block_size);
370  const unsigned int keys2_end = std::min(input_size, keys1_end + sorted_block_size);
371  sorted_block_size = target_sorted_block_size;
372  const unsigned int diag0_local = std::min(input_size, mask & thread_offset);
373 
374  const unsigned int num_keys1 = keys1_end - keys1_beg;
375  const unsigned int num_keys2 = keys2_end - keys1_end;
376 
377  const unsigned int keys1_beg_local = merge_path(&keys_shared[keys1_beg],
378  &keys_shared[keys1_end],
379  num_keys1,
380  num_keys2,
381  diag0_local,
382  compare_function);
383  const unsigned int keys2_beg_local = diag0_local - keys1_beg_local;
384  range_t range_local
385  = {keys1_beg_local + keys1_beg, keys1_end, keys2_beg_local + keys1_end, keys2_end};
386 
387  serial_merge(keys_shared, thread_keys, range_local, compare_function);
388  }
389  }
390 
391  template<class BinaryFunction>
392  ROCPRIM_DEVICE ROCPRIM_INLINE void sort_merge_impl(const unsigned int flat_tid,
393  const unsigned int input_size,
394  unsigned int sorted_block_size,
395  storage_type& storage,
396  BinaryFunction compare_function,
397  Key (&thread_keys)[ItemsPerThread],
398  Value (&thread_values)[ItemsPerThread])
399  {
400  const unsigned int thread_offset = flat_tid * ItemsPerThread;
401  auto& keys_shared = storage.keys.get();
402  auto& values_shared = storage.values.get();
403  // loop as long as sorted_block_size < input_size
404  while(sorted_block_size < input_size)
405  {
406  copy_to_shared(thread_keys, thread_values, flat_tid, keys_shared, values_shared);
407  const unsigned int target_sorted_block_size = sorted_block_size * 2;
408  const unsigned int mask = target_sorted_block_size - 1;
409  const unsigned int keys1_beg = ~mask & thread_offset;
410  const unsigned int keys1_end = std::min(input_size, keys1_beg + sorted_block_size);
411  const unsigned int keys2_end = std::min(input_size, keys1_end + sorted_block_size);
412  sorted_block_size = target_sorted_block_size;
413  const unsigned int diag0_local = std::min(input_size, mask & thread_offset);
414 
415  const unsigned int num_keys1 = keys1_end - keys1_beg;
416  const unsigned int num_keys2 = keys2_end - keys1_end;
417 
418  const unsigned int keys1_beg_local = merge_path(&keys_shared[keys1_beg],
419  &keys_shared[keys1_end],
420  num_keys1,
421  num_keys2,
422  diag0_local,
423  compare_function);
424  const unsigned int keys2_beg_local = diag0_local - keys1_beg_local;
425  range_t range_local
426  = {keys1_beg_local + keys1_beg, keys1_end, keys2_beg_local + keys1_end, keys2_end};
427 
428  serial_merge(keys_shared,
429  thread_keys,
430  values_shared,
431  thread_values,
432  range_local,
433  compare_function);
434  }
435  }
436 };
437 
438 } // end namespace detail
439 
440 END_ROCPRIM_NAMESPACE
441 
442 #endif // ROCPRIM_BLOCK_DETAIL_BLOCK_SORT_MERGE_HPP_
Definition: block_sort_merge.hpp:41
The warp_sort class provides warp-wide methods for computing a parallel sort of items across thread w...
Definition: warp_sort.hpp:99
ROCPRIM_HOST_DEVICE constexpr T min(const T &a, const T &b)
Returns the minimum of its arguments.
Definition: functional.hpp:63
Deprecated: Configuration of device-level scan primitives.
Definition: block_histogram.hpp:62
ROCPRIM_DEVICE ROCPRIM_INLINE void syncthreads()
Synchronize all threads in a block (tile)
Definition: thread.hpp:216
Definition: merge_path.hpp:33