13 #ifndef MLPACK_CORE_MATH_MULTIPLY_SLICES_IMPL_HPP 14 #define MLPACK_CORE_MATH_MULTIPLY_SLICES_IMPL_HPP 21 template <
typename CubeType>
23 const CubeType& cubeB,
24 const bool aTranspose,
25 const bool bTranspose)
27 size_t rows = cubeA.n_rows, cols = cubeB.n_cols, slices = cubeA.n_slices;
29 if (cubeA.n_slices != cubeB.n_slices)
30 Log::Fatal <<
"Number of slices is not same in both cubes." << std::endl;
32 if (aTranspose && bTranspose)
34 if (cubeA.n_rows != cubeB.n_cols)
35 Log::Fatal <<
"Matrix multiplication invalid!" << std::endl;
39 else if (bTranspose && !aTranspose)
41 if (cubeA.n_cols != cubeB.n_cols)
42 Log::Fatal <<
"Matrix multiplication invalid!" << std::endl;
45 else if (aTranspose && !bTranspose)
47 if (cubeA.n_rows != cubeB.n_rows)
48 Log::Fatal <<
"Matrix multiplication invalid!" << std::endl;
53 if (cubeA.n_cols != cubeB.n_rows)
54 Log::Fatal <<
"Matrix multiplication invalid!" << std::endl;
57 CubeType z(rows, cols, slices);
59 if (aTranspose && bTranspose)
61 for (
size_t i = 0; i < slices; ++i)
62 z.slice(i) = arma::trans(cubeB.slice(i) * cubeA.slice(i));
64 else if (bTranspose && !aTranspose)
66 for (
size_t i = 0; i < slices; ++i)
67 z.slice(i) = cubeA.slice(i) * cubeB.slice(i).t();
69 else if (aTranspose && !bTranspose)
71 for (
size_t i = 0; i < slices; ++i)
72 z.slice(i) = cubeA.slice(i).t() * cubeB.slice(i);
76 for (
size_t i = 0; i < slices; ++i)
77 z.slice(i) = cubeA.slice(i) * cubeB.slice(i);
82 template <
typename MatType,
typename CubeType>
84 const CubeType& cubeB,
85 const bool aTranspose,
86 const bool bTranspose)
88 size_t rows = matA.n_rows, cols = cubeB.n_cols, slices = cubeB.n_slices;
90 if (aTranspose && bTranspose)
92 if (matA.n_rows != cubeB.n_cols)
93 Log::Fatal <<
"Matrix multiplication invalid!" << std::endl;
97 else if (bTranspose && !aTranspose)
99 if (matA.n_cols != cubeB.n_cols)
100 Log::Fatal <<
"Matrix multiplication invalid!" << std::endl;
103 else if (aTranspose && !bTranspose)
105 if (matA.n_rows != cubeB.n_rows)
106 Log::Fatal <<
"Matrix multiplication invalid!" << std::endl;
111 if (matA.n_cols != cubeB.n_rows)
112 Log::Fatal <<
"Matrix multiplication invalid!" << std::endl;
115 CubeType z(rows, cols, slices);
117 if (aTranspose && bTranspose)
119 for (
size_t i = 0; i < slices; ++i)
120 z.slice(i) = arma::trans(cubeB.slice(i) * matA);
124 for (
size_t i = 0; i < slices; ++i)
125 z.slice(i) = matA * cubeB.slice(i).t();
129 for (
size_t i = 0; i < slices; ++i)
130 z.slice(i) = matA.t() * cubeB.slice(i);
134 for (
size_t i = 0; i < slices; ++i)
135 z.slice(i) = matA * cubeB.slice(i);
140 template <
typename CubeType,
typename MatType>
143 const bool aTranspose,
144 const bool bTranspose)
146 size_t rows = cubeA.n_rows, cols = matB.n_cols, slices = cubeA.n_slices;
148 if (aTranspose && bTranspose)
150 if (cubeA.n_rows != matB.n_cols)
151 Log::Fatal <<
"Matrix multiplication invalid!" << std::endl;
155 else if (bTranspose && !aTranspose)
157 if (cubeA.n_cols != matB.n_cols)
158 Log::Fatal <<
"Matrix multiplication invalid!" << std::endl;
161 else if (aTranspose && !bTranspose)
163 if (cubeA.n_rows != matB.n_rows)
164 Log::Fatal <<
"Matrix multiplication invalid!" << std::endl;
168 if (cubeA.n_cols != matB.n_rows)
169 Log::Fatal <<
"Matrix multiplication invalid!" << std::endl;
171 CubeType z(rows, cols, slices);
173 if (aTranspose && bTranspose)
175 for (
size_t i = 0; i < slices; ++i)
176 z.slice(i) = arma::trans(matB * cubeA.slice(i));
178 else if (bTranspose && !aTranspose)
180 for (
size_t i = 0; i < slices; ++i)
181 z.slice(i) = cubeA.slice(i) * matB.t();
183 else if (aTranspose && !bTranspose)
185 for (
size_t i = 0; i < slices; ++i)
186 z.slice(i) = cubeA.slice(i).t() * matB;
190 for (
size_t i = 0; i < slices; ++i)
191 z.slice(i) = cubeA.slice(i) * matB;
static MLPACK_EXPORT util::PrefixedOutStream Fatal
Prints fatal messages prefixed with [FATAL], then terminates the program.
Definition: log.hpp:90
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
CubeType MultiplyMat2Cube(const MatType &matA, const CubeType &cubeB, const bool aTranspose=false, const bool bTranspose=false)
Matrix multiplication of a matrix and all the slices of a cube.
Definition: multiply_slices_impl.hpp:83
CubeType MultiplyCube2Cube(const CubeType &cubeA, const CubeType &cubeB, const bool aTranspose=false, const bool bTranspose=false)
Matrix multiplication of slices of two cubes.
Definition: multiply_slices_impl.hpp:22
CubeType MultiplyCube2Mat(const CubeType &cubeA, const MatType &matB, const bool aTranspose=false, const bool bTranspose=false)
Matrix multiplication of all slices of a cube with a matrix.
Definition: multiply_slices_impl.hpp:141