mlpack
join_impl.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_ANN_LAYER_JOIN_IMPL_HPP
13 #define MLPACK_METHODS_ANN_LAYER_JOIN_IMPL_HPP
14 
15 // In case it hasn't yet been included.
16 #include "join.hpp"
17 
18 namespace mlpack {
19 namespace ann {
20 
21 template<typename InputDataType, typename OutputDataType>
23  inSizeRows(0),
24  inSizeCols(0)
25 {
26  // Nothing to do here.
27 }
28 
29 template<typename InputDataType, typename OutputDataType>
30 template<typename InputType, typename OutputType>
32  const InputType& input, OutputType& output)
33 {
34  inSizeRows = input.n_rows;
35  inSizeCols = input.n_cols;
36  output = arma::vectorise(input);
37 }
38 
39 template<typename InputDataType, typename OutputDataType>
40 template<typename eT>
42  const arma::Mat<eT>& /* input */,
43  const arma::Mat<eT>& gy,
44  arma::Mat<eT>& g)
45 {
46  g = arma::mat(((arma::Mat<eT>&) gy).memptr(), inSizeRows, inSizeCols, false,
47  false);
48 }
49 
50 template<typename InputDataType, typename OutputDataType>
51 template<typename Archive>
53  Archive& ar,
54  const uint32_t /* version */)
55 {
56  ar(CEREAL_NVP(inSizeRows));
57  ar(CEREAL_NVP(inSizeCols));
58 }
59 
60 } // namespace ann
61 } // namespace mlpack
62 
63 #endif
void Backward(const arma::Mat< eT > &, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backwards through f, using the results from the feed forward pass.
Definition: join_impl.hpp:41
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
Join()
Create the Join object.
Definition: join_impl.hpp:22
void Forward(const InputType &input, OutputType &output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
Definition: join_impl.hpp:31
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
Definition: join_impl.hpp:52