OpenKalman
Tensor.hpp
Go to the documentation of this file.
1 /* This file is part of OpenKalman, a header-only C++ library for
2  * Kalman filters and other recursive filters.
3  *
4  * Copyright (c) 2023 Christopher Lee Ogden <ogden@gatech.edu>
5  *
6  * This Source Code Form is subject to the terms of the Mozilla Public
7  * License, v. 2.0. If a copy of the MPL was not distributed with this
8  * file, You can obtain one at https://mozilla.org/MPL/2.0/.
9  */
10 
16 #ifndef OPENKALMAN_EIGEN_TRAITS_TENSOR_HPP
17 #define OPENKALMAN_EIGEN_TRAITS_TENSOR_HPP
18 
19 
20 namespace OpenKalman::interface
21 {
22  template<typename Scalar, int NumIndices, int options, typename IndexType>
23  struct object_traits<Eigen::Tensor<Scalar, NumIndices, options, IndexType>>
24  : Eigen3::object_traits_tensor_base<Eigen::Tensor<Scalar, NumIndices, options, IndexType>>
25  {
26  private:
27 
29 
30  public:
31 
32  template<typename Arg, typename N>
33  static constexpr std::size_t get_pattern_collection(const Arg& arg, N n) { return arg.dimension(n); }
34 
35  // nested_object() not defined
36 
37  // get_constant() not defined
38 
39  // get_constant_diagonal() not defined
40 
41 
42 #ifdef __cpp_lib_concepts
43  template<typename Arg, std::convertible_to<IndexType>...I> requires (sizeof...(I) == NumIndices)
44 #else
45  template<typename Arg, typename...I, std::enable_if_t<(stdex::convertible_to<I, IndexType> and ...) and
46  (sizeof...(I) == NumIndices), int> = 0>
47 #endif
48  static constexpr decltype(auto) get(Arg&& arg, I...i)
49  {
50  if constexpr ((Eigen::internal::traits<std::decay_t<Arg>>::Flags & Eigen::LvalueBit) != 0)
51  return std::forward<Arg>(arg).coeffRef(static_cast<IndexType>(i)...);
52  else
53  return std::forward<Arg>(arg).coeff(static_cast<IndexType>(i)...);
54  }
55 
56 
57 #ifdef __cpp_lib_concepts
58  template<typename Arg, std::convertible_to<IndexType>...I> requires (sizeof...(I) == NumIndices) and
59  ((Eigen::internal::traits<std::decay_t<Arg>>::Flags & Eigen::LvalueBit) != 0x0)
60 #else
61  template<typename Arg, typename...I, std::enable_if_t<(stdex::convertible_to<I, IndexType> and ...) and
62  (sizeof...(I) == NumIndices) and ((Eigen::internal::traits<std::decay_t<Arg>>::Flags & Eigen::LvalueBit) != 0x0), int> = 0>
63 #endif
64  static void set(Arg& arg, const scalar_type_of_t<Arg>& s, I...i)
65  {
66  arg.coeffRef(static_cast<IndexType>(i)...) = s;
67  }
68 
69  static constexpr bool is_writable = true;
70 
71  template<typename Arg>
72  static constexpr auto * const
73  raw_data(Arg& arg) { return arg.data(); }
74 
75  static constexpr data_layout layout = options & Eigen::RowMajor ? data_layout::right : data_layout::left;
76 
77  };
78 
79 }
80 
81 #endif
Definition: basics.hpp:41
Definition: eigen-comma-initializers.hpp:20
decltype(auto) constexpr get_pattern_collection(T &&t)
Get the coordinates::pattern_collection associated with indexible object T.
Definition: get_pattern_collection.hpp:59
Trait object providing get and set routines for Eigen tensors.
Definition: eigen-tensor-forward-declarations.hpp:114
Definition: object_traits.hpp:38