13 #ifndef MLPACK_METHODS_ANN_LAYER_CONCAT_HPP 14 #define MLPACK_METHODS_ANN_LAYER_CONCAT_HPP 18 #include "../visitor/delete_visitor.hpp" 19 #include "../visitor/delta_visitor.hpp" 20 #include "../visitor/output_parameter_visitor.hpp" 39 typename InputDataType = arma::mat,
40 typename OutputDataType = arma::mat,
41 typename... CustomLayers
//! Create the Concat object using the specified parameters.
//! Both flags are optional: `model` defaults to false and `run` defaults to
//! true.
//! NOTE(review): the exact meaning of `model` and `run` is defined in the
//! implementation file, not visible here -- confirm before relying on it.
52 Concat(
const bool model =
false,
53 const bool run =
true);
63 Concat(arma::Row<size_t>& inputSize,
65 const bool model =
false,
66 const bool run =
true);
81 void Forward(
const arma::Mat<eT>& input, arma::Mat<eT>& output);
94 const arma::Mat<eT>& gy,
106 template<
typename eT>
107 void Backward(
const arma::Mat<eT>& ,
108 const arma::Mat<eT>& gy,
119 template<
typename eT>
120 void Gradient(
const arma::Mat<eT>& ,
121 const arma::Mat<eT>& error,
133 template<
typename eT>
134 void Gradient(
const arma::Mat<eT>& input,
135 const arma::Mat<eT>& error,
136 arma::Mat<eT>& gradient,
144 template <
class LayerType,
class... Args>
145 void Add(Args... args) { network.push_back(
new LayerType(args...)); }
152 void Add(LayerTypes<CustomLayers...> layer) { network.push_back(layer); }
155 std::vector<LayerTypes<CustomLayers...> >&
Model()
171 bool Run()
const {
return run; }
173 bool&
Run() {
return run; }
175 arma::mat
const& InputParameter()
const {
return inputParameter; }
185 arma::mat
const&
Delta()
const {
return delta; }
187 arma::mat&
Delta() {
return delta; }
190 arma::mat
const&
Gradient()
const {
return gradient; }
//! Serialize the layer.
//! @param ar Archive to serialize to or from.
//! NOTE(review): the unnamed uint32_t parameter is presumably the archive
//! version and is unused by name here -- confirm against the implementation.
203 template<
typename Archive>
204 void serialize(Archive& ar,
const uint32_t );
208 arma::Row<size_t> inputSize;
227 std::vector<LayerTypes<CustomLayers...> > network;
230 OutputDataType weights;
242 std::vector<LayerTypes<CustomLayers...> > empty;
248 arma::mat inputParameter;
251 arma::mat outputParameter;
DeleteVisitor executes the destructor of the instantiated object.
Definition: delete_visitor.hpp:27
arma::mat & InputParameter()
Modify the input parameter.
Definition: concat.hpp:177
Implementation of the Add module class.
Definition: add.hpp:34
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
Definition: concat_impl.hpp:270
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
Concat(const bool model=false, const bool run=true)
Create the Concat object using the specified parameters.
Definition: concat_impl.hpp:28
The core includes that mlpack expects; standard C++ includes and Armadillo.
size_t WeightSize() const
Get the size of the weight matrix.
Definition: concat.hpp:198
arma::mat & Gradient()
Modify the gradient.
Definition: concat.hpp:192
const arma::mat & Parameters() const
Return the initial point for the optimization.
Definition: concat.hpp:166
arma::mat const & Gradient() const
Get the gradient.
Definition: concat.hpp:190
Implementation of the Concat class.
Definition: concat.hpp:43
arma::mat const & OutputParameter() const
Get the output parameter.
Definition: concat.hpp:180
OutputParameterVisitor exposes the output parameter of the given module.
Definition: output_parameter_visitor.hpp:27
arma::mat & Parameters()
Modify the initial point for the optimization.
Definition: concat.hpp:168
bool Run() const
Get the value of run parameter.
Definition: concat.hpp:171
arma::mat const & Delta() const
Get the delta.
Definition: concat.hpp:185
std::vector< LayerTypes< CustomLayers... > > & Model()
Return the model modules.
Definition: concat.hpp:155
~Concat()
Destroy the layers held by the model.
Definition: concat_impl.hpp:91
DeltaVisitor exposes the delta parameter of the given module.
Definition: delta_visitor.hpp:27
void Backward(const arma::Mat< eT > &, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Ordinary feed backward pass of a neural network, using 3rd-order tensors as input, calculating the function f(x) by propagating x backwards through f.
Definition: concat_impl.hpp:139
void Forward(const arma::Mat< eT > &input, arma::Mat< eT > &output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
Definition: concat_impl.hpp:104
arma::mat & OutputParameter()
Modify the output parameter.
Definition: concat.hpp:182
size_t const & ConcatAxis() const
Get the axis of concatenation.
Definition: concat.hpp:195
bool & Run()
Modify the value of run parameter.
Definition: concat.hpp:173
arma::mat & Delta()
Modify the delta.
Definition: concat.hpp:187