21 #ifndef ROCPRIM_BLOCK_DETAIL_BLOCK_SORT_MERGE_HPP_ 22 #define ROCPRIM_BLOCK_DETAIL_BLOCK_SORT_MERGE_HPP_ 24 #include "../../config.hpp" 25 #include "../../detail/merge_path.hpp" 26 #include "../../detail/various.hpp" 27 #include "../../warp/detail/warp_sort_stable.hpp" 28 #include "../../warp/warp_sort.hpp" 30 BEGIN_ROCPRIM_NAMESPACE
36 unsigned int BlockSizeX,
37 unsigned int BlockSizeY,
38 unsigned int BlockSizeZ,
39 unsigned int ItemsPerThread,
43 static constexpr
const unsigned int BlockSize = BlockSizeX * BlockSizeY * BlockSizeZ;
44 static constexpr
const unsigned int ItemsPerBlock = BlockSize * ItemsPerThread;
45 static constexpr
const unsigned int WarpSortSize =
std::min(BlockSize, 16u);
46 static constexpr
const bool with_values = !std::is_same<Value, rocprim::empty_type>::value;
49 = rocprim::detail::warp_sort_stable<Key, BlockSize, WarpSortSize, ItemsPerThread, Value>;
51 static_assert(rocprim::detail::is_power_of_two(BlockSize),
52 "BlockSize must be a power of two for block_sort_merge!");
54 static_assert(rocprim::detail::is_power_of_two(ItemsPerThread),
55 "ItemsPerThread must be a power of two for block_sort_merge!");
57 template<
bool with_values>
60 typename warp_sort_type::storage_type
warp_sort;
65 union storage_type_<true>
67 typename warp_sort_type::storage_type
warp_sort;
76 using storage_type = storage_type_<with_values>;
78 template<
class BinaryFunction>
79 ROCPRIM_DEVICE ROCPRIM_INLINE
void 80 sort(Key& thread_key, storage_type& storage, BinaryFunction compare_function)
82 Key thread_keys[] = {thread_key};
83 this->sort_impl<ItemsPerBlock>(
84 ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>(),
88 thread_key = thread_keys[0];
91 template<
class BinaryFunction>
92 ROCPRIM_DEVICE ROCPRIM_INLINE
void sort(Key (&thread_keys)[ItemsPerThread],
93 storage_type& storage,
94 BinaryFunction compare_function)
96 this->sort_impl<ItemsPerBlock>(
97 ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>(),
103 template<
class BinaryFunction>
104 ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
void sort(Key& thread_key, BinaryFunction compare_function)
106 ROCPRIM_SHARED_MEMORY storage_type storage;
107 this->sort(thread_key, storage, compare_function);
110 template<
class BinaryFunction>
111 ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
void sort(Key (&thread_keys)[ItemsPerThread],
112 BinaryFunction compare_function)
114 ROCPRIM_SHARED_MEMORY storage_type storage;
115 this->sort(thread_keys, storage, compare_function);
118 template<
class BinaryFunction>
119 ROCPRIM_DEVICE ROCPRIM_INLINE
void sort(Key& thread_key,
121 storage_type& storage,
122 BinaryFunction compare_function)
124 Key thread_keys[] = {thread_key};
125 Value thread_values[] = {thread_value};
126 this->sort_impl<ItemsPerBlock>(
127 ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>(),
132 thread_key = thread_keys[0];
133 thread_value = thread_values[0];
136 template<
class BinaryFunction>
137 ROCPRIM_DEVICE ROCPRIM_INLINE
void sort(Key (&thread_keys)[ItemsPerThread],
138 Value (&thread_values)[ItemsPerThread],
139 storage_type& storage,
140 BinaryFunction compare_function)
142 this->sort_impl<ItemsPerBlock>(
143 ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>(),
150 template<
class BinaryFunction>
151 ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
void 152 sort(Key& thread_key, Value& thread_value, BinaryFunction compare_function)
154 ROCPRIM_SHARED_MEMORY storage_type storage;
155 this->sort(thread_key, thread_value, storage, compare_function);
158 template<
class BinaryFunction>
159 ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
void sort(Key (&thread_keys)[ItemsPerThread],
160 Value (&thread_values)[ItemsPerThread],
161 BinaryFunction compare_function)
163 ROCPRIM_SHARED_MEMORY storage_type storage;
164 this->sort(thread_keys, thread_values, storage, compare_function);
167 template<
class BinaryFunction>
168 ROCPRIM_DEVICE ROCPRIM_INLINE
void sort(Key& thread_key,
169 storage_type& storage,
171 BinaryFunction compare_function)
173 Key thread_keys[] = {thread_key};
174 this->sort_impl(::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>(),
179 thread_key = thread_keys[0];
182 template<
class BinaryFunction>
183 ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
void sort(Key (&thread_keys)[ItemsPerThread],
184 storage_type& storage,
186 BinaryFunction compare_function)
188 this->sort_impl(::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>(),
195 template<
class BinaryFunction>
196 ROCPRIM_DEVICE ROCPRIM_INLINE
void sort(Key& thread_key,
198 storage_type& storage,
200 BinaryFunction compare_function)
202 Key thread_keys[] = {thread_key};
203 Value thread_values[] = {thread_value};
204 this->sort_impl(::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>(),
210 thread_key = thread_keys[0];
211 thread_value = thread_values[0];
214 template<
class BinaryFunction>
215 ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
void sort(Key (&thread_keys)[ItemsPerThread],
216 Value (&thread_values)[ItemsPerThread],
217 storage_type& storage,
219 BinaryFunction compare_function)
221 this->sort_impl(::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>(),
230 ROCPRIM_DEVICE ROCPRIM_INLINE
void 231 copy_to_shared(Key& k,
const unsigned int flat_tid, Key* keys_shared)
233 keys_shared[flat_tid] = k;
237 ROCPRIM_DEVICE ROCPRIM_INLINE
void 238 copy_to_shared(Key (&k)[ItemsPerThread],
const unsigned int flat_tid, Key* keys_shared)
241 for(
unsigned int item = 0; item < ItemsPerThread; ++item)
243 keys_shared[ItemsPerThread * flat_tid + item] = k[item];
248 ROCPRIM_DEVICE ROCPRIM_INLINE
void copy_to_shared(
249 Key& k, Value& v,
const unsigned int flat_tid, Key* keys_shared, Value* values_shared)
251 keys_shared[flat_tid] = k;
252 values_shared[flat_tid] = v;
256 ROCPRIM_DEVICE ROCPRIM_INLINE
void copy_to_shared(Key (&k)[ItemsPerThread],
257 Value (&v)[ItemsPerThread],
258 const unsigned int flat_tid,
260 Value* values_shared)
263 for(
unsigned int item = 0; item < ItemsPerThread; ++item)
265 keys_shared[ItemsPerThread * flat_tid + item] = k[item];
266 values_shared[ItemsPerThread * flat_tid + item] = v[item];
271 template<
unsigned int Size,
class BinaryFunction>
272 ROCPRIM_DEVICE ROCPRIM_INLINE
void sort_impl(
const unsigned int flat_tid,
273 storage_type& storage,
274 BinaryFunction compare_function,
275 Key (&keys)[ItemsPerThread])
277 if(Size > ItemsPerBlock)
282 ws.sort(keys, storage.warp_sort, compare_function);
283 sort_merge_impl(flat_tid,
285 ItemsPerThread * WarpSortSize,
291 template<
unsigned int Size,
class BinaryFunction>
292 ROCPRIM_DEVICE ROCPRIM_INLINE
void sort_impl(
const unsigned int flat_tid,
293 storage_type& storage,
294 BinaryFunction compare_function,
295 Key (&keys)[ItemsPerThread],
296 Value (&values)[ItemsPerThread])
298 if(Size > ItemsPerBlock)
303 ws.sort(keys, values, storage.warp_sort, compare_function);
304 sort_merge_impl(flat_tid,
306 ItemsPerThread * WarpSortSize,
313 template<
class BinaryFunction>
314 ROCPRIM_DEVICE ROCPRIM_INLINE
void sort_impl(
const unsigned int flat_tid,
315 const unsigned int input_size,
316 storage_type& storage,
317 BinaryFunction compare_function,
318 Key (&keys)[ItemsPerThread])
321 ws.sort(keys, storage.warp_sort, input_size, compare_function);
322 sort_merge_impl(flat_tid,
324 ItemsPerThread * WarpSortSize,
330 template<
class BinaryFunction>
331 ROCPRIM_DEVICE ROCPRIM_INLINE
void sort_impl(
const unsigned int flat_tid,
332 const unsigned int input_size,
333 storage_type& storage,
334 BinaryFunction compare_function,
335 Key (&keys)[ItemsPerThread],
336 Value (&values)[ItemsPerThread])
339 ws.sort(keys, values, storage.warp_sort, input_size, compare_function);
340 sort_merge_impl(flat_tid,
342 ItemsPerThread * WarpSortSize,
349 template<
class BinaryFunction>
350 ROCPRIM_DEVICE ROCPRIM_INLINE
void sort_merge_impl(
const unsigned int flat_tid,
351 const unsigned int input_size,
352 unsigned int sorted_block_size,
353 storage_type& storage,
354 BinaryFunction compare_function,
355 Key (&thread_keys)[ItemsPerThread])
357 const unsigned int thread_offset = flat_tid * ItemsPerThread;
358 auto& keys_shared = storage.keys.get();
360 if(ItemsPerThread == 1 && thread_offset > input_size)
363 while(sorted_block_size < input_size)
365 copy_to_shared(thread_keys, flat_tid, keys_shared);
366 const unsigned int target_sorted_block_size = sorted_block_size * 2;
367 const unsigned int mask = target_sorted_block_size - 1;
368 const unsigned int keys1_beg = ~mask & thread_offset;
369 const unsigned int keys1_end =
std::min(input_size, keys1_beg + sorted_block_size);
370 const unsigned int keys2_end =
std::min(input_size, keys1_end + sorted_block_size);
371 sorted_block_size = target_sorted_block_size;
372 const unsigned int diag0_local =
std::min(input_size, mask & thread_offset);
374 const unsigned int num_keys1 = keys1_end - keys1_beg;
375 const unsigned int num_keys2 = keys2_end - keys1_end;
377 const unsigned int keys1_beg_local = merge_path(&keys_shared[keys1_beg],
378 &keys_shared[keys1_end],
383 const unsigned int keys2_beg_local = diag0_local - keys1_beg_local;
385 = {keys1_beg_local + keys1_beg, keys1_end, keys2_beg_local + keys1_end, keys2_end};
387 serial_merge(keys_shared, thread_keys, range_local, compare_function);
391 template<
class BinaryFunction>
392 ROCPRIM_DEVICE ROCPRIM_INLINE
void sort_merge_impl(
const unsigned int flat_tid,
393 const unsigned int input_size,
394 unsigned int sorted_block_size,
395 storage_type& storage,
396 BinaryFunction compare_function,
397 Key (&thread_keys)[ItemsPerThread],
398 Value (&thread_values)[ItemsPerThread])
400 const unsigned int thread_offset = flat_tid * ItemsPerThread;
401 auto& keys_shared = storage.keys.get();
402 auto& values_shared = storage.values.get();
404 while(sorted_block_size < input_size)
406 copy_to_shared(thread_keys, thread_values, flat_tid, keys_shared, values_shared);
407 const unsigned int target_sorted_block_size = sorted_block_size * 2;
408 const unsigned int mask = target_sorted_block_size - 1;
409 const unsigned int keys1_beg = ~mask & thread_offset;
410 const unsigned int keys1_end =
std::min(input_size, keys1_beg + sorted_block_size);
411 const unsigned int keys2_end =
std::min(input_size, keys1_end + sorted_block_size);
412 sorted_block_size = target_sorted_block_size;
413 const unsigned int diag0_local =
std::min(input_size, mask & thread_offset);
415 const unsigned int num_keys1 = keys1_end - keys1_beg;
416 const unsigned int num_keys2 = keys2_end - keys1_end;
418 const unsigned int keys1_beg_local = merge_path(&keys_shared[keys1_beg],
419 &keys_shared[keys1_end],
424 const unsigned int keys2_beg_local = diag0_local - keys1_beg_local;
426 = {keys1_beg_local + keys1_beg, keys1_end, keys2_beg_local + keys1_end, keys2_end};
428 serial_merge(keys_shared,
440 END_ROCPRIM_NAMESPACE
442 #endif // ROCPRIM_BLOCK_DETAIL_BLOCK_SORT_MERGE_HPP_ Definition: block_sort_merge.hpp:41
The warp_sort class provides warp-wide methods for computing a parallel sort of items across thread w...
Definition: warp_sort.hpp:99
ROCPRIM_HOST_DEVICE constexpr T min(const T &a, const T &b)
Returns the minimum of its arguments.
Definition: functional.hpp:63
Deprecated: Configuration of device-level scan primitives.
Definition: block_histogram.hpp:62
ROCPRIM_DEVICE ROCPRIM_INLINE void syncthreads()
Synchronize all threads in a block (tile)
Definition: thread.hpp:216
Definition: merge_path.hpp:33