TooN
LU.h
1 // -*- c++ -*-
2 
3 // Copyright (C) 2005,2009 Tom Drummond (twd20@cam.ac.uk),
4 // Ed Rosten (er258@cam.ac.uk)
5 
6 //All rights reserved.
7 //
8 //Redistribution and use in source and binary forms, with or without
9 //modification, are permitted provided that the following conditions
10 //are met:
11 //1. Redistributions of source code must retain the above copyright
12 // notice, this list of conditions and the following disclaimer.
13 //2. Redistributions in binary form must reproduce the above copyright
14 // notice, this list of conditions and the following disclaimer in the
15 // documentation and/or other materials provided with the distribution.
16 //
17 //THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND OTHER CONTRIBUTORS ``AS IS''
18 //AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19 //IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20 //ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR OTHER CONTRIBUTORS BE
21 //LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
22 //CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
23 //SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
24 //INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
25 //CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
26 //ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
27 //POSSIBILITY OF SUCH DAMAGE.
28 
29 #ifndef TOON_INCLUDE_LU_H
30 #define TOON_INCLUDE_LU_H
31 
32 #include <iostream>
33 
34 #include <TooN/TooN.h>
35 #include <TooN/lapack.h>
36 
37 
38 namespace TooN {
66 template <int Size=-1, class Precision=double>
67 class LU {
68  public:
69 
72  template<int S1, int S2, class Base>
74  :my_lu(m.num_rows(),m.num_cols()),my_IPIV(m.num_rows()){
75  compute(m);
76  }
77 
79  template<int S1, int S2, class Base>
81  //check for consistency with Size
82  SizeMismatch<Size, S1>::test(my_lu.num_rows(),m.num_rows());
83  SizeMismatch<Size, S2>::test(my_lu.num_rows(),m.num_cols());
84 
85  //Make a local copy. This is guaranteed contiguous
86  my_lu=m;
87  FortranInteger lda = m.num_rows();
88  FortranInteger M = m.num_rows();
89  FortranInteger N = m.num_rows();
90 
91  getrf_(&M,&N,&my_lu[0][0],&lda,&my_IPIV[0],&my_info);
92 
93  if(my_info < 0){
94  std::cerr << "error in LU, INFO was " << my_info << std::endl;
95  }
96  }
97 
100  template <int Rows, int NRHS, class Base>
102  //Check the number of rows is OK.
103  SizeMismatch<Size, Rows>::test(my_lu.num_rows(), rhs.num_rows());
104 
105  Matrix<Size, NRHS, Precision> result(rhs);
106 
107  FortranInteger M=rhs.num_cols();
108  FortranInteger N=my_lu.num_rows();
109  double alpha=1;
110  FortranInteger lda=my_lu.num_rows();
111  FortranInteger ldb=rhs.num_cols();
112  trsm_("R","U","N","N",&M,&N,&alpha,&my_lu[0][0],&lda,&result[0][0],&ldb);
113  trsm_("R","L","N","U",&M,&N,&alpha,&my_lu[0][0],&lda,&result[0][0],&ldb);
114 
115  // now do the row swapping (lapack dlaswp.f only shuffles fortran rows = Rowmajor cols)
116  for(int i=N-1; i>=0; i--){
117  const int swaprow = my_IPIV[i]-1; // fortran arrays start at 1
118  for(int j=0; j<NRHS; j++){
119  Precision temp = result[i][j];
120  result[i][j] = result[swaprow][j];
121  result[swaprow][j] = temp;
122  }
123  }
124  return result;
125  }
126 
129  template <int Rows, class Base>
131  //Check the number of rows is OK.
132  SizeMismatch<Size, Rows>::test(my_lu.num_rows(), rhs.size());
133 
134  Vector<Size, Precision> result(rhs);
135 
136  FortranInteger M=1;
137  FortranInteger N=my_lu.num_rows();
138  double alpha=1;
139  FortranInteger lda=my_lu.num_rows();
140  FortranInteger ldb=1;
141  trsm_("R","U","N","N",&M,&N,&alpha,&my_lu[0][0],&lda,&result[0],&ldb);
142  trsm_("R","L","N","U",&M,&N,&alpha,&my_lu[0][0],&lda,&result[0],&ldb);
143 
144  // now do the row swapping (lapack dlaswp.f only shuffles fortran rows = Rowmajor cols)
145  for(int i=N-1; i>=0; i--){
146  const int swaprow = my_IPIV[i]-1; // fortran arrays start at 1
147  Precision temp = result[i];
148  result[i] = result[swaprow];
149  result[swaprow] = temp;
150  }
151  return result;
152  }
153 
157  Matrix<Size,Size,Precision> Inverse(my_lu);
158  FortranInteger N = my_lu.num_rows();
159  FortranInteger lda=my_lu.num_rows();
160  FortranInteger lwork=-1;
161  Precision size;
162  getri_(&N, &Inverse[0][0], &lda, &my_IPIV[0], &size, &lwork, &my_info);
163  lwork=FortranInteger(size);
164  Precision* WORK = new Precision[lwork];
165  getri_(&N, &Inverse[0][0], &lda, &my_IPIV[0], WORK, &lwork, &my_info);
166  delete [] WORK;
167  return Inverse;
168  }
169 
175  const Matrix<Size,Size,Precision>& get_lu()const {return my_lu;}
176 
177  private:
178  inline int get_sign() const {
179  int result=1;
180  for(int i=0; i<my_lu.num_rows()-1; i++){
181  if(my_IPIV[i] > i+1){
182  result=-result;
183  }
184  }
185  return result;
186  }
187  public:
188 
190  inline Precision determinant() const {
191  Precision result = get_sign();
192  for (int i=0; i<my_lu.num_rows(); i++){
193  result*=my_lu(i,i);
194  }
195  return result;
196  }
197 
199  int get_info() const { return my_info; }
200 
201  private:
202 
204  FortranInteger my_info;
205  Vector<Size, FortranInteger> my_IPIV; //Convenient static-or-dynamic array of ints :-)
206 
207 };
208 }
209 
210 
211 #endif
Pretty generic SFINAE introspection generator.
Definition: vec_test.cc:21
A matrix.
Definition: matrix.hh:105
Matrix< Size, Size, Precision > get_inverse()
Calculate inverse of the matrix.
Definition: LU.h:156
LU(const Matrix< S1, S2, Precision, Base > &m)
Construct the LU decomposition of a matrix.
Definition: LU.h:73
Performs LU decomposition and back substitutes to solve equations.
Definition: LU.h:67
Matrix< Size, NRHS, Precision > backsub(const Matrix< Rows, NRHS, Precision, Base > &rhs)
Calculate result of multiplying the inverse of M by another matrix.
Definition: LU.h:101
Precision determinant() const
Calculate the determinant of the matrix.
Definition: LU.h:190
Vector< Size, Precision > backsub(const Vector< Rows, Precision, Base > &rhs)
Calculate result of multiplying the inverse of M by a vector.
Definition: LU.h:130
void compute(const Matrix< S1, S2, Precision, Base > &m)
Perform the LU decomposition of another matrix.
Definition: LU.h:80
Definition: size_mismatch.hh:103
int get_info() const
Get the LAPACK info.
Definition: LU.h:199
const Matrix< Size, Size, Precision > & get_lu() const
Returns the L and U matrices.
Definition: LU.h:175