// rocPRIM: test_device_segmented_radix_sort.hpp
// MIT License
//
// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.
22 
23 #ifndef TEST_DEVICE_SEGMENTED_RADIX_SORT_HPP_
24 #define TEST_DEVICE_SEGMENTED_RADIX_SORT_HPP_
25 
#include "../common_test_header.hpp"

// required rocprim headers
#include <rocprim/device/device_segmented_radix_sort.hpp>

// required test headers
#include "test_utils_custom_float_type.hpp"
#include "test_utils_sort_comparator.hpp"
#include "test_utils_types.hpp"

// required STL headers
#include <algorithm>
#include <cstddef>
#include <cstdlib>
#include <limits>
#include <random>
#include <utility>
#include <vector>
35 
36 template<class Key,
37  class Value,
38  bool Descending,
39  unsigned int StartBit,
40  unsigned int EndBit,
41  unsigned int MinSegmentLength,
42  unsigned int MaxSegmentLength,
43  class Config = rocprim::default_config>
44 struct params
45 {
46  using key_type = Key;
47  using value_type = Value;
48  static constexpr bool descending = Descending;
49  static constexpr unsigned int start_bit = StartBit;
50  static constexpr unsigned int end_bit = EndBit;
51  static constexpr unsigned int min_segment_length = MinSegmentLength;
52  static constexpr unsigned int max_segment_length = MaxSegmentLength;
53  using config = Config;
54 };
55 
56 using config_default = rocprim::segmented_radix_sort_config<
57  4, //< long radix bits
58  3, //< short radix bits
59  rocprim::kernel_config<256, 4> //< sort block size, items per thread
60  >;
61 
62 using config_semi_custom = rocprim::segmented_radix_sort_config<
63  3, //< long radix bits
64  2, //< short radix bits
65  rocprim::kernel_config<128, 4>, //< sort block size, items per thread
66  rocprim::WarpSortConfig<16, //< logical warp size small
67  8 //< items per thread small
68  >>;
69 
70 using config_semi_custom_warp_config = rocprim::segmented_radix_sort_config<
71  3, //< long radix bits
72  2, //< short radix bits
73  rocprim::kernel_config<128, 4>, //< sort block size, items per thread
74  rocprim::WarpSortConfig<16, //< logical warp size small
75  2, //< items per thread small
76  512, //< block size small
77  0, //< partitioning threshold
78  true //< enable unpartitioned sort
79  >>;
80 
81 using config_custom = rocprim::segmented_radix_sort_config<
82  3, //< long radix bits
83  2, //< short radix bits
84  rocprim::kernel_config<128, 4>, //< sort block size, items per thread
85  rocprim::WarpSortConfig<16, //< logical warp size small
86  2, //< items per thread small
87  512, //< block size small
88  0, //< partitioning threshold
89  true, //< enable unpartitioned sort
90  32, //< logical warp size medium
91  4, //< items per thread medium
92  256 //< block size medium
93  >>;
94 
95 template<class Params>
96 class RocprimDeviceSegmentedRadixSort : public ::testing::Test
97 {
98 public:
99  using params = Params;
100 };
101 
102 TYPED_TEST_SUITE_P(RocprimDeviceSegmentedRadixSort);
103 
104 template<typename TestFixture>
105 inline void sort_keys()
106 {
107  int device_id = test_common_utils::obtain_device_from_ctest();
108  SCOPED_TRACE(testing::Message() << "with device_id = " << device_id);
109  HIP_CHECK(hipSetDevice(device_id));
110 
111  using key_type = typename TestFixture::params::key_type;
112  using config = typename TestFixture::params::config;
113  static constexpr bool descending = TestFixture::params::descending;
114  static constexpr unsigned int start_bit = TestFixture::params::start_bit;
115  static constexpr unsigned int end_bit = TestFixture::params::end_bit;
116 
117  using offset_type = unsigned int;
118 
119  hipStream_t stream = 0;
120 
121  const bool debug_synchronous = false;
122 
123  std::random_device rd;
124  std::default_random_engine gen(rd());
125 
126  std::uniform_int_distribution<size_t> segment_length_dis(
127  TestFixture::params::min_segment_length,
128  TestFixture::params::max_segment_length);
129 
130  for(size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++)
131  {
132  unsigned int seed_value
133  = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count];
134  SCOPED_TRACE(testing::Message() << "with seed = " << seed_value);
135 
136  for(size_t size : test_utils::get_sizes(seed_value))
137  {
138  SCOPED_TRACE(testing::Message() << "with size = " << size);
139 
140  // Generate data
141  std::vector<key_type> keys_input;
142  if(rocprim::is_floating_point<key_type>::value)
143  {
144  keys_input = test_utils::get_random_data<key_type>(size,
145  static_cast<key_type>(-1000),
146  static_cast<key_type>(+1000),
147  seed_value);
148  }
149  else
150  {
151  keys_input
152  = test_utils::get_random_data<key_type>(size,
155  seed_value);
156  }
157 
158  std::vector<offset_type> offsets;
159  unsigned int segments_count = 0;
160  size_t offset = 0;
161  while(offset < size)
162  {
163  const size_t segment_length = segment_length_dis(gen);
164  offsets.push_back(offset);
165  segments_count++;
166  offset += segment_length;
167  }
168  offsets.push_back(size);
169 
170  key_type* d_keys_input;
171  key_type* d_keys_output;
172  HIP_CHECK(test_common_utils::hipMallocHelper(&d_keys_input, size * sizeof(key_type)));
173  HIP_CHECK(test_common_utils::hipMallocHelper(&d_keys_output, size * sizeof(key_type)));
174  HIP_CHECK(hipMemcpy(d_keys_input,
175  keys_input.data(),
176  size * sizeof(key_type),
177  hipMemcpyHostToDevice));
178 
179  offset_type* d_offsets;
180  HIP_CHECK(
181  test_common_utils::hipMallocHelper(&d_offsets,
182  (segments_count + 1) * sizeof(offset_type)));
183  HIP_CHECK(hipMemcpy(d_offsets,
184  offsets.data(),
185  (segments_count + 1) * sizeof(offset_type),
186  hipMemcpyHostToDevice));
187 
188  // Calculate expected results on host
189  std::vector<key_type> expected(keys_input);
190  for(size_t i = 0; i < segments_count; i++)
191  {
192  std::stable_sort(
193  expected.begin() + offsets[i],
194  expected.begin() + offsets[i + 1],
196  }
197 
198  size_t temporary_storage_bytes = 0;
199  HIP_CHECK(rocprim::segmented_radix_sort_keys<config>(nullptr,
200  temporary_storage_bytes,
201  d_keys_input,
202  d_keys_output,
203  size,
204  segments_count,
205  d_offsets,
206  d_offsets + 1,
207  start_bit,
208  end_bit));
209 
210  ASSERT_GT(temporary_storage_bytes, 0U);
211 
212  void* d_temporary_storage;
213  HIP_CHECK(
214  test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes));
215 
216  if(descending)
217  {
218  HIP_CHECK(rocprim::segmented_radix_sort_keys_desc<config>(d_temporary_storage,
219  temporary_storage_bytes,
220  d_keys_input,
221  d_keys_output,
222  size,
223  segments_count,
224  d_offsets,
225  d_offsets + 1,
226  start_bit,
227  end_bit,
228  stream,
229  debug_synchronous));
230  }
231  else
232  {
233  HIP_CHECK(rocprim::segmented_radix_sort_keys<config>(d_temporary_storage,
234  temporary_storage_bytes,
235  d_keys_input,
236  d_keys_output,
237  size,
238  segments_count,
239  d_offsets,
240  d_offsets + 1,
241  start_bit,
242  end_bit,
243  stream,
244  debug_synchronous));
245  }
246 
247  std::vector<key_type> keys_output(size);
248  HIP_CHECK(hipMemcpy(keys_output.data(),
249  d_keys_output,
250  size * sizeof(key_type),
251  hipMemcpyDeviceToHost));
252 
253  HIP_CHECK(hipFree(d_temporary_storage));
254  HIP_CHECK(hipFree(d_keys_input));
255  HIP_CHECK(hipFree(d_keys_output));
256  HIP_CHECK(hipFree(d_offsets));
257 
258  ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(keys_output, expected));
259  }
260  }
261 }
262 
263 template<typename TestFixture>
264 inline void sort_pairs()
265 {
266  int device_id = test_common_utils::obtain_device_from_ctest();
267  SCOPED_TRACE(testing::Message() << "with device_id = " << device_id);
268  HIP_CHECK(hipSetDevice(device_id));
269 
270  using key_type = typename TestFixture::params::key_type;
271  using value_type = typename TestFixture::params::value_type;
272  using config = typename TestFixture::params::config;
273  constexpr bool descending = TestFixture::params::descending;
274  constexpr unsigned int start_bit = TestFixture::params::start_bit;
275  constexpr unsigned int end_bit = TestFixture::params::end_bit;
276 
277  using offset_type = unsigned int;
278 
279  hipStream_t stream = 0;
280 
281  const bool debug_synchronous = false;
282 
283  std::random_device rd;
284  std::default_random_engine gen(rd());
285 
286  std::uniform_int_distribution<size_t> segment_length_dis(
287  TestFixture::params::min_segment_length,
288  TestFixture::params::max_segment_length);
289 
290  for(size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++)
291  {
292  unsigned int seed_value
293  = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count];
294  SCOPED_TRACE(testing::Message() << "with seed = " << seed_value);
295 
296  for(size_t size : test_utils::get_sizes(seed_value))
297  {
298  SCOPED_TRACE(testing::Message() << "with size = " << size);
299 
300  // Generate data
301  std::vector<key_type> keys_input;
302  if(rocprim::is_floating_point<key_type>::value)
303  {
304  keys_input = test_utils::get_random_data<key_type>(size,
305  static_cast<key_type>(-1000),
306  static_cast<key_type>(+1000),
307  seed_value);
308  }
309  else
310  {
311  keys_input
312  = test_utils::get_random_data<key_type>(size,
315  seed_value);
316  }
317 
318  std::vector<offset_type> offsets;
319  unsigned int segments_count = 0;
320  size_t offset = 0;
321  while(offset < size)
322  {
323  const size_t segment_length = segment_length_dis(gen);
324  offsets.push_back(offset);
325  segments_count++;
326  offset += segment_length;
327  }
328  offsets.push_back(size);
329 
330  std::vector<value_type> values_input(size);
331  test_utils::iota(values_input.begin(), values_input.end(), 0);
332 
333  key_type* d_keys_input;
334  key_type* d_keys_output;
335  HIP_CHECK(test_common_utils::hipMallocHelper(&d_keys_input, size * sizeof(key_type)));
336  HIP_CHECK(test_common_utils::hipMallocHelper(&d_keys_output, size * sizeof(key_type)));
337  HIP_CHECK(hipMemcpy(d_keys_input,
338  keys_input.data(),
339  size * sizeof(key_type),
340  hipMemcpyHostToDevice));
341 
342  value_type* d_values_input;
343  value_type* d_values_output;
344  HIP_CHECK(
345  test_common_utils::hipMallocHelper(&d_values_input, size * sizeof(value_type)));
346  HIP_CHECK(
347  test_common_utils::hipMallocHelper(&d_values_output, size * sizeof(value_type)));
348  HIP_CHECK(hipMemcpy(d_values_input,
349  values_input.data(),
350  size * sizeof(value_type),
351  hipMemcpyHostToDevice));
352 
353  offset_type* d_offsets;
354  HIP_CHECK(
355  test_common_utils::hipMallocHelper(&d_offsets,
356  (segments_count + 1) * sizeof(offset_type)));
357  HIP_CHECK(hipMemcpy(d_offsets,
358  offsets.data(),
359  (segments_count + 1) * sizeof(offset_type),
360  hipMemcpyHostToDevice));
361 
362  using key_value = std::pair<key_type, value_type>;
363 
364  // Calculate expected results on host
365  std::vector<key_value> expected(size);
366  for(size_t i = 0; i < size; i++)
367  {
368  expected[i] = key_value(keys_input[i], values_input[i]);
369  }
370  for(size_t i = 0; i < segments_count; i++)
371  {
372  std::stable_sort(expected.begin() + offsets[i],
373  expected.begin() + offsets[i + 1],
375  value_type,
376  descending,
377  start_bit,
378  end_bit>());
379  }
380  std::vector<key_type> keys_expected(size);
381  std::vector<value_type> values_expected(size);
382  for(size_t i = 0; i < size; i++)
383  {
384  keys_expected[i] = expected[i].first;
385  values_expected[i] = expected[i].second;
386  }
387 
388  void* d_temporary_storage = nullptr;
389  size_t temporary_storage_bytes = 0;
390  HIP_CHECK(rocprim::segmented_radix_sort_pairs<config>(d_temporary_storage,
391  temporary_storage_bytes,
392  d_keys_input,
393  d_keys_output,
394  d_values_input,
395  d_values_output,
396  size,
397  segments_count,
398  d_offsets,
399  d_offsets + 1,
400  start_bit,
401  end_bit));
402 
403  ASSERT_GT(temporary_storage_bytes, 0U);
404 
405  HIP_CHECK(
406  test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes));
407 
408  if(descending)
409  {
410  HIP_CHECK(rocprim::segmented_radix_sort_pairs_desc<config>(d_temporary_storage,
411  temporary_storage_bytes,
412  d_keys_input,
413  d_keys_output,
414  d_values_input,
415  d_values_output,
416  size,
417  segments_count,
418  d_offsets,
419  d_offsets + 1,
420  start_bit,
421  end_bit,
422  stream,
423  debug_synchronous));
424  }
425  else
426  {
427  HIP_CHECK(rocprim::segmented_radix_sort_pairs<config>(d_temporary_storage,
428  temporary_storage_bytes,
429  d_keys_input,
430  d_keys_output,
431  d_values_input,
432  d_values_output,
433  size,
434  segments_count,
435  d_offsets,
436  d_offsets + 1,
437  start_bit,
438  end_bit,
439  stream,
440  debug_synchronous));
441  }
442 
443  std::vector<key_type> keys_output(size);
444  HIP_CHECK(hipMemcpy(keys_output.data(),
445  d_keys_output,
446  size * sizeof(key_type),
447  hipMemcpyDeviceToHost));
448 
449  std::vector<value_type> values_output(size);
450  HIP_CHECK(hipMemcpy(values_output.data(),
451  d_values_output,
452  size * sizeof(value_type),
453  hipMemcpyDeviceToHost));
454 
455  HIP_CHECK(hipFree(d_temporary_storage));
456  HIP_CHECK(hipFree(d_keys_input));
457  HIP_CHECK(hipFree(d_values_input));
458  HIP_CHECK(hipFree(d_keys_output));
459  HIP_CHECK(hipFree(d_values_output));
460  HIP_CHECK(hipFree(d_offsets));
461 
462  ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(keys_output, keys_expected));
463  ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(values_output, values_expected));
464  }
465  }
466 }
467 
468 template<typename TestFixture>
469 inline void sort_keys_double_buffer()
470 {
471  int device_id = test_common_utils::obtain_device_from_ctest();
472  SCOPED_TRACE(testing::Message() << "with device_id = " << device_id);
473  HIP_CHECK(hipSetDevice(device_id));
474 
475  using key_type = typename TestFixture::params::key_type;
476  using config = typename TestFixture::params::config;
477  constexpr bool descending = TestFixture::params::descending;
478  constexpr unsigned int start_bit = TestFixture::params::start_bit;
479  constexpr unsigned int end_bit = TestFixture::params::end_bit;
480 
481  using offset_type = unsigned int;
482 
483  hipStream_t stream = 0;
484 
485  const bool debug_synchronous = false;
486 
487  std::random_device rd;
488  std::default_random_engine gen(rd());
489 
490  std::uniform_int_distribution<size_t> segment_length_dis(
491  TestFixture::params::min_segment_length,
492  TestFixture::params::max_segment_length);
493 
494  for(size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++)
495  {
496  unsigned int seed_value
497  = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count];
498  SCOPED_TRACE(testing::Message() << "with seed = " << seed_value);
499 
500  for(size_t size : test_utils::get_sizes(seed_value))
501  {
502  SCOPED_TRACE(testing::Message() << "with size = " << size);
503 
504  // Generate data
505  std::vector<key_type> keys_input;
506  if(rocprim::is_floating_point<key_type>::value)
507  {
508  keys_input = test_utils::get_random_data<key_type>(size,
509  static_cast<key_type>(-1000),
510  static_cast<key_type>(+1000),
511  seed_value);
512  }
513  else
514  {
515  keys_input
516  = test_utils::get_random_data<key_type>(size,
519  seed_value);
520  }
521 
522  std::vector<offset_type> offsets;
523  unsigned int segments_count = 0;
524  size_t offset = 0;
525  while(offset < size)
526  {
527  const size_t segment_length = segment_length_dis(gen);
528  offsets.push_back(offset);
529  segments_count++;
530  offset += segment_length;
531  }
532  offsets.push_back(size);
533 
534  key_type* d_keys_input;
535  key_type* d_keys_output;
536  HIP_CHECK(test_common_utils::hipMallocHelper(&d_keys_input, size * sizeof(key_type)));
537  HIP_CHECK(test_common_utils::hipMallocHelper(&d_keys_output, size * sizeof(key_type)));
538  HIP_CHECK(hipMemcpy(d_keys_input,
539  keys_input.data(),
540  size * sizeof(key_type),
541  hipMemcpyHostToDevice));
542 
543  offset_type* d_offsets;
544  HIP_CHECK(
545  test_common_utils::hipMallocHelper(&d_offsets,
546  (segments_count + 1) * sizeof(offset_type)));
547  HIP_CHECK(hipMemcpy(d_offsets,
548  offsets.data(),
549  (segments_count + 1) * sizeof(offset_type),
550  hipMemcpyHostToDevice));
551 
552  // Calculate expected results on host
553  std::vector<key_type> expected(keys_input);
554  for(size_t i = 0; i < segments_count; i++)
555  {
556  std::stable_sort(
557  expected.begin() + offsets[i],
558  expected.begin() + offsets[i + 1],
560  }
561 
562  rocprim::double_buffer<key_type> d_keys(d_keys_input, d_keys_output);
563 
564  size_t temporary_storage_bytes = 0;
565  HIP_CHECK(rocprim::segmented_radix_sort_keys<config>(nullptr,
566  temporary_storage_bytes,
567  d_keys,
568  size,
569  segments_count,
570  d_offsets,
571  d_offsets + 1,
572  start_bit,
573  end_bit));
574 
575  ASSERT_GT(temporary_storage_bytes, 0U);
576 
577  void* d_temporary_storage;
578  HIP_CHECK(
579  test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes));
580 
581  if(descending)
582  {
583  HIP_CHECK(rocprim::segmented_radix_sort_keys_desc<config>(d_temporary_storage,
584  temporary_storage_bytes,
585  d_keys,
586  size,
587  segments_count,
588  d_offsets,
589  d_offsets + 1,
590  start_bit,
591  end_bit,
592  stream,
593  debug_synchronous));
594  }
595  else
596  {
597  HIP_CHECK(rocprim::segmented_radix_sort_keys<config>(d_temporary_storage,
598  temporary_storage_bytes,
599  d_keys,
600  size,
601  segments_count,
602  d_offsets,
603  d_offsets + 1,
604  start_bit,
605  end_bit,
606  stream,
607  debug_synchronous));
608  }
609 
610  std::vector<key_type> keys_output(size);
611  HIP_CHECK(hipMemcpy(keys_output.data(),
612  d_keys.current(),
613  size * sizeof(key_type),
614  hipMemcpyDeviceToHost));
615 
616  HIP_CHECK(hipFree(d_temporary_storage));
617  HIP_CHECK(hipFree(d_keys_input));
618  HIP_CHECK(hipFree(d_keys_output));
619  HIP_CHECK(hipFree(d_offsets));
620 
621  ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(keys_output, expected));
622  }
623  }
624 }
625 
626 template<typename TestFixture>
627 inline void sort_pairs_double_buffer()
628 {
629  int device_id = test_common_utils::obtain_device_from_ctest();
630  SCOPED_TRACE(testing::Message() << "with device_id = " << device_id);
631  HIP_CHECK(hipSetDevice(device_id));
632 
633  using key_type = typename TestFixture::params::key_type;
634  using value_type = typename TestFixture::params::value_type;
635  using config = typename TestFixture::params::config;
636  constexpr bool descending = TestFixture::params::descending;
637  constexpr unsigned int start_bit = TestFixture::params::start_bit;
638  constexpr unsigned int end_bit = TestFixture::params::end_bit;
639 
640  using offset_type = unsigned int;
641 
642  hipStream_t stream = 0;
643 
644  const bool debug_synchronous = false;
645 
646  std::random_device rd;
647  std::default_random_engine gen(rd());
648 
649  std::uniform_int_distribution<size_t> segment_length_dis(
650  TestFixture::params::min_segment_length,
651  TestFixture::params::max_segment_length);
652 
653  for(size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++)
654  {
655  unsigned int seed_value
656  = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count];
657  SCOPED_TRACE(testing::Message() << "with seed = " << seed_value);
658 
659  for(size_t size : test_utils::get_sizes(seed_value))
660  {
661  SCOPED_TRACE(testing::Message() << "with size = " << size);
662 
663  // Generate data
664  std::vector<key_type> keys_input;
665  if(rocprim::is_floating_point<key_type>::value)
666  {
667  keys_input = test_utils::get_random_data<key_type>(size,
668  static_cast<key_type>(-1000),
669  static_cast<key_type>(+1000),
670  seed_value);
671  }
672  else
673  {
674  keys_input
675  = test_utils::get_random_data<key_type>(size,
678  seed_value);
679  }
680 
681  std::vector<offset_type> offsets;
682  unsigned int segments_count = 0;
683  size_t offset = 0;
684  while(offset < size)
685  {
686  const size_t segment_length = segment_length_dis(gen);
687  offsets.push_back(offset);
688  segments_count++;
689  offset += segment_length;
690  }
691  offsets.push_back(size);
692 
693  std::vector<value_type> values_input(size);
694  test_utils::iota(values_input.begin(), values_input.end(), 0);
695 
696  key_type* d_keys_input;
697  key_type* d_keys_output;
698  HIP_CHECK(test_common_utils::hipMallocHelper(&d_keys_input, size * sizeof(key_type)));
699  HIP_CHECK(test_common_utils::hipMallocHelper(&d_keys_output, size * sizeof(key_type)));
700  HIP_CHECK(hipMemcpy(d_keys_input,
701  keys_input.data(),
702  size * sizeof(key_type),
703  hipMemcpyHostToDevice));
704 
705  value_type* d_values_input;
706  value_type* d_values_output;
707  HIP_CHECK(
708  test_common_utils::hipMallocHelper(&d_values_input, size * sizeof(value_type)));
709  HIP_CHECK(
710  test_common_utils::hipMallocHelper(&d_values_output, size * sizeof(value_type)));
711  HIP_CHECK(hipMemcpy(d_values_input,
712  values_input.data(),
713  size * sizeof(value_type),
714  hipMemcpyHostToDevice));
715 
716  offset_type* d_offsets;
717  HIP_CHECK(
718  test_common_utils::hipMallocHelper(&d_offsets,
719  (segments_count + 1) * sizeof(offset_type)));
720  HIP_CHECK(hipMemcpy(d_offsets,
721  offsets.data(),
722  (segments_count + 1) * sizeof(offset_type),
723  hipMemcpyHostToDevice));
724 
725  using key_value = std::pair<key_type, value_type>;
726 
727  // Calculate expected results on host
728  std::vector<key_value> expected(size);
729  for(size_t i = 0; i < size; i++)
730  {
731  expected[i] = key_value(keys_input[i], values_input[i]);
732  }
733  for(size_t i = 0; i < segments_count; i++)
734  {
735  std::stable_sort(expected.begin() + offsets[i],
736  expected.begin() + offsets[i + 1],
738  value_type,
739  descending,
740  start_bit,
741  end_bit>());
742  }
743  std::vector<key_type> keys_expected(size);
744  std::vector<value_type> values_expected(size);
745  for(size_t i = 0; i < size; i++)
746  {
747  keys_expected[i] = expected[i].first;
748  values_expected[i] = expected[i].second;
749  }
750 
751  rocprim::double_buffer<key_type> d_keys(d_keys_input, d_keys_output);
752  rocprim::double_buffer<value_type> d_values(d_values_input, d_values_output);
753 
754  void* d_temporary_storage = nullptr;
755  size_t temporary_storage_bytes = 0;
756  HIP_CHECK(rocprim::segmented_radix_sort_pairs<config>(d_temporary_storage,
757  temporary_storage_bytes,
758  d_keys,
759  d_values,
760  size,
761  segments_count,
762  d_offsets,
763  d_offsets + 1,
764  start_bit,
765  end_bit));
766 
767  ASSERT_GT(temporary_storage_bytes, 0U);
768 
769  HIP_CHECK(
770  test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes));
771 
772  if(descending)
773  {
774  HIP_CHECK(rocprim::segmented_radix_sort_pairs_desc<config>(d_temporary_storage,
775  temporary_storage_bytes,
776  d_keys,
777  d_values,
778  size,
779  segments_count,
780  d_offsets,
781  d_offsets + 1,
782  start_bit,
783  end_bit,
784  stream,
785  debug_synchronous));
786  }
787  else
788  {
789  HIP_CHECK(rocprim::segmented_radix_sort_pairs<config>(d_temporary_storage,
790  temporary_storage_bytes,
791  d_keys,
792  d_values,
793  size,
794  segments_count,
795  d_offsets,
796  d_offsets + 1,
797  start_bit,
798  end_bit,
799  stream,
800  debug_synchronous));
801  }
802 
803  std::vector<key_type> keys_output(size);
804  HIP_CHECK(hipMemcpy(keys_output.data(),
805  d_keys.current(),
806  size * sizeof(key_type),
807  hipMemcpyDeviceToHost));
808 
809  std::vector<value_type> values_output(size);
810  HIP_CHECK(hipMemcpy(values_output.data(),
811  d_values.current(),
812  size * sizeof(value_type),
813  hipMemcpyDeviceToHost));
814 
815  HIP_CHECK(hipFree(d_temporary_storage));
816  HIP_CHECK(hipFree(d_keys_input));
817  HIP_CHECK(hipFree(d_keys_output));
818  HIP_CHECK(hipFree(d_values_input));
819  HIP_CHECK(hipFree(d_values_output));
820  HIP_CHECK(hipFree(d_offsets));
821 
822  ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(keys_output, keys_expected));
823  ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(values_output, values_expected));
824  }
825  }
826 }
827 
828 #endif // TEST_DEVICE_SEGMENTED_RADIX_SORT_HPP_
// Cross-reference notes carried over from the generated documentation listing:
//   ROCPRIM_HOST_DEVICE constexpr T max(const T &a, const T &b)
//     Returns the maximum of its arguments. Definition: functional.hpp:55
//   Definition: test_utils_sort_comparator.hpp:110
//   Definition: test_warp_exchange.cpp:34
//   ROCPRIM_HOST_DEVICE constexpr T min(const T &a, const T &b)
//     Returns the minimum of its arguments. Definition: functional.hpp:63
//   Definition: test_device_segmented_radix_sort.hpp:96
//   Definition: test_utils_sort_comparator.hpp:45
//   Definition: test_device_binary_search.cpp:37