23 #ifndef TEST_DEVICE_SEGMENTED_RADIX_SORT_HPP_ 24 #define TEST_DEVICE_SEGMENTED_RADIX_SORT_HPP_ 26 #include "../common_test_header.hpp" 29 #include <rocprim/device/device_segmented_radix_sort.hpp> 32 #include "test_utils_custom_float_type.hpp" 33 #include "test_utils_sort_comparator.hpp" 34 #include "test_utils_types.hpp" 39 unsigned int StartBit,
41 unsigned int MinSegmentLength,
42 unsigned int MaxSegmentLength,
43 class Config = rocprim::default_config>
47 using value_type = Value;
48 static constexpr
bool descending = Descending;
49 static constexpr
unsigned int start_bit = StartBit;
50 static constexpr
unsigned int end_bit = EndBit;
51 static constexpr
unsigned int min_segment_length = MinSegmentLength;
52 static constexpr
unsigned int max_segment_length = MaxSegmentLength;
53 using config = Config;
56 using config_default = rocprim::segmented_radix_sort_config<
59 rocprim::kernel_config<256, 4>
62 using config_semi_custom = rocprim::segmented_radix_sort_config<
65 rocprim::kernel_config<128, 4>,
66 rocprim::WarpSortConfig<16,
70 using config_semi_custom_warp_config = rocprim::segmented_radix_sort_config<
73 rocprim::kernel_config<128, 4>,
74 rocprim::WarpSortConfig<16,
81 using config_custom = rocprim::segmented_radix_sort_config<
84 rocprim::kernel_config<128, 4>,
85 rocprim::WarpSortConfig<16,
95 template<
class Params>
104 template<
typename TestFixture>
105 inline void sort_keys()
107 int device_id = test_common_utils::obtain_device_from_ctest();
108 SCOPED_TRACE(testing::Message() <<
"with device_id = " << device_id);
109 HIP_CHECK(hipSetDevice(device_id));
111 using key_type =
typename TestFixture::params::key_type;
112 using config =
typename TestFixture::params::config;
113 static constexpr
bool descending = TestFixture::params::descending;
114 static constexpr
unsigned int start_bit = TestFixture::params::start_bit;
115 static constexpr
unsigned int end_bit = TestFixture::params::end_bit;
117 using offset_type =
unsigned int;
119 hipStream_t stream = 0;
121 const bool debug_synchronous =
false;
123 std::random_device rd;
124 std::default_random_engine gen(rd());
126 std::uniform_int_distribution<size_t> segment_length_dis(
127 TestFixture::params::min_segment_length,
128 TestFixture::params::max_segment_length);
130 for(
size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++)
132 unsigned int seed_value
133 = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count];
134 SCOPED_TRACE(testing::Message() <<
"with seed = " << seed_value);
136 for(
size_t size : test_utils::get_sizes(seed_value))
138 SCOPED_TRACE(testing::Message() <<
"with size = " << size);
141 std::vector<key_type> keys_input;
142 if(rocprim::is_floating_point<key_type>::value)
144 keys_input = test_utils::get_random_data<key_type>(size,
145 static_cast<key_type
>(-1000),
146 static_cast<key_type>(+1000),
152 = test_utils::get_random_data<key_type>(size,
158 std::vector<offset_type> offsets;
159 unsigned int segments_count = 0;
163 const size_t segment_length = segment_length_dis(gen);
164 offsets.push_back(offset);
166 offset += segment_length;
168 offsets.push_back(size);
170 key_type* d_keys_input;
171 key_type* d_keys_output;
172 HIP_CHECK(test_common_utils::hipMallocHelper(&d_keys_input, size *
sizeof(key_type)));
173 HIP_CHECK(test_common_utils::hipMallocHelper(&d_keys_output, size *
sizeof(key_type)));
174 HIP_CHECK(hipMemcpy(d_keys_input,
176 size *
sizeof(key_type),
177 hipMemcpyHostToDevice));
179 offset_type* d_offsets;
181 test_common_utils::hipMallocHelper(&d_offsets,
182 (segments_count + 1) *
sizeof(offset_type)));
183 HIP_CHECK(hipMemcpy(d_offsets,
185 (segments_count + 1) *
sizeof(offset_type),
186 hipMemcpyHostToDevice));
189 std::vector<key_type> expected(keys_input);
190 for(
size_t i = 0; i < segments_count; i++)
193 expected.begin() + offsets[i],
194 expected.begin() + offsets[i + 1],
198 size_t temporary_storage_bytes = 0;
199 HIP_CHECK(rocprim::segmented_radix_sort_keys<config>(
nullptr,
200 temporary_storage_bytes,
210 ASSERT_GT(temporary_storage_bytes, 0U);
212 void* d_temporary_storage;
214 test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes));
218 HIP_CHECK(rocprim::segmented_radix_sort_keys_desc<config>(d_temporary_storage,
219 temporary_storage_bytes,
233 HIP_CHECK(rocprim::segmented_radix_sort_keys<config>(d_temporary_storage,
234 temporary_storage_bytes,
247 std::vector<key_type> keys_output(size);
248 HIP_CHECK(hipMemcpy(keys_output.data(),
250 size *
sizeof(key_type),
251 hipMemcpyDeviceToHost));
253 HIP_CHECK(hipFree(d_temporary_storage));
254 HIP_CHECK(hipFree(d_keys_input));
255 HIP_CHECK(hipFree(d_keys_output));
256 HIP_CHECK(hipFree(d_offsets));
258 ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(keys_output, expected));
263 template<
typename TestFixture>
264 inline void sort_pairs()
266 int device_id = test_common_utils::obtain_device_from_ctest();
267 SCOPED_TRACE(testing::Message() <<
"with device_id = " << device_id);
268 HIP_CHECK(hipSetDevice(device_id));
270 using key_type =
typename TestFixture::params::key_type;
271 using value_type =
typename TestFixture::params::value_type;
272 using config =
typename TestFixture::params::config;
273 constexpr
bool descending = TestFixture::params::descending;
274 constexpr
unsigned int start_bit = TestFixture::params::start_bit;
275 constexpr
unsigned int end_bit = TestFixture::params::end_bit;
277 using offset_type =
unsigned int;
279 hipStream_t stream = 0;
281 const bool debug_synchronous =
false;
283 std::random_device rd;
284 std::default_random_engine gen(rd());
286 std::uniform_int_distribution<size_t> segment_length_dis(
287 TestFixture::params::min_segment_length,
288 TestFixture::params::max_segment_length);
290 for(
size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++)
292 unsigned int seed_value
293 = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count];
294 SCOPED_TRACE(testing::Message() <<
"with seed = " << seed_value);
296 for(
size_t size : test_utils::get_sizes(seed_value))
298 SCOPED_TRACE(testing::Message() <<
"with size = " << size);
301 std::vector<key_type> keys_input;
302 if(rocprim::is_floating_point<key_type>::value)
304 keys_input = test_utils::get_random_data<key_type>(size,
305 static_cast<key_type
>(-1000),
306 static_cast<key_type>(+1000),
312 = test_utils::get_random_data<key_type>(size,
318 std::vector<offset_type> offsets;
319 unsigned int segments_count = 0;
323 const size_t segment_length = segment_length_dis(gen);
324 offsets.push_back(offset);
326 offset += segment_length;
328 offsets.push_back(size);
330 std::vector<value_type> values_input(size);
331 test_utils::iota(values_input.begin(), values_input.end(), 0);
333 key_type* d_keys_input;
334 key_type* d_keys_output;
335 HIP_CHECK(test_common_utils::hipMallocHelper(&d_keys_input, size *
sizeof(key_type)));
336 HIP_CHECK(test_common_utils::hipMallocHelper(&d_keys_output, size *
sizeof(key_type)));
337 HIP_CHECK(hipMemcpy(d_keys_input,
339 size *
sizeof(key_type),
340 hipMemcpyHostToDevice));
342 value_type* d_values_input;
343 value_type* d_values_output;
345 test_common_utils::hipMallocHelper(&d_values_input, size *
sizeof(value_type)));
347 test_common_utils::hipMallocHelper(&d_values_output, size *
sizeof(value_type)));
348 HIP_CHECK(hipMemcpy(d_values_input,
350 size *
sizeof(value_type),
351 hipMemcpyHostToDevice));
353 offset_type* d_offsets;
355 test_common_utils::hipMallocHelper(&d_offsets,
356 (segments_count + 1) *
sizeof(offset_type)));
357 HIP_CHECK(hipMemcpy(d_offsets,
359 (segments_count + 1) *
sizeof(offset_type),
360 hipMemcpyHostToDevice));
362 using key_value = std::pair<key_type, value_type>;
365 std::vector<key_value> expected(size);
366 for(
size_t i = 0; i < size; i++)
368 expected[i] = key_value(keys_input[i], values_input[i]);
370 for(
size_t i = 0; i < segments_count; i++)
372 std::stable_sort(expected.begin() + offsets[i],
373 expected.begin() + offsets[i + 1],
380 std::vector<key_type> keys_expected(size);
381 std::vector<value_type> values_expected(size);
382 for(
size_t i = 0; i < size; i++)
384 keys_expected[i] = expected[i].first;
385 values_expected[i] = expected[i].second;
388 void* d_temporary_storage =
nullptr;
389 size_t temporary_storage_bytes = 0;
390 HIP_CHECK(rocprim::segmented_radix_sort_pairs<config>(d_temporary_storage,
391 temporary_storage_bytes,
403 ASSERT_GT(temporary_storage_bytes, 0U);
406 test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes));
410 HIP_CHECK(rocprim::segmented_radix_sort_pairs_desc<config>(d_temporary_storage,
411 temporary_storage_bytes,
427 HIP_CHECK(rocprim::segmented_radix_sort_pairs<config>(d_temporary_storage,
428 temporary_storage_bytes,
443 std::vector<key_type> keys_output(size);
444 HIP_CHECK(hipMemcpy(keys_output.data(),
446 size *
sizeof(key_type),
447 hipMemcpyDeviceToHost));
449 std::vector<value_type> values_output(size);
450 HIP_CHECK(hipMemcpy(values_output.data(),
452 size *
sizeof(value_type),
453 hipMemcpyDeviceToHost));
455 HIP_CHECK(hipFree(d_temporary_storage));
456 HIP_CHECK(hipFree(d_keys_input));
457 HIP_CHECK(hipFree(d_values_input));
458 HIP_CHECK(hipFree(d_keys_output));
459 HIP_CHECK(hipFree(d_values_output));
460 HIP_CHECK(hipFree(d_offsets));
462 ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(keys_output, keys_expected));
463 ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(values_output, values_expected));
468 template<
typename TestFixture>
469 inline void sort_keys_double_buffer()
471 int device_id = test_common_utils::obtain_device_from_ctest();
472 SCOPED_TRACE(testing::Message() <<
"with device_id = " << device_id);
473 HIP_CHECK(hipSetDevice(device_id));
475 using key_type =
typename TestFixture::params::key_type;
476 using config =
typename TestFixture::params::config;
477 constexpr
bool descending = TestFixture::params::descending;
478 constexpr
unsigned int start_bit = TestFixture::params::start_bit;
479 constexpr
unsigned int end_bit = TestFixture::params::end_bit;
481 using offset_type =
unsigned int;
483 hipStream_t stream = 0;
485 const bool debug_synchronous =
false;
487 std::random_device rd;
488 std::default_random_engine gen(rd());
490 std::uniform_int_distribution<size_t> segment_length_dis(
491 TestFixture::params::min_segment_length,
492 TestFixture::params::max_segment_length);
494 for(
size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++)
496 unsigned int seed_value
497 = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count];
498 SCOPED_TRACE(testing::Message() <<
"with seed = " << seed_value);
500 for(
size_t size : test_utils::get_sizes(seed_value))
502 SCOPED_TRACE(testing::Message() <<
"with size = " << size);
505 std::vector<key_type> keys_input;
506 if(rocprim::is_floating_point<key_type>::value)
508 keys_input = test_utils::get_random_data<key_type>(size,
509 static_cast<key_type
>(-1000),
510 static_cast<key_type>(+1000),
516 = test_utils::get_random_data<key_type>(size,
522 std::vector<offset_type> offsets;
523 unsigned int segments_count = 0;
527 const size_t segment_length = segment_length_dis(gen);
528 offsets.push_back(offset);
530 offset += segment_length;
532 offsets.push_back(size);
534 key_type* d_keys_input;
535 key_type* d_keys_output;
536 HIP_CHECK(test_common_utils::hipMallocHelper(&d_keys_input, size *
sizeof(key_type)));
537 HIP_CHECK(test_common_utils::hipMallocHelper(&d_keys_output, size *
sizeof(key_type)));
538 HIP_CHECK(hipMemcpy(d_keys_input,
540 size *
sizeof(key_type),
541 hipMemcpyHostToDevice));
543 offset_type* d_offsets;
545 test_common_utils::hipMallocHelper(&d_offsets,
546 (segments_count + 1) *
sizeof(offset_type)));
547 HIP_CHECK(hipMemcpy(d_offsets,
549 (segments_count + 1) *
sizeof(offset_type),
550 hipMemcpyHostToDevice));
553 std::vector<key_type> expected(keys_input);
554 for(
size_t i = 0; i < segments_count; i++)
557 expected.begin() + offsets[i],
558 expected.begin() + offsets[i + 1],
562 rocprim::double_buffer<key_type> d_keys(d_keys_input, d_keys_output);
564 size_t temporary_storage_bytes = 0;
565 HIP_CHECK(rocprim::segmented_radix_sort_keys<config>(
nullptr,
566 temporary_storage_bytes,
575 ASSERT_GT(temporary_storage_bytes, 0U);
577 void* d_temporary_storage;
579 test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes));
583 HIP_CHECK(rocprim::segmented_radix_sort_keys_desc<config>(d_temporary_storage,
584 temporary_storage_bytes,
597 HIP_CHECK(rocprim::segmented_radix_sort_keys<config>(d_temporary_storage,
598 temporary_storage_bytes,
610 std::vector<key_type> keys_output(size);
611 HIP_CHECK(hipMemcpy(keys_output.data(),
613 size *
sizeof(key_type),
614 hipMemcpyDeviceToHost));
616 HIP_CHECK(hipFree(d_temporary_storage));
617 HIP_CHECK(hipFree(d_keys_input));
618 HIP_CHECK(hipFree(d_keys_output));
619 HIP_CHECK(hipFree(d_offsets));
621 ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(keys_output, expected));
626 template<
typename TestFixture>
627 inline void sort_pairs_double_buffer()
629 int device_id = test_common_utils::obtain_device_from_ctest();
630 SCOPED_TRACE(testing::Message() <<
"with device_id = " << device_id);
631 HIP_CHECK(hipSetDevice(device_id));
633 using key_type =
typename TestFixture::params::key_type;
634 using value_type =
typename TestFixture::params::value_type;
635 using config =
typename TestFixture::params::config;
636 constexpr
bool descending = TestFixture::params::descending;
637 constexpr
unsigned int start_bit = TestFixture::params::start_bit;
638 constexpr
unsigned int end_bit = TestFixture::params::end_bit;
640 using offset_type =
unsigned int;
642 hipStream_t stream = 0;
644 const bool debug_synchronous =
false;
646 std::random_device rd;
647 std::default_random_engine gen(rd());
649 std::uniform_int_distribution<size_t> segment_length_dis(
650 TestFixture::params::min_segment_length,
651 TestFixture::params::max_segment_length);
653 for(
size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++)
655 unsigned int seed_value
656 = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count];
657 SCOPED_TRACE(testing::Message() <<
"with seed = " << seed_value);
659 for(
size_t size : test_utils::get_sizes(seed_value))
661 SCOPED_TRACE(testing::Message() <<
"with size = " << size);
664 std::vector<key_type> keys_input;
665 if(rocprim::is_floating_point<key_type>::value)
667 keys_input = test_utils::get_random_data<key_type>(size,
668 static_cast<key_type
>(-1000),
669 static_cast<key_type>(+1000),
675 = test_utils::get_random_data<key_type>(size,
681 std::vector<offset_type> offsets;
682 unsigned int segments_count = 0;
686 const size_t segment_length = segment_length_dis(gen);
687 offsets.push_back(offset);
689 offset += segment_length;
691 offsets.push_back(size);
693 std::vector<value_type> values_input(size);
694 test_utils::iota(values_input.begin(), values_input.end(), 0);
696 key_type* d_keys_input;
697 key_type* d_keys_output;
698 HIP_CHECK(test_common_utils::hipMallocHelper(&d_keys_input, size *
sizeof(key_type)));
699 HIP_CHECK(test_common_utils::hipMallocHelper(&d_keys_output, size *
sizeof(key_type)));
700 HIP_CHECK(hipMemcpy(d_keys_input,
702 size *
sizeof(key_type),
703 hipMemcpyHostToDevice));
705 value_type* d_values_input;
706 value_type* d_values_output;
708 test_common_utils::hipMallocHelper(&d_values_input, size *
sizeof(value_type)));
710 test_common_utils::hipMallocHelper(&d_values_output, size *
sizeof(value_type)));
711 HIP_CHECK(hipMemcpy(d_values_input,
713 size *
sizeof(value_type),
714 hipMemcpyHostToDevice));
716 offset_type* d_offsets;
718 test_common_utils::hipMallocHelper(&d_offsets,
719 (segments_count + 1) *
sizeof(offset_type)));
720 HIP_CHECK(hipMemcpy(d_offsets,
722 (segments_count + 1) *
sizeof(offset_type),
723 hipMemcpyHostToDevice));
725 using key_value = std::pair<key_type, value_type>;
728 std::vector<key_value> expected(size);
729 for(
size_t i = 0; i < size; i++)
731 expected[i] = key_value(keys_input[i], values_input[i]);
733 for(
size_t i = 0; i < segments_count; i++)
735 std::stable_sort(expected.begin() + offsets[i],
736 expected.begin() + offsets[i + 1],
743 std::vector<key_type> keys_expected(size);
744 std::vector<value_type> values_expected(size);
745 for(
size_t i = 0; i < size; i++)
747 keys_expected[i] = expected[i].first;
748 values_expected[i] = expected[i].second;
751 rocprim::double_buffer<key_type> d_keys(d_keys_input, d_keys_output);
752 rocprim::double_buffer<value_type> d_values(d_values_input, d_values_output);
754 void* d_temporary_storage =
nullptr;
755 size_t temporary_storage_bytes = 0;
756 HIP_CHECK(rocprim::segmented_radix_sort_pairs<config>(d_temporary_storage,
757 temporary_storage_bytes,
767 ASSERT_GT(temporary_storage_bytes, 0U);
770 test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes));
774 HIP_CHECK(rocprim::segmented_radix_sort_pairs_desc<config>(d_temporary_storage,
775 temporary_storage_bytes,
789 HIP_CHECK(rocprim::segmented_radix_sort_pairs<config>(d_temporary_storage,
790 temporary_storage_bytes,
803 std::vector<key_type> keys_output(size);
804 HIP_CHECK(hipMemcpy(keys_output.data(),
806 size *
sizeof(key_type),
807 hipMemcpyDeviceToHost));
809 std::vector<value_type> values_output(size);
810 HIP_CHECK(hipMemcpy(values_output.data(),
812 size *
sizeof(value_type),
813 hipMemcpyDeviceToHost));
815 HIP_CHECK(hipFree(d_temporary_storage));
816 HIP_CHECK(hipFree(d_keys_input));
817 HIP_CHECK(hipFree(d_keys_output));
818 HIP_CHECK(hipFree(d_values_input));
819 HIP_CHECK(hipFree(d_values_output));
820 HIP_CHECK(hipFree(d_offsets));
822 ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(keys_output, keys_expected));
823 ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(values_output, values_expected));
828 #endif // TEST_DEVICE_SEGMENTED_RADIX_SORT_HPP_ ROCPRIM_HOST_DEVICE constexpr T max(const T &a, const T &b)
Returns the maximum of its arguments.
Definition: functional.hpp:55
Definition: test_utils_sort_comparator.hpp:110
Definition: test_warp_exchange.cpp:34
ROCPRIM_HOST_DEVICE constexpr T min(const T &a, const T &b)
Returns the minimum of its arguments.
Definition: functional.hpp:63
Definition: test_device_segmented_radix_sort.hpp:96
Definition: test_utils_sort_comparator.hpp:45
Definition: test_device_binary_search.cpp:37