rocPRIM
test_device_radix_sort.hpp
1 // MIT License
2 //
3 // Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved.
4 //
5 // Permission is hereby granted, free of charge, to any person obtaining a copy
6 // of this software and associated documentation files (the "Software"), to deal
7 // in the Software without restriction, including without limitation the rights
8 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 // copies of the Software, and to permit persons to whom the Software is
10 // furnished to do so, subject to the following conditions:
11 //
12 // The above copyright notice and this permission notice shall be included in all
13 // copies or substantial portions of the Software.
14 //
15 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 // SOFTWARE.
22 
23 #ifndef TEST_DEVICE_RADIX_SORT_HPP_
24 #define TEST_DEVICE_RADIX_SORT_HPP_
25 
26 #include "../common_test_header.hpp"
27 
28 // required rocprim headers
29 #include <rocprim/device/device_radix_sort.hpp>
30 
31 // required test headers
32 #include "test_utils_custom_float_type.hpp"
33 #include "test_utils_sort_comparator.hpp"
34 #include "test_utils_types.hpp"
35 
/// \brief Compile-time description of one radix sort test case.
///
/// \tparam Key type of the keys being sorted.
/// \tparam Value type of the values carried along with the keys (pair tests only).
/// \tparam Descending \p true to test the descending sort variants.
/// \tparam StartBit first key bit passed to the sort and to the host comparator.
/// \tparam EndBit end of the key bit range; defaults to the full width of \p Key.
/// \tparam CheckLargeSizes \p true to also run sizes above <tt>1 << 17</tt>.
template<class Key,
         class Value,
         bool         Descending      = false,
         unsigned int StartBit        = 0,
         unsigned int EndBit          = sizeof(Key) * 8,
         bool         CheckLargeSizes = false>
struct params
{
    using key_type   = Key;
    using value_type = Value;

