TooN
SVD.h
1 // -*- c++ -*-
2 
3 // Copyright (C) 2005,2009 Tom Drummond (twd20@cam.ac.uk)
4 
5 //All rights reserved.
6 //
7 //Redistribution and use in source and binary forms, with or without
8 //modification, are permitted provided that the following conditions
9 //are met:
10 //1. Redistributions of source code must retain the above copyright
11 // notice, this list of conditions and the following disclaimer.
12 //2. Redistributions in binary form must reproduce the above copyright
13 // notice, this list of conditions and the following disclaimer in the
14 // documentation and/or other materials provided with the distribution.
15 //
16 //THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND OTHER CONTRIBUTORS ``AS IS''
17 //AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18 //IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19 //ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR OTHER CONTRIBUTORS BE
20 //LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21 //CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22 //SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23 //INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24 //CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25 //ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26 //POSSIBILITY OF SUCH DAMAGE.
27 
28 #ifndef __SVD_H
29 #define __SVD_H
30 
31 #include <TooN/TooN.h>
32 #include <TooN/lapack.h>
33 #include <algorithm>
34 
35 namespace TooN {
36 
37  // TODO - should this depend on precision?
// Default value of the `condition` argument of SVD::backsub, SVD::get_pinv and
// SVD::rank. A singular value s[i] is treated as zero when
// s[i] * condition <= s[0], i.e. when it is more than `condition` times
// smaller than the first singular value. (Kept at namespace scope as a
// global constant rather than a per-class member.)
static const double condition_no = 1.0e9;
39 
40 
41 
42 
43 
44 
45 
86 template<int Rows=Dynamic, int Cols=Rows, typename Precision=DefaultPrecision>
87 class SVD {
88  // this is the size of the diagonal
89  // NB works for semi-dynamic sizes because -1 < +ve ints
90  static const int Min_Dim = Rows<Cols?Rows:Cols;
91 
92 public:
93 
95  SVD() {}
96 
98  SVD(int rows, int cols)
99  : my_copy(rows,cols),
100  my_diagonal(std::min(rows,cols)),
101  my_square(std::min(rows,cols), std::min(rows,cols))
102  {}
103 
106  template <int R2, int C2, typename P2, typename B2>
108  : my_copy(m),
109  my_diagonal(std::min(m.num_rows(),m.num_cols())),
110  my_square(std::min(m.num_rows(),m.num_cols()),std::min(m.num_rows(),m.num_cols()))
111  {
112  do_compute();
113  }
114 
116  template <int R2, int C2, typename P2, typename B2>
117  void compute(const Matrix<R2,C2,P2,B2>& m){
118  my_copy=m;
119  do_compute();
120  }
121 
122  private:
123  void do_compute(){
124  Precision* const a = my_copy.my_data;
125  int lda = my_copy.num_cols();
126  int m = my_copy.num_cols();
127  int n = my_copy.num_rows();
128  Precision* const uorvt = my_square.my_data;
129  Precision* const s = my_diagonal.my_data;
130  int ldu;
131  int ldvt = lda;
132  int LWORK;
133  int INFO;
134  char JOBU;
135  char JOBVT;
136 
137  if(is_vertical()){ // u is a
138  JOBU='O';
139  JOBVT='S';
140  ldu = lda;
141  } else { // vt is a
142  JOBU='S';
143  JOBVT='O';
144  ldu = my_square.num_cols();
145  }
146 
147  Precision* wk;
148 
149  Precision size;
150  LWORK = -1;
151 
152  // arguments are scrambled because we use rowmajor and lapack uses colmajor
153  // thus u and vt play each other's roles.
154  gesvd_( &JOBVT, &JOBU, &m, &n, a, &lda, s, uorvt,
155  &ldvt, uorvt, &ldu, &size, &LWORK, &INFO);
156 
157  LWORK = (long int)(size);
158  wk = new Precision[LWORK];
159 
160  gesvd_( &JOBVT, &JOBU, &m, &n, a, &lda, s, uorvt,
161  &ldvt, uorvt, &ldu, wk, &LWORK, &INFO);
162 
163  delete[] wk;
164  }
165 
166  bool is_vertical(){
167  return (my_copy.num_rows() >= my_copy.num_cols());
168  }
169 
170  int min_dim(){ return std::min(my_copy.num_rows(), my_copy.num_cols()); }
171 
172  public:
173 
178  template <int Rows2, int Cols2, typename P2, typename B2>
180  backsub(const Matrix<Rows2,Cols2,P2,B2>& rhs, const Precision condition=condition_no)
181  {
182  Vector<Min_Dim> inv_diag(min_dim());
183  get_inv_diag(inv_diag,condition);
184  return (get_VT().T() * diagmult(inv_diag, (get_U().T() * rhs)));
185  }
186 
191  template <int Size, typename P2, typename B2>
193  backsub(const Vector<Size,P2,B2>& rhs, const Precision condition=condition_no)
194  {
195  Vector<Min_Dim> inv_diag(min_dim());
196  get_inv_diag(inv_diag,condition);
197  return (get_VT().T() * diagmult(inv_diag, (get_U().T() * rhs)));
198  }
199 
204  Matrix<Cols,Rows> get_pinv(const Precision condition = condition_no){
205  Vector<Min_Dim> inv_diag(min_dim());
206  get_inv_diag(inv_diag,condition);
207  return diagmult(get_VT().T(),inv_diag) * get_U().T();
208  }
209 
212  Precision determinant() {
213  Precision result = my_diagonal[0];
214  for(int i=1; i<my_diagonal.size(); i++){
215  result *= my_diagonal[i];
216  }
217  return result;
218  }
219 
222  int rank(const Precision condition = condition_no) {
223  if (my_diagonal[0] == 0) return 0;
224  int result=1;
225  for(int i=0; i<min_dim(); i++){
226  if(my_diagonal[i] * condition <= my_diagonal[0]){
227  result++;
228  }
229  }
230  return result;
231  }
232 
237  if(is_vertical()){
239  (my_copy.my_data,my_copy.num_rows(),my_copy.num_cols());
240  } else {
242  (my_square.my_data, my_square.num_rows(), my_square.num_cols());
243  }
244  }
245 
247  Vector<Min_Dim,Precision>& get_diagonal(){ return my_diagonal; }
248 
253  if(is_vertical()){
255  (my_square.my_data, my_square.num_rows(), my_square.num_cols());
256  } else {
258  (my_copy.my_data,my_copy.num_rows(),my_copy.num_cols());
259  }
260  }
261 
267  void get_inv_diag(Vector<Min_Dim>& inv_diag, const Precision condition){
268  for(int i=0; i<min_dim(); i++){
269  if(my_diagonal[i] * condition <= my_diagonal[0]){
270  inv_diag[i]=0;
271  } else {
272  inv_diag[i]=static_cast<Precision>(1)/my_diagonal[i];
273  }
274  }
275  }
276 
277 private:
279  Vector<Min_Dim,Precision> my_diagonal;
280  Matrix<Min_Dim,Min_Dim,Precision,RowMajor> my_square; // square matrix (U or V' depending on the shape of my_copy)
281 };
282 
283 
284 
285 
286 
287 
291 template<int Size, typename Precision>
292 struct SQSVD : public SVD<Size, Size, Precision> {
296  SQSVD() {}
297  SQSVD(int size) : SVD<Size,Size,Precision>(size, size) {}
298 
299  template <int R2, int C2, typename P2, typename B2>
302 };
303 
304 
305 }
306 
307 
308 #endif
int rank(const Precision condition=condition_no)
Calculate the rank of the matrix.
Definition: SVD.h:222
SVD(const Matrix< R2, C2, P2, B2 > &m)
Construct the SVD decomposition of a matrix.
Definition: SVD.h:107
void get_inv_diag(Vector< Min_Dim > &inv_diag, const Precision condition)
Return the pseudo-inverse diagonal.
Definition: SVD.h:267
Pretty generic SFINAE introspection generator.
Definition: vec_test.cc:21
A vector.
Definition: vector.hh:126
Matrix< Rows, Min_Dim, Precision, Reference::RowMajor > get_U()
Return the U matrix from the decomposition. The size of this depends on the shape of the original matrix...
Definition: SVD.h:236
Definition: TooN.h:364
void compute(const Matrix< R2, C2, P2, B2 > &m)
Compute the SVD decomposition of M, typically used after the default constructor. ...
Definition: SVD.h:117
SVD(int rows, int cols)
Constructor for Rows=-1 or Cols=-1 (or both).
Definition: SVD.h:98
Matrix< Cols, Rows > get_pinv(const Precision condition=condition_no)
Calculate (pseudo-)inverse of the matrix.
Definition: SVD.h:204
Precision determinant()
Calculate the product of the singular values; for square matrices this is the absolute value of the determinant.
Definition: SVD.h:212
Vector< Min_Dim, Precision > & get_diagonal()
Return the singular values as a vector.
Definition: SVD.h:247
Vector< Cols, typename Internal::MultiplyType< Precision, P2 >::type > backsub(const Vector< Size, P2, B2 > &rhs, const Precision condition=condition_no)
Calculate result of multiplying the (pseudo-)inverse of M by a vector.
Definition: SVD.h:193
Matrix< Min_Dim, Cols, Precision, Reference::RowMajor > get_VT()
Return the VT matrix from the decomposition. The size of this depends on the shape of the original matrix...
Definition: SVD.h:252
Matrix< Cols, Cols2, typename Internal::MultiplyType< Precision, P2 >::type > backsub(const Matrix< Rows2, Cols2, P2, B2 > &rhs, const Precision condition=condition_no)
Calculate result of multiplying the (pseudo-)inverse of M by another matrix.
Definition: SVD.h:180
Performs SVD and back substitute to solve equations.
Definition: SVD.h:87
SVD()
Default constructor for Rows>0 and Cols>0.
Definition: SVD.h:95
Version of SVD forced to be square, principally here to allow use in WLS.
Definition: SVD.h:292