rocPRIM
device_segmented_radix_sort_config.hpp
1 // Copyright (c) 2018-2022 Advanced Micro Devices, Inc. All rights reserved.
2 //
3 // Permission is hereby granted, free of charge, to any person obtaining a copy
4 // of this software and associated documentation files (the "Software"), to deal
5 // in the Software without restriction, including without limitation the rights
6 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 // copies of the Software, and to permit persons to whom the Software is
8 // furnished to do so, subject to the following conditions:
9 //
10 // The above copyright notice and this permission notice shall be included in
11 // all copies or substantial portions of the Software.
12 //
13 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19 // THE SOFTWARE.
20 
21 #ifndef ROCPRIM_DEVICE_DEVICE_SEGMENTED_RADIX_SORT_CONFIG_HPP_
22 #define ROCPRIM_DEVICE_DEVICE_SEGMENTED_RADIX_SORT_CONFIG_HPP_
23 
24 #include <algorithm>
25 #include <type_traits>
26 
27 #include "../config.hpp"
28 #include "../detail/various.hpp"
29 #include "../functional.hpp"
30 
31 #include "config_types.hpp"
32 
35 
36 BEGIN_ROCPRIM_NAMESPACE
37 
/// \brief Configuration of the warp sort part of the device segmented radix sort operation.
///
/// Carries the launch parameters of the small and the medium segment
/// processing kernels.
///
/// \tparam LogicalWarpSizeSmall - number of threads in the logical warp in the
/// small segment processing kernel.
/// \tparam ItemsPerThreadSmall - number of items processed by a thread in the
/// small segment processing kernel.
/// \tparam BlockSizeSmall - number of threads per block in the small segment
/// processing kernel.
/// \tparam PartitioningThreshold - if the number of segments is at least this
/// value, the segments are partitioned.
/// \tparam EnableUnpartitionedWarpSort - if \p true, warp sort may be used on
/// small segments even when the segments are not partitioned.
/// \tparam LogicalWarpSizeMedium - number of threads in the logical warp in the
/// medium segment processing kernel; defaults to at least 32.
/// \tparam ItemsPerThreadMedium - number of items processed by a thread in the
/// medium segment processing kernel; defaults to at least 4.
/// \tparam BlockSizeMedium - number of threads per block in the medium segment
/// processing kernel.
template<unsigned int LogicalWarpSizeSmall,
         unsigned int ItemsPerThreadSmall,
         unsigned int BlockSizeSmall              = 256,
         unsigned int PartitioningThreshold       = 3000,
         bool         EnableUnpartitionedWarpSort = true,
         unsigned int LogicalWarpSizeMedium       = std::max(32u, LogicalWarpSizeSmall),
         unsigned int ItemsPerThreadMedium        = std::max(4u, ItemsPerThreadSmall),
         unsigned int BlockSizeMedium             = 256>
struct WarpSortConfig
{
    // A medium warp must be able to hold everything a small warp can.
    static_assert(LogicalWarpSizeSmall * ItemsPerThreadSmall
                      <= LogicalWarpSizeMedium * ItemsPerThreadMedium,
                  "The number of items processed by a small warp cannot be larger than the number "
                  "of items processed by a medium warp");

