rocPRIM
merge_path.hpp
1 // Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved.
2 //
3 // Permission is hereby granted, free of charge, to any person obtaining a copy
4 // of this software and associated documentation files (the "Software"), to deal
5 // in the Software without restriction, including without limitation the rights
6 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 // copies of the Software, and to permit persons to whom the Software is
8 // furnished to do so, subject to the following conditions:
9 //
10 // The above copyright notice and this permission notice shall be included in
11 // all copies or substantial portions of the Software.
12 //
13 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19 // THE SOFTWARE.
20 
21 #ifndef ROCPRIM_DETAIL_MERGE_PATH_HPP_
22 #define ROCPRIM_DETAIL_MERGE_PATH_HPP_
23 
24 #include "../config.hpp"
25 
26 #include <iterator>
27 
28 BEGIN_ROCPRIM_NAMESPACE
29 
30 namespace detail
31 {
32 
/// \brief Index bounds of two ranges, <tt>[begin1, end1)</tt> and <tt>[begin2, end2)</tt>.
///
/// The type is a plain aggregate: it can be brace-initialized in the order
/// <tt>{begin1, end1, begin2, end2}</tt>.
struct range_t
{
    /// Index of the first element of the first range.
    unsigned int begin1;
    /// Index one past the last element of the first range.
    unsigned int end1;
    /// Index of the first element of the second range.
    unsigned int begin2;
    /// Index one past the last element of the second range.
    unsigned int end2;

    /// \brief Length of the first range.
    /// \return <tt>end1 - begin1</tt>. The subtraction is unsigned, so the result
    /// wraps around if \p end1 is smaller than \p begin1.
    ROCPRIM_DEVICE ROCPRIM_INLINE
    constexpr unsigned int count1() const
    {
        return end1 - begin1;
    }

    /// \brief Length of the second range.
    /// \return <tt>end2 - begin2</tt>. The subtraction is unsigned, so the result
    /// wraps around if \p end2 is smaller than \p begin2.
    ROCPRIM_DEVICE ROCPRIM_INLINE
    constexpr unsigned int count2() const
    {
        return end2 - begin2;
    }
};
50 
/// \brief Finds where the merge path crosses diagonal \p diag.
///
/// Binary-searches for the number of elements \c a that must be taken from
/// \p keys_input1 so that, together with <tt>diag - a</tt> elements taken from
/// \p keys_input2, they make up the first \p diag elements of the merged output.
/// Both inputs are expected to be sorted with respect to \p compare_function.
/// Elements of the first input that compare equivalent to elements of the second
/// input are placed first (the search advances whenever
/// <tt>compare_function(key2, key1)</tt> is \c false).
///
/// \tparam KeysInputIterator1 random-access iterator type of the first input.
/// \tparam KeysInputIterator2 random-access iterator type of the second input.
/// \tparam OffsetT integral type used for sizes and offsets.
/// \tparam BinaryFunction callable as <tt>bool(key_type_2, key_type_1)</tt>.
///
/// \param keys_input1 beginning of the first input.
/// \param keys_input2 beginning of the second input.
/// \param input1_size number of elements in the first input.
/// \param input2_size number of elements in the second input.
/// \param diag index of the diagonal, i.e. the number of merged elements that
/// precede the split. NOTE(review): assumed to satisfy
/// <tt>diag <= input1_size + input2_size</tt>; confirm against callers.
/// \param compare_function strict-weak-ordering predicate.
/// \return number of elements of the first input that lie before the split; the
/// matching count for the second input is <tt>diag - result</tt>.
template<class KeysInputIterator1, class KeysInputIterator2, class OffsetT, class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_INLINE OffsetT merge_path(KeysInputIterator1 keys_input1,
                                                 KeysInputIterator2 keys_input2,
                                                 const OffsetT      input1_size,
                                                 const OffsetT      input2_size,
                                                 const OffsetT      diag,
                                                 BinaryFunction     compare_function)
{
    using key_type_1 = typename std::iterator_traits<KeysInputIterator1>::value_type;
    using key_type_2 = typename std::iterator_traits<KeysInputIterator2>::value_type;

    // Search window for the split: at most input2_size elements can come from the
    // second input, and at most min(diag, input1_size) from the first one.
    // Both bounds are spelled out in OffsetT so that no helper has to be in scope
    // and no literal of a different type takes part in the arithmetic.
    OffsetT begin = diag < input2_size ? static_cast<OffsetT>(0) : diag - input2_size;
    OffsetT end   = diag < input1_size ? diag : input1_size;

    while(begin < end)
    {
        // Midpoint computed from the window width: `begin + end` could wrap around
        // in OffsetT when both bounds are in the upper half of its value range.
        const OffsetT a = begin + (end - begin) / 2;
        // Partner position on the same diagonal in the second input.
        const OffsetT b = diag - 1 - a;

        const key_type_1 input_a = keys_input1[a];
        const key_type_2 input_b = keys_input2[b];

        if(!compare_function(input_b, input_a))
        {
            // keys_input1[a] goes before keys_input2[b]: the split is right of a.
            begin = a + 1;
        }
        else
        {
            // keys_input2[b] goes first: the split is at a or left of it.
            end = a;
        }
    }

    return begin;
}
83 
/// \brief Shared implementation of the \p serial_merge overloads.
///
/// Sequentially merges the two ranges of \p keys_shared described by \p range and
/// writes the first \p ItemsPerThread merged keys to \p outputs. For every output
/// slot, <tt>on_item(i, source)</tt> is invoked right after <tt>outputs[i]</tt> has
/// been written, where \c source is the position in \p keys_shared the key was
/// taken from.
///
/// A key of the first range is chosen when the second range is exhausted, or when
/// the first range is not exhausted and <tt>compare_function(b, a)</tt> is \c false
/// (so equivalent keys of the first range come first).
///
/// \note The loop always runs \p ItemsPerThread iterations. The keys at
/// <tt>range.begin1</tt> and <tt>range.begin2</tt> are loaded unconditionally, and
/// after each output the key one position past the consumed one is loaded as well.
/// The caller therefore has to make sure these positions of \p keys_shared are
/// readable, including the slot right after a range's last element. Output slots
/// produced after both ranges are exhausted hold unspecified keys.
///
/// \param keys_shared keys of both ranges, addressed through \p range.
/// \param outputs receives the merged keys.
/// \param range bounds of the two ranges inside \p keys_shared (taken by value and
/// advanced locally).
/// \param compare_function strict-weak-ordering predicate.
/// \param on_item callable as <tt>on_item(unsigned int i, unsigned int source)</tt>.
template<class KeyType, unsigned int ItemsPerThread, class BinaryFunction, class ItemFunction>
ROCPRIM_DEVICE ROCPRIM_INLINE void serial_merge_impl(KeyType* keys_shared,
                                                     KeyType (&outputs)[ItemsPerThread],
                                                     range_t        range,
                                                     BinaryFunction compare_function,
                                                     ItemFunction   on_item)
{
    // Heads of the two ranges, kept in locals so each iteration needs one new load.
    KeyType a = keys_shared[range.begin1];
    KeyType b = keys_shared[range.begin2];

    ROCPRIM_UNROLL
    for(unsigned int i = 0; i < ItemsPerThread; ++i)
    {
        const bool take_first
            = (range.begin2 >= range.end2)
              || ((range.begin1 < range.end1) && !compare_function(b, a));
        const unsigned int source = take_first ? range.begin1 : range.begin2;

        outputs[i] = take_first ? a : b;
        on_item(i, source);

        // Read ahead: fetch the successor of the consumed key and advance that range.
        const unsigned int following = source + 1;
        const KeyType      next      = keys_shared[following];
        if(take_first)
        {
            a            = next;
            range.begin1 = following;
        }
        else
        {
            b            = next;
            range.begin2 = following;
        }
    }
}

/// \brief Serial merge that also reports where every output key came from.
///
/// Merges the two ranges of \p keys_shared described by \p range into \p outputs
/// and stores, for each output slot, the position in \p keys_shared of the key
/// that was written there.
///
/// \param keys_shared keys of both ranges; see \p serial_merge_impl for the
/// positions that have to be readable.
/// \param outputs receives \p ItemsPerThread merged keys.
/// \param index receives the source position of each merged key.
/// \param range bounds of the two ranges inside \p keys_shared.
/// \param compare_function strict-weak-ordering predicate.
template<class KeyType, unsigned int ItemsPerThread, class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_INLINE void serial_merge(KeyType* keys_shared,
                                                KeyType (&outputs)[ItemsPerThread],
                                                unsigned int (&index)[ItemsPerThread],
                                                range_t        range,
                                                BinaryFunction compare_function)
{
    serial_merge_impl(keys_shared,
                      outputs,
                      range,
                      compare_function,
                      [&index](unsigned int i, unsigned int source) { index[i] = source; });
}

/// \brief Serial merge of keys only.
///
/// Merges the two ranges of \p keys_shared described by \p range into \p outputs.
///
/// \param keys_shared keys of both ranges; see \p serial_merge_impl for the
/// positions that have to be readable.
/// \param outputs receives \p ItemsPerThread merged keys.
/// \param range bounds of the two ranges inside \p keys_shared.
/// \param compare_function strict-weak-ordering predicate.
template<class KeyType, unsigned int ItemsPerThread, class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_INLINE void serial_merge(KeyType* keys_shared,
                                                KeyType (&outputs)[ItemsPerThread],
                                                range_t        range,
                                                BinaryFunction compare_function)
{
    serial_merge_impl(keys_shared,
                      outputs,
                      range,
                      compare_function,
                      [](unsigned int, unsigned int) {});
}

/// \brief Serial merge of key-value pairs.
///
/// Merges the two ranges of \p keys_shared described by \p range into \p outputs
/// and, for each output slot, copies the element of \p values_shared located at
/// the same position as the chosen key into \p values.
///
/// \param keys_shared keys of both ranges; see \p serial_merge_impl for the
/// positions that have to be readable.
/// \param outputs receives \p ItemsPerThread merged keys.
/// \param values_shared values, indexed like \p keys_shared. It is read at the
/// source position of every output slot, including slots produced after both
/// ranges are exhausted.
/// \param values receives the value belonging to each merged key.
/// \param range bounds of the two ranges inside \p keys_shared.
/// \param compare_function strict-weak-ordering predicate.
template<class KeyType, class ValueType, unsigned int ItemsPerThread, class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_INLINE void serial_merge(KeyType* keys_shared,
                                                KeyType (&outputs)[ItemsPerThread],
                                                ValueType* values_shared,
                                                ValueType (&values)[ItemsPerThread],
                                                range_t        range,
                                                BinaryFunction compare_function)
{
    serial_merge_impl(keys_shared,
                      outputs,
                      range,
                      compare_function,
                      [&values, values_shared](unsigned int i, unsigned int source)
                      { values[i] = values_shared[source]; });
}
187 
188 } // end namespace detail
189 
190 END_ROCPRIM_NAMESPACE
191 
192 #endif // ROCPRIM_DETAIL_MERGE_PATH_HPP_
ROCPRIM_HOST_DEVICE constexpr T min(const T &a, const T &b)
Returns the minimum of its arguments.
Definition: functional.hpp:63
Deprecated: Configuration of device-level scan primitives.
Definition: block_histogram.hpp:62
ROCPRIM_DEVICE ROCPRIM_INLINE void syncthreads()
Synchronize all threads in a block (tile)
Definition: thread.hpp:216
Definition: merge_path.hpp:33