23 #ifndef TEST_DEVICE_RADIX_SORT_HPP_    24 #define TEST_DEVICE_RADIX_SORT_HPP_    26 #include "../common_test_header.hpp"    29 #include <rocprim/device/device_radix_sort.hpp>    32 #include "test_utils_custom_float_type.hpp"    33 #include "test_utils_sort_comparator.hpp"    34 #include "test_utils_types.hpp"    38          bool         Descending      = 
false,
    39          unsigned int StartBit        = 0,
    40          unsigned int EndBit          = 
sizeof(Key) * 8,
    41          bool         CheckLargeSizes = 
false>
    45     using value_type                                = Value;
    46     static constexpr 
bool         descending        = Descending;
    47     static constexpr 
unsigned int start_bit         = StartBit;
    48     static constexpr 
unsigned int end_bit           = EndBit;
    49     static constexpr 
bool         check_large_sizes = CheckLargeSizes;
    52 template<
class Params>
    61 template<
typename TestFixture>
    62 inline void sort_keys()
    64     int device_id = test_common_utils::obtain_device_from_ctest();
    65     SCOPED_TRACE(testing::Message() << 
"with device_id = " << device_id);
    66     HIP_CHECK(hipSetDevice(device_id));
    68     using key_type                           = 
typename TestFixture::params::key_type;
    69     constexpr 
bool         descending        = TestFixture::params::descending;
    70     constexpr 
unsigned int start_bit         = TestFixture::params::start_bit;
    71     constexpr 
unsigned int end_bit           = TestFixture::params::end_bit;
    72     constexpr 
bool         check_large_sizes = TestFixture::params::check_large_sizes;
    74     hipStream_t stream = 0;
    76     const bool debug_synchronous = 
false;
    78     bool in_place = 
