22 #ifndef HIPCUB_TEST_UTILS_ARGMINMAX_HPP 23 #define HIPCUB_TEST_UTILS_ARGMINMAX_HPP 25 #include <hipcub/thread/thread_operators.hpp> 26 #include <type_traits> 33 template<
typename OffsetT,
35 std::enable_if_t<std::is_same<T, test_utils::half>::value
36 || std::is_same<T, test_utils::bfloat16>::value,
39 HIPCUB_HOST_DEVICE __forceinline__ hipcub::KeyValuePair<OffsetT, T>
40 operator()(
const hipcub::KeyValuePair<OffsetT, T>& a,
41 const hipcub::KeyValuePair<OffsetT, T>& b)
const 43 const hipcub::KeyValuePair<OffsetT, float> native_a(a.key, a.value);
44 const hipcub::KeyValuePair<OffsetT, float> native_b(b.key, b.value);
46 if((native_b.value > native_a.value)
47 || ((native_a.value == native_b.value) && (native_b.key < native_a.key)))
57 template<
typename OffsetT,
59 std::enable_if_t<std::is_same<T, test_utils::half>::value
60 || std::is_same<T, test_utils::bfloat16>::value,
63 HIPCUB_HOST_DEVICE __forceinline__ hipcub::KeyValuePair<OffsetT, T>
64 operator()(
const hipcub::KeyValuePair<OffsetT, T>& a,
65 const hipcub::KeyValuePair<OffsetT, T>& b)
const 67 const hipcub::KeyValuePair<OffsetT, float> native_a(a.key, a.value);
68 const hipcub::KeyValuePair<OffsetT, float> native_b(b.key, b.value);
70 if((native_b.value < native_a.value)
71 || ((native_a.value == native_b.value) && (native_b.key < native_a.key)))
81 typedef hipcub::ArgMax type;
100 typedef hipcub::ArgMin type;
103 #ifdef __HIP_PLATFORM_NVIDIA__ 117 #endif //HIPCUB_TEST_UTILS_ARGMINMAX_HPP Definition: test_utils_argminmax.hpp:79
Definition: test_utils_argminmax.hpp:98
Arg max functor - Because NVIDIA's hipcub::ArgMax doesn't work with bfloat16 (HOST-SIDE) ...
Definition: thread_operators.hpp:125
Arg min functor - Because NVIDIA's hipcub::ArgMin doesn't work with bfloat16 (HOST-SIDE) ...
Definition: thread_operators.hpp:140
Definition: identity_iterator.hpp:26