10 #ifndef EIGEN_MATRIX_SQUARE_ROOT 11 #define EIGEN_MATRIX_SQUARE_ROOT 26 template <
typename MatrixType>
42 eigen_assert(A.rows() == A.cols());
53 template <
typename ResultType>
void compute(ResultType &result);
56 typedef typename MatrixType::Index Index;
61 void compute2x2diagonalBlock(
MatrixType& sqrtT,
const MatrixType& T,
typename MatrixType::Index i);
63 typename MatrixType::Index i,
typename MatrixType::Index j);
65 typename MatrixType::Index i,
typename MatrixType::Index j);
67 typename MatrixType::Index i,
typename MatrixType::Index j);
69 typename MatrixType::Index i,
typename MatrixType::Index j);
71 template <
typename SmallMatrixType>
72 static void solveAuxiliaryEquation(SmallMatrixType& X,
const SmallMatrixType& A,
73 const SmallMatrixType& B,
const SmallMatrixType&
C);
78 template <
typename MatrixType>
79 template <
typename ResultType>
82 result.resize(m_A.rows(), m_A.cols());
83 computeDiagonalPartOfSqrt(result, m_A);
84 computeOffDiagonalPartOfSqrt(result, m_A);
89 template <
typename MatrixType>
94 const Index
size = m_A.rows();
95 for (Index i = 0; i <
size; i++) {
96 if (i == size - 1 || T.coeff(i+1, i) == 0) {
97 eigen_assert(T(i,i) >= 0);
98 sqrtT.coeffRef(i,i) = sqrt(T.coeff(i,i));
101 compute2x2diagonalBlock(sqrtT, T, i);
109 template <
typename MatrixType>
113 const Index
size = m_A.rows();
114 for (Index j = 1; j <
size; j++) {
115 if (T.coeff(j, j-1) != 0)
117 for (Index i = j-1; i >= 0; i--) {
118 if (i > 0 && T.coeff(i, i-1) != 0)
120 bool iBlockIs2x2 = (i < size - 1) && (T.coeff(i+1, i) != 0);
121 bool jBlockIs2x2 = (j < size - 1) && (T.coeff(j+1, j) != 0);
122 if (iBlockIs2x2 && jBlockIs2x2)
123 compute2x2offDiagonalBlock(sqrtT, T, i, j);
124 else if (iBlockIs2x2 && !jBlockIs2x2)
125 compute2x1offDiagonalBlock(sqrtT, T, i, j);
126 else if (!iBlockIs2x2 && jBlockIs2x2)
127 compute1x2offDiagonalBlock(sqrtT, T, i, j);
128 else if (!iBlockIs2x2 && !jBlockIs2x2)
129 compute1x1offDiagonalBlock(sqrtT, T, i, j);
136 template <
typename MatrixType>
144 sqrtT.template block<2,2>(i,i)
145 = (es.eigenvectors() * es.eigenvalues().cwiseSqrt().asDiagonal() * es.eigenvectors().inverse()).real();
151 template <
typename MatrixType>
154 typename MatrixType::Index i,
typename MatrixType::Index j)
156 Scalar tmp = (sqrtT.row(i).segment(i+1,j-i-1) * sqrtT.col(j).segment(i+1,j-i-1)).value();
157 sqrtT.coeffRef(i,j) = (T.coeff(i,j) - tmp) / (sqrtT.coeff(i,i) + sqrtT.coeff(j,j));
161 template <
typename MatrixType>
164 typename MatrixType::Index i,
typename MatrixType::Index j)
168 rhs -= sqrtT.block(i, i+1, 1, j-i-1) * sqrtT.block(i+1, j, j-i-1, 2);
170 A += sqrtT.template block<2,2>(j,j).transpose();
171 sqrtT.template block<1,2>(i,j).transpose() = A.fullPivLu().solve(rhs.transpose());
175 template <
typename MatrixType>
178 typename MatrixType::Index i,
typename MatrixType::Index j)
182 rhs -= sqrtT.block(i, i+2, 2, j-i-2) * sqrtT.block(i+2, j, j-i-2, 1);
184 A += sqrtT.template block<2,2>(i,i);
185 sqrtT.template block<2,1>(i,j) = A.fullPivLu().solve(rhs);
189 template <
typename MatrixType>
192 typename MatrixType::Index i,
typename MatrixType::Index j)
198 C -= sqrtT.block(i, i+2, 2, j-i-2) * sqrtT.block(i+2, j, j-i-2, 2);
200 solveAuxiliaryEquation(X, A, B, C);
201 sqrtT.template block<2,2>(i,j) = X;
205 template <
typename MatrixType>
206 template <
typename SmallMatrixType>
209 const SmallMatrixType& B,
const SmallMatrixType&
C)
212 EIGEN_INTERNAL_ERROR_PLEASE_FILE_A_BUG_REPORT);
215 coeffMatrix.coeffRef(0,0) = A.coeff(0,0) + B.coeff(0,0);
216 coeffMatrix.coeffRef(1,1) = A.coeff(0,0) + B.coeff(1,1);
217 coeffMatrix.coeffRef(2,2) = A.coeff(1,1) + B.coeff(0,0);
218 coeffMatrix.coeffRef(3,3) = A.coeff(1,1) + B.coeff(1,1);
219 coeffMatrix.coeffRef(0,1) = B.coeff(1,0);
220 coeffMatrix.coeffRef(0,2) = A.coeff(0,1);
221 coeffMatrix.coeffRef(1,0) = B.coeff(0,1);
222 coeffMatrix.coeffRef(1,3) = A.coeff(0,1);
223 coeffMatrix.coeffRef(2,0) = A.coeff(1,0);
224 coeffMatrix.coeffRef(2,3) = B.coeff(1,0);
225 coeffMatrix.coeffRef(3,1) = A.coeff(1,0);
226 coeffMatrix.coeffRef(3,2) = B.coeff(0,1);
229 rhs.coeffRef(0) = C.coeff(0,0);
230 rhs.coeffRef(1) = C.coeff(0,1);
231 rhs.coeffRef(2) = C.coeff(1,0);
232 rhs.coeffRef(3) = C.coeff(1,1);
235 result = coeffMatrix.fullPivLu().solve(rhs);
237 X.coeffRef(0,0) = result.coeff(0);
238 X.coeffRef(0,1) = result.coeff(1);
239 X.coeffRef(1,0) = result.coeff(2);
240 X.coeffRef(1,1) = result.coeff(3);
255 template <
typename MatrixType>
262 eigen_assert(A.rows() == A.cols());
274 template <
typename ResultType>
void compute(ResultType &result);
280 template <
typename MatrixType>
281 template <
typename ResultType>
288 result.resize(m_A.rows(), m_A.cols());
289 typedef typename MatrixType::Index Index;
290 for (Index i = 0; i < m_A.rows(); i++) {
291 result.coeffRef(i,i) = sqrt(m_A.coeff(i,i));
293 for (Index j = 1; j < m_A.cols(); j++) {
294 for (Index i = j-1; i >= 0; i--) {
297 Scalar tmp = (result.row(i).segment(i+1,j-i-1) * result.col(j).segment(i+1,j-i-1)).value();
299 result.coeffRef(i,j) = (m_A.coeff(i,j) - tmp) / (result.coeff(i,i) + result.coeff(j,j));
333 template <
typename ResultType>
void compute(ResultType &result);
339 template <
typename MatrixType>
347 eigen_assert(A.rows() == A.cols());
350 template <
typename ResultType>
void compute(ResultType &result)
358 MatrixType sqrtT = MatrixType::Zero(m_A.rows(), m_A.cols());
362 result = U * sqrtT * U.adjoint();
372 template <
typename MatrixType>
380 eigen_assert(A.rows() == A.cols());
383 template <
typename ResultType>
void compute(ResultType &result)
395 result = U * (sqrtT.template triangularView<Upper>() * U.adjoint());
416 :
public ReturnByValue<MatrixSquareRootReturnValue<Derived> >
418 typedef typename Derived::Index Index;
432 template <
typename ResultType>
433 inline void evalTo(ResultType& result)
const 435 const typename Derived::PlainObject srcEvaluated = m_src.eval();
440 Index rows()
const {
return m_src.rows(); }
441 Index cols()
const {
return m_src.cols(); }
444 const Derived& m_src;
450 template<
typename Derived>
453 typedef typename Derived::PlainObject ReturnType;
457 template <
typename Derived>
460 eigen_assert(rows() == cols());
466 #endif // EIGEN_MATRIX_FUNCTION
Proxy for the matrix square root of some matrix (expression).
Definition: ForwardDeclarations.h:274
iterative scaling algorithm to equilibrate rows and column norms in matrices
Definition: TestIMU_Common.h:87
void compute(ResultType &result)
Compute the matrix square root.
Definition: ReturnByValue.h:50
void compute(ResultType &result)
Compute the matrix square root.
Definition: MatrixSquareRoot.h:80
detail::size< coerce_list< Ts... >> size
Get the size of a list (number of elements.)
Definition: Size.h:56
void compute(ResultType &result)
Compute the matrix square root.
Definition: MatrixSquareRoot.h:282
Class for computing matrix square roots of upper quasi-triangular matrices.
Definition: MatrixSquareRoot.h:27
Class for computing matrix square roots of upper triangular matrices.
Definition: MatrixSquareRoot.h:256
void evalTo(ResultType &result) const
Compute the matrix square root.
Definition: MatrixSquareRoot.h:433
MatrixSquareRootQuasiTriangular(const MatrixType &A)
Constructor.
Definition: MatrixSquareRoot.h:39
const MatrixType & matrixT() const
Returns the quasi-triangular matrix in the Schur decomposition.
Definition: RealSchur.h:143
Class for computing matrix square roots of general matrices.
Definition: MatrixSquareRoot.h:313
Definition: BandTriangularSolver.h:13
const MatrixType & matrixU() const
Returns the orthogonal matrix in the Schur decomposition.
Definition: RealSchur.h:126
MatrixSquareRootReturnValue(const Derived &src)
Constructor.
Definition: MatrixSquareRoot.h:425
const ComplexMatrixType & matrixU() const
Returns the unitary matrix in the Schur decomposition.
Definition: ComplexSchur.h:137
Definition: EigenSolver.h:64
Base class for all dense matrices, vectors, and expressions.
Definition: MatrixBase.h:48
const ComplexMatrixType & matrixT() const
Returns the triangular matrix in the Schur decomposition.
Definition: ComplexSchur.h:161
Definition: ForwardDeclarations.h:17
double Scalar
Common scalar type.
Definition: FlexibleKalmanBase.h:48