false;
    80     for(
size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++)
    82         unsigned int seed_value
    83             = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count];
    84         SCOPED_TRACE(testing::Message() << 
"with seed = " << seed_value);
    86         auto sizes = test_utils::get_sizes(seed_value);
    87         sizes.push_back(1 << 23);
    89         for(
size_t size : sizes)
    91             if(size > (1 << 17) && !check_large_sizes)
    94             SCOPED_TRACE(testing::Message() << 
"with size = " << size);
    99             std::vector<key_type> keys_input;
   100             if(rocprim::is_floating_point<key_type>::value)
   102                 keys_input = test_utils::get_random_data<key_type>(size,
   103                                                                    static_cast<key_type
>(-1000),
   104                                                                    static_cast<key_type>(+1000),
   106                 test_utils::add_special_values(keys_input, seed_value);
   111                     = test_utils::get_random_data<key_type>(size,
   117             key_type* d_keys_input;
   118             key_type* d_keys_output;
   119             HIP_CHECK(test_common_utils::hipMallocHelper(&d_keys_input, size * 
sizeof(key_type)));
   122                 d_keys_output = d_keys_input;
   127                     test_common_utils::hipMallocHelper(&d_keys_output, size * 
sizeof(key_type)));
   129             HIP_CHECK(hipMemcpy(d_keys_input,
   131                                 size * 
sizeof(key_type),
   132                                 hipMemcpyHostToDevice));
   135             std::vector<key_type> expected(keys_input);
   142             using config = rocprim::radix_sort_config_v2<rocprim::default_config,
   143                                                          rocprim::default_config,
   144                                                          rocprim::default_config,
   147             size_t temporary_storage_bytes;
   148             HIP_CHECK(rocprim::radix_sort_keys<config>(
nullptr,
   149                                                        temporary_storage_bytes,
   156             ASSERT_GT(temporary_storage_bytes, 0);
   158             void* d_temporary_storage;
   160                 test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes));
   164                 HIP_CHECK(rocprim::radix_sort_keys_desc<config>(d_temporary_storage,
   165                                                                 temporary_storage_bytes,
   176                 HIP_CHECK(rocprim::radix_sort_keys<config>(d_temporary_storage,
   177                                                            temporary_storage_bytes,
   187             std::vector<key_type> keys_output(size);
   188             HIP_CHECK(hipMemcpy(keys_output.data(),
   190                                 size * 
sizeof(key_type),
   191                                 hipMemcpyDeviceToHost));
   193             HIP_CHECK(hipFree(d_temporary_storage));
   194             HIP_CHECK(hipFree(d_keys_input));
   197                 HIP_CHECK(hipFree(d_keys_output));
   200             ASSERT_NO_FATAL_FAILURE(test_utils::assert_bit_eq(keys_output, expected));
   205 template<
typename TestFixture>
   206 inline void sort_pairs()
   208     int device_id = test_common_utils::obtain_device_from_ctest();
   209     SCOPED_TRACE(testing::Message() << 
"with device_id = " << device_id);
   210     HIP_CHECK(hipSetDevice(device_id));
   212     using key_type                           = 
typename TestFixture::params::key_type;
   213     using value_type                         = 
typename TestFixture::params::value_type;
   214     constexpr 
bool         descending        = TestFixture::params::descending;
   215     constexpr 
unsigned int start_bit         = TestFixture::params::start_bit;
   216     constexpr 
unsigned int end_bit           = TestFixture::params::end_bit;
   217     constexpr 
bool         check_large_sizes = TestFixture::params::check_large_sizes;
   219     hipStream_t stream = 0;
   221     const bool debug_synchronous = 
false;
   223     bool in_place = 
false;
   225     for(
size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++)
   227         unsigned int seed_value
   228             = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count];
   229         SCOPED_TRACE(testing::Message() << 
"with seed = " << seed_value);
   231         auto sizes = test_utils::get_sizes(seed_value);
   232         sizes.push_back(1 << 23);
   234         for(
size_t size : sizes)
   236             if(size > (1 << 17) && !check_large_sizes)
   239             SCOPED_TRACE(testing::Message() << 
"with size = " << size);
   241             in_place = !in_place;
   244             std::vector<key_type> keys_input;
   245             if(rocprim::is_floating_point<key_type>::value)
   247                 keys_input = test_utils::get_random_data<key_type>(size,
   248                                                                    static_cast<key_type
>(-1000),
   249                                                                    static_cast<key_type>(+1000),
   251                 test_utils::add_special_values(keys_input, seed_value);
   256                     = test_utils::get_random_data<key_type>(size,
   262             std::vector<value_type> values_input(size);
   263             test_utils::iota(values_input.begin(), values_input.end(), 0);
   265             key_type* d_keys_input;
   266             key_type* d_keys_output;
   267             HIP_CHECK(test_common_utils::hipMallocHelper(&d_keys_input, size * 
sizeof(key_type)));
   270                 d_keys_output = d_keys_input;
   275                     test_common_utils::hipMallocHelper(&d_keys_output, size * 
sizeof(key_type)));
   277             HIP_CHECK(hipMemcpy(d_keys_input,
   279                                 size * 
sizeof(key_type),
   280                                 hipMemcpyHostToDevice));
   282             value_type* d_values_input;
   283             value_type* d_values_output;
   285                 test_common_utils::hipMallocHelper(&d_values_input, size * 
sizeof(value_type)));
   288                 d_values_output = d_values_input;
   292                 HIP_CHECK(test_common_utils::hipMallocHelper(&d_values_output,
   293                                                              size * 
sizeof(value_type)));
   295             HIP_CHECK(hipMemcpy(d_values_input,
   297                                 size * 
sizeof(value_type),
   298                                 hipMemcpyHostToDevice));
   300             using key_value = std::pair<key_type, value_type>;
   303             std::vector<key_value> expected(size);
   304             for(
size_t i = 0; i < size; i++)
   306                 expected[i] = key_value(keys_input[i], values_input[i]);
   312                     key_value_comparator<key_type, value_type, descending, start_bit, end_bit>());
   313             std::vector<key_type>   keys_expected(size);
   314             std::vector<value_type> values_expected(size);
   315             for(
size_t i = 0; i < size; i++)
   317                 keys_expected[i]   = expected[i].first;
   318                 values_expected[i] = expected[i].second;
   322             using config = rocprim::radix_sort_config_v2<
   323                 rocprim::kernel_config<256, 1>,
   324                 rocprim::merge_sort_config<128, 64, 2, 128, 64, 2>,
   325                 rocprim::radix_sort_onesweep_config<rocprim::kernel_config<128, 1>,
   326                                                     rocprim::kernel_config<128, 1>,
   330             void*  d_temporary_storage = 
nullptr;
   331             size_t temporary_storage_bytes;
   332             HIP_CHECK(rocprim::radix_sort_pairs<config>(d_temporary_storage,
   333                                                         temporary_storage_bytes,
   342             ASSERT_GT(temporary_storage_bytes, 0);
   345                 test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes));
   349                 HIP_CHECK(rocprim::radix_sort_pairs_desc<config>(d_temporary_storage,
   350                                                                  temporary_storage_bytes,
   363                 HIP_CHECK(rocprim::radix_sort_pairs<config>(d_temporary_storage,
   364                                                             temporary_storage_bytes,
   376             std::vector<key_type> keys_output(size);
   377             HIP_CHECK(hipMemcpy(keys_output.data(),
   379                                 size * 
sizeof(key_type),
   380                                 hipMemcpyDeviceToHost));
   382             std::vector<value_type> values_output(size);
   383             HIP_CHECK(hipMemcpy(values_output.data(),
   385                                 size * 
sizeof(value_type),
   386                                 hipMemcpyDeviceToHost));
   388             HIP_CHECK(hipFree(d_temporary_storage));
   389             HIP_CHECK(hipFree(d_keys_input));
   390             HIP_CHECK(hipFree(d_values_input));
   393                 HIP_CHECK(hipFree(d_keys_output));
   394                 HIP_CHECK(hipFree(d_values_output));
   397             ASSERT_NO_FATAL_FAILURE(test_utils::assert_bit_eq(keys_output, keys_expected));
   398             ASSERT_NO_FATAL_FAILURE(test_utils::assert_bit_eq(values_output, values_expected));
   403 template<
typename TestFixture>
   404 inline void sort_keys_double_buffer()
   406     int device_id = test_common_utils::obtain_device_from_ctest();
   407     SCOPED_TRACE(testing::Message() << 
"with device_id = " << device_id);
   408     HIP_CHECK(hipSetDevice(device_id));
   410     using key_type                           = 
typename TestFixture::params::key_type;
   411     constexpr 
bool         descending        = TestFixture::params::descending;
   412     constexpr 
unsigned int start_bit         = TestFixture::params::start_bit;
   413     constexpr 
unsigned int end_bit           = TestFixture::params::end_bit;
   414     constexpr 
bool         check_large_sizes = TestFixture::params::check_large_sizes;
   416     hipStream_t stream = 0;
   418     const bool debug_synchronous = 
false;
   420     for(
size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++)
   422         unsigned int seed_value
   423             = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count];
   424         SCOPED_TRACE(testing::Message() << 
"with seed = " << seed_value);
   426         auto sizes = test_utils::get_sizes(seed_value);
   427         sizes.push_back(1 << 23);
   429         for(
size_t size : sizes)
   431             if(size > (1 << 17) && !check_large_sizes)
   434             SCOPED_TRACE(testing::Message() << 
"with size = " << size);
   437             std::vector<key_type> keys_input;
   438             if(rocprim::is_floating_point<key_type>::value)
   440                 keys_input = test_utils::get_random_data<key_type>(size,
   441                                                                    static_cast<key_type
>(-1000),
   442                                                                    static_cast<key_type>(+1000),
   444                 test_utils::add_special_values(keys_input, seed_value);
   449                     = test_utils::get_random_data<key_type>(size,
   455             key_type* d_keys_input;
   456             key_type* d_keys_output;
   457             HIP_CHECK(test_common_utils::hipMallocHelper(&d_keys_input, size * 
sizeof(key_type)));
   458             HIP_CHECK(test_common_utils::hipMallocHelper(&d_keys_output, size * 
sizeof(key_type)));
   459             HIP_CHECK(hipMemcpy(d_keys_input,
   461                                 size * 
sizeof(key_type),
   462                                 hipMemcpyHostToDevice));
   465             std::vector<key_type> expected(keys_input);
   471             rocprim::double_buffer<key_type> d_keys(d_keys_input, d_keys_output);
   473             size_t temporary_storage_bytes;
   475                                                temporary_storage_bytes,
   481             ASSERT_GT(temporary_storage_bytes, 0);
   483             void* d_temporary_storage;
   485                 test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes));
   490                                                         temporary_storage_bytes,
   501                                                    temporary_storage_bytes,
   510             HIP_CHECK(hipFree(d_temporary_storage));
   512             std::vector<key_type> keys_output(size);
   513             HIP_CHECK(hipMemcpy(keys_output.data(),
   515                                 size * 
sizeof(key_type),
   516                                 hipMemcpyDeviceToHost));
   518             HIP_CHECK(hipFree(d_keys_input));
   519             HIP_CHECK(hipFree(d_keys_output));
   521             ASSERT_NO_FATAL_FAILURE(test_utils::assert_bit_eq(keys_output, expected));
   526 template<
typename TestFixture>
   527 inline void sort_pairs_double_buffer()
   529     int device_id = test_common_utils::obtain_device_from_ctest();
   530     SCOPED_TRACE(testing::Message() << 
"with device_id = " << device_id);
   531     HIP_CHECK(hipSetDevice(device_id));
   533     using key_type                           = 
typename TestFixture::params::key_type;
   534     using value_type                         = 
typename TestFixture::params::value_type;
   535     constexpr 
bool         descending        = TestFixture::params::descending;
   536     constexpr 
unsigned int start_bit         = TestFixture::params::start_bit;
   537     constexpr 
unsigned int end_bit           = TestFixture::params::end_bit;
   538     constexpr 
bool         check_large_sizes = TestFixture::params::check_large_sizes;
   540     hipStream_t stream = 0;
   542     const bool debug_synchronous = 
false;
   544     for(
size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++)
   546         unsigned int seed_value
   547             = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count];
   548         SCOPED_TRACE(testing::Message() << 
"with seed = " << seed_value);
   550         auto sizes = test_utils::get_sizes(seed_value);
   551         sizes.push_back(1 << 23);
   553         for(
size_t size : sizes)
   555             if(size > (1 << 17) && !check_large_sizes)
   558             SCOPED_TRACE(testing::Message() << 
"with size = " << size);
   561             std::vector<key_type> keys_input;
   562             if(rocprim::is_floating_point<key_type>::value)
   564                 keys_input = test_utils::get_random_data<key_type>(size,
   565                                                                    static_cast<key_type
>(-1000),
   566                                                                    static_cast<key_type>(+1000),
   568                 test_utils::add_special_values(keys_input, seed_value);
   573                     = test_utils::get_random_data<key_type>(size,
   579             std::vector<value_type> values_input(size);
   580             test_utils::iota(values_input.begin(), values_input.end(), 0);
   582             key_type* d_keys_input;
   583             key_type* d_keys_output;
   584             HIP_CHECK(test_common_utils::hipMallocHelper(&d_keys_input, size * 
sizeof(key_type)));
   585             HIP_CHECK(test_common_utils::hipMallocHelper(&d_keys_output, size * 
sizeof(key_type)));
   586             HIP_CHECK(hipMemcpy(d_keys_input,
   588                                 size * 
sizeof(key_type),
   589                                 hipMemcpyHostToDevice));
   591             value_type* d_values_input;
   592             value_type* d_values_output;
   594                 test_common_utils::hipMallocHelper(&d_values_input, size * 
sizeof(value_type)));
   596                 test_common_utils::hipMallocHelper(&d_values_output, size * 
sizeof(value_type)));
   597             HIP_CHECK(hipMemcpy(d_values_input,
   599                                 size * 
sizeof(value_type),
   600                                 hipMemcpyHostToDevice));
   602             using key_value = std::pair<key_type, value_type>;
   605             std::vector<key_value> expected(size);
   606             for(
size_t i = 0; i < size; i++)
   608                 expected[i] = key_value(keys_input[i], values_input[i]);
   614                     key_value_comparator<key_type, value_type, descending, start_bit, end_bit>());
   615             std::vector<key_type>   keys_expected(size);
   616             std::vector<value_type> values_expected(size);
   617             for(
size_t i = 0; i < size; i++)
   619                 keys_expected[i]   = expected[i].first;
   620                 values_expected[i] = expected[i].second;
   623             rocprim::double_buffer<key_type>   d_keys(d_keys_input, d_keys_output);
   624             rocprim::double_buffer<value_type> d_values(d_values_input, d_values_output);
   626             void*  d_temporary_storage = 
nullptr;
   627             size_t temporary_storage_bytes;
   629                                                 temporary_storage_bytes,
   636             ASSERT_GT(temporary_storage_bytes, 0);
   639                 test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes));
   644                                                          temporary_storage_bytes,
   656                                                     temporary_storage_bytes,
   666             HIP_CHECK(hipFree(d_temporary_storage));
   668             std::vector<key_type> keys_output(size);
   669             HIP_CHECK(hipMemcpy(keys_output.data(),
   671                                 size * 
sizeof(key_type),
   672                                 hipMemcpyDeviceToHost));
   674             std::vector<value_type> values_output(size);
   675             HIP_CHECK(hipMemcpy(values_output.data(),
   677                                 size * 
sizeof(value_type),
   678                                 hipMemcpyDeviceToHost));
   680             HIP_CHECK(hipFree(d_keys_input));
   681             HIP_CHECK(hipFree(d_keys_output));
   682             HIP_CHECK(hipFree(d_values_input));
   683             HIP_CHECK(hipFree(d_values_output));
   685             ASSERT_NO_FATAL_FAILURE(test_utils::assert_bit_eq(keys_output, keys_expected));
   686             ASSERT_NO_FATAL_FAILURE(test_utils::assert_bit_eq(values_output, values_expected));
   691 inline void sort_keys_over_4g()
   693     using key_type                                 = uint8_t;
   694     constexpr 
unsigned int start_bit               = 0;
   695     constexpr 
unsigned int end_bit                 = 8ull * 
sizeof(key_type);
   696     constexpr hipStream_t  stream                  = 0;
   697     constexpr 
bool         debug_synchronous       = 
false;
   698     constexpr 
size_t       size                    = (1ull << 32) + 32;
   699     constexpr 
size_t       number_of_possible_keys = 1ull << (8ull * 
sizeof(key_type));
   700     assert(std::is_unsigned<key_type>::value);
   701     std::vector<size_t> 
histogram(number_of_possible_keys, 0);
   702     const int           seed_value = rand();
   704     const int device_id = test_common_utils::obtain_device_from_ctest();
   705     SCOPED_TRACE(testing::Message() << 
"with device_id = " << device_id);
   706     HIP_CHECK(hipSetDevice(device_id));
   708     std::vector<key_type> keys_input
   709         = test_utils::get_random_data<key_type>(size,
   715     std::for_each(keys_input.begin(), keys_input.end(), [&](
const key_type& a) { histogram[a]++; });
   717     key_type* d_keys_input_output{};
   718     size_t key_type_storage_bytes = size * 
sizeof(key_type);
   720     HIP_CHECK(test_common_utils::hipMallocHelper(&d_keys_input_output, key_type_storage_bytes));
   721     HIP_CHECK(hipMemcpy(d_keys_input_output,
   723                         key_type_storage_bytes,
   724                         hipMemcpyHostToDevice));
   726     size_t temporary_storage_bytes;
   728                                        temporary_storage_bytes,
   737     ASSERT_GT(temporary_storage_bytes, 0);
   739     hipDeviceProp_t prop;
   740     HIP_CHECK(hipGetDeviceProperties(&prop, device_id));
   742    size_t total_storage_bytes = key_type_storage_bytes +  temporary_storage_bytes;
   743     if (total_storage_bytes > (static_cast<size_t>(prop.totalGlobalMem * 0.90))) {
   744         HIP_CHECK(hipFree(d_keys_input_output));
   745         GTEST_SKIP() << 
"Test case device memory requirement (" << total_storage_bytes << 
" bytes) exceeds available memory on current device ("   746                      << prop.totalGlobalMem << 
" bytes). Skipping test";
   749     void* d_temporary_storage;
   750     HIP_CHECK(test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes));
   753                                        temporary_storage_bytes,
   762     std::vector<key_type> output(keys_input.size());
   763     HIP_CHECK(hipMemcpy(output.data(),
   765                         size * 
sizeof(key_type),
   766                         hipMemcpyDeviceToHost));
   771         for(
size_t j = 0; j < histogram[i]; ++j)
   773             ASSERT_EQ(static_cast<size_t>(output[counter]), i);
   777     ASSERT_EQ(counter, size);
   779     HIP_CHECK(hipFree(d_keys_input_output));
   780     HIP_CHECK(hipFree(d_temporary_storage));
   783 #endif // TEST_DEVICE_RADIX_SORT_HPP_ ROCPRIM_HOST_DEVICE constexpr T max(const T &a, const T &b)
Returns the maximum of its arguments. 
Definition: functional.hpp:55
Definition: test_warp_exchange.cpp:34
hipError_t radix_sort_pairs_desc(void *temporary_storage, size_t &storage_size, KeysInputIterator keys_input, KeysOutputIterator keys_output, ValuesInputIterator values_input, ValuesOutputIterator values_output, Size size, unsigned int begin_bit=0, unsigned int end_bit=8 *sizeof(Key), hipStream_t stream=0, bool debug_synchronous=false)
Parallel descending radix sort-by-key primitive for device level. 
Definition: device_radix_sort.hpp:1157
Definition: test_device_radix_sort.hpp:53
ROCPRIM_HOST_DEVICE constexpr T min(const T &a, const T &b)
Returns the minimum of its arguments. 
Definition: functional.hpp:63
hipError_t radix_sort_keys_desc(void *temporary_storage, size_t &storage_size, KeysInputIterator keys_input, KeysOutputIterator keys_output, Size size, unsigned int begin_bit=0, unsigned int end_bit=8 *sizeof(Key), hipStream_t stream=0, bool debug_synchronous=false)
Parallel descending radix sort primitive for device level. 
Definition: device_radix_sort.hpp:910
Definition: test_utils_sort_comparator.hpp:45
Definition: test_device_binary_search.cpp:37
Definition: benchmark_block_histogram.cpp:64
hipError_t radix_sort_keys(void *temporary_storage, size_t &storage_size, KeysInputIterator keys_input, KeysOutputIterator keys_output, Size size, unsigned int begin_bit=0, unsigned int end_bit=8 *sizeof(Key), hipStream_t stream=0, bool debug_synchronous=false)
Parallel ascending radix sort primitive for device level. 
Definition: device_radix_sort.hpp:803
hipError_t radix_sort_pairs(void *temporary_storage, size_t &storage_size, KeysInputIterator keys_input, KeysOutputIterator keys_output, ValuesInputIterator values_input, ValuesOutputIterator values_output, Size size, unsigned int begin_bit=0, unsigned int end_bit=8 *sizeof(Key), hipStream_t stream=0, bool debug_synchronous=false)
Parallel ascending radix sort-by-key primitive for device level. 
Definition: device_radix_sort.hpp:1035