OpenKalman
TensorContractionOp.hpp
Go to the documentation of this file.
1 /* This file is part of OpenKalman, a header-only C++ library for
2  * Kalman filters and other recursive filters.
3  *
4  * Copyright (c) 2023 Christopher Lee Ogden <ogden@gatech.edu>
5  *
6  * This Source Code Form is subject to the terms of the Mozilla Public
7  * License, v. 2.0. If a copy of the MPL was not distributed with this
8  * file, You can obtain one at https://mozilla.org/MPL/2.0/.
9  */
10 
16 #ifndef OPENKALMAN_EIGEN_TRAITS_TENSORCONTRACTIONOP_HPP
17 #define OPENKALMAN_EIGEN_TRAITS_TENSORCONTRACTIONOP_HPP
18 
19 
20 namespace OpenKalman::interface
21 {
22  template<typename Indices, typename LhsXprType, typename RhsXprType, typename OutputKernelType>
23  struct object_traits<Eigen::TensorContractionOp<Indices, LhsXprType, RhsXprType, OutputKernelType>>
24  : Eigen3::object_traits_tensor_base<Eigen::TensorContractionOp<Indices, LhsXprType, RhsXprType, OutputKernelType>>
25  {
26  private:
27 
28  using Xpr = Eigen::TensorContractionOp<Indices, LhsXprType, RhsXprType, OutputKernelType>;
30 
31  public:
32 
33  template<typename Arg, typename N>
34  static constexpr std::size_t get_pattern_collection(const Arg& arg, N n)
35  {
36  using IndexType = typename Xpr::Index;
37  return Eigen::TensorEvaluator<const Arg, Eigen::DefaultDevice>{arg, Eigen::DefaultDevice{}}.dimensions()[static_cast<IndexType>(n)];
38  }
39 
40 
41  // nested_object() not defined
42 
43 
44  template<typename Arg>
45  static constexpr auto get_constant(const Arg& arg)
46  {
47  using Scalar = scalar_type_of_t<Arg>;
48 
49  if constexpr (zero<LhsXprType>)
50  {
51  return constant_value{arg.lhsExpression()};
52  }
53  else if constexpr (zero<RhsXprType>)
54  {
55  return constant_value{arg.rhsExpression()};
56  }
57  else if constexpr (constant_diagonal_matrix<LhsXprType> and constant_matrix<RhsXprType>)
58  {
59  if constexpr (collections::size_of_v<decltype(arg.indices())> == 1)
60  {
61  return constant_diagonal_value{arg.lhsExpression()} * constant_value{arg.rhsExpression()};
62  }
63  else
64  {
65  auto& indices = arg.indices();
66  auto dims = Eigen::TensorEvaluator<const Arg, Eigen::DefaultDevice>{arg, Eigen::DefaultDevice{}}.dimensions();
67  auto f = [&dims](const Scalar& a, auto b) -> Scalar { return a * dims[b.first]; };
68  auto factor = std::accumulate(++indices.cbegin(), indices.cend(), Scalar{1}, f);
69  return factor * (constant_diagonal_value{arg.lhsExpression()} * constant_value{arg.rhsExpression()});
70  }
71  }
72  else if constexpr (constant_matrix<LhsXprType> and constant_diagonal_matrix<RhsXprType>)
73  {
74  if constexpr (collections::size_of_v<decltype(arg.indices())> == 1)
75  {
76  return constant_value{arg.lhsExpression()} * constant_diagonal_value{arg.rhsExpression()};
77  }
78  else
79  {
80  auto& indices = arg.indices();
81  auto dims = Eigen::TensorEvaluator<const Arg, Eigen::DefaultDevice>{arg, Eigen::DefaultDevice{}}.dimensions();
82  auto f = [&dims](const Scalar& a, auto b) -> Scalar { return a * dims[b.first]; };
83  auto factor = std::accumulate(++indices.cbegin(), indices.cend(), Scalar{1}, f);
84  return factor * (constant_value{arg.lhsExpression()} * constant_diagonal_value{arg.rhsExpression()});
85  }
86  }
87  else
88  {
89  auto& indices = arg.indices();
90  auto dims = Eigen::TensorEvaluator<const Arg, Eigen::DefaultDevice>{arg, Eigen::DefaultDevice{}}.dimensions();
91  auto f = [&dims](const Scalar& a, auto b) -> Scalar { return a * dims[b.first]; };
92  auto factor = std::accumulate(indices.cbegin(), indices.cend(), Scalar{1}, f);
93  return factor * (constant_value{arg.lhsExpression()} * constant_value{arg.rhsExpression()});
94  }
95  }
96 
97 
98  template<typename Arg>
99  static constexpr auto get_constant_diagonal(const Arg& arg)
100  {
101  if constexpr (collections::size_of_v<decltype(arg.indices())> == 1)
102  {
103  return values::operation(std::multiplies{},
104  constant_diagonal_value{arg.lhs()}, constant_diagonal_value{arg.rhs()});
105  }
106  else
107  {
108  using Scalar = scalar_type_of_t<Arg>;
109  auto& indices = arg.indices();
110  auto dims = Eigen::TensorEvaluator<const Arg, Eigen::DefaultDevice>{arg, Eigen::DefaultDevice{}}.dimensions();
111  auto f = [&dims](const Scalar& a, auto b) -> Scalar { return a * dims[b.first]; };
112  auto factor = std::accumulate(++indices.cbegin(), indices.cend(), Scalar{1}, f);
113  return factor * (constant_diagonal_value{arg.lhsExpression()} * constant_value{arg.rhsExpression()});
114  }
115  }
116 
117 
118  // one_dimensional not defined
119 
120  // is_square not defined
121 
122  //template<triangle_type t>
123  //static constexpr bool triangle_type_value = collections::size_of_v<decltype(std::declval<T>().indices())> == 1 and
124  // triangular_matrix<LhsXprType, t> and triangular_matrix<RhsXprType, t>;
125 
126 
127  static constexpr bool is_triangular_adapter = false;
128 
129 
130  static constexpr bool is_hermitian = collections::size_of_v<decltype(std::declval<Xpr>().indices())> == 1 and
131  ((constant_diagonal_matrix<LhsXprType> and hermitian_matrix<RhsXprType, applicability::permitted>) or
132  (constant_diagonal_matrix<RhsXprType> and hermitian_matrix<LhsXprType, applicability::permitted>));
133 
134 
135  static constexpr bool is_writable = false;
136 
137 
138  // raw_data() not defined
139 
140 
141  // layout not defined
142 
143  };
144 
145 }
146 
147 #endif
Definition: basics.hpp:41
Definition: eigen-comma-initializers.hpp:20
decltype(auto) constexpr get_pattern_collection(T &&t)
Get the coordinates::pattern_collection associated with indexible object T.
Definition: get_pattern_collection.hpp:59
Trait object providing get and set routines for Eigen tensors.
Definition: eigen-tensor-forward-declarations.hpp:114
Definition: object_traits.hpp:38
constexpr auto constant_value(T &&t)
The constant value associated with a constant_object or constant_diagonal_object. ...
Definition: constant_value.hpp:37
constexpr std::size_t size_of_v
Helper for collections::size_of.
Definition: size_of.hpp:60
constexpr auto operation(Operation &&op, Args &&...args)
A potentially constant-evaluated operation involving some number of values.
Definition: operation.hpp:98