mlpack
concatenate_impl.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_LAYER_CONCATENATE_IMPL_HPP
14 #define MLPACK_METHODS_ANN_LAYER_CONCATENATE_IMPL_HPP
15 
16 // In case it hasn't yet been included.
17 #include "concatenate.hpp"
18 
19 namespace mlpack {
20 namespace ann {
21 
22 template<typename InputDataType, typename OutputDataType>
24  inRows(0)
25 {
26  // Nothing to do here.
27 }
28 
29 template<typename InputDataType, typename OutputDataType>
31  const Concatenate& layer) :
32  inRows(layer.inRows),
33  weights(layer.weights),
34  delta(layer.delta),
35  concat(layer.concat)
36 {
37  // Nothing to to here.
38 }
39 
40 template<typename InputDataType, typename OutputDataType>
42  inRows(layer.inRows),
43  weights(std::move(layer.weights)),
44  delta(std::move(layer.delta)),
45  concat(std::move(layer.concat))
46 {
47  // Nothing to do here.
48 }
49 
50 template<typename InputDataType, typename OutputDataType>
53 operator=(const Concatenate& layer)
54 {
55  if (this != &layer)
56  {
57  inRows = layer.inRows;
58  weights = layer.weights;
59  delta = layer.delta;
60  concat = layer.concat;
61  }
62 
63  return *this;
64 }
65 
66 template<typename InputDataType, typename OutputDataType>
70 {
71  if (this != &layer)
72  {
73  inRows = layer.inRows;
74  weights = std::move(layer.weights);
75  delta = std::move(layer.delta);
76  concat = std::move(layer.concat);
77  }
78  return *this;
79 }
80 
81 template<typename InputDataType, typename OutputDataType>
82 template<typename eT>
84  const arma::Mat<eT>& input, arma::Mat<eT>& output)
85 {
86  if (concat.is_empty())
87  Log::Warn << "The concat matrix has not been provided." << std::endl;
88 
89  if (input.n_cols != concat.n_cols)
90  {
91  Log::Fatal << "The number of columns of the concat matrix should be equal "
92  << "to the number of columns of input matrix." << std::endl;
93  }
94 
95  inRows = input.n_rows;
96  output = arma::join_cols(input, concat);
97 }
98 
99 template<typename InputDataType, typename OutputDataType>
100 template<typename eT>
102  const arma::Mat<eT>& /* input */,
103  const arma::Mat<eT>& gy,
104  arma::Mat<eT>& g)
105 {
106  g = gy.submat(0, 0, inRows - 1, concat.n_cols - 1);
107 }
108 
109 } // namespace ann
110 } // namespace mlpack
111 
112 #endif
Implementation of the Concatenate module class.
Definition: concatenate.hpp:36
static MLPACK_EXPORT util::PrefixedOutStream Fatal
Prints fatal messages prefixed with [FATAL], then terminates the program.
Definition: log.hpp:90
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
Definition: pointer_wrapper.hpp:23
Concatenate()
Default constructor: create the Concatenate object, initializing inRows to 0.
Definition: concatenate_impl.hpp:23
void Backward(const arma::Mat< eT > &, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
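Ordinary feed backward pass of a neural network: the gradient g is the part of gy covering the first inRows rows, i.e. the rows that correspond to the original input rather than to the concatenated matrix.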
Definition: concatenate_impl.hpp:101
static MLPACK_EXPORT util::PrefixedOutStream Warn
Prints warning messages prefixed with [WARN ].
Definition: log.hpp:87
void Forward(const arma::Mat< eT > &input, arma::Mat< eT > &output)
Ordinary feed forward pass of a neural network: records the number of input rows and outputs the input with the concat matrix appended below it (arma::join_cols(input, concat)).
Definition: concatenate_impl.hpp:83
Concatenate & operator=(const Concatenate &layer)
Copy assignment operator.
Definition: concatenate_impl.hpp:53