mlpack
multiply_slices_impl.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_CORE_MATH_MULTIPLY_SLICES_IMPL_HPP
14 #define MLPACK_CORE_MATH_MULTIPLY_SLICES_IMPL_HPP
15 
16 #include "multiply_slices.hpp"
17 
18 namespace mlpack {
19 namespace math {
20 
21 template <typename CubeType>
22 CubeType MultiplyCube2Cube(const CubeType& cubeA,
23  const CubeType& cubeB,
24  const bool aTranspose,
25  const bool bTranspose)
26 {
27  size_t rows = cubeA.n_rows, cols = cubeB.n_cols, slices = cubeA.n_slices;
28 
29  if (cubeA.n_slices != cubeB.n_slices)
30  Log::Fatal << "Number of slices is not same in both cubes." << std::endl;
31 
32  if (aTranspose && bTranspose)
33  {
34  if (cubeA.n_rows != cubeB.n_cols)
35  Log::Fatal << "Matrix multiplication invalid!" << std::endl;
36  rows = cubeA.n_cols;
37  cols = cubeB.n_rows;
38  }
39  else if (bTranspose && !aTranspose)
40  {
41  if (cubeA.n_cols != cubeB.n_cols)
42  Log::Fatal << "Matrix multiplication invalid!" << std::endl;
43  cols = cubeB.n_rows;
44  }
45  else if (aTranspose && !bTranspose)
46  {
47  if (cubeA.n_rows != cubeB.n_rows)
48  Log::Fatal << "Matrix multiplication invalid!" << std::endl;
49  rows = cubeA.n_cols;
50  }
51  else
52  {
53  if (cubeA.n_cols != cubeB.n_rows)
54  Log::Fatal << "Matrix multiplication invalid!" << std::endl;
55  }
56 
57  CubeType z(rows, cols, slices);
58 
59  if (aTranspose && bTranspose)
60  {
61  for (size_t i = 0; i < slices; ++i)
62  z.slice(i) = arma::trans(cubeB.slice(i) * cubeA.slice(i));
63  }
64  else if (bTranspose && !aTranspose)
65  {
66  for (size_t i = 0; i < slices; ++i)
67  z.slice(i) = cubeA.slice(i) * cubeB.slice(i).t();
68  }
69  else if (aTranspose && !bTranspose)
70  {
71  for (size_t i = 0; i < slices; ++i)
72  z.slice(i) = cubeA.slice(i).t() * cubeB.slice(i);
73  }
74  else
75  {
76  for (size_t i = 0; i < slices; ++i)
77  z.slice(i) = cubeA.slice(i) * cubeB.slice(i);
78  }
79  return z;
80 }
81 
82 template <typename MatType, typename CubeType>
83 CubeType MultiplyMat2Cube(const MatType& matA,
84  const CubeType& cubeB,
85  const bool aTranspose,
86  const bool bTranspose)
87 {
88  size_t rows = matA.n_rows, cols = cubeB.n_cols, slices = cubeB.n_slices;
89 
90  if (aTranspose && bTranspose)
91  {
92  if (matA.n_rows != cubeB.n_cols)
93  Log::Fatal << "Matrix multiplication invalid!" << std::endl;
94  rows = matA.n_cols;
95  cols = cubeB.n_rows;
96  }
97  else if (bTranspose && !aTranspose)
98  {
99  if (matA.n_cols != cubeB.n_cols)
100  Log::Fatal << "Matrix multiplication invalid!" << std::endl;
101  cols = cubeB.n_rows;
102  }
103  else if (aTranspose && !bTranspose)
104  {
105  if (matA.n_rows != cubeB.n_rows)
106  Log::Fatal << "Matrix multiplication invalid!" << std::endl;
107  rows = matA.n_cols;
108  }
109  else
110  {
111  if (matA.n_cols != cubeB.n_rows)
112  Log::Fatal << "Matrix multiplication invalid!" << std::endl;
113  }
114 
115  CubeType z(rows, cols, slices);
116 
117  if (aTranspose && bTranspose)
118  {
119  for (size_t i = 0; i < slices; ++i)
120  z.slice(i) = arma::trans(cubeB.slice(i) * matA);
121  }
122  else if (bTranspose)
123  {
124  for (size_t i = 0; i < slices; ++i)
125  z.slice(i) = matA * cubeB.slice(i).t();
126  }
127  else if (aTranspose)
128  {
129  for (size_t i = 0; i < slices; ++i)
130  z.slice(i) = matA.t() * cubeB.slice(i);
131  }
132  else
133  {
134  for (size_t i = 0; i < slices; ++i)
135  z.slice(i) = matA * cubeB.slice(i);
136  }
137  return z;
138 }
139 
140 template <typename CubeType, typename MatType>
141 CubeType MultiplyCube2Mat(const CubeType& cubeA,
142  const MatType& matB,
143  const bool aTranspose,
144  const bool bTranspose)
145 {
146  size_t rows = cubeA.n_rows, cols = matB.n_cols, slices = cubeA.n_slices;
147 
148  if (aTranspose && bTranspose)
149  {
150  if (cubeA.n_rows != matB.n_cols)
151  Log::Fatal << "Matrix multiplication invalid!" << std::endl;
152  rows = cubeA.n_cols;
153  cols = matB.n_rows;
154  }
155  else if (bTranspose && !aTranspose)
156  {
157  if (cubeA.n_cols != matB.n_cols)
158  Log::Fatal << "Matrix multiplication invalid!" << std::endl;
159  cols = matB.n_rows;
160  }
161  else if (aTranspose && !bTranspose)
162  {
163  if (cubeA.n_rows != matB.n_rows)
164  Log::Fatal << "Matrix multiplication invalid!" << std::endl;
165  rows = cubeA.n_cols;
166  }
167  else
168  if (cubeA.n_cols != matB.n_rows)
169  Log::Fatal << "Matrix multiplication invalid!" << std::endl;
170 
171  CubeType z(rows, cols, slices);
172 
173  if (aTranspose && bTranspose)
174  {
175  for (size_t i = 0; i < slices; ++i)
176  z.slice(i) = arma::trans(matB * cubeA.slice(i));
177  }
178  else if (bTranspose && !aTranspose)
179  {
180  for (size_t i = 0; i < slices; ++i)
181  z.slice(i) = cubeA.slice(i) * matB.t();
182  }
183  else if (aTranspose && !bTranspose)
184  {
185  for (size_t i = 0; i < slices; ++i)
186  z.slice(i) = cubeA.slice(i).t() * matB;
187  }
188  else
189  {
190  for (size_t i = 0; i < slices; ++i)
191  z.slice(i) = cubeA.slice(i) * matB;
192  }
193  return z;
194 }
195 
196 } // namespace math
197 } // namespace mlpack
198 
199 #endif
static MLPACK_EXPORT util::PrefixedOutStream Fatal
Prints fatal messages prefixed with [FATAL], then terminates the program.
Definition: log.hpp:90
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
CubeType MultiplyMat2Cube(const MatType &matA, const CubeType &cubeB, const bool aTranspose=false, const bool bTranspose=false)
Matrix multiplication of a matrix and all the slices of a cube.
Definition: multiply_slices_impl.hpp:83
CubeType MultiplyCube2Cube(const CubeType &cubeA, const CubeType &cubeB, const bool aTranspose=false, const bool bTranspose=false)
Matrix multiplication of slices of two cubes.
Definition: multiply_slices_impl.hpp:22
CubeType MultiplyCube2Mat(const CubeType &cubeA, const MatType &matB, const bool aTranspose=false, const bool bTranspose=false)
Matrix multiplication of all slices of a cube with a matrix.
Definition: multiply_slices_impl.hpp:141