OpenKalman
layout_stride.hpp
1 //@HEADER
2 // ************************************************************************
3 //
4 // Kokkos v. 4.0
5 // Copyright (2022) National Technology & Engineering
6 // Solutions of Sandia, LLC (NTESS).
7 //
8 // Under the terms of Contract DE-NA0003525 with NTESS,
9 // the U.S. Government retains certain rights in this software.
10 //
11 // Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
12 // See https://kokkos.org/LICENSE for license information.
13 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
14 //
15 //@HEADER
16 #pragma once
17 
18 #include "macros.hpp"
19 #include "extents.hpp"
20 #include "trait_backports.hpp"
21 #include "compressed_pair.hpp"
22 
23 #if !defined(_MDSPAN_USE_ATTRIBUTE_NO_UNIQUE_ADDRESS)
24 # include "no_unique_address.hpp"
25 #endif
26 
27 #include <algorithm>
28 #include <numeric>
29 #include <array>
30 #ifdef __cpp_lib_span
31 #include <span>
32 #endif
33 #if defined(_MDSPAN_USE_CONCEPTS) && MDSPAN_HAS_CXX_20
34 #include<concepts>
35 #endif
36 
37 namespace std {
38 namespace experimental {
39 
// Layout policy tags used by layout_stride below (e.g. in the explicit-ness
// condition of its converting constructor). Only the nested mapping template
// is declared here; it is defined in the corresponding layout header.
struct layout_left {
  template <typename Extents>
  class mapping;
};

struct layout_right {
  template <typename Extents>
  class mapping;
};
48 
49 namespace detail {
// __is_mapping_of<Layout, Mapping> is true exactly when Mapping is the type
// Layout::mapping<Mapping::extents_type>, i.e. Mapping is one of Layout's mappings.
50  template<class Layout, class Mapping>
51  constexpr bool __is_mapping_of =
52  is_same<typename Layout::template mapping<typename Mapping::extents_type>, Mapping>::value;
53 
54 #if defined(_MDSPAN_USE_CONCEPTS) && MDSPAN_HAS_CXX_20
// __layout_mapping_alike<M>: M exposes static is_always_strided(),
// is_always_exhaustive() and is_always_unique(), each returning exactly bool.
// Used in C++20 mode to constrain layout_stride's converting constructor and
// its operator== against other strided mappings.
// NOTE(review): the listing numbering skips 57 and 61-63, so some requirements
// of this concept are not visible in this rendering -- confirm against the header.
55  template<class M>
56  concept __layout_mapping_alike = requires {
58  { M::is_always_strided() } -> same_as<bool>;
59  { M::is_always_exhaustive() } -> same_as<bool>;
60  { M::is_always_unique() } -> same_as<bool>;
64  };
65 #endif
66 } // namespace detail
67 
// Layout policy whose mapping stores one run-time stride per dimension; the
// offset of a multi-index is sum(idx[r] * stride(r)) (see _call_op_impl).
68 struct layout_stride {
69  template <class Extents>
70  class mapping
// Without [[no_unique_address]] the extents and the stride array are held in a
// detail::__compressed_pair-based base instead of a data member.
// NOTE(review): listing line 72 (the base-class introducer) is missing from this
// rendering -- confirm against the header.
71 #if !defined(_MDSPAN_USE_ATTRIBUTE_NO_UNIQUE_ADDRESS)
73  detail::__compressed_pair<
74  Extents,
75  std::array<typename Extents::index_type, Extents::rank()>
76  >
77  >
78 #endif
79  {
80  public:
81  using extents_type = Extents;
82  using index_type = typename extents_type::index_type;
83  using size_type = typename extents_type::size_type;
84  using rank_type = typename extents_type::rank_type;
85  using layout_type = layout_stride;
86 
87  // This could be a `requires`, but I think it's better and clearer as a `static_assert`.
88  static_assert(detail::__is_extents_v<Extents>, "std::experimental::layout_stride::mapping must be instantiated with a specialization of std::experimental::extents.");
89 
90 
91  private:
92 
93  //----------------------------------------------------------------------------
94 
// One index_type stride per dimension, in a fixed-size array.
95  using __strides_storage_t = array<index_type, extents_type::rank()>;//::std::experimental::dextents<index_type, extents_type::rank()>;
// NOTE(review): listing lines 96 and 101 are missing; they presumably declare the
// __member_pair_t and __base_t aliases used below -- confirm against the header.
97 
98 #if defined(_MDSPAN_USE_ATTRIBUTE_NO_UNIQUE_ADDRESS)
99  _MDSPAN_NO_UNIQUE_ADDRESS __member_pair_t __members;
100 #else
102 #endif
103 
104  MDSPAN_FORCE_INLINE_FUNCTION constexpr __strides_storage_t const&
105  __strides_storage() const noexcept {
106 #if defined(_MDSPAN_USE_ATTRIBUTE_NO_UNIQUE_ADDRESS)
107  return __members.__second();
108 #else
109  return this->__base_t::__ref().__second();
110 #endif
111  }
112  MDSPAN_FORCE_INLINE_FUNCTION _MDSPAN_CONSTEXPR_14 __strides_storage_t&
113  __strides_storage() noexcept {
114 #if defined(_MDSPAN_USE_ATTRIBUTE_NO_UNIQUE_ADDRESS)
115  return __members.__second();
116 #else
117  return this->__base_t::__ref().__second();
118 #endif
119  }
120 
// __get_size: product of all extents (1 for rank 0), each cast to index_type.
// NOTE(review): listing line 123 (the function signature) is missing from this
// rendering -- confirm the parameter list against the header.
121  template<class SizeType, ::std::size_t ... Ep, ::std::size_t ... Idx>
122  _MDSPAN_HOST_DEVICE
124  return _MDSPAN_FOLD_TIMES_RIGHT( static_cast<index_type>(extents().extent(Idx)), 1 );
125  }
126 
127  //----------------------------------------------------------------------------
128 
// Every mapping specialization may access the private members of every other one.
129  template <class>
130  friend class mapping;
131 
132  //----------------------------------------------------------------------------
133 
134  // Workaround for non-deducibility of the index sequence template parameter if it's given at the top level
// Primary template is only declared; the index_sequence partial specialization
// below provides the implementation helpers.
135  template <class>
136  struct __deduction_workaround;
137 
138  template <size_t... Idxs>
139  struct __deduction_workaround<index_sequence<Idxs...>>
140  {
141  template <class OtherExtents>
142  MDSPAN_INLINE_FUNCTION
143  static constexpr bool _eq_impl(mapping const& self, mapping<OtherExtents> const& other) noexcept {
144  return _MDSPAN_FOLD_AND((self.stride(Idxs) == other.stride(Idxs)) /* && ... */)
145  && _MDSPAN_FOLD_AND((self.extents().extent(Idxs) == other.extents().extent(Idxs)) /* || ... */);
146  }
147  template <class OtherExtents>
148  MDSPAN_INLINE_FUNCTION
149  static constexpr bool _not_eq_impl(mapping const& self, mapping<OtherExtents> const& other) noexcept {
150  return _MDSPAN_FOLD_OR((self.stride(Idxs) != other.stride(Idxs)) /* || ... */)
151  || _MDSPAN_FOLD_OR((self.extents().extent(Idxs) != other.extents().extent(Idxs)) /* || ... */);
152  }
153 
154  template <class... Integral>
155  MDSPAN_FORCE_INLINE_FUNCTION
156  static constexpr size_t _call_op_impl(mapping const& self, Integral... idxs) noexcept {
157  return _MDSPAN_FOLD_PLUS_RIGHT((idxs * self.stride(Idxs)), /* + ... + */ 0);
158  }
159 
160  MDSPAN_INLINE_FUNCTION
161  static constexpr size_t _req_span_size_impl(mapping const& self) noexcept {
162  // assumes no negative strides; not sure if I'm allowed to assume that or not
163  return __impl::_call_op_impl(self, (self.extents().template __extent<Idxs>() - 1)...) + 1;
164  }
165 
166  template<class OtherMapping>
167  MDSPAN_INLINE_FUNCTION
168  static constexpr const __strides_storage_t fill_strides(const OtherMapping& map) {
169  return __strides_storage_t{static_cast<index_type>(map.stride(Idxs))...};
170  }
171 
172  MDSPAN_INLINE_FUNCTION
173  static constexpr const __strides_storage_t& fill_strides(const __strides_storage_t& s) {
174  return s;
175  }
176 
177  template<class IntegralType>
178  MDSPAN_INLINE_FUNCTION
179  static constexpr const __strides_storage_t fill_strides(const array<IntegralType,extents_type::rank()>& s) {
180  return __strides_storage_t{static_cast<index_type>(s[Idxs])...};
181  }
182 
183 #ifdef __cpp_lib_span
184  template<class IntegralType>
185  MDSPAN_INLINE_FUNCTION
186  static constexpr const __strides_storage_t fill_strides(const span<IntegralType,extents_type::rank()>& s) {
187  return __strides_storage_t{static_cast<index_type>(s[Idxs])...};
188  }
189 #endif
190 
191  template<size_t K>
192  MDSPAN_INLINE_FUNCTION
193  static constexpr size_t __return_zero() { return 0; }
194 
195  template<class Mapping>
196  MDSPAN_INLINE_FUNCTION
197  static constexpr typename Mapping::index_type
198  __OFFSET(const Mapping& m) { return m(__return_zero<Idxs>()...); }
199  };
200 
201  // Can't use defaulted parameter in the __deduction_workaround template because of a bug in MSVC warning C4348.
202  using __impl = __deduction_workaround<make_index_sequence<Extents::rank()>>;
203 
204 
205  //----------------------------------------------------------------------------
206 
207 #if defined(_MDSPAN_USE_ATTRIBUTE_NO_UNIQUE_ADDRESS)
208  MDSPAN_INLINE_FUNCTION constexpr explicit
209  mapping(__member_pair_t&& __m) : __members(::std::move(__m)) {}
210 #else
211  MDSPAN_INLINE_FUNCTION constexpr explicit
212  mapping(__base_t&& __b) : __base_t(::std::move(__b)) {}
213 #endif
214 
215  public:
216 
217  //--------------------------------------------------------------------------------
218 
219  MDSPAN_INLINE_FUNCTION_DEFAULTED constexpr mapping() noexcept = default;
220  MDSPAN_INLINE_FUNCTION_DEFAULTED constexpr mapping(mapping const&) noexcept = default;
221 
// Constructs a mapping from extents `e` and a std::array of strides `s`.
// Participates only when const IntegralTypes& is convertible to, and nothrow
// constructible into, index_type; every s[i] is converted via __impl::fill_strides.
// The preconditions listed in the body are documented but not checked.
// NOTE(review): listing line 240 is missing; it presumably opens the initializer
// that the `)})` on line 246 closes -- confirm against the header.
222  MDSPAN_TEMPLATE_REQUIRES(
223  class IntegralTypes,
224  /* requires */ (
225  // MSVC 19.32 does not like using index_type here, requires the typename Extents::index_type
226  // error C2641: cannot deduce template arguments for 'std::experimental::layout_stride::mapping'
227  _MDSPAN_TRAIT(is_convertible, const remove_const_t<IntegralTypes>&, typename Extents::index_type) &&
228  _MDSPAN_TRAIT(is_nothrow_constructible, typename Extents::index_type, const remove_const_t<IntegralTypes>&)
229  )
230  )
231  MDSPAN_INLINE_FUNCTION
232  constexpr
233  mapping(
234  extents_type const& e,
235  array<IntegralTypes, extents_type::rank()> const& s
236  ) noexcept
237 #if defined(_MDSPAN_USE_ATTRIBUTE_NO_UNIQUE_ADDRESS)
238  : __members{
239 #else
241 #endif
242  e, __strides_storage_t(__impl::fill_strides(s))
243 #if defined(_MDSPAN_USE_ATTRIBUTE_NO_UNIQUE_ADDRESS)
244  }
245 #else
246  )})
247 #endif
248  {
249  /*
250  * TODO: check preconditions
251  * - s[i] > 0 is true for all i in the range [0, rank_ ).
252  * - REQUIRED-SPAN-SIZE(e, s) is a representable value of type index_type ([basic.fundamental]).
253  * - If rank_ is greater than 0, then there exists a permutation P of the integers in the
254  * range [0, rank_), such that s[ pi ] >= s[ pi − 1 ] * e.extent( pi − 1 ) is true for
255  * all i in the range [1, rank_ ), where pi is the ith element of P.
256  */
257  }
258 
259 #ifdef __cpp_lib_span
// Same as the std::array overload, but takes the strides as a fixed-extent
// std::span (only when <span> is available); every s[i] is converted via
// __impl::fill_strides. The preconditions in the body are not checked.
// NOTE(review): listing line 278 is missing; it presumably opens the initializer
// that the `)})` on line 284 closes -- confirm against the header.
260  MDSPAN_TEMPLATE_REQUIRES(
261  class IntegralTypes,
262  /* requires */ (
263  // MSVC 19.32 does not like using index_type here, requires the typename Extents::index_type
264  // error C2641: cannot deduce template arguments for 'std::experimental::layout_stride::mapping'
265  _MDSPAN_TRAIT(is_convertible, const remove_const_t<IntegralTypes>&, typename Extents::index_type) &&
266  _MDSPAN_TRAIT(is_nothrow_constructible, typename Extents::index_type, const remove_const_t<IntegralTypes>&)
267  )
268  )
269  MDSPAN_INLINE_FUNCTION
270  constexpr
271  mapping(
272  extents_type const& e,
273  span<IntegralTypes, extents_type::rank()> const& s
274  ) noexcept
275 #if defined(_MDSPAN_USE_ATTRIBUTE_NO_UNIQUE_ADDRESS)
276  : __members{
277 #else
279 #endif
280  e, __strides_storage_t(__impl::fill_strides(s))
281 #if defined(_MDSPAN_USE_ATTRIBUTE_NO_UNIQUE_ADDRESS)
282  }
283 #else
284  )})
285 #endif
286  {
287  /*
288  * TODO: check preconditions
289  * - s[i] > 0 is true for all i in the range [0, rank_ ).
290  * - REQUIRED-SPAN-SIZE(e, s) is a representable value of type index_type ([basic.fundamental]).
291  * - If rank_ is greater than 0, then there exists a permutation P of the integers in the
292  * range [0, rank_), such that s[ pi ] >= s[ pi − 1 ] * e.extent( pi − 1 ) is true for
293  * all i in the range [1, rank_ ), where pi is the ith element of P.
294  */
295  }
296 #endif // __cpp_lib_span
297 
// Converting constructor from another unique, always-strided layout mapping whose
// extents_type is constructible into extents_type. It copies other.extents() and
// other.stride(r) for each r (via __impl::fill_strides).
// Pre-C++20: constrained with MDSPAN_TEMPLATE_REQUIRES (other must be a mapping of
// its own layout_type); C++20: constrained with detail::__layout_mapping_alike.
298 #if !(defined(_MDSPAN_USE_CONCEPTS) && MDSPAN_HAS_CXX_20)
299  MDSPAN_TEMPLATE_REQUIRES(
300  class StridedLayoutMapping,
301  /* requires */ (
302  _MDSPAN_TRAIT(is_constructible, extents_type, typename StridedLayoutMapping::extents_type) &&
303  detail::__is_mapping_of<typename StridedLayoutMapping::layout_type, StridedLayoutMapping> &&
304  StridedLayoutMapping::is_always_unique() &&
305  StridedLayoutMapping::is_always_strided()
306  )
307  )
308 #else
309  template<class StridedLayoutMapping>
310  requires(
311  detail::__layout_mapping_alike<StridedLayoutMapping> &&
312  _MDSPAN_TRAIT(is_constructible, extents_type, typename StridedLayoutMapping::extents_type) &&
313  StridedLayoutMapping::is_always_unique() &&
314  StridedLayoutMapping::is_always_strided()
315  )
316 #endif
// Explicitness depends on whether StridedLayoutMapping is a layout_left,
// layout_right or layout_stride mapping.
// NOTE(review): listing lines 318 and 328 are missing, so the full explicit
// condition and the non-attribute initializer opener are not visible here --
// confirm against the header.
317  MDSPAN_CONDITIONAL_EXPLICIT(
319  (detail::__is_mapping_of<layout_left, StridedLayoutMapping> ||
320  detail::__is_mapping_of<layout_right, StridedLayoutMapping> ||
321  detail::__is_mapping_of<layout_stride, StridedLayoutMapping>)
322  ) // needs two () due to comma
323  MDSPAN_INLINE_FUNCTION _MDSPAN_CONSTEXPR_14
324  mapping(StridedLayoutMapping const& other) noexcept // NOLINT(google-explicit-constructor)
325 #if defined(_MDSPAN_USE_ATTRIBUTE_NO_UNIQUE_ADDRESS)
326  : __members{
327 #else
329 #endif
330  other.extents(), __strides_storage_t(__impl::fill_strides(other))
331 #if defined(_MDSPAN_USE_ATTRIBUTE_NO_UNIQUE_ADDRESS)
332  }
333 #else
334  )})
335 #endif
336  {
337  /*
338  * TODO: check preconditions
339  * - other.stride(i) > 0 is true for all i in the range [0, rank_ ).
340  * - other.required_span_size() is a representable value of type index_type ([basic.fundamental]).
341  * - OFFSET(other) == 0
342  */
343  }
344 
345  //--------------------------------------------------------------------------------
346 
347  MDSPAN_INLINE_FUNCTION_DEFAULTED _MDSPAN_CONSTEXPR_14_DEFAULTED
348  mapping& operator=(mapping const&) noexcept = default;
349 
350  MDSPAN_INLINE_FUNCTION constexpr const extents_type& extents() const noexcept {
351 #if defined(_MDSPAN_USE_ATTRIBUTE_NO_UNIQUE_ADDRESS)
352  return __members.__first();
353 #else
354  return this->__base_t::__ref().__first();
355 #endif
356  };
357 
358  MDSPAN_INLINE_FUNCTION
359  constexpr array< index_type, extents_type::rank() > strides() const noexcept {
360  return __strides_storage();
361  }
362 
363  MDSPAN_INLINE_FUNCTION
364  constexpr index_type required_span_size() const noexcept {
365  index_type span_size = 1;
366  for(unsigned r = 0; r < extents_type::rank(); r++) {
367  // Return early if any of the extents are zero
368  if(extents().extent(r)==0) return 0;
369  span_size += ( static_cast<index_type>(extents().extent(r) - 1 ) * __strides_storage()[r]);
370  }
371  return span_size;
372  }
373 
374 
375  MDSPAN_TEMPLATE_REQUIRES(
376  class... Indices,
377  /* requires */ (
378  sizeof...(Indices) == Extents::rank() &&
379  _MDSPAN_FOLD_AND(_MDSPAN_TRAIT(is_convertible, Indices, index_type) /*&& ...*/ ) &&
380  _MDSPAN_FOLD_AND(_MDSPAN_TRAIT(is_nothrow_constructible, index_type, Indices) /*&& ...*/)
381  )
382  )
383  MDSPAN_FORCE_INLINE_FUNCTION
384  constexpr index_type operator()(Indices... idxs) const noexcept {
385  return static_cast<index_type>(__impl::_call_op_impl(*this, static_cast<index_type>(idxs)...));
386  }
387 
388  MDSPAN_INLINE_FUNCTION static constexpr bool is_always_unique() noexcept { return true; }
389  MDSPAN_INLINE_FUNCTION static constexpr bool is_always_exhaustive() noexcept {
390  return false;
391  }
392  MDSPAN_INLINE_FUNCTION static constexpr bool is_always_strided() noexcept { return true; }
393 
394  MDSPAN_INLINE_FUNCTION static constexpr bool is_unique() noexcept { return true; }
395  MDSPAN_INLINE_FUNCTION _MDSPAN_CONSTEXPR_14 bool is_exhaustive() const noexcept {
396  return required_span_size() == __get_size(extents(), make_index_sequence<extents_type::rank()>());
397  }
398  MDSPAN_INLINE_FUNCTION static constexpr bool is_strided() noexcept { return true; }
399 
400 
401  MDSPAN_INLINE_FUNCTION
402  constexpr index_type stride(rank_type r) const noexcept
403 #if MDSPAN_HAS_CXX_20
404  requires ( Extents::rank() > 0 )
405 #endif
406  {
407  return __strides_storage()[r];
408  }
409 
410 #if !(defined(_MDSPAN_USE_CONCEPTS) && MDSPAN_HAS_CXX_20)
411  MDSPAN_TEMPLATE_REQUIRES(
412  class StridedLayoutMapping,
413  /* requires */ (
414  detail::__is_mapping_of<typename StridedLayoutMapping::layout_type, StridedLayoutMapping> &&
415  (extents_type::rank() == StridedLayoutMapping::extents_type::rank()) &&
416  StridedLayoutMapping::is_always_strided()
417  )
418  )
419 #else
420  template<class StridedLayoutMapping>
421  requires(
422  detail::__layout_mapping_alike<StridedLayoutMapping> &&
423  (extents_type::rank() == StridedLayoutMapping::extents_type::rank()) &&
424  StridedLayoutMapping::is_always_strided()
425  )
426 #endif
427  MDSPAN_INLINE_FUNCTION
428  friend constexpr bool operator==(const mapping& x, const StridedLayoutMapping& y) noexcept {
429  bool strides_match = true;
430  for(rank_type r = 0; r < extents_type::rank(); r++)
431  strides_match = strides_match && (x.stride(r) == y.stride(r));
432  return (x.extents() == y.extents()) &&
433  (__impl::__OFFSET(y)== static_cast<typename StridedLayoutMapping::index_type>(0)) &&
434  strides_match;
435  }
436 
437  // This one is not technically part of the proposal. Just here to make implementation a bit more optimal hopefully
438  MDSPAN_TEMPLATE_REQUIRES(
439  class OtherExtents,
440  /* requires */ (
441  (extents_type::rank() == OtherExtents::rank())
442  )
443  )
444  MDSPAN_INLINE_FUNCTION
445  friend constexpr bool operator==(mapping const& lhs, mapping<OtherExtents> const& rhs) noexcept {
446  return __impl::_eq_impl(lhs, rhs);
447  }
448 
449 #if !MDSPAN_HAS_CXX_20
450  MDSPAN_TEMPLATE_REQUIRES(
451  class StridedLayoutMapping,
452  /* requires */ (
453  detail::__is_mapping_of<typename StridedLayoutMapping::layout_type, StridedLayoutMapping> &&
454  (extents_type::rank() == StridedLayoutMapping::extents_type::rank()) &&
455  StridedLayoutMapping::is_always_strided()
456  )
457  )
458  MDSPAN_INLINE_FUNCTION
459  friend constexpr bool operator!=(const mapping& x, const StridedLayoutMapping& y) noexcept {
460  return not (x == y);
461  }
462 
463  MDSPAN_TEMPLATE_REQUIRES(
464  class OtherExtents,
465  /* requires */ (
466  (extents_type::rank() == OtherExtents::rank())
467  )
468  )
469  MDSPAN_INLINE_FUNCTION
470  friend constexpr bool operator!=(mapping const& lhs, mapping<OtherExtents> const& rhs) noexcept {
471  return __impl::_not_eq_impl(lhs, rhs);
472  }
473 #endif
474 
475  };
476 };
477 
478 } // end namespace experimental
479 } // end namespace std
Definition: layout_stride.hpp:44
Definition: compressed_pair.hpp:31
Definition: layout_right.hpp:29
Definition: layout_stride.hpp:70
constexpr bool value
T is a fixed or dynamic value that is reducible to a number.
Definition: value.hpp:45
Definition: layout_stride.hpp:40
Definition: layout_stride.hpp:68
Definition: trait_backports.hpp:64
Definition: layout_left.hpp:28
Definition: extents.hpp:372