// Header for the device-level radix sort tests (include guard, includes, test parameter bundle).
// The parameter declaration below defaults to Descending = false, StartBit = 0,
// EndBit = sizeof(Key) * 8 and CheckLargeSizes = false.
// NOTE(review): the embedded line numbers skip values, so parts of the original file (including the
// name and the leading template parameters of this declaration) are not visible in this listing.
23 #ifndef TEST_DEVICE_RADIX_SORT_HPP_ 24 #define TEST_DEVICE_RADIX_SORT_HPP_ 26 #include "../common_test_header.hpp" 29 #include <rocprim/device/device_radix_sort.hpp> 32 #include "test_utils_custom_float_type.hpp" 33 #include "test_utils_sort_comparator.hpp" 34 #include "test_utils_types.hpp" 38 bool Descending =
false,
39 unsigned int StartBit = 0,
40 unsigned int EndBit =
sizeof(Key) * 8,
41 bool CheckLargeSizes =
false>
// Re-export the template arguments as members; the test bodies in this file read them through
// TestFixture::params::descending, ::start_bit, ::end_bit and ::check_large_sizes.
45 using value_type = Value;
46 static constexpr
bool descending = Descending;
47 static constexpr
unsigned int start_bit = StartBit;
48 static constexpr
unsigned int end_bit = EndBit;
49 static constexpr
bool check_large_sizes = CheckLargeSizes;
// Template header taking a single Params type.
// NOTE(review): the declaration it introduces is not visible in this listing; presumably it is the
// test fixture that exposes Params as `params`, which the test bodies access as
// TestFixture::params. Confirm against the full file.
52 template<
class Params>
// Test body for the key-only radix sort with raw input/output pointers.
// For every seed and every size it generates random host keys, uploads them, calls
// rocprim::radix_sort_keys / radix_sort_keys_desc with an explicit config, copies the result back
// and checks it against `expected` with test_utils::assert_bit_eq.
// TestFixture must provide params::key_type, params::descending, params::start_bit,
// params::end_bit and params::check_large_sizes.
// NOTE(review): several original lines (braces, branch conditions, some call arguments) are not
// visible in this listing; the comments below hedge wherever they depend on them.
61 template<
typename TestFixture>
62 inline void sort_keys()
// Select the device reported by the ctest helper before doing any HIP work.
64 int device_id = test_common_utils::obtain_device_from_ctest();
65 SCOPED_TRACE(testing::Message() <<
"with device_id = " << device_id);
66 HIP_CHECK(hipSetDevice(device_id));
68 using key_type =
typename TestFixture::params::key_type;
69 constexpr
bool descending = TestFixture::params::descending;
70 constexpr
unsigned int start_bit = TestFixture::params::start_bit;
71 constexpr
unsigned int end_bit = TestFixture::params::end_bit;
72 constexpr
bool check_large_sizes = TestFixture::params::check_large_sizes;
74 hipStream_t stream = 0;
76 const bool debug_synchronous =
false;
// Starts as false; any statement that changes it is not visible in this listing.
78 bool in_place =
false;
// Seed schedule: random_seeds_count seeds taken from rand(), then seed_size fixed entries of `seeds`.
80 for(
size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++)
82 unsigned int seed_value
83 = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count];
84 SCOPED_TRACE(testing::Message() <<
"with seed = " << seed_value);
// Append 1 << 23 elements to the sizes returned by get_sizes().
86 auto sizes = test_utils::get_sizes(seed_value);
87 sizes.push_back(1 << 23);
89 for(
size_t size : sizes)
// Sizes above 1 << 17 are gated on check_large_sizes (the guarded statement is not visible here).
91 if(size > (1 << 17) && !check_large_sizes)
94 SCOPED_TRACE(testing::Message() <<
"with size = " << size);
// Floating-point keys: random data requested with bounds -1000 / +1000, after which
// add_special_values() is applied to the vector.
99 std::vector<key_type> keys_input;
100 if(rocprim::is_floating_point<key_type>::value)
102 keys_input = test_utils::get_random_data<key_type>(size,
103 static_cast<key_type
>(-1000),
104 static_cast<key_type>(+1000),
106 test_utils::add_special_values(keys_input, seed_value);
// Presumably the non-floating-point branch; its remaining arguments are not visible here.
111 = test_utils::get_random_data<key_type>(size,
117 key_type* d_keys_input;
118 key_type* d_keys_output;
119 HIP_CHECK(test_common_utils::hipMallocHelper(&d_keys_input, size *
sizeof(key_type)));
// Output pointer aliases the input buffer; presumably the in_place case (condition not visible).
122 d_keys_output = d_keys_input;
// Separate output allocation of the same byte size (presumably the non-aliased case).
127 test_common_utils::hipMallocHelper(&d_keys_output, size *
sizeof(key_type)));
129 HIP_CHECK(hipMemcpy(d_keys_input,
131 size *
sizeof(key_type),
132 hipMemcpyHostToDevice));
// Reference data starts as a copy of the input; the host-side ordering step is not visible here.
135 std::vector<key_type> expected(keys_input);
142 using config = rocprim::radix_sort_config_v2<rocprim::default_config,
143 rocprim::default_config,
144 rocprim::default_config,
// First call passes nullptr as the storage pointer; the ASSERT_GT below requires that it left a
// non-zero byte count in temporary_storage_bytes.
147 size_t temporary_storage_bytes;
148 HIP_CHECK(rocprim::radix_sort_keys<config>(
nullptr,
149 temporary_storage_bytes,
156 ASSERT_GT(temporary_storage_bytes, 0);
158 void* d_temporary_storage;
160 test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes));
// Descending and ascending entry points; the selecting condition is not visible here
// (presumably `descending`).
164 HIP_CHECK(rocprim::radix_sort_keys_desc<config>(d_temporary_storage,
165 temporary_storage_bytes,
176 HIP_CHECK(rocprim::radix_sort_keys<config>(d_temporary_storage,
177 temporary_storage_bytes,
// Copy the device result into keys_output on the host.
187 std::vector<key_type> keys_output(size);
188 HIP_CHECK(hipMemcpy(keys_output.data(),
190 size *
sizeof(key_type),
191 hipMemcpyDeviceToHost));
193 HIP_CHECK(hipFree(d_temporary_storage));
194 HIP_CHECK(hipFree(d_keys_input));
// NOTE(review): d_keys_output can alias d_keys_input (assigned earlier in this function);
// presumably this free is guarded by a condition that is not visible here. Confirm, to rule out
// a double free.
197 HIP_CHECK(hipFree(d_keys_output));
// Compare the device output with the reference via assert_bit_eq.
200 ASSERT_NO_FATAL_FAILURE(test_utils::assert_bit_eq(keys_output, expected));
// Test body for the key-value radix sort with raw input/output pointers.
// For every seed and every size it generates random keys plus iota-filled values, uploads both,
// calls rocprim::radix_sort_pairs / radix_sort_pairs_desc with an explicit config, copies keys and
// values back and checks each against its expected vector with test_utils::assert_bit_eq.
// TestFixture must provide params::key_type, params::value_type, params::descending,
// params::start_bit, params::end_bit and params::check_large_sizes.
// NOTE(review): several original lines (braces, branch conditions, some call arguments) are not
// visible in this listing; the comments below hedge wherever they depend on them.
205 template<
typename TestFixture>
206 inline void sort_pairs()
// Select the device reported by the ctest helper before doing any HIP work.
208 int device_id = test_common_utils::obtain_device_from_ctest();
209 SCOPED_TRACE(testing::Message() <<
"with device_id = " << device_id);
210 HIP_CHECK(hipSetDevice(device_id));
212 using key_type =
typename TestFixture::params::key_type;
213 using value_type =
typename TestFixture::params::value_type;
214 constexpr
bool descending = TestFixture::params::descending;
215 constexpr
unsigned int start_bit = TestFixture::params::start_bit;
216 constexpr
unsigned int end_bit = TestFixture::params::end_bit;
217 constexpr
bool check_large_sizes = TestFixture::params::check_large_sizes;
219 hipStream_t stream = 0;
221 const bool debug_synchronous =
false;
223 bool in_place =
false;
// Seed schedule: random_seeds_count seeds taken from rand(), then seed_size fixed entries of `seeds`.
225 for(
size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++)
227 unsigned int seed_value
228 = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count];
229 SCOPED_TRACE(testing::Message() <<
"with seed = " << seed_value);
// Append 1 << 23 elements to the sizes returned by get_sizes().
231 auto sizes = test_utils::get_sizes(seed_value);
232 sizes.push_back(1 << 23);
234 for(
size_t size : sizes)
// Sizes above 1 << 17 are gated on check_large_sizes (the guarded statement is not visible here).
236 if(size > (1 << 17) && !check_large_sizes)
239 SCOPED_TRACE(testing::Message() <<
"with size = " << size);
// Flip the flag for every size; presumably it selects between aliased and separate output
// buffers below (the branch conditions are not visible here).
241 in_place = !in_place;
// Floating-point keys: random data requested with bounds -1000 / +1000, after which
// add_special_values() is applied to the vector.
244 std::vector<key_type> keys_input;
245 if(rocprim::is_floating_point<key_type>::value)
247 keys_input = test_utils::get_random_data<key_type>(size,
248 static_cast<key_type
>(-1000),
249 static_cast<key_type>(+1000),
251 test_utils::add_special_values(keys_input, seed_value);
// Presumably the non-floating-point branch; its remaining arguments are not visible here.
256 = test_utils::get_random_data<key_type>(size,
// Values are filled by test_utils::iota starting from 0.
262 std::vector<value_type> values_input(size);
263 test_utils::iota(values_input.begin(), values_input.end(), 0);
265 key_type* d_keys_input;
266 key_type* d_keys_output;
267 HIP_CHECK(test_common_utils::hipMallocHelper(&d_keys_input, size *
sizeof(key_type)));
// Key output aliases the key input; presumably the in_place case (condition not visible).
270 d_keys_output = d_keys_input;
// Separate key output allocation of the same byte size (presumably the non-aliased case).
275 test_common_utils::hipMallocHelper(&d_keys_output, size *
sizeof(key_type)));
277 HIP_CHECK(hipMemcpy(d_keys_input,
279 size *
sizeof(key_type),
280 hipMemcpyHostToDevice));
282 value_type* d_values_input;
283 value_type* d_values_output;
285 test_common_utils::hipMallocHelper(&d_values_input, size *
sizeof(value_type)));
// Value output aliases the value input; presumably the in_place case (condition not visible).
288 d_values_output = d_values_input;
// Separate value output allocation of the same byte size (presumably the non-aliased case).
292 HIP_CHECK(test_common_utils::hipMallocHelper(&d_values_output,
293 size *
sizeof(value_type)));
295 HIP_CHECK(hipMemcpy(d_values_input,
297 size *
sizeof(value_type),
298 hipMemcpyHostToDevice));
// Build the reference as (key, value) pairs so both parts stay together.
300 using key_value = std::pair<key_type, value_type>;
303 std::vector<key_value> expected(size);
304 for(
size_t i = 0; i < size; i++)
306 expected[i] = key_value(keys_input[i], values_input[i]);
// Comparator argument of a call whose head is not visible here (presumably a host-side sort of
// `expected`); it is parameterised on descending, start_bit and end_bit.
312 key_value_comparator<key_type, value_type, descending, start_bit, end_bit>());
// Split the reference pairs back into separate key and value vectors for comparison.
313 std::vector<key_type> keys_expected(size);
314 std::vector<value_type> values_expected(size);
315 for(
size_t i = 0; i < size; i++)
317 keys_expected[i] = expected[i].first;
318 values_expected[i] = expected[i].second;
// Explicit (non-default) kernel, merge sort and onesweep configurations; the tail of this
// alias is not visible here.
322 using config = rocprim::radix_sort_config_v2<
323 rocprim::kernel_config<256, 1>,
324 rocprim::merge_sort_config<128, 64, 2, 128, 64, 2>,
325 rocprim::radix_sort_onesweep_config<rocprim::kernel_config<128, 1>,
326 rocprim::kernel_config<128, 1>,
// First call is made while d_temporary_storage is still nullptr; the ASSERT_GT below requires
// that it left a non-zero byte count in temporary_storage_bytes.
330 void* d_temporary_storage =
nullptr;
331 size_t temporary_storage_bytes;
332 HIP_CHECK(rocprim::radix_sort_pairs<config>(d_temporary_storage,
333 temporary_storage_bytes,
342 ASSERT_GT(temporary_storage_bytes, 0);
345 test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes));
// Descending and ascending entry points; the selecting condition is not visible here
// (presumably `descending`).
349 HIP_CHECK(rocprim::radix_sort_pairs_desc<config>(d_temporary_storage,
350 temporary_storage_bytes,
363 HIP_CHECK(rocprim::radix_sort_pairs<config>(d_temporary_storage,
364 temporary_storage_bytes,
// Copy the device results into host vectors.
376 std::vector<key_type> keys_output(size);
377 HIP_CHECK(hipMemcpy(keys_output.data(),
379 size *
sizeof(key_type),
380 hipMemcpyDeviceToHost));
382 std::vector<value_type> values_output(size);
383 HIP_CHECK(hipMemcpy(values_output.data(),
385 size *
sizeof(value_type),
386 hipMemcpyDeviceToHost));
388 HIP_CHECK(hipFree(d_temporary_storage));
389 HIP_CHECK(hipFree(d_keys_input));
390 HIP_CHECK(hipFree(d_values_input));
// NOTE(review): both output pointers can alias their inputs (assigned earlier in this function);
// presumably these frees are guarded by a condition that is not visible here. Confirm, to rule
// out a double free.
393 HIP_CHECK(hipFree(d_keys_output));
394 HIP_CHECK(hipFree(d_values_output));
// Compare keys and values with their references via assert_bit_eq.
397 ASSERT_NO_FATAL_FAILURE(test_utils::assert_bit_eq(keys_output, keys_expected));
398 ASSERT_NO_FATAL_FAILURE(test_utils::assert_bit_eq(values_output, values_expected));
// Test body for the key-only radix sort driven through rocprim::double_buffer.
// For every seed and every size it generates random host keys, uploads them into d_keys_input,
// wraps d_keys_input / d_keys_output in a double_buffer, copies a device result back and checks
// it against `expected` with test_utils::assert_bit_eq.
// TestFixture must provide params::key_type, params::descending, params::start_bit,
// params::end_bit and params::check_large_sizes.
// NOTE(review): the heads of the sort calls, branch conditions and braces are not visible in this
// listing; the comments below hedge wherever they depend on them.
403 template<
typename TestFixture>
404 inline void sort_keys_double_buffer()
// Select the device reported by the ctest helper before doing any HIP work.
406 int device_id = test_common_utils::obtain_device_from_ctest();
407 SCOPED_TRACE(testing::Message() <<
"with device_id = " << device_id);
408 HIP_CHECK(hipSetDevice(device_id));
410 using key_type =
typename TestFixture::params::key_type;
411 constexpr
bool descending = TestFixture::params::descending;
412 constexpr
unsigned int start_bit = TestFixture::params::start_bit;
413 constexpr
unsigned int end_bit = TestFixture::params::end_bit;
414 constexpr
bool check_large_sizes = TestFixture::params::check_large_sizes;
416 hipStream_t stream = 0;
418 const bool debug_synchronous =
false;
// Seed schedule: random_seeds_count seeds taken from rand(), then seed_size fixed entries of `seeds`.
420 for(
size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++)
422 unsigned int seed_value
423 = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count];
424 SCOPED_TRACE(testing::Message() <<
"with seed = " << seed_value);
// Append 1 << 23 elements to the sizes returned by get_sizes().
426 auto sizes = test_utils::get_sizes(seed_value);
427 sizes.push_back(1 << 23);
429 for(
size_t size : sizes)
// Sizes above 1 << 17 are gated on check_large_sizes (the guarded statement is not visible here).
431 if(size > (1 << 17) && !check_large_sizes)
434 SCOPED_TRACE(testing::Message() <<
"with size = " << size);
// Floating-point keys: random data requested with bounds -1000 / +1000, after which
// add_special_values() is applied to the vector.
437 std::vector<key_type> keys_input;
438 if(rocprim::is_floating_point<key_type>::value)
440 keys_input = test_utils::get_random_data<key_type>(size,
441 static_cast<key_type
>(-1000),
442 static_cast<key_type>(+1000),
444 test_utils::add_special_values(keys_input, seed_value);
// Presumably the non-floating-point branch; its remaining arguments are not visible here.
449 = test_utils::get_random_data<key_type>(size,
// Two distinct device buffers of equal size are always allocated; only the first is filled.
455 key_type* d_keys_input;
456 key_type* d_keys_output;
457 HIP_CHECK(test_common_utils::hipMallocHelper(&d_keys_input, size *
sizeof(key_type)));
458 HIP_CHECK(test_common_utils::hipMallocHelper(&d_keys_output, size *
sizeof(key_type)));
459 HIP_CHECK(hipMemcpy(d_keys_input,
461 size *
sizeof(key_type),
462 hipMemcpyHostToDevice));
// Reference data starts as a copy of the input; the host-side ordering step is not visible here.
465 std::vector<key_type> expected(keys_input);
// Hand both buffers to the sort as one double_buffer object.
471 rocprim::double_buffer<key_type> d_keys(d_keys_input, d_keys_output);
// The call that fills temporary_storage_bytes is only partly visible; the ASSERT_GT below
// requires a non-zero byte count before the storage is allocated.
473 size_t temporary_storage_bytes;
475 temporary_storage_bytes,
481 ASSERT_GT(temporary_storage_bytes, 0);
483 void* d_temporary_storage;
485 test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes));
// Argument lines of two further calls whose heads are not visible here (presumably the
// descending and ascending sort variants).
490 temporary_storage_bytes,
501 temporary_storage_bytes,
510 HIP_CHECK(hipFree(d_temporary_storage));
// Copy a device result into keys_output; the source pointer argument is not visible here.
512 std::vector<key_type> keys_output(size);
513 HIP_CHECK(hipMemcpy(keys_output.data(),
515 size *
sizeof(key_type),
516 hipMemcpyDeviceToHost));
518 HIP_CHECK(hipFree(d_keys_input));
519 HIP_CHECK(hipFree(d_keys_output));
// Compare the device output with the reference via assert_bit_eq.
521 ASSERT_NO_FATAL_FAILURE(test_utils::assert_bit_eq(keys_output, expected));
// Test body for the key-value radix sort driven through rocprim::double_buffer.
// For every seed and every size it generates random keys plus iota-filled values, uploads both,
// wraps the key and value buffers in double_buffer objects, copies device results back and checks
// keys and values against their expected vectors with test_utils::assert_bit_eq.
// TestFixture must provide params::key_type, params::value_type, params::descending,
// params::start_bit, params::end_bit and params::check_large_sizes.
// NOTE(review): the heads of the sort calls, branch conditions and braces are not visible in this
// listing; the comments below hedge wherever they depend on them.
526 template<
typename TestFixture>
527 inline void sort_pairs_double_buffer()
// Select the device reported by the ctest helper before doing any HIP work.
529 int device_id = test_common_utils::obtain_device_from_ctest();
530 SCOPED_TRACE(testing::Message() <<
"with device_id = " << device_id);
531 HIP_CHECK(hipSetDevice(device_id));
533 using key_type =
typename TestFixture::params::key_type;
534 using value_type =
typename TestFixture::params::value_type;
535 constexpr
bool descending = TestFixture::params::descending;
536 constexpr
unsigned int start_bit = TestFixture::params::start_bit;
537 constexpr
unsigned int end_bit = TestFixture::params::end_bit;
538 constexpr
bool check_large_sizes = TestFixture::params::check_large_sizes;
540 hipStream_t stream = 0;
542 const bool debug_synchronous =
false;
// Seed schedule: random_seeds_count seeds taken from rand(), then seed_size fixed entries of `seeds`.
544 for(
size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++)
546 unsigned int seed_value
547 = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count];
548 SCOPED_TRACE(testing::Message() <<
"with seed = " << seed_value);
// Append 1 << 23 elements to the sizes returned by get_sizes().
550 auto sizes = test_utils::get_sizes(seed_value);
551 sizes.push_back(1 << 23);
553 for(
size_t size : sizes)
// Sizes above 1 << 17 are gated on check_large_sizes (the guarded statement is not visible here).
555 if(size > (1 << 17) && !check_large_sizes)
558 SCOPED_TRACE(testing::Message() <<
"with size = " << size);
// Floating-point keys: random data requested with bounds -1000 / +1000, after which
// add_special_values() is applied to the vector.
561 std::vector<key_type> keys_input;
562 if(rocprim::is_floating_point<key_type>::value)
564 keys_input = test_utils::get_random_data<key_type>(size,
565 static_cast<key_type
>(-1000),
566 static_cast<key_type>(+1000),
568 test_utils::add_special_values(keys_input, seed_value);
// Presumably the non-floating-point branch; its remaining arguments are not visible here.
573 = test_utils::get_random_data<key_type>(size,
// Values are filled by test_utils::iota starting from 0.
579 std::vector<value_type> values_input(size);
580 test_utils::iota(values_input.begin(), values_input.end(), 0);
// Two distinct key buffers of equal size are always allocated; only the first is filled.
582 key_type* d_keys_input;
583 key_type* d_keys_output;
584 HIP_CHECK(test_common_utils::hipMallocHelper(&d_keys_input, size *
sizeof(key_type)));
585 HIP_CHECK(test_common_utils::hipMallocHelper(&d_keys_output, size *
sizeof(key_type)));
586 HIP_CHECK(hipMemcpy(d_keys_input,
588 size *
sizeof(key_type),
589 hipMemcpyHostToDevice));
// Same layout for the values: two buffers, the first one filled from the host.
591 value_type* d_values_input;
592 value_type* d_values_output;
594 test_common_utils::hipMallocHelper(&d_values_input, size *
sizeof(value_type)));
596 test_common_utils::hipMallocHelper(&d_values_output, size *
sizeof(value_type)));
597 HIP_CHECK(hipMemcpy(d_values_input,
599 size *
sizeof(value_type),
600 hipMemcpyHostToDevice));
// Build the reference as (key, value) pairs so both parts stay together.
602 using key_value = std::pair<key_type, value_type>;
605 std::vector<key_value> expected(size);
606 for(
size_t i = 0; i < size; i++)
608 expected[i] = key_value(keys_input[i], values_input[i]);
// Comparator argument of a call whose head is not visible here (presumably a host-side sort of
// `expected`); it is parameterised on descending, start_bit and end_bit.
614 key_value_comparator<key_type, value_type, descending, start_bit, end_bit>());
// Split the reference pairs back into separate key and value vectors for comparison.
615 std::vector<key_type> keys_expected(size);
616 std::vector<value_type> values_expected(size);
617 for(
size_t i = 0; i < size; i++)
619 keys_expected[i] = expected[i].first;
620 values_expected[i] = expected[i].second;
// Hand each buffer pair to the sort as one double_buffer object.
623 rocprim::double_buffer<key_type> d_keys(d_keys_input, d_keys_output);
624 rocprim::double_buffer<value_type> d_values(d_values_input, d_values_output);
// The call that fills temporary_storage_bytes is only partly visible; the ASSERT_GT below
// requires a non-zero byte count before the storage is allocated.
626 void* d_temporary_storage =
nullptr;
627 size_t temporary_storage_bytes;
629 temporary_storage_bytes,
636 ASSERT_GT(temporary_storage_bytes, 0);
639 test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes));
// Argument lines of two further calls whose heads are not visible here (presumably the
// descending and ascending sort variants).
644 temporary_storage_bytes,
656 temporary_storage_bytes,
666 HIP_CHECK(hipFree(d_temporary_storage));
// Copy device results into host vectors; the source pointer arguments are not visible here.
668 std::vector<key_type> keys_output(size);
669 HIP_CHECK(hipMemcpy(keys_output.data(),
671 size *
sizeof(key_type),
672 hipMemcpyDeviceToHost));
674 std::vector<value_type> values_output(size);
675 HIP_CHECK(hipMemcpy(values_output.data(),
677 size *
sizeof(value_type),
678 hipMemcpyDeviceToHost));
680 HIP_CHECK(hipFree(d_keys_input));
681 HIP_CHECK(hipFree(d_keys_output));
682 HIP_CHECK(hipFree(d_values_input));
683 HIP_CHECK(hipFree(d_values_output));
// Compare keys and values with their references via assert_bit_eq.
685 ASSERT_NO_FATAL_FAILURE(test_utils::assert_bit_eq(keys_output, keys_expected));
686 ASSERT_NO_FATAL_FAILURE(test_utils::assert_bit_eq(values_output, values_expected));
// Test body for a key sort of (1 << 32) + 32 uint8_t elements, i.e. more elements than a 32-bit
// index can address. The input is summarised as a per-key histogram on the host, uploaded into a
// single device buffer, and the copied-back output is validated against that histogram.
// The test is skipped when the key buffer plus temporary storage would exceed 90% of the device's
// total global memory.
// NOTE(review): the heads of the sort calls, the outer verification loop and braces are not
// visible in this listing; the comments below hedge wherever they depend on them.
691 inline void sort_keys_over_4g()
693 using key_type = uint8_t;
694 constexpr
unsigned int start_bit = 0;
695 constexpr
unsigned int end_bit = 8ull *
sizeof(key_type);
696 constexpr hipStream_t stream = 0;
697 constexpr
bool debug_synchronous =
false;
698 constexpr
size_t size = (1ull << 32) + 32;
// One histogram bin per representable key value (1 << 8 bins for uint8_t).
699 constexpr
size_t number_of_possible_keys = 1ull << (8ull *
sizeof(key_type));
// The key is used directly as a histogram index below, hence the unsigned requirement.
// NOTE(review): this is a compile-time property checked with a runtime assert; a static_assert
// would enforce it at build time.
700 assert(std::is_unsigned<key_type>::value);
701 std::vector<size_t>
histogram(number_of_possible_keys, 0);
702 const int seed_value = rand();
704 const int device_id = test_common_utils::obtain_device_from_ctest();
705 SCOPED_TRACE(testing::Message() <<
"with device_id = " << device_id);
706 HIP_CHECK(hipSetDevice(device_id));
// Remaining arguments of this call are not visible here.
708 std::vector<key_type> keys_input
709 = test_utils::get_random_data<key_type>(size,
// Count how often each key value occurs in the input.
715 std::for_each(keys_input.begin(), keys_input.end(), [&](
const key_type& a) { histogram[a]++; });
// Single device buffer; its name suggests it serves as both sort input and output.
717 key_type* d_keys_input_output{};
718 size_t key_type_storage_bytes = size *
sizeof(key_type);
720 HIP_CHECK(test_common_utils::hipMallocHelper(&d_keys_input_output, key_type_storage_bytes));
721 HIP_CHECK(hipMemcpy(d_keys_input_output,
723 key_type_storage_bytes,
724 hipMemcpyHostToDevice));
// The call that fills temporary_storage_bytes is only partly visible; the ASSERT_GT below
// requires a non-zero byte count.
726 size_t temporary_storage_bytes;
728 temporary_storage_bytes,
737 ASSERT_GT(temporary_storage_bytes, 0);
739 hipDeviceProp_t prop;
740 HIP_CHECK(hipGetDeviceProperties(&prop, device_id));
// Memory guard: release the key buffer and skip when keys plus temporary storage exceed 90% of
// prop.totalGlobalMem.
742 size_t total_storage_bytes = key_type_storage_bytes + temporary_storage_bytes;
743 if (total_storage_bytes > (static_cast<size_t>(prop.totalGlobalMem * 0.90))) {
744 HIP_CHECK(hipFree(d_keys_input_output));
745 GTEST_SKIP() <<
"Test case device memory requirement (" << total_storage_bytes <<
" bytes) exceeds available memory on current device (" 746 << prop.totalGlobalMem <<
" bytes). Skipping test";
749 void* d_temporary_storage;
750 HIP_CHECK(test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes));
// Argument line of a call whose head is not visible here (presumably the actual sort).
753 temporary_storage_bytes,
// Copy the device result into a host vector of the same length as the input.
762 std::vector<key_type> output(keys_input.size());
763 HIP_CHECK(hipMemcpy(output.data(),
765 size *
sizeof(key_type),
766 hipMemcpyDeviceToHost));
// Each key value i must fill histogram[i] output slots starting at `counter`; the enclosing loop
// over i and the update of `counter` are not visible here.
// NOTE(review): a failing ASSERT returns before the hipFree calls below, leaking both buffers.
771 for(
size_t j = 0; j < histogram[i]; ++j)
773 ASSERT_EQ(static_cast<size_t>(output[counter]), i);
// Every output element must have been visited exactly once.
777 ASSERT_EQ(counter, size);
779 HIP_CHECK(hipFree(d_keys_input_output));
780 HIP_CHECK(hipFree(d_temporary_storage));
783 #endif // TEST_DEVICE_RADIX_SORT_HPP_
ROCPRIM_HOST_DEVICE constexpr T max(const T &a, const T &b)
Returns the maximum of its arguments.
Definition: functional.hpp:55
Definition: test_warp_exchange.cpp:34
hipError_t radix_sort_pairs_desc(void *temporary_storage, size_t &storage_size, KeysInputIterator keys_input, KeysOutputIterator keys_output, ValuesInputIterator values_input, ValuesOutputIterator values_output, Size size, unsigned int begin_bit=0, unsigned int end_bit=8 *sizeof(Key), hipStream_t stream=0, bool debug_synchronous=false)
Parallel descending radix sort-by-key primitive for device level.
Definition: device_radix_sort.hpp:1157
Definition: test_device_radix_sort.hpp:53
ROCPRIM_HOST_DEVICE constexpr T min(const T &a, const T &b)
Returns the minimum of its arguments.
Definition: functional.hpp:63
hipError_t radix_sort_keys_desc(void *temporary_storage, size_t &storage_size, KeysInputIterator keys_input, KeysOutputIterator keys_output, Size size, unsigned int begin_bit=0, unsigned int end_bit=8 *sizeof(Key), hipStream_t stream=0, bool debug_synchronous=false)
Parallel descending radix sort primitive for device level.
Definition: device_radix_sort.hpp:910
Definition: test_utils_sort_comparator.hpp:45
Definition: test_device_binary_search.cpp:37
Definition: benchmark_block_histogram.cpp:64
hipError_t radix_sort_keys(void *temporary_storage, size_t &storage_size, KeysInputIterator keys_input, KeysOutputIterator keys_output, Size size, unsigned int begin_bit=0, unsigned int end_bit=8 *sizeof(Key), hipStream_t stream=0, bool debug_synchronous=false)
Parallel ascending radix sort primitive for device level.
Definition: device_radix_sort.hpp:803
hipError_t radix_sort_pairs(void *temporary_storage, size_t &storage_size, KeysInputIterator keys_input, KeysOutputIterator keys_output, ValuesInputIterator values_input, ValuesOutputIterator values_output, Size size, unsigned int begin_bit=0, unsigned int end_bit=8 *sizeof(Key), hipStream_t stream=0, bool debug_synchronous=false)
Parallel ascending radix sort-by-key primitive for device level.
Definition: device_radix_sort.hpp:1035