hipCUB
test_utils_argminmax.hpp
1 // MIT License
2 //
3 // Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved.
4 //
5 // Permission is hereby granted, free of charge, to any person obtaining a copy
6 // of this software and associated documentation files (the "Software"), to deal
7 // in the Software without restriction, including without limitation the rights
8 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 // copies of the Software, and to permit persons to whom the Software is
10 // furnished to do so, subject to the following conditions:
11 //
12 // The above copyright notice and this permission notice shall be included in all
13 // copies or substantial portions of the Software.
14 //
15 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
21 
22 #ifndef HIPCUB_TEST_UTILS_ARGMINMAX_HPP
23 #define HIPCUB_TEST_UTILS_ARGMINMAX_HPP
24 
25 #include <hipcub/thread/thread_operators.hpp>
26 #include <type_traits>
27 
31 struct ArgMax
32 {
33  template<typename OffsetT,
34  class T,
35  std::enable_if_t<std::is_same<T, test_utils::half>::value
36  || std::is_same<T, test_utils::bfloat16>::value,
37  bool>
38  = true>
39  HIPCUB_HOST_DEVICE __forceinline__ hipcub::KeyValuePair<OffsetT, T>
40  operator()(const hipcub::KeyValuePair<OffsetT, T>& a,
41  const hipcub::KeyValuePair<OffsetT, T>& b) const
42  {
43  const hipcub::KeyValuePair<OffsetT, float> native_a(a.key, a.value);
44  const hipcub::KeyValuePair<OffsetT, float> native_b(b.key, b.value);
45 
46  if((native_b.value > native_a.value)
47  || ((native_a.value == native_b.value) && (native_b.key < native_a.key)))
48  return b;
49  return a;
50  }
51 };
55 struct ArgMin
56 {
57  template<typename OffsetT,
58  class T,
59  std::enable_if_t<std::is_same<T, test_utils::half>::value
60  || std::is_same<T, test_utils::bfloat16>::value,
61  bool>
62  = true>
63  HIPCUB_HOST_DEVICE __forceinline__ hipcub::KeyValuePair<OffsetT, T>
64  operator()(const hipcub::KeyValuePair<OffsetT, T>& a,
65  const hipcub::KeyValuePair<OffsetT, T>& b) const
66  {
67  const hipcub::KeyValuePair<OffsetT, float> native_a(a.key, a.value);
68  const hipcub::KeyValuePair<OffsetT, float> native_b(b.key, b.value);
69 
70  if((native_b.value < native_a.value)
71  || ((native_a.value == native_b.value) && (native_b.key < native_a.key)))
72  return b;
73  return a;
74  }
75 };
76 
77 // Maximum to operator selector
78 template<typename T>
80 {
81  typedef hipcub::ArgMax type;
82 };
83 
84 template<>
86 {
87  typedef ArgMax type;
88 };
89 
90 template<>
91 struct ArgMaxSelector<test_utils::bfloat16>
92 {
93  typedef ArgMax type;
94 };
95 
96 // Minimum to operator selector
97 template<typename T>
99 {
100  typedef hipcub::ArgMin type;
101 };
102 
// On NVIDIA builds only, route half and bfloat16 through the local ArgMin functor.
// NOTE(review): the ArgMaxSelector half/bfloat16 specializations in this header are
// not guarded by __HIP_PLATFORM_NVIDIA__; confirm this asymmetry is intentional.
#ifdef __HIP_PLATFORM_NVIDIA__
template<>
struct ArgMinSelector<test_utils::half>
{
    using type = ArgMin;
};

template<>
struct ArgMinSelector<test_utils::bfloat16>
{
    using type = ArgMin;
};
#endif // __HIP_PLATFORM_NVIDIA__
116 
117 #endif //HIPCUB_TEST_UTILS_ARGMINMAX_HPP
Definition: test_utils_argminmax.hpp:79
Definition: test_utils_argminmax.hpp:98
Arg max functor - Because NVIDIA's hipcub::ArgMax doesn't work with bfloat16 (HOST-SIDE) ...
Definition: thread_operators.hpp:125
Arg min functor - Because NVIDIA's hipcub::ArgMin doesn't work with bfloat16 (HOST-SIDE) ...
Definition: thread_operators.hpp:140
Definition: identity_iterator.hpp:26