    /// \brief The number of threads in the logical warp in the small segment processing kernel.
    static constexpr unsigned int logical_warp_size_small = LogicalWarpSizeSmall;
    /// \brief The number of items processed by a thread in the small segment processing kernel.
    static constexpr unsigned int items_per_thread_small = ItemsPerThreadSmall;
    /// \brief The number of threads per block in the small segment processing kernel.
    static constexpr unsigned int block_size_small = BlockSizeSmall;
    /// \brief Segment count from which the segments are partitioned.
    static constexpr unsigned int partitioning_threshold = PartitioningThreshold;
    /// \brief Whether warp sort may be used on small segments without partitioning.
    static constexpr bool enable_unpartitioned_warp_sort = EnableUnpartitionedWarpSort;
    /// \brief The number of threads in the logical warp in the medium segment processing kernel.
    static constexpr unsigned int logical_warp_size_medium = LogicalWarpSizeMedium;
    /// \brief The number of items processed by a thread in the medium segment processing kernel.
    static constexpr unsigned int items_per_thread_medium = ItemsPerThreadMedium;
    /// \brief The number of threads per block in the medium segment processing kernel.
    static constexpr unsigned int block_size_medium = BlockSizeMedium;
};
89 
/// \brief Indicates if the warp level sorting is disabled in the
/// device segmented radix sort configuration.
///
/// Exposes the same members as \p WarpSortConfig, all pinned to placeholder
/// values (sizes of 1, a threshold of 0, unpartitioned warp sort off).
struct DisabledWarpSortConfig
{
    /// \brief The number of threads in the logical warp in the small segment processing kernel.
    static constexpr unsigned int logical_warp_size_small = 1;
    /// \brief The number of items processed by a thread in the small segment processing kernel.
    static constexpr unsigned int items_per_thread_small = 1;
    /// \brief The number of threads per block in the small segment processing kernel.
    static constexpr unsigned int block_size_small = 1;
    /// \brief Segment count from which the segments are partitioned.
    static constexpr unsigned int partitioning_threshold = 0;
    /// \brief Whether warp sort may be used on small segments without partitioning.
    static constexpr bool enable_unpartitioned_warp_sort = false;
    /// \brief The number of threads in the logical warp in the medium segment processing kernel.
    static constexpr unsigned int logical_warp_size_medium = 1;
    /// \brief The number of items processed by a thread in the medium segment processing kernel.
    static constexpr unsigned int items_per_thread_medium = 1;
    /// \brief The number of threads per block in the medium segment processing kernel.
    static constexpr unsigned int block_size_medium = 1;
};
113 
118 template<class Key, unsigned int MediumWarpSize = ROCPRIM_WARP_SIZE_32>
120  = std::conditional_t<sizeof(Key) < 2,
122  WarpSortConfig<32, //< logical warp size - small kernel
123  4, //< items per thread - small kernel
124  256, //< block size - small kernel
125  3000, //< partitioning threshold
126  (sizeof(Key) > 2), //< enable unpartitioned warp sort
127  MediumWarpSize, //< logical warp size - medium kernel
128  4, //< items per thread - medium kernel
129  256 //< block size - medium kernel
130  >>;
131 
148 template<
149  unsigned int LongRadixBits,
150  unsigned int ShortRadixBits,
151  class SortConfig,
153 >
155 {
157  static constexpr unsigned int long_radix_bits = LongRadixBits;
159  static constexpr unsigned int short_radix_bits = ShortRadixBits;
161  using sort = SortConfig;
163  using warp_sort_config = WarpSortConfig;
164 };
165 
166 namespace detail
167 {
168 
// Key/value tuning table that picks a segmented_radix_sort_config from
// sizeof(Key) and sizeof(Value). Presumably the primary template of
// segmented_radix_sort_config_803: it precedes the _900 tables and 803 is the
// remaining architecture case in the selector -- TODO confirm.
// NOTE(review): this listing dropped source lines from the block: the struct
// declaration (listing line 170), every `select_type_case<` opener
// (176/180/184/188) and the configuration arguments of the 1- and 2-byte key
// cases (178/182). Those tuning values cannot be recovered from this text;
// restore them from the upstream header before compiling.
169 template<class Key, class Value>
171 {
// Larger of sizeof(Key) and sizeof(Value), divided by sizeof(int) through
// ceiling_div (rounding up, going by the helper's name).
172  static constexpr unsigned int item_scale =
173  ::rocprim::detail::ceiling_div<unsigned int>(::rocprim::max(sizeof(Key), sizeof(Value)), sizeof(int));
174 
175  using type = select_type<
177  (sizeof(Key) == 1 && sizeof(Value) <= 8),
179  >,
181  (sizeof(Key) == 2 && sizeof(Value) <= 8),
183  >,
185  (sizeof(Key) == 4 && sizeof(Value) <= 8),
186  segmented_radix_sort_config<7, 6, kernel_config<256, 15>, select_warp_sort_config_t<Key> >
187  >,
189  (sizeof(Key) == 8 && sizeof(Value) <= 8),
190  segmented_radix_sort_config<7, 6, kernel_config<256, 13>, select_warp_sort_config_t<Key> >
191  >,
// Final entry carries no condition (presumably the fallback when no case
// matches): the second kernel_config argument is 15 / item_scale, clamped to
// at least 1.
192  segmented_radix_sort_config<7, 6, kernel_config<256, ::rocprim::max(1u, 15u / item_scale)>, select_warp_sort_config_t<Key> >
193  >;
194 };
195 
196 template<class Key>
198  : select_type<
199  select_type_case<sizeof(Key) == 1, segmented_radix_sort_config<8, 7, kernel_config<256, 10>, select_warp_sort_config_t<Key> > >,
200  select_type_case<sizeof(Key) == 2, segmented_radix_sort_config<8, 7, kernel_config<256, 10>, select_warp_sort_config_t<Key> > >,
201  select_type_case<sizeof(Key) == 4, segmented_radix_sort_config<7, 6, kernel_config<256, 9>, select_warp_sort_config_t<Key> > >,
202  select_type_case<sizeof(Key) == 8, segmented_radix_sort_config<7, 6, kernel_config<256, 7>, select_warp_sort_config_t<Key> > >
203  > { };
204 
// Key/value tuning table that picks a segmented_radix_sort_config from
// sizeof(Key) and sizeof(Value). Presumably the primary template of
// segmented_radix_sort_config_900, whose <Key, empty_type> specialization
// follows directly -- TODO confirm.
// NOTE(review): this listing dropped source lines from the block: the struct
// declaration (listing line 206), every `select_type_case<` opener
// (212/216/220/224) and the configuration arguments of the 1- and 2-byte key
// cases (214/218). Those tuning values cannot be recovered from this text;
// restore them from the upstream header before compiling.
205 template<class Key, class Value>
207 {
// Larger of sizeof(Key) and sizeof(Value), divided by sizeof(int) through
// ceiling_div (rounding up, going by the helper's name).
208  static constexpr unsigned int item_scale =
209  ::rocprim::detail::ceiling_div<unsigned int>(::rocprim::max(sizeof(Key), sizeof(Value)), sizeof(int));
210 
211  using type = select_type<
213  (sizeof(Key) == 1 && sizeof(Value) <= 8),
215  >,
217  (sizeof(Key) == 2 && sizeof(Value) <= 8),
219  >,
221  (sizeof(Key) == 4 && sizeof(Value) <= 8),
222  segmented_radix_sort_config<7, 6, kernel_config<256, 15>, select_warp_sort_config_t<Key> >
223  >,
225  (sizeof(Key) == 8 && sizeof(Value) <= 8),
226  segmented_radix_sort_config<7, 6, kernel_config<256, 15>, select_warp_sort_config_t<Key> >
227  >,
// Final entry carries no condition (presumably the fallback when no case
// matches): the second kernel_config argument is 15 / item_scale, clamped to
// at least 1.
228  segmented_radix_sort_config<7, 6, kernel_config<256, ::rocprim::max(1u, 15u / item_scale)>, select_warp_sort_config_t<Key> >
229  >;
230 };
231 
232 template<class Key>
233 struct segmented_radix_sort_config_900<Key, empty_type>
234  : select_type<
235  select_type_case<sizeof(Key) == 1, segmented_radix_sort_config<4, 3, kernel_config<256, 10>, select_warp_sort_config_t<Key> > >,
236  select_type_case<sizeof(Key) == 2, segmented_radix_sort_config<6, 5, kernel_config<256, 10>, select_warp_sort_config_t<Key> > >,
237  select_type_case<sizeof(Key) == 4, segmented_radix_sort_config<7, 6, kernel_config<256, 17>, select_warp_sort_config_t<Key> > >,
238  select_type_case<sizeof(Key) == 8, segmented_radix_sort_config<7, 6, kernel_config<256, 15>, select_warp_sort_config_t<Key> > >
239  > { };
240 
// Key/value tuning table that picks a segmented_radix_sort_config from
// sizeof(Key) and sizeof(Value); every visible entry requests a
// ROCPRIM_WARP_SIZE_64 medium warp. Presumably the primary template of
// segmented_radix_sort_config_90a, whose <Key, empty_type> specialization
// follows directly -- TODO confirm.
// NOTE(review): this listing dropped source lines from the block: the struct
// declaration (listing line 242), every `select_type_case<` opener
// (248/254/260/266), most of the 1- and 2-byte key configurations
// (250/252/253 and 256/258/259) and the kernel_config lines of the 4- and
// 8-byte cases (264/270). The missing tuning values cannot be recovered from
// this text; restore them from the upstream header before compiling.
241 template<class Key, class Value>
243 {
// Larger of sizeof(Key) and sizeof(Value), divided by sizeof(int) through
// ceiling_div (rounding up, going by the helper's name).
244  static constexpr unsigned int item_scale =
245  ::rocprim::detail::ceiling_div<unsigned int>(::rocprim::max(sizeof(Key), sizeof(Value)), sizeof(int));
246 
247  using type = select_type<
249  (sizeof(Key) == 1 && sizeof(Value) <= 8),
251  4,
255  (sizeof(Key) == 2 && sizeof(Value) <= 8),
257  5,
261  (sizeof(Key) == 4 && sizeof(Value) <= 8),
262  segmented_radix_sort_config<7,
263  6,
265  select_warp_sort_config_t<Key, ROCPRIM_WARP_SIZE_64>>>,
267  (sizeof(Key) == 8 && sizeof(Value) <= 8),
268  segmented_radix_sort_config<7,
269  6,
271  select_warp_sort_config_t<Key, ROCPRIM_WARP_SIZE_64>>>,
// Final entry carries no condition (presumably the fallback when no case
// matches): the second kernel_config argument is 15 / item_scale, clamped to
// at least 1.
272  segmented_radix_sort_config<7,
273  6,
274  kernel_config<256, ::rocprim::max(1u, 15u / item_scale)>,
275  select_warp_sort_config_t<Key, ROCPRIM_WARP_SIZE_64>>>;
276 };
277 
278 template<class Key>
279 struct segmented_radix_sort_config_90a<Key, empty_type>
280  : select_type<
281  select_type_case<
282  sizeof(Key) == 1,
283  segmented_radix_sort_config<4,
284  3,
285  kernel_config<256, 10>,
286  select_warp_sort_config_t<Key, ROCPRIM_WARP_SIZE_64>>>,
287  select_type_case<
288  sizeof(Key) == 2,
289  segmented_radix_sort_config<6,
290  5,
291  kernel_config<256, 10>,
292  select_warp_sort_config_t<Key, ROCPRIM_WARP_SIZE_64>>>,
293  select_type_case<
294  sizeof(Key) == 4,
295  segmented_radix_sort_config<7,
296  6,
297  kernel_config<256, 17>,
298  select_warp_sort_config_t<Key, ROCPRIM_WARP_SIZE_64>>>,
299  select_type_case<
300  sizeof(Key) == 8,
301  segmented_radix_sort_config<7,
302  6,
303  kernel_config<256, 15>,
304  select_warp_sort_config_t<Key, ROCPRIM_WARP_SIZE_64>>>>
305 {};
306 
// Key/value tuning table that picks a segmented_radix_sort_config from
// sizeof(Key) and sizeof(Value). Presumably the primary template of
// segmented_radix_sort_config_1030, whose <Key, empty_type> specialization
// follows directly -- TODO confirm.
// NOTE(review): this listing dropped source lines from the block: the struct
// declaration (listing line 308), every `select_type_case<` opener
// (314/318/322/326) and the configuration arguments of the 1- and 2-byte key
// cases (316/320). Those tuning values cannot be recovered from this text;
// restore them from the upstream header before compiling.
307 template<class Key, class Value>
309 {
// Larger of sizeof(Key) and sizeof(Value), divided by sizeof(int) through
// ceiling_div (rounding up, going by the helper's name).
310  static constexpr unsigned int item_scale =
311  ::rocprim::detail::ceiling_div<unsigned int>(::rocprim::max(sizeof(Key), sizeof(Value)), sizeof(int));
312 
313  using type = select_type<
315  (sizeof(Key) == 1 && sizeof(Value) <= 8),
317  >,
319  (sizeof(Key) == 2 && sizeof(Value) <= 8),
321  >,
323  (sizeof(Key) == 4 && sizeof(Value) <= 8),
324  segmented_radix_sort_config<7, 6, kernel_config<256, 15>, select_warp_sort_config_t<Key> >
325  >,
327  (sizeof(Key) == 8 && sizeof(Value) <= 8),
328  segmented_radix_sort_config<7, 6, kernel_config<256, 15>, select_warp_sort_config_t<Key> >
329  >,
// Final entry carries no condition (presumably the fallback when no case
// matches): the second kernel_config argument is 15 / item_scale, clamped to
// at least 1.
330  segmented_radix_sort_config<7, 6, kernel_config<256, ::rocprim::max(1u, 15u / item_scale)>, select_warp_sort_config_t<Key> >
331  >;
332 };
333 
334 template<class Key>
335 struct segmented_radix_sort_config_1030<Key, empty_type>
336  : select_type<
337  select_type_case<sizeof(Key) == 1, segmented_radix_sort_config<4, 3, kernel_config<256, 10>, select_warp_sort_config_t<Key> > >,
338  select_type_case<sizeof(Key) == 2, segmented_radix_sort_config<6, 5, kernel_config<256, 10>, select_warp_sort_config_t<Key> > >,
339  select_type_case<sizeof(Key) == 4, segmented_radix_sort_config<7, 6, kernel_config<256, 17>, select_warp_sort_config_t<Key> > >,
340  select_type_case<sizeof(Key) == 8, segmented_radix_sort_config<7, 6, kernel_config<256, 15>, select_warp_sort_config_t<Key> > >
341  > { };
342 
343 template<unsigned int TargetArch, class Key, class Value>
345  : select_arch<
346  TargetArch,
347  select_arch_case<803, detail::segmented_radix_sort_config_803<Key, Value>>,
348  select_arch_case<900, detail::segmented_radix_sort_config_900<Key, Value>>,
349  select_arch_case<906, detail::segmented_radix_sort_config_90a<Key, Value>>,
350  select_arch_case<908, detail::segmented_radix_sort_config_90a<Key, Value>>,
351  select_arch_case<ROCPRIM_ARCH_90a, detail::segmented_radix_sort_config_90a<Key, Value>>,
352  select_arch_case<1030, detail::segmented_radix_sort_config_1030<Key, Value>>,
353  detail::segmented_radix_sort_config_900<Key, Value>>
354 {};
355 
356 } // end namespace detail
357 
358 END_ROCPRIM_NAMESPACE
359 
361 // end of group primitivesmodule_deviceconfigs
362 
363 #endif // ROCPRIM_DEVICE_DEVICE_SEGMENTED_RADIX_SORT_CONFIG_HPP_
Empty type used as a placeholder, usually used to flag that given template parameter should not be used.
Definition: types.hpp:135
static constexpr unsigned int items_per_thread_small
The number of items processed by a thread in the small segment processing kernel.
Definition: device_segmented_radix_sort_config.hpp:73
Definition: device_segmented_radix_sort_config.hpp:206
ROCPRIM_HOST_DEVICE constexpr T max(const T &a, const T &b)
Returns the maximum of its arguments.
Definition: functional.hpp:55
Definition: device_segmented_radix_sort_config.hpp:344
static constexpr bool enable_unpartitioned_warp_sort
If set to true, warp sort can be used to sort the small segments, even if the total number of segments is below partitioning_threshold.
Definition: device_segmented_radix_sort_config.hpp:81
Indicates if the warp level sorting is disabled in the device segmented radix sort configuration.
Definition: device_segmented_radix_sort_config.hpp:92
static constexpr unsigned int partitioning_threshold
If the number of segments is at least partitioning_threshold, then the segments are partitioned into small and large segment groups, and each group is handled by a different, specialized kernel.
Definition: device_segmented_radix_sort_config.hpp:78
Configuration of device-level segmented radix sort operation.
Definition: device_segmented_radix_sort_config.hpp:154
SortConfig sort
Configuration of radix sort kernel.
Definition: device_segmented_radix_sort_config.hpp:161
Definition: various.hpp:236
static constexpr unsigned int block_size_small
The number of threads per block in the small segment processing kernel.
Definition: device_segmented_radix_sort_config.hpp:75
Definition: device_segmented_radix_sort_config.hpp:170
Definition: device_segmented_radix_sort_config.hpp:242
Deprecated: Configuration of device-level scan primitives.
Definition: block_histogram.hpp:62
Definition: device_segmented_radix_sort_config.hpp:308
static constexpr unsigned int logical_warp_size_small
The number of threads in the logical warp in the small segment processing kernel.
Definition: device_segmented_radix_sort_config.hpp:71
static constexpr unsigned int items_per_thread_medium
The number of items processed by a thread in the medium segment processing kernel.
Definition: device_segmented_radix_sort_config.hpp:85
Definition: config_types.hpp:140
Configuration of the warp sort part of the device segmented radix sort operation.
Definition: device_segmented_radix_sort_config.hpp:64
static constexpr unsigned int logical_warp_size_medium
The number of threads in the logical warp in the medium segment processing kernel.
Definition: device_segmented_radix_sort_config.hpp:83
static constexpr unsigned int block_size_medium
The number of threads per block in the medium segment processing kernel.
Definition: device_segmented_radix_sort_config.hpp:87
std::conditional_t< sizeof(Key)< 2, DisabledWarpSortConfig, WarpSortConfig< 32, 4, 256, 3000,(sizeof(Key) > 2), MediumWarpSize, 4, 256 > > select_warp_sort_config_t
Selects the appropriate WarpSortConfig based on the size of the key type.
Definition: device_segmented_radix_sort_config.hpp:130