rocPRIM
binary_op_wrappers.hpp
1 // Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved.
2 //
3 // Permission is hereby granted, free of charge, to any person obtaining a copy
4 // of this software and associated documentation files (the "Software"), to deal
5 // in the Software without restriction, including without limitation the rights
6 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 // copies of the Software, and to permit persons to whom the Software is
8 // furnished to do so, subject to the following conditions:
9 //
10 // The above copyright notice and this permission notice shall be included in
11 // all copies or substantial portions of the Software.
12 //
13 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19 // THE SOFTWARE.
20 
21 #ifndef ROCPRIM_DETAIL_BINARY_OP_WRAPPERS_HPP_
22 #define ROCPRIM_DETAIL_BINARY_OP_WRAPPERS_HPP_
23 
24 #include <type_traits>
25 
26 #include "../config.hpp"
27 #include "../intrinsics.hpp"
28 #include "../types.hpp"
29 #include "../functional.hpp"
30 
31 #include "../detail/various.hpp"
32 
33 BEGIN_ROCPRIM_NAMESPACE
34 
35 namespace detail
36 {
37 
38 template<
39  class BinaryFunction,
40  class ResultType = typename BinaryFunction::result_type,
41  class InputType = typename BinaryFunction::input_type
42 >
44 {
45  using result_type = ResultType;
46  using input_type = InputType;
47 
48  ROCPRIM_HOST_DEVICE inline
49  reverse_binary_op_wrapper() = default;
50 
51  ROCPRIM_HOST_DEVICE inline
52  reverse_binary_op_wrapper(BinaryFunction binary_op)
53  : binary_op_(binary_op)
54  {
55  }
56 
57  ROCPRIM_HOST_DEVICE inline
58  ~reverse_binary_op_wrapper() = default;
59 
60  ROCPRIM_HOST_DEVICE inline
61  result_type operator()(const input_type& t1, const input_type& t2)
62  {
63  return binary_op_(t2, t1);
64  }
65 
66 private:
67  BinaryFunction binary_op_;
68 };
69 
70 // Wrapper for performing head-flagged scan
71 template<class V, class F, class BinaryFunction>
73 {
74  static_assert(std::is_convertible<F, bool>::value, "F must be convertible to bool");
75 
76  using result_type = rocprim::tuple<V, F>;
77  using input_type = result_type;
78 
79  ROCPRIM_HOST_DEVICE inline
80  headflag_scan_op_wrapper() = default;
81 
82  ROCPRIM_HOST_DEVICE inline
83  headflag_scan_op_wrapper(BinaryFunction scan_op)
84  : scan_op_(scan_op)
85  {
86  }
87 
88  ROCPRIM_HOST_DEVICE inline
89  ~headflag_scan_op_wrapper() = default;
90 
91  ROCPRIM_HOST_DEVICE inline
92  result_type operator()(const input_type& t1, const input_type& t2)
93  {
94  return rocprim::make_tuple(
95  rocprim::get<1>(t2) == 0
96  ? scan_op_(rocprim::get<0>(t1), rocprim::get<0>(t2))
97  : static_cast<decltype(scan_op_(rocprim::get<0>(t1), rocprim::get<0>(t2)))>(
98  rocprim::get<0>(t2)),
99  F{rocprim::get<1>(t2) || rocprim::get<1>(t1)});
100  }
101 
102 private:
103  BinaryFunction scan_op_;
104 };
105 
106 
107 template<class EqualityOp>
109 {
110  using equality_op_type = EqualityOp;
111 
112  ROCPRIM_HOST_DEVICE inline
113  inequality_wrapper() = default;
114 
115  ROCPRIM_HOST_DEVICE inline
116  inequality_wrapper(equality_op_type equality_op)
117  : equality_op(equality_op)
118  {}
119 
120  template<class T, class U>
121  ROCPRIM_DEVICE ROCPRIM_INLINE
122  bool operator()(const T &a, const U &b)
123  {
124  return !equality_op(a, b);
125  }
126 
127  equality_op_type equality_op;
128 };
129 
130 } // end of detail namespace
131 
132 END_ROCPRIM_NAMESPACE
133 
134 #endif // ROCPRIM_DETAIL_BINARY_OP_WRAPPERS_HPP_
Definition: binary_op_wrappers.hpp:43
Deprecated: Configuration of device-level scan primitives.
Definition: block_histogram.hpp:62
Definition: binary_op_wrappers.hpp:72
Definition: binary_op_wrappers.hpp:108