    static constexpr bool         descending        = Descending;
    static constexpr unsigned int start_bit         = StartBit;
    static constexpr unsigned int end_bit           = EndBit;
    static constexpr bool         check_large_sizes = CheckLargeSizes;
};
51 
52 template<class Params>
53 class RocprimDeviceRadixSort : public ::testing::Test
54 {
55 public:
56  using params = Params;
57 };
58 
59 TYPED_TEST_SUITE_P(RocprimDeviceRadixSort);
60 
61 template<typename TestFixture>
62 inline void sort_keys()
63 {
64  int device_id = test_common_utils::obtain_device_from_ctest();
65  SCOPED_TRACE(testing::Message() << "with device_id = " << device_id);
66  HIP_CHECK(hipSetDevice(device_id));
67 
68  using key_type = typename TestFixture::params::key_type;
69  constexpr bool descending = TestFixture::params::descending;
70  constexpr unsigned int start_bit = TestFixture::params::start_bit;
71  constexpr unsigned int end_bit = TestFixture::params::end_bit;
72  constexpr bool check_large_sizes = TestFixture::params::check_large_sizes;
73 
74  hipStream_t stream = 0;
75 
76  const bool debug_synchronous = false;
77 
78  bool in_place = false;
79 
80  for(size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++)
81  {
82  unsigned int seed_value
83  = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count];
84  SCOPED_TRACE(testing::Message() << "with seed = " << seed_value);
85 
86  auto sizes = test_utils::get_sizes(seed_value);
87  sizes.push_back(1 << 23);
88 
89  for(size_t size : sizes)
90  {
91  if(size > (1 << 17) && !check_large_sizes)
92  break;
93 
94  SCOPED_TRACE(testing::Message() << "with size = " << size);
95 
96  in_place = !in_place;
97 
98  // Generate data
99  std::vector<key_type> keys_input;
100  if(rocprim::is_floating_point<key_type>::value)
101  {
102  keys_input = test_utils::get_random_data<key_type>(size,
103  static_cast<key_type>(-1000),
104  static_cast<key_type>(+1000),
105  seed_value);
106  test_utils::add_special_values(keys_input, seed_value);
107  }
108  else
109  {
110  keys_input
111  = test_utils::get_random_data<key_type>(size,
114  seed_value);
115  }
116 
117  key_type* d_keys_input;
118  key_type* d_keys_output;
119  HIP_CHECK(test_common_utils::hipMallocHelper(&d_keys_input, size * sizeof(key_type)));
120  if(in_place)
121  {
122  d_keys_output = d_keys_input;
123  }
124  else
125  {
126  HIP_CHECK(
127  test_common_utils::hipMallocHelper(&d_keys_output, size * sizeof(key_type)));
128  }
129  HIP_CHECK(hipMemcpy(d_keys_input,
130  keys_input.data(),
131  size * sizeof(key_type),
132  hipMemcpyHostToDevice));
133 
134  // Calculate expected results on host
135  std::vector<key_type> expected(keys_input);
136  std::stable_sort(
137  expected.begin(),
138  expected.end(),
140 
141  // Use arbitrary custom config to increase test coverage without making more test cases
142  using config = rocprim::radix_sort_config_v2<rocprim::default_config,
143  rocprim::default_config,
144  rocprim::default_config,
145  1024 * 512>;
146 
147  size_t temporary_storage_bytes;
148  HIP_CHECK(rocprim::radix_sort_keys<config>(nullptr,
149  temporary_storage_bytes,
150  d_keys_input,
151  d_keys_output,
152  size,
153  start_bit,
154  end_bit));
155 
156  ASSERT_GT(temporary_storage_bytes, 0);
157 
158  void* d_temporary_storage;
159  HIP_CHECK(
160  test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes));
161 
162  if(descending)
163  {
164  HIP_CHECK(rocprim::radix_sort_keys_desc<config>(d_temporary_storage,
165  temporary_storage_bytes,
166  d_keys_input,
167  d_keys_output,
168  size,
169  start_bit,
170  end_bit,
171  stream,
172  debug_synchronous));
173  }
174  else
175  {
176  HIP_CHECK(rocprim::radix_sort_keys<config>(d_temporary_storage,
177  temporary_storage_bytes,
178  d_keys_input,
179  d_keys_output,
180  size,
181  start_bit,
182  end_bit,
183  stream,
184  debug_synchronous));
185  }
186 
187  std::vector<key_type> keys_output(size);
188  HIP_CHECK(hipMemcpy(keys_output.data(),
189  d_keys_output,
190  size * sizeof(key_type),
191  hipMemcpyDeviceToHost));
192 
193  HIP_CHECK(hipFree(d_temporary_storage));
194  HIP_CHECK(hipFree(d_keys_input));
195  if(!in_place)
196  {
197  HIP_CHECK(hipFree(d_keys_output));
198  }
199 
200  ASSERT_NO_FATAL_FAILURE(test_utils::assert_bit_eq(keys_output, expected));
201  }
202  }
203 }
204 
205 template<typename TestFixture>
206 inline void sort_pairs()
207 {
208  int device_id = test_common_utils::obtain_device_from_ctest();
209  SCOPED_TRACE(testing::Message() << "with device_id = " << device_id);
210  HIP_CHECK(hipSetDevice(device_id));
211 
212  using key_type = typename TestFixture::params::key_type;
213  using value_type = typename TestFixture::params::value_type;
214  constexpr bool descending = TestFixture::params::descending;
215  constexpr unsigned int start_bit = TestFixture::params::start_bit;
216  constexpr unsigned int end_bit = TestFixture::params::end_bit;
217  constexpr bool check_large_sizes = TestFixture::params::check_large_sizes;
218 
219  hipStream_t stream = 0;
220 
221  const bool debug_synchronous = false;
222 
223  bool in_place = false;
224 
225  for(size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++)
226  {
227  unsigned int seed_value
228  = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count];
229  SCOPED_TRACE(testing::Message() << "with seed = " << seed_value);
230 
231  auto sizes = test_utils::get_sizes(seed_value);
232  sizes.push_back(1 << 23);
233 
234  for(size_t size : sizes)
235  {
236  if(size > (1 << 17) && !check_large_sizes)
237  break;
238 
239  SCOPED_TRACE(testing::Message() << "with size = " << size);
240 
241  in_place = !in_place;
242 
243  // Generate data
244  std::vector<key_type> keys_input;
245  if(rocprim::is_floating_point<key_type>::value)
246  {
247  keys_input = test_utils::get_random_data<key_type>(size,
248  static_cast<key_type>(-1000),
249  static_cast<key_type>(+1000),
250  seed_value);
251  test_utils::add_special_values(keys_input, seed_value);
252  }
253  else
254  {
255  keys_input
256  = test_utils::get_random_data<key_type>(size,
259  seed_value);
260  }
261 
262  std::vector<value_type> values_input(size);
263  test_utils::iota(values_input.begin(), values_input.end(), 0);
264 
265  key_type* d_keys_input;
266  key_type* d_keys_output;
267  HIP_CHECK(test_common_utils::hipMallocHelper(&d_keys_input, size * sizeof(key_type)));
268  if(in_place)
269  {
270  d_keys_output = d_keys_input;
271  }
272  else
273  {
274  HIP_CHECK(
275  test_common_utils::hipMallocHelper(&d_keys_output, size * sizeof(key_type)));
276  }
277  HIP_CHECK(hipMemcpy(d_keys_input,
278  keys_input.data(),
279  size * sizeof(key_type),
280  hipMemcpyHostToDevice));
281 
282  value_type* d_values_input;
283  value_type* d_values_output;
284  HIP_CHECK(
285  test_common_utils::hipMallocHelper(&d_values_input, size * sizeof(value_type)));
286  if(in_place)
287  {
288  d_values_output = d_values_input;
289  }
290  else
291  {
292  HIP_CHECK(test_common_utils::hipMallocHelper(&d_values_output,
293  size * sizeof(value_type)));
294  }
295  HIP_CHECK(hipMemcpy(d_values_input,
296  values_input.data(),
297  size * sizeof(value_type),
298  hipMemcpyHostToDevice));
299 
300  using key_value = std::pair<key_type, value_type>;
301 
302  // Calculate expected results on host
303  std::vector<key_value> expected(size);
304  for(size_t i = 0; i < size; i++)
305  {
306  expected[i] = key_value(keys_input[i], values_input[i]);
307  }
308  std::stable_sort(
309  expected.begin(),
310  expected.end(),
311  test_utils::
312  key_value_comparator<key_type, value_type, descending, start_bit, end_bit>());
313  std::vector<key_type> keys_expected(size);
314  std::vector<value_type> values_expected(size);
315  for(size_t i = 0; i < size; i++)
316  {
317  keys_expected[i] = expected[i].first;
318  values_expected[i] = expected[i].second;
319  }
320 
321  // Use arbitrary custom config to increase test coverage without making more test cases
322  using config = rocprim::radix_sort_config_v2<
323  rocprim::kernel_config<256, 1>,
324  rocprim::merge_sort_config<128, 64, 2, 128, 64, 2>,
325  rocprim::radix_sort_onesweep_config<rocprim::kernel_config<128, 1>,
326  rocprim::kernel_config<128, 1>,
327  4>,
328  1024 * 512>;
329 
330  void* d_temporary_storage = nullptr;
331  size_t temporary_storage_bytes;
332  HIP_CHECK(rocprim::radix_sort_pairs<config>(d_temporary_storage,
333  temporary_storage_bytes,
334  d_keys_input,
335  d_keys_output,
336  d_values_input,
337  d_values_output,
338  size,
339  start_bit,
340  end_bit));
341 
342  ASSERT_GT(temporary_storage_bytes, 0);
343 
344  HIP_CHECK(
345  test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes));
346 
347  if(descending)
348  {
349  HIP_CHECK(rocprim::radix_sort_pairs_desc<config>(d_temporary_storage,
350  temporary_storage_bytes,
351  d_keys_input,
352  d_keys_output,
353  d_values_input,
354  d_values_output,
355  size,
356  start_bit,
357  end_bit,
358  stream,
359  debug_synchronous));
360  }
361  else
362  {
363  HIP_CHECK(rocprim::radix_sort_pairs<config>(d_temporary_storage,
364  temporary_storage_bytes,
365  d_keys_input,
366  d_keys_output,
367  d_values_input,
368  d_values_output,
369  size,
370  start_bit,
371  end_bit,
372  stream,
373  debug_synchronous));
374  }
375 
376  std::vector<key_type> keys_output(size);
377  HIP_CHECK(hipMemcpy(keys_output.data(),
378  d_keys_output,
379  size * sizeof(key_type),
380  hipMemcpyDeviceToHost));
381 
382  std::vector<value_type> values_output(size);
383  HIP_CHECK(hipMemcpy(values_output.data(),
384  d_values_output,
385  size * sizeof(value_type),
386  hipMemcpyDeviceToHost));
387 
388  HIP_CHECK(hipFree(d_temporary_storage));
389  HIP_CHECK(hipFree(d_keys_input));
390  HIP_CHECK(hipFree(d_values_input));
391  if(!in_place)
392  {
393  HIP_CHECK(hipFree(d_keys_output));
394  HIP_CHECK(hipFree(d_values_output));
395  }
396 
397  ASSERT_NO_FATAL_FAILURE(test_utils::assert_bit_eq(keys_output, keys_expected));
398  ASSERT_NO_FATAL_FAILURE(test_utils::assert_bit_eq(values_output, values_expected));
399  }
400  }
401 }
402 
403 template<typename TestFixture>
404 inline void sort_keys_double_buffer()
405 {
406  int device_id = test_common_utils::obtain_device_from_ctest();
407  SCOPED_TRACE(testing::Message() << "with device_id = " << device_id);
408  HIP_CHECK(hipSetDevice(device_id));
409 
410  using key_type = typename TestFixture::params::key_type;
411  constexpr bool descending = TestFixture::params::descending;
412  constexpr unsigned int start_bit = TestFixture::params::start_bit;
413  constexpr unsigned int end_bit = TestFixture::params::end_bit;
414  constexpr bool check_large_sizes = TestFixture::params::check_large_sizes;
415 
416  hipStream_t stream = 0;
417 
418  const bool debug_synchronous = false;
419 
420  for(size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++)
421  {
422  unsigned int seed_value
423  = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count];
424  SCOPED_TRACE(testing::Message() << "with seed = " << seed_value);
425 
426  auto sizes = test_utils::get_sizes(seed_value);
427  sizes.push_back(1 << 23);
428 
429  for(size_t size : sizes)
430  {
431  if(size > (1 << 17) && !check_large_sizes)
432  break;
433 
434  SCOPED_TRACE(testing::Message() << "with size = " << size);
435 
436  // Generate data
437  std::vector<key_type> keys_input;
438  if(rocprim::is_floating_point<key_type>::value)
439  {
440  keys_input = test_utils::get_random_data<key_type>(size,
441  static_cast<key_type>(-1000),
442  static_cast<key_type>(+1000),
443  seed_value);
444  test_utils::add_special_values(keys_input, seed_value);
445  }
446  else
447  {
448  keys_input
449  = test_utils::get_random_data<key_type>(size,
452  seed_value);
453  }
454 
455  key_type* d_keys_input;
456  key_type* d_keys_output;
457  HIP_CHECK(test_common_utils::hipMallocHelper(&d_keys_input, size * sizeof(key_type)));
458  HIP_CHECK(test_common_utils::hipMallocHelper(&d_keys_output, size * sizeof(key_type)));
459  HIP_CHECK(hipMemcpy(d_keys_input,
460  keys_input.data(),
461  size * sizeof(key_type),
462  hipMemcpyHostToDevice));
463 
464  // Calculate expected results on host
465  std::vector<key_type> expected(keys_input);
466  std::stable_sort(
467  expected.begin(),
468  expected.end(),
470 
471  rocprim::double_buffer<key_type> d_keys(d_keys_input, d_keys_output);
472 
473  size_t temporary_storage_bytes;
474  HIP_CHECK(rocprim::radix_sort_keys(nullptr,
475  temporary_storage_bytes,
476  d_keys,
477  size,
478  start_bit,
479  end_bit));
480 
481  ASSERT_GT(temporary_storage_bytes, 0);
482 
483  void* d_temporary_storage;
484  HIP_CHECK(
485  test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes));
486 
487  if(descending)
488  {
489  HIP_CHECK(rocprim::radix_sort_keys_desc(d_temporary_storage,
490  temporary_storage_bytes,
491  d_keys,
492  size,
493  start_bit,
494  end_bit,
495  stream,
496  debug_synchronous));
497  }
498  else
499  {
500  HIP_CHECK(rocprim::radix_sort_keys(d_temporary_storage,
501  temporary_storage_bytes,
502  d_keys,
503  size,
504  start_bit,
505  end_bit,
506  stream,
507  debug_synchronous));
508  }
509 
510  HIP_CHECK(hipFree(d_temporary_storage));
511 
512  std::vector<key_type> keys_output(size);
513  HIP_CHECK(hipMemcpy(keys_output.data(),
514  d_keys.current(),
515  size * sizeof(key_type),
516  hipMemcpyDeviceToHost));
517 
518  HIP_CHECK(hipFree(d_keys_input));
519  HIP_CHECK(hipFree(d_keys_output));
520 
521  ASSERT_NO_FATAL_FAILURE(test_utils::assert_bit_eq(keys_output, expected));
522  }
523  }
524 }
525 
526 template<typename TestFixture>
527 inline void sort_pairs_double_buffer()
528 {
529  int device_id = test_common_utils::obtain_device_from_ctest();
530  SCOPED_TRACE(testing::Message() << "with device_id = " << device_id);
531  HIP_CHECK(hipSetDevice(device_id));
532 
533  using key_type = typename TestFixture::params::key_type;
534  using value_type = typename TestFixture::params::value_type;
535  constexpr bool descending = TestFixture::params::descending;
536  constexpr unsigned int start_bit = TestFixture::params::start_bit;
537  constexpr unsigned int end_bit = TestFixture::params::end_bit;
538  constexpr bool check_large_sizes = TestFixture::params::check_large_sizes;
539 
540  hipStream_t stream = 0;
541 
542  const bool debug_synchronous = false;
543 
544  for(size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++)
545  {
546  unsigned int seed_value
547  = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count];
548  SCOPED_TRACE(testing::Message() << "with seed = " << seed_value);
549 
550  auto sizes = test_utils::get_sizes(seed_value);
551  sizes.push_back(1 << 23);
552 
553  for(size_t size : sizes)
554  {
555  if(size > (1 << 17) && !check_large_sizes)
556  break;
557 
558  SCOPED_TRACE(testing::Message() << "with size = " << size);
559 
560  // Generate data
561  std::vector<key_type> keys_input;
562  if(rocprim::is_floating_point<key_type>::value)
563  {
564  keys_input = test_utils::get_random_data<key_type>(size,
565  static_cast<key_type>(-1000),
566  static_cast<key_type>(+1000),
567  seed_value);
568  test_utils::add_special_values(keys_input, seed_value);
569  }
570  else
571  {
572  keys_input
573  = test_utils::get_random_data<key_type>(size,
576  seed_value);
577  }
578 
579  std::vector<value_type> values_input(size);
580  test_utils::iota(values_input.begin(), values_input.end(), 0);
581 
582  key_type* d_keys_input;
583  key_type* d_keys_output;
584  HIP_CHECK(test_common_utils::hipMallocHelper(&d_keys_input, size * sizeof(key_type)));
585  HIP_CHECK(test_common_utils::hipMallocHelper(&d_keys_output, size * sizeof(key_type)));
586  HIP_CHECK(hipMemcpy(d_keys_input,
587  keys_input.data(),
588  size * sizeof(key_type),
589  hipMemcpyHostToDevice));
590 
591  value_type* d_values_input;
592  value_type* d_values_output;
593  HIP_CHECK(
594  test_common_utils::hipMallocHelper(&d_values_input, size * sizeof(value_type)));
595  HIP_CHECK(
596  test_common_utils::hipMallocHelper(&d_values_output, size * sizeof(value_type)));
597  HIP_CHECK(hipMemcpy(d_values_input,
598  values_input.data(),
599  size * sizeof(value_type),
600  hipMemcpyHostToDevice));
601 
602  using key_value = std::pair<key_type, value_type>;
603 
604  // Calculate expected results on host
605  std::vector<key_value> expected(size);
606  for(size_t i = 0; i < size; i++)
607  {
608  expected[i] = key_value(keys_input[i], values_input[i]);
609  }
610  std::stable_sort(
611  expected.begin(),
612  expected.end(),
613  test_utils::
614  key_value_comparator<key_type, value_type, descending, start_bit, end_bit>());
615  std::vector<key_type> keys_expected(size);
616  std::vector<value_type> values_expected(size);
617  for(size_t i = 0; i < size; i++)
618  {
619  keys_expected[i] = expected[i].first;
620  values_expected[i] = expected[i].second;
621  }
622 
623  rocprim::double_buffer<key_type> d_keys(d_keys_input, d_keys_output);
624  rocprim::double_buffer<value_type> d_values(d_values_input, d_values_output);
625 
626  void* d_temporary_storage = nullptr;
627  size_t temporary_storage_bytes;
628  HIP_CHECK(rocprim::radix_sort_pairs(d_temporary_storage,
629  temporary_storage_bytes,
630  d_keys,
631  d_values,
632  size,
633  start_bit,
634  end_bit));
635 
636  ASSERT_GT(temporary_storage_bytes, 0);
637 
638  HIP_CHECK(
639  test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes));
640 
641  if(descending)
642  {
643  HIP_CHECK(rocprim::radix_sort_pairs_desc(d_temporary_storage,
644  temporary_storage_bytes,
645  d_keys,
646  d_values,
647  size,
648  start_bit,
649  end_bit,
650  stream,
651  debug_synchronous));
652  }
653  else
654  {
655  HIP_CHECK(rocprim::radix_sort_pairs(d_temporary_storage,
656  temporary_storage_bytes,
657  d_keys,
658  d_values,
659  size,
660  start_bit,
661  end_bit,
662  stream,
663  debug_synchronous));
664  }
665 
666  HIP_CHECK(hipFree(d_temporary_storage));
667 
668  std::vector<key_type> keys_output(size);
669  HIP_CHECK(hipMemcpy(keys_output.data(),
670  d_keys.current(),
671  size * sizeof(key_type),
672  hipMemcpyDeviceToHost));
673 
674  std::vector<value_type> values_output(size);
675  HIP_CHECK(hipMemcpy(values_output.data(),
676  d_values.current(),
677  size * sizeof(value_type),
678  hipMemcpyDeviceToHost));
679 
680  HIP_CHECK(hipFree(d_keys_input));
681  HIP_CHECK(hipFree(d_keys_output));
682  HIP_CHECK(hipFree(d_values_input));
683  HIP_CHECK(hipFree(d_values_output));
684 
685  ASSERT_NO_FATAL_FAILURE(test_utils::assert_bit_eq(keys_output, keys_expected));
686  ASSERT_NO_FATAL_FAILURE(test_utils::assert_bit_eq(values_output, values_expected));
687  }
688  }
689 }
690 
691 inline void sort_keys_over_4g()
692 {
693  using key_type = uint8_t;
694  constexpr unsigned int start_bit = 0;
695  constexpr unsigned int end_bit = 8ull * sizeof(key_type);
696  constexpr hipStream_t stream = 0;
697  constexpr bool debug_synchronous = false;
698  constexpr size_t size = (1ull << 32) + 32;
699  constexpr size_t number_of_possible_keys = 1ull << (8ull * sizeof(key_type));
700  assert(std::is_unsigned<key_type>::value);
701  std::vector<size_t> histogram(number_of_possible_keys, 0);
702  const int seed_value = rand();
703 
704  const int device_id = test_common_utils::obtain_device_from_ctest();
705  SCOPED_TRACE(testing::Message() << "with device_id = " << device_id);
706  HIP_CHECK(hipSetDevice(device_id));
707 
708  std::vector<key_type> keys_input
709  = test_utils::get_random_data<key_type>(size,
712  seed_value);
713 
714  //generate histogram of the randomly generated values
715  std::for_each(keys_input.begin(), keys_input.end(), [&](const key_type& a) { histogram[a]++; });
716 
717  key_type* d_keys_input_output{};
718  size_t key_type_storage_bytes = size * sizeof(key_type);
719 
720  HIP_CHECK(test_common_utils::hipMallocHelper(&d_keys_input_output, key_type_storage_bytes));
721  HIP_CHECK(hipMemcpy(d_keys_input_output,
722  keys_input.data(),
723  key_type_storage_bytes,
724  hipMemcpyHostToDevice));
725 
726  size_t temporary_storage_bytes;
727  HIP_CHECK(rocprim::radix_sort_keys(nullptr,
728  temporary_storage_bytes,
729  d_keys_input_output,
730  d_keys_input_output,
731  size,
732  start_bit,
733  end_bit,
734  stream,
735  debug_synchronous));
736 
737  ASSERT_GT(temporary_storage_bytes, 0);
738 
739  hipDeviceProp_t prop;
740  HIP_CHECK(hipGetDeviceProperties(&prop, device_id));
741 
742  size_t total_storage_bytes = key_type_storage_bytes + temporary_storage_bytes;
743  if (total_storage_bytes > (static_cast<size_t>(prop.totalGlobalMem * 0.90))) {
744  HIP_CHECK(hipFree(d_keys_input_output));
745  GTEST_SKIP() << "Test case device memory requirement (" << total_storage_bytes << " bytes) exceeds available memory on current device ("
746  << prop.totalGlobalMem << " bytes). Skipping test";
747  }
748 
749  void* d_temporary_storage;
750  HIP_CHECK(test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes));
751 
752  HIP_CHECK(rocprim::radix_sort_keys(d_temporary_storage,
753  temporary_storage_bytes,
754  d_keys_input_output,
755  d_keys_input_output,
756  size,
757  start_bit,
758  end_bit,
759  stream,
760  debug_synchronous));
761 
762  std::vector<key_type> output(keys_input.size());
763  HIP_CHECK(hipMemcpy(output.data(),
764  d_keys_input_output,
765  size * sizeof(key_type),
766  hipMemcpyDeviceToHost));
767 
768  size_t counter = 0;
769  for(size_t i = 0; i <= std::numeric_limits<key_type>::max(); ++i)
770  {
771  for(size_t j = 0; j < histogram[i]; ++j)
772  {
773  ASSERT_EQ(static_cast<size_t>(output[counter]), i);
774  ++counter;
775  }
776  }
777  ASSERT_EQ(counter, size);
778 
779  HIP_CHECK(hipFree(d_keys_input_output));
780  HIP_CHECK(hipFree(d_temporary_storage));
781 }
782 
783 #endif // TEST_DEVICE_RADIX_SORT_HPP_
ROCPRIM_HOST_DEVICE constexpr T max(const T &a, const T &b)
Returns the maximum of its arguments.
Definition: functional.hpp:55
Definition: test_warp_exchange.cpp:34
hipError_t radix_sort_pairs_desc(void *temporary_storage, size_t &storage_size, KeysInputIterator keys_input, KeysOutputIterator keys_output, ValuesInputIterator values_input, ValuesOutputIterator values_output, Size size, unsigned int begin_bit=0, unsigned int end_bit=8 *sizeof(Key), hipStream_t stream=0, bool debug_synchronous=false)
Parallel descending radix sort-by-key primitive for device level.
Definition: device_radix_sort.hpp:1157
Definition: test_device_radix_sort.hpp:53
ROCPRIM_HOST_DEVICE constexpr T min(const T &a, const T &b)
Returns the minimum of its arguments.
Definition: functional.hpp:63
hipError_t radix_sort_keys_desc(void *temporary_storage, size_t &storage_size, KeysInputIterator keys_input, KeysOutputIterator keys_output, Size size, unsigned int begin_bit=0, unsigned int end_bit=8 *sizeof(Key), hipStream_t stream=0, bool debug_synchronous=false)
Parallel descending radix sort primitive for device level.
Definition: device_radix_sort.hpp:910
Definition: test_utils_sort_comparator.hpp:45
Definition: test_device_binary_search.cpp:37
Definition: benchmark_block_histogram.cpp:64
hipError_t radix_sort_keys(void *temporary_storage, size_t &storage_size, KeysInputIterator keys_input, KeysOutputIterator keys_output, Size size, unsigned int begin_bit=0, unsigned int end_bit=8 *sizeof(Key), hipStream_t stream=0, bool debug_synchronous=false)
Parallel ascending radix sort primitive for device level.
Definition: device_radix_sort.hpp:803
hipError_t radix_sort_pairs(void *temporary_storage, size_t &storage_size, KeysInputIterator keys_input, KeysOutputIterator keys_output, ValuesInputIterator values_input, ValuesOutputIterator values_output, Size size, unsigned int begin_bit=0, unsigned int end_bit=8 *sizeof(Key), hipStream_t stream=0, bool debug_synchronous=false)
Parallel ascending radix sort-by-key primitive for device level.
Definition: device_radix_sort.hpp:1035