rocPRIM
temp_storage.hpp
1 // Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved.
2 //
3 // Permission is hereby granted, free of charge, to any person obtaining a copy
4 // of this software and associated documentation files (the "Software"), to deal
5 // in the Software without restriction, including without limitation the rights
6 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 // copies of the Software, and to permit persons to whom the Software is
8 // furnished to do so, subject to the following conditions:
9 //
10 // The above copyright notice and this permission notice shall be included in
11 // all copies or substantial portions of the Software.
12 //
13 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19 // THE SOFTWARE.
20 
21 #ifndef ROCPRIM_DETAIL_TEMP_STORAGE_HPP_
22 #define ROCPRIM_DETAIL_TEMP_STORAGE_HPP_
23 
24 #include <cstddef>
25 
26 #include "../config.hpp"
27 #include "../types.hpp"
28 #include "various.hpp"
29 
30 BEGIN_ROCPRIM_NAMESPACE
31 namespace detail
32 {
33 
34 namespace temp_storage
35 {
36 
/// Default alignment, in bytes, of a temporary storage partition when no explicit
/// alignment is requested.
constexpr size_t default_alignment = 256;

/// Smallest number of bytes that `partition()` reports as required, so that callers
/// are never asked to allocate 0 bytes of memory.
constexpr size_t minimum_allocation_size = 4;

/// \brief Value structure that describes the required layout of some piece of temporary
/// memory: its size and its alignment.
///
/// This is an aggregate, so it can be brace-initialized as `layout{size}` or
/// `layout{size, alignment}`.
struct layout
{
    /// The required size of the temporary memory, in bytes. Defaults to 0 so that a
    /// default-initialized layout never carries an indeterminate size.
    size_t size = 0;

    /// The required alignment of the temporary memory, in bytes. Defaults to default_alignment.
    size_t alignment = default_alignment;
};
55 
59 template<typename T>
61 {
63  T** dest;
66 
69  {
70  return this->storage_layout;
71  }
72 
76  void set_storage(void* const storage)
77  {
78  *this->dest = this->storage_layout.size == 0 ? nullptr : static_cast<T*>(storage);
79  }
80 };
81 
86 template<typename T>
87 simple_partition<T> make_partition(T** dest, layout storage_layout)
88 {
89  return simple_partition<T>{dest, storage_layout};
90 }
91 
97 template<typename T>
98 simple_partition<T> make_partition(T** dest, size_t size, size_t alignment = default_alignment)
99 {
100  return make_partition(dest, {size, alignment});
101 }
102 
108 template<typename T>
109 simple_partition<T> ptr_aligned_array(T** dest, size_t elements)
110 {
111  return make_partition(dest, elements * sizeof(T), alignof(T));
112 }
113 
122 template<typename... Ts>
124 {
126  ::rocprim::tuple<Ts...> sub_partitions;
127 
129  linear_partition(Ts... sub_partitions) : sub_partitions{sub_partitions...} {}
130 
133  {
134  size_t required_alignment = 1;
135  size_t required_size = 0;
136 
137  for_each_in_tuple(this->sub_partitions,
138  [&](auto& sub_partition)
139  {
140  const auto sub_layout = sub_partition.get_layout();
141 
142  required_alignment
143  = std::max(required_alignment, sub_layout.alignment);
144 
145  if(sub_layout.size > 0)
146  required_size = align_size(required_size, sub_layout.alignment)
147  + sub_layout.size;
148  });
149 
150  return {required_size, required_alignment};
151  }
152 
156  void set_storage(void* const storage)
157  {
158  size_t offset = 0;
159  for_each_in_tuple(this->sub_partitions,
160  [&](auto& sub_partition)
161  {
162  const auto sub_layout = sub_partition.get_layout();
163 
164  if(sub_layout.size > 0)
165  offset = align_size(offset, sub_layout.alignment);
166 
167  sub_partition.set_storage(
168  static_cast<void*>(static_cast<char*>(storage) + offset));
169  offset += sub_layout.size;
170  });
171  }
172 };
173 
177 template<typename... Ts>
178 linear_partition<Ts...> make_linear_partition(Ts... ts)
179 {
180  return linear_partition<Ts...>(ts...);
181 }
182 
190 template<typename... Ts>
192 {
194  ::rocprim::tuple<Ts...> sub_partitions;
195 
197  union_partition(Ts... sub_partitions) : sub_partitions{sub_partitions...} {}
198 
201  {
202  size_t required_alignment = 1;
203  size_t required_size = 0;
204 
205  for_each_in_tuple(this->sub_partitions,
206  [&](auto& sub_partition)
207  {
208  const auto sub_layout = sub_partition.get_layout();
209 
210  required_alignment
211  = std::max(required_alignment, sub_layout.alignment);
212  required_size = std::max(required_size, sub_layout.size);
213  });
214 
215  return {required_size, required_alignment};
216  }
217 
221  void set_storage(void* const storage)
222  {
223  for_each_in_tuple(this->sub_partitions,
224  [&](auto& sub_partition) { sub_partition.set_storage(storage); });
225  }
226 };
227 
231 template<typename... Ts>
232 union_partition<Ts...> make_union_partition(Ts... ts)
233 {
234  return union_partition<Ts...>(ts...);
235 }
236 
263 template<typename TempStoragePartition>
264 hipError_t
265  partition(void* const temporary_storage, size_t& storage_size, TempStoragePartition partition)
266 {
267  const auto layout = partition.get_layout();
268  // Make sure the user wont try to allocate 0 bytes of memory.
269  const size_t required_size = std::max(layout.size, minimum_allocation_size);
270 
271  if(temporary_storage == nullptr)
272  {
273  storage_size = required_size;
274  return hipSuccess;
275  }
276  else if(storage_size < required_size)
277  {
278  return hipErrorInvalidValue;
279  }
280 
281  partition.set_storage(temporary_storage);
282 
283  return hipSuccess;
284 }
285 } // namespace temp_storage
286 
287 } // namespace detail
288 END_ROCPRIM_NAMESPACE
289 
290 #endif // ROCPRIM_DETAIL_TEMP_STORAGE_HPP_
A partition that represents a linear sequence of sub-partitions.
Definition: temp_storage.hpp:123
ROCPRIM_HOST_DEVICE constexpr T max(const T &a, const T &b)
Returns the maximum of its arguments.
Definition: functional.hpp:55
layout get_layout()
Compute the required layout for this type and return it.
Definition: temp_storage.hpp:200
This structure describes a single required partition of temporary global memory, as well as where to store the pointer to that memory.
Definition: temp_storage.hpp:60
hipError_t partition(void *temporary_storage, size_t &storage_size, InputIterator input, FlagIterator flags, OutputIterator output, SelectedCountOutputIterator selected_count_output, const size_t size, const hipStream_t stream=0, const bool debug_synchronous=false)
Parallel select primitive for device level using range of flags.
Definition: device_partition.hpp:721
T ** dest
The location to store the pointer to the partitioned memory.
Definition: temp_storage.hpp:63
layout get_layout()
Compute the required layout for this type and return it.
Definition: temp_storage.hpp:132
This value-structure describes the required layout of some piece of temporary memory, which includes the required size and the required alignment.
Definition: temp_storage.hpp:47
::rocprim::tuple< Ts... > sub_partitions
The sub-partitions in this union_partition.
Definition: temp_storage.hpp:194
layout storage_layout
The required memory layout of the memory that this partition should get.
Definition: temp_storage.hpp:65
A partition that represents a union of sub-partitions of temporary memories which are not used at the same time.
Definition: temp_storage.hpp:191
Deprecated: Configuration of device-level scan primitives.
Definition: block_histogram.hpp:62
void set_storage(void *const storage)
Assigns the final storage for this partition.
Definition: temp_storage.hpp:156
size_t size
The required size of the temporary memory, in bytes.
Definition: temp_storage.hpp:50
void set_storage(void *const storage)
Assigns the final storage for this partition.
Definition: temp_storage.hpp:221
void set_storage(void *const storage)
Assigns the final storage for this partition.
Definition: temp_storage.hpp:76
layout get_layout()
Compute the required layout for this type and return it.
Definition: temp_storage.hpp:68
linear_partition(Ts... sub_partitions)
Constructor.
Definition: temp_storage.hpp:129
size_t alignment
The required alignment of the temporary memory, in bytes. Defaults to default_alignment.
Definition: temp_storage.hpp:53
::rocprim::tuple< Ts... > sub_partitions
The sub-partitions in this linear_partition.
Definition: temp_storage.hpp:126
union_partition(Ts... sub_partitions)
Constructor.
Definition: temp_storage.hpp:197