OpenKalman
mdspan.hpp
1 //@HEADER
2 // ************************************************************************
3 //
4 // Kokkos v. 4.0
5 // Copyright (2022) National Technology & Engineering
6 // Solutions of Sandia, LLC (NTESS).
7 //
8 // Under the terms of Contract DE-NA0003525 with NTESS,
9 // the U.S. Government retains certain rights in this software.
10 //
11 // Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
12 // See https://kokkos.org/LICENSE for license information.
13 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
14 //
15 //@HEADER
16 
17 #pragma once
18 
19 #include "default_accessor.hpp"
20 #include "layout_right.hpp"
21 #include "extents.hpp"
22 #include "trait_backports.hpp"
23 #include "compressed_pair.hpp"
24 
25 namespace std {
26 namespace experimental {
27 
28 template <
29  class ElementType,
30  class Extents,
31  class LayoutPolicy = layout_right,
32  class AccessorPolicy = default_accessor<ElementType>
33 >
34 class mdspan
35 {
36 private:
// Compile-time guard: instantiation fails with the message below whenever
// detail::__is_extents_v<Extents> is false.
37  static_assert(detail::__is_extents_v<Extents>, "std::experimental::mdspan's Extents template parameter must be a specialization of std::experimental::extents.");
38 
39  // Workaround for non-deducibility of the index sequence template parameter if it's given at the top level
// Primary template: declared only, never defined. The members live in the
// partial specialization for index_sequence<Idxs...>.
40  template <class>
41  struct __deduction_workaround;
42 
// Partial specialization that exposes the index pack Idxs, so the static
// helpers below can expand an expression once per index. The fold macros
// (_MDSPAN_FOLD_*) are defined outside this file.
43  template <size_t... Idxs>
44  struct __deduction_workaround<index_sequence<Idxs...>>
45  {
// Combines extent(Idxs) for every index through _MDSPAN_FOLD_TIMES_RIGHT,
// seeded with size_t(1) (so an empty pack yields the seed).
46  MDSPAN_FORCE_INLINE_FUNCTION static constexpr
47  size_t __size(mdspan const& __self) noexcept {
48  return _MDSPAN_FOLD_TIMES_RIGHT((__self.__mapping_ref().extents().extent(Idxs)), /* * ... * */ size_t(1));
49  }
// True only when rank() > 0 AND the _MDSPAN_FOLD_OR over
// "extent(Idxs) == index_type(0)" is true; a rank-0 mdspan is never empty.
50  MDSPAN_FORCE_INLINE_FUNCTION static constexpr
51  bool __empty(mdspan const& __self) noexcept {
52  return (__self.rank()>0) && _MDSPAN_FOLD_OR((__self.__mapping_ref().extents().extent(Idxs)==index_type(0)));
53  }
// Element lookup from an array of indices: indices[Idxs]... is expanded into
// the mapping's call operator, and the result is passed together with the
// data handle to the accessor's access(). ReferenceType is not deducible and
// must be supplied explicitly by the caller.
54  template <class ReferenceType, class SizeType, size_t N>
55  MDSPAN_FORCE_INLINE_FUNCTION static constexpr
56  ReferenceType __callop(mdspan const& __self, const array<SizeType, N>& indices) noexcept {
57  return __self.__accessor_ref().access(__self.__ptr_ref(), __self.__mapping_ref()(indices[Idxs]...));
58  }
// Same lookup for a span whose static extent equals Extents::rank().
// NOTE(review): unlike the span-taking members of mdspan, this overload is
// not wrapped in a preprocessor guard, so the name `span` must be visible
// whenever this header is compiled -- confirm that is intended.
59  // Added by CLO:
60  template <class ReferenceType, class SizeType>
61  MDSPAN_FORCE_INLINE_FUNCTION static constexpr
62  ReferenceType __callop(mdspan const& __self, const span<SizeType, Extents::rank()>& indices) noexcept {
63  return __self.__accessor_ref().access(__self.__ptr_ref(), __self.__mapping_ref()(indices[Idxs]...));
64  }
65  // End
66  };
67 
68 public:
69 
70  //--------------------------------------------------------------------------------
71  // Domain and codomain types
72 
// Public member types.
//  - extents_type / layout_type / accessor_type / element_type repeat the
//    class template arguments.
//  - mapping_type is the layout policy's nested mapping template applied to
//    extents_type.
//  - value_type is element_type with cv-qualifiers removed.
//  - index_type, size_type and rank_type are taken from extents_type.
//  - data_handle_type and reference are taken from accessor_type.
73  using extents_type = Extents;
74  using layout_type = LayoutPolicy;
75  using accessor_type = AccessorPolicy;
76  using mapping_type = typename layout_type::template mapping<extents_type>;
77  using element_type = ElementType;
78  using value_type = remove_cv_t<element_type>;
79  using index_type = typename extents_type::index_type;
80  using size_type = typename extents_type::size_type;
81  using rank_type = typename extents_type::rank_type;
82  using data_handle_type = typename accessor_type::data_handle_type;
83  using reference = typename accessor_type::reference;
84 
// Rank and extent queries. The three static functions forward to the
// same-named static members of extents_type; extent(r) reads the run-time
// extent of dimension r from the extents held by the mapping.
85  MDSPAN_INLINE_FUNCTION static constexpr size_t rank() noexcept { return extents_type::rank(); }
86  MDSPAN_INLINE_FUNCTION static constexpr size_t rank_dynamic() noexcept { return extents_type::rank_dynamic(); }
87  MDSPAN_INLINE_FUNCTION static constexpr size_t static_extent(size_t r) noexcept { return extents_type::static_extent(r); }
88  MDSPAN_INLINE_FUNCTION constexpr index_type extent(size_t r) const noexcept { return __mapping_ref().extents().extent(r); };
89 
90 private:
91 
92  // Can't use defaulted parameter in the __deduction_workaround template because of a bug in MSVC warning C4348.
// __impl is __deduction_workaround instantiated with the index sequence
// 0 .. rank()-1; size(), empty() and the array/span subscript overloads
// delegate to its static helpers.
93  using __impl = __deduction_workaround<make_index_sequence<extents_type::rank()>>;
94 
96 
97 public:
98 
99  //--------------------------------------------------------------------------------
100  // [mdspan.basic.cons], mdspan constructors, assignment, and destructor
101 
// Default constructor.
// When MDSPAN_HAS_CXX_20 is false it is defaulted unconditionally. Otherwise
// it is defaulted under a requires-clause: at least one dynamic extent, and
// data_handle_type, mapping_type and accessor_type all default constructible.
102 #if !MDSPAN_HAS_CXX_20
103  MDSPAN_INLINE_FUNCTION_DEFAULTED constexpr mdspan() = default;
104 #else
105  MDSPAN_INLINE_FUNCTION_DEFAULTED constexpr mdspan()
106  requires(
107  // nvhpc has a bug where using just rank_dynamic() here doesn't work ...
108  (extents_type::rank_dynamic() > 0) &&
109  _MDSPAN_TRAIT(is_default_constructible, data_handle_type) &&
110  _MDSPAN_TRAIT(is_default_constructible, mapping_type) &&
111  _MDSPAN_TRAIT(is_default_constructible, accessor_type)
112  ) = default;
113 #endif
// Copy and move construction are compiler-generated (defaulted).
114  MDSPAN_INLINE_FUNCTION_DEFAULTED constexpr mdspan(const mdspan&) = default;
115  MDSPAN_INLINE_FUNCTION_DEFAULTED constexpr mdspan(mdspan&&) = default;
116 
// Constructs from a data handle and a pack of extents.
// Constraints (all must hold):
//  - every SizeTypes is convertible to index_type and index_type is nothrow
//    constructible from it;
//  - the pack has either rank() or rank_dynamic() elements;
//  - mapping_type is constructible from extents_type;
//  - accessor_type is default constructible.
// Each argument is cast to index_type and handed to extents_type, the mapping
// is built from those extents, and the accessor is default-constructed.
117  MDSPAN_TEMPLATE_REQUIRES(
118  class... SizeTypes,
119  /* requires */ (
120  _MDSPAN_FOLD_AND(_MDSPAN_TRAIT(is_convertible, SizeTypes, index_type) /* && ... */) &&
121  _MDSPAN_FOLD_AND(_MDSPAN_TRAIT(is_nothrow_constructible, index_type, SizeTypes) /* && ... */) &&
122  ((sizeof...(SizeTypes) == rank()) || (sizeof...(SizeTypes) == rank_dynamic())) &&
123  _MDSPAN_TRAIT(is_constructible, mapping_type, extents_type) &&
124  _MDSPAN_TRAIT(is_default_constructible, accessor_type)
125  )
126  )
127  MDSPAN_INLINE_FUNCTION
128  explicit constexpr mdspan(data_handle_type p, SizeTypes... dynamic_extents)
129  // TODO @proposal-bug shouldn't I be allowed to do `move(p)` here?
130  : __members(std::move(p), __map_acc_pair_t(mapping_type(extents_type(static_cast<index_type>(std::move(dynamic_extents))...)), accessor_type()))
131  { }
132 
// Constructs from a data handle and an array of N extents.
// Constraints: SizeType converts to index_type (and index_type is nothrow
// constructible from it), N equals rank() or rank_dynamic(), mapping_type is
// constructible from extents_type, and accessor_type is default constructible.
// The constructor is explicit exactly when N != rank_dynamic().
// The array is passed as-is to extents_type; the accessor is
// default-constructed.
133  MDSPAN_TEMPLATE_REQUIRES(
134  class SizeType, size_t N,
135  /* requires */ (
136  _MDSPAN_TRAIT(is_convertible, SizeType, index_type) &&
137  _MDSPAN_TRAIT(is_nothrow_constructible, index_type, SizeType) &&
138  ((N == rank()) || (N == rank_dynamic())) &&
139  _MDSPAN_TRAIT(is_constructible, mapping_type, extents_type) &&
140  _MDSPAN_TRAIT(is_default_constructible, accessor_type)
141  )
142  )
143  MDSPAN_CONDITIONAL_EXPLICIT(N != rank_dynamic())
144  MDSPAN_INLINE_FUNCTION
145  constexpr mdspan(data_handle_type p, const array<SizeType, N>& dynamic_extents)
146  : __members(std::move(p), __map_acc_pair_t(mapping_type(extents_type(dynamic_extents)), accessor_type()))
147  { }
148 
// Constructs from a data handle and a span of N extents; same constraints and
// conditional explicitness as the array overload. The span is passed to
// extents_type through as_const, and the accessor is default-constructed.
// Only compiled when __cpp_lib_span is defined.
// NOTE(review): the span-taking operator[] is additionally enabled by
// OPENKALMAN_COMPATIBILITY_SPAN, but this constructor is not -- confirm
// whether the difference is intentional.
149 #ifdef __cpp_lib_span
150  MDSPAN_TEMPLATE_REQUIRES(
151  class SizeType, size_t N,
152  /* requires */ (
153  _MDSPAN_TRAIT(is_convertible, SizeType, index_type) &&
154  _MDSPAN_TRAIT(is_nothrow_constructible, index_type, SizeType) &&
155  ((N == rank()) || (N == rank_dynamic())) &&
156  _MDSPAN_TRAIT(is_constructible, mapping_type, extents_type) &&
157  _MDSPAN_TRAIT(is_default_constructible, accessor_type)
158  )
159  )
160  MDSPAN_CONDITIONAL_EXPLICIT(N != rank_dynamic())
161  MDSPAN_INLINE_FUNCTION
162  constexpr mdspan(data_handle_type p, span<SizeType, N> dynamic_extents)
163  : __members(std::move(p), __map_acc_pair_t(mapping_type(extents_type(as_const(dynamic_extents))), accessor_type()))
164  { }
165 #endif
166 
// Constructs from a data handle and an extents object. Requires
// accessor_type to be default constructible and mapping_type to be
// constructible from extents_type. The mapping is built from `exts` and the
// accessor is default-constructed.
167  MDSPAN_FUNCTION_REQUIRES(
168  (MDSPAN_INLINE_FUNCTION constexpr),
169  mdspan, (data_handle_type p, const extents_type& exts), ,
170  /* requires */ (_MDSPAN_TRAIT(is_default_constructible, accessor_type) &&
171  _MDSPAN_TRAIT(is_constructible, mapping_type, extents_type))
172  ) : __members(std::move(p), __map_acc_pair_t(mapping_type(exts), accessor_type()))
173  { }
174 
// Constructs from a data handle and a ready-made mapping. Requires
// accessor_type to be default constructible. The mapping is copied and the
// accessor is default-constructed.
175  MDSPAN_FUNCTION_REQUIRES(
176  (MDSPAN_INLINE_FUNCTION constexpr),
177  mdspan, (data_handle_type p, const mapping_type& m), ,
178  /* requires */ (_MDSPAN_TRAIT(is_default_constructible, accessor_type))
179  ) : __members(std::move(p), __map_acc_pair_t(m, accessor_type()))
180  { }
181 
// Constructs from a data handle, a mapping and an accessor: the handle is
// moved in, the mapping and the accessor are copied. Unconstrained.
182  MDSPAN_INLINE_FUNCTION
183  constexpr mdspan(data_handle_type p, const mapping_type& m, const accessor_type& a)
184  : __members(std::move(p), __map_acc_pair_t(m, a))
185  { }
186 
187  MDSPAN_TEMPLATE_REQUIRES(
188  class OtherElementType, class OtherExtents, class OtherLayoutPolicy, class OtherAccessor,
189  /* requires */ (
190  _MDSPAN_TRAIT(is_constructible, mapping_type, typename OtherLayoutPolicy::template mapping<OtherExtents>) &&
191  _MDSPAN_TRAIT(is_constructible, accessor_type, OtherAccessor)
192  )
193  )
194  MDSPAN_INLINE_FUNCTION
196  : __members(other.__ptr_ref(), __map_acc_pair_t(other.__mapping_ref(), other.__accessor_ref()))
197  {
198  static_assert(_MDSPAN_TRAIT(is_constructible, data_handle_type, typename OtherAccessor::data_handle_type),"Incompatible data_handle_type for mdspan construction");
199  static_assert(_MDSPAN_TRAIT(is_constructible, extents_type, OtherExtents),"Incompatible extents for mdspan construction");
200  /*
201  * TODO: Check precondition
202  * For each rank index r of extents_type, static_extent(r) == dynamic_extent || static_extent(r) == other.extent(r) is true.
203  */
204  }
205 
206  /* Might need this on NVIDIA?
207  MDSPAN_INLINE_FUNCTION_DEFAULTED
208  ~mdspan() = default;
209  */
210 
// Copy and move assignment are compiler-generated (defaulted).
211  MDSPAN_INLINE_FUNCTION_DEFAULTED _MDSPAN_CONSTEXPR_14_DEFAULTED mdspan& operator=(const mdspan&) = default;
212  MDSPAN_INLINE_FUNCTION_DEFAULTED _MDSPAN_CONSTEXPR_14_DEFAULTED mdspan& operator=(mdspan&&) = default;
213 
214 
215  //--------------------------------------------------------------------------------
216  // [mdspan.basic.mapping], mdspan mapping domain multidimensional index to access codomain element
217 
// Multi-argument subscript, compiled only when MDSPAN_USE_BRACKET_OPERATOR is
// true. Takes exactly rank() indices, each convertible to index_type (and
// index_type nothrow constructible from it). Every index is cast to
// index_type, the mapping turns them into an offset, and the accessor's
// access() is called with the data handle and that offset.
218  #if MDSPAN_USE_BRACKET_OPERATOR
219  MDSPAN_TEMPLATE_REQUIRES(
220  class... SizeTypes,
221  /* requires */ (
222  _MDSPAN_FOLD_AND(_MDSPAN_TRAIT(is_convertible, SizeTypes, index_type) /* && ... */) &&
223  _MDSPAN_FOLD_AND(_MDSPAN_TRAIT(is_nothrow_constructible, index_type, SizeTypes) /* && ... */) &&
224  (rank() == sizeof...(SizeTypes))
225  )
226  )
227  MDSPAN_FORCE_INLINE_FUNCTION
228  constexpr reference operator[](SizeTypes... indices) const
229  {
230  return __accessor_ref().access(__ptr_ref(), __mapping_ref()(static_cast<index_type>(std::move(indices))...));
231  }
232  #endif
233 
// Subscript taking the whole multidimensional index as an array with exactly
// rank() elements. SizeType must be convertible to index_type (and index_type
// nothrow constructible from it). The lookup is delegated to
// __impl::__callop, which expands the array elements into the mapping.
234  MDSPAN_TEMPLATE_REQUIRES(
235  class SizeType,
236  /* requires */ (
237  _MDSPAN_TRAIT(is_convertible, SizeType, index_type) &&
238  _MDSPAN_TRAIT(is_nothrow_constructible, index_type, SizeType)
239  )
240  )
241  MDSPAN_FORCE_INLINE_FUNCTION
242  constexpr reference operator[](const array<SizeType, rank()>& indices) const
243  {
244  return __impl::template __callop<reference>(*this, indices);
245  }
246 
// Subscript taking the whole multidimensional index as a span with exactly
// rank() elements. SizeType must be convertible to index_type (and index_type
// nothrow constructible from it). The lookup is delegated to
// __impl::__callop, which expands the span elements into the mapping.
//
// Compiled when either __cpp_lib_span or OPENKALMAN_COMPATIBILITY_SPAN is
// defined (the latter is defined outside this file; presumably it signals a
// span backport -- TODO confirm).
//
// The guard is spelled with `||`, not the alternative token `or`: compilers
// that do not treat `or` as an operator by default (MSVC without
// /permissive- or <iso646.h>) see an undefined identifier inside the #if
// expression and reject the directive.
247  #if defined(__cpp_lib_span) || defined(OPENKALMAN_COMPATIBILITY_SPAN)
248  MDSPAN_TEMPLATE_REQUIRES(
249  class SizeType,
250  /* requires */ (
251  _MDSPAN_TRAIT(is_convertible, SizeType, index_type) &&
252  _MDSPAN_TRAIT(is_nothrow_constructible, index_type, SizeType)
253  )
254  )
255  MDSPAN_FORCE_INLINE_FUNCTION
256  constexpr reference operator[](span<SizeType, rank()> indices) const
257  {
258  return __impl::template __callop<reference>(*this, indices);
259  }
260  #endif // __cpp_lib_span || OPENKALMAN_COMPATIBILITY_SPAN
261 
// Single-index subscript used when MDSPAN_USE_BRACKET_OPERATOR is false.
// Only participates for rank-1 mdspans and for Index types convertible to
// index_type (with index_type nothrow constructible from Index). The index is
// cast to index_type, mapped to an offset, and passed with the data handle to
// the accessor's access().
262  #if !MDSPAN_USE_BRACKET_OPERATOR
263  MDSPAN_TEMPLATE_REQUIRES(
264  class Index,
265  /* requires */ (
266  _MDSPAN_TRAIT(is_convertible, Index, index_type) &&
267  _MDSPAN_TRAIT(is_nothrow_constructible, index_type, Index) &&
268  extents_type::rank() == 1
269  )
270  )
271  MDSPAN_FORCE_INLINE_FUNCTION
272  constexpr reference operator[](Index idx) const
273  {
274  return __accessor_ref().access(__ptr_ref(), __mapping_ref()(static_cast<index_type>(std::move(idx))));
275  }
276  #endif
277 
278  #if MDSPAN_USE_PAREN_OPERATOR
// Multi-argument call operator (this region is compiled only when
// MDSPAN_USE_PAREN_OPERATOR is true). Takes exactly rank() indices, each
// convertible to index_type (and index_type nothrow constructible from it).
// Every index is cast to index_type, mapped to an offset, and passed with the
// data handle to the accessor's access().
279  MDSPAN_TEMPLATE_REQUIRES(
280  class... SizeTypes,
281  /* requires */ (
282  _MDSPAN_FOLD_AND(_MDSPAN_TRAIT(is_convertible, SizeTypes, index_type) /* && ... */) &&
283  _MDSPAN_FOLD_AND(_MDSPAN_TRAIT(is_nothrow_constructible, index_type, SizeTypes) /* && ... */) &&
284  extents_type::rank() == sizeof...(SizeTypes)
285  )
286  )
287  MDSPAN_FORCE_INLINE_FUNCTION
288  constexpr reference operator()(SizeTypes... indices) const
289  {
290  return __accessor_ref().access(__ptr_ref(), __mapping_ref()(static_cast<index_type>(std::move(indices))...));
291  }
292 
// Call operator taking the whole multidimensional index as an array with
// exactly rank() elements; SizeType must be convertible to index_type (and
// index_type nothrow constructible from it). Delegates to __impl::__callop.
293  MDSPAN_TEMPLATE_REQUIRES(
294  class SizeType,
295  /* requires */ (
296  _MDSPAN_TRAIT(is_convertible, SizeType, index_type) &&
297  _MDSPAN_TRAIT(is_nothrow_constructible, index_type, SizeType)
298  )
299  )
300  MDSPAN_FORCE_INLINE_FUNCTION
301  constexpr reference operator()(const array<SizeType, rank()>& indices) const
302  {
303  return __impl::template __callop<reference>(*this, indices);
304  }
305 
// Call operator taking the whole multidimensional index as a span with
// exactly rank() elements; same constraints as the array overload, and the
// same delegation to __impl::__callop. Only compiled when __cpp_lib_span is
// defined.
// NOTE(review): the span-taking operator[] is additionally enabled by
// OPENKALMAN_COMPATIBILITY_SPAN, but this overload is not -- confirm whether
// the difference is intentional.
306  #ifdef __cpp_lib_span
307  MDSPAN_TEMPLATE_REQUIRES(
308  class SizeType,
309  /* requires */ (
310  _MDSPAN_TRAIT(is_convertible, SizeType, index_type) &&
311  _MDSPAN_TRAIT(is_nothrow_constructible, index_type, SizeType)
312  )
313  )
314  MDSPAN_FORCE_INLINE_FUNCTION
315  constexpr reference operator()(span<SizeType, rank()> indices) const
316  {
317  return __impl::template __callop<reference>(*this, indices);
318  }
319  #endif // __cpp_lib_span
320  #endif // MDSPAN_USE_PAREN_OPERATOR
321 
// Returns the value computed by __impl::__size for this mdspan, i.e. the
// fold of all extents seeded with size_t(1).
322  MDSPAN_INLINE_FUNCTION constexpr size_t size() const noexcept {
323  return __impl::__size(*this);
324  };
325 
// Returns the value computed by __impl::__empty for this mdspan: true only
// for rank() > 0 when some extent compares equal to index_type(0).
326  MDSPAN_INLINE_FUNCTION constexpr bool empty() const noexcept {
327  return __impl::__empty(*this);
328  };
329 
// Friend swap for two mdspans of the same type.
// When neither _MDSPAN_HAS_HIP nor _MDSPAN_HAS_CUDA is defined, the data
// handle, the mapping and the accessor are swapped one by one through
// unqualified swap calls. Otherwise the exchange is done with a temporary
// copy and two copy assignments.
330  MDSPAN_INLINE_FUNCTION
331  friend constexpr void swap(mdspan& x, mdspan& y) noexcept {
332  // can't call the std::swap inside on HIP
333  #if !defined(_MDSPAN_HAS_HIP) && !defined(_MDSPAN_HAS_CUDA)
334  swap(x.__ptr_ref(), y.__ptr_ref());
335  swap(x.__mapping_ref(), y.__mapping_ref());
336  swap(x.__accessor_ref(), y.__accessor_ref());
337  #else
338  mdspan tmp = y;
339  y = x;
340  x = tmp;
341  #endif
342  }
343 
344  //--------------------------------------------------------------------------------
345  // [mdspan.basic.domobs], mdspan observers of the domain multidimensional index space
346 
347 
// Read-only access to the stored state, returned by const reference:
// the extents held by the mapping, the data handle, the mapping itself and
// the accessor.
348  MDSPAN_INLINE_FUNCTION constexpr const extents_type& extents() const noexcept { return __mapping_ref().extents(); };
349  MDSPAN_INLINE_FUNCTION constexpr const data_handle_type& data_handle() const noexcept { return __ptr_ref(); };
350  MDSPAN_INLINE_FUNCTION constexpr const mapping_type& mapping() const noexcept { return __mapping_ref(); };
351  MDSPAN_INLINE_FUNCTION constexpr const accessor_type& accessor() const noexcept { return __accessor_ref(); };
352 
353  //--------------------------------------------------------------------------------
354  // [mdspan.basic.obs], mdspan observers of the mapping
355 
// Static layout properties, each forwarded to the same-named static member
// of mapping_type.
356  MDSPAN_INLINE_FUNCTION static constexpr bool is_always_unique() noexcept { return mapping_type::is_always_unique(); };
357  MDSPAN_INLINE_FUNCTION static constexpr bool is_always_exhaustive() noexcept { return mapping_type::is_always_exhaustive(); };
358  MDSPAN_INLINE_FUNCTION static constexpr bool is_always_strided() noexcept { return mapping_type::is_always_strided(); };
359 
// Run-time layout properties, each forwarded to the same-named member of the
// stored mapping. stride(r) is the only one not declared noexcept.
360  MDSPAN_INLINE_FUNCTION constexpr bool is_unique() const noexcept { return __mapping_ref().is_unique(); };
361  MDSPAN_INLINE_FUNCTION constexpr bool is_exhaustive() const noexcept { return __mapping_ref().is_exhaustive(); };
362  MDSPAN_INLINE_FUNCTION constexpr bool is_strided() const noexcept { return __mapping_ref().is_strided(); };
363  MDSPAN_INLINE_FUNCTION constexpr index_type stride(size_t r) const { return __mapping_ref().stride(r); };
364 
365 private:
366 
368 
// Private accessors into the nested pair stored in __members, each in a
// mutable and a const flavour:
//   data handle -> __members.__first()
//   mapping     -> __members.__second().__first()
//   accessor    -> __members.__second().__second()
// (__members itself is declared on a line not shown in this listing.)
369  MDSPAN_FORCE_INLINE_FUNCTION _MDSPAN_CONSTEXPR_14 data_handle_type& __ptr_ref() noexcept { return __members.__first(); }
370  MDSPAN_FORCE_INLINE_FUNCTION constexpr data_handle_type const& __ptr_ref() const noexcept { return __members.__first(); }
371  MDSPAN_FORCE_INLINE_FUNCTION _MDSPAN_CONSTEXPR_14 mapping_type& __mapping_ref() noexcept { return __members.__second().__first(); }
372  MDSPAN_FORCE_INLINE_FUNCTION constexpr mapping_type const& __mapping_ref() const noexcept { return __members.__second().__first(); }
373  MDSPAN_FORCE_INLINE_FUNCTION _MDSPAN_CONSTEXPR_14 accessor_type& __accessor_ref() noexcept { return __members.__second().__second(); }
374  MDSPAN_FORCE_INLINE_FUNCTION constexpr accessor_type const& __accessor_ref() const noexcept { return __members.__second().__second(); }
375 
// Every mdspan specialization is a friend, so one specialization can reach
// the private __ptr_ref / __mapping_ref / __accessor_ref of another (the
// converting constructor reads them from `other`).
376  template <class, class, class, class>
377  friend class mdspan;
378 
379 };
380 
381 #if defined(_MDSPAN_USE_CLASS_TEMPLATE_ARGUMENT_DEDUCTION)
// Deduction guide for mdspan(pointer, extent, extent, ...): requires a
// non-empty pack of integral types and deduces
// mdspan<ElementType, dextents<size_t, N>> with N = sizeof...(SizeTypes).
382 MDSPAN_TEMPLATE_REQUIRES(
383  class ElementType, class... SizeTypes,
384  /* requires */ _MDSPAN_FOLD_AND(_MDSPAN_TRAIT(is_integral, SizeTypes) /* && ... */) &&
385  (sizeof...(SizeTypes) > 0)
386 )
387 MDSPAN_DEDUCTION_GUIDE explicit mdspan(ElementType*, SizeTypes...)
388  -> mdspan<ElementType, ::std::experimental::dextents<size_t, sizeof...(SizeTypes)>>;
389 
390 MDSPAN_TEMPLATE_REQUIRES(
391  class Pointer,
392  (_MDSPAN_TRAIT(is_pointer, std::remove_reference_t<Pointer>))
393 )
395 
396 MDSPAN_TEMPLATE_REQUIRES(
397  class CArray,
398  (_MDSPAN_TRAIT(is_array, CArray) && (rank_v<CArray> == 1))
399 )
401 
402 template <class ElementType, class SizeType, size_t N>
403 MDSPAN_DEDUCTION_GUIDE mdspan(ElementType*, const ::std::array<SizeType, N>&)
405 
406 #ifdef __cpp_lib_span
407 template <class ElementType, class SizeType, size_t N>
408 MDSPAN_DEDUCTION_GUIDE mdspan(ElementType*, ::std::span<SizeType, N>)
410 #endif
411 
412 // This one is necessary because all the constructors take `data_handle_type`s, not
413 // `ElementType*`s, and `data_handle_type` is taken from `accessor_type::data_handle_type`, which
414 // seems to throw off automatic deduction guides.
// Deduction guide for mdspan(pointer, extents): keeps the extents type
// exactly as passed and leaves layout and accessor at their defaults.
415 template <class ElementType, class SizeType, size_t... ExtentsPack>
416 MDSPAN_DEDUCTION_GUIDE mdspan(ElementType*, const extents<SizeType, ExtentsPack...>&)
417  -> mdspan<ElementType, ::std::experimental::extents<SizeType, ExtentsPack...>>;
418 
419 template <class ElementType, class MappingType>
420 MDSPAN_DEDUCTION_GUIDE mdspan(ElementType*, const MappingType&)
422 
423 template <class MappingType, class AccessorType>
424 MDSPAN_DEDUCTION_GUIDE mdspan(const typename AccessorType::data_handle_type, const MappingType&, const AccessorType&)
426 #endif
427 
428 
429 
430 } // end namespace experimental
431 } // end namespace std
Definition: compressed_pair.hpp:31
Definition: mdspan.hpp:34
Definition: extents.hpp:372