// NOTE(review): this listing is a Doxygen source-page extraction; the original
// file's line numbers are fused into the text and many lines are missing.
// Comments below describe only what the visible tokens establish.
// Header guard + visitor includes, then the template header of the first
// Concat constructor.
13 #ifndef MLPACK_METHODS_ANN_LAYER_CONCAT_IMPL_HPP 14 #define MLPACK_METHODS_ANN_LAYER_CONCAT_IMPL_HPP 19 #include "../visitor/forward_visitor.hpp" 20 #include "../visitor/backward_visitor.hpp" 21 #include "../visitor/gradient_visitor.hpp" 26 template<
typename InputDataType,
typename OutputDataType,
27 typename... CustomLayers>
// Concat(model, run): per the member brief, model=false and run=true are the
// defaults. The only visible body statement zero-sizes the weight matrix —
// the Concat layer itself holds no trainable parameters.
29 const bool model,
const bool run) :
36 weights.set_size(0, 0);
39 template<
typename InputDataType,
typename OutputDataType,
40 typename... CustomLayers>
// Concat(inputSize, axis, ...): axis-aware constructor. Computes `channels`,
// the factor by which matrices are reshaped so concatenation along `axis`
// can be expressed as a row-wise join (see Forward/Backward below).
42 arma::Row<size_t>& inputSize,
// No trainable parameters of its own.
52 weights.set_size(0, 0);
// Running products of the dimensions before/after the concatenation axis.
55 size_t oldColSize = 1, newColSize = 1;
// Only possible when the caller supplied input dimensions.
61 if (inputSize.n_elem > 0)
// newColSize accumulates all dimensions strictly after `axis`
// (start index clamped to inputSize.n_elem).
66 size_t i = std::min(axis + 1, (
size_t) inputSize.n_elem);
67 for (; i < inputSize.n_elem; ++i)
69 newColSize *= inputSize[i];
// An axis was requested but no input dimensions were given — cannot
// derive the channel count, so fail loudly.
74 throw std::logic_error(
"Input dimensions not specified.");
// Guard against a zero product (some inputSize entry was 0), which would
// make the division below meaningless.
83 throw std::logic_error(
"Col size is zero.");
// channels is the interleave factor used by every reshape in this file.
// NOTE(review): oldColSize's accumulation loop is not visible in this
// extraction — presumably it multiplies a subset of dimensions; confirm
// against the full source.
85 channels = newColSize / oldColSize;
89 template<
typename InputDataType,
typename OutputDataType,
90 typename... CustomLayers>
// ~Concat(): destroys the layers held in `network` via deleteVisitor.
// NOTE(review): any guard (e.g. skipping deletion when the layer does not
// own the sub-network) is not visible in this extraction — confirm against
// the full source before assuming unconditional deletion.
96 std::for_each(network.begin(), network.end(),
97 boost::apply_visitor(deleteVisitor));
101 template<
typename InputDataType,
typename OutputDataType,
102 typename... CustomLayers>
103 template<
typename eT>
// Forward(input, output): feeds `input` through every sub-layer, then
// concatenates their outputs row-wise, treating each output as having
// `channels` interleaved channels.
105 const arma::Mat<eT>& input, arma::Mat<eT>& output)
// Run the forward pass of each sub-layer (the ForwardVisitor invocation is
// only partially visible in this extraction).
109 for (
size_t i = 0; i < network.size(); ++i)
112 boost::apply_visitor(outputParameterVisitor, network[i])),
// Start from the first sub-layer's output ...
117 output = boost::apply_visitor(outputParameterVisitor, network.front());
// ... and fold the channel dimension into the columns so that join_cols
// below concatenates along the intended axis.
120 output.reshape(output.n_rows / channels, output.n_cols * channels);
122 for (
size_t i = 1; i < network.size(); ++i)
124 arma::Mat<eT> out = boost::apply_visitor(outputParameterVisitor,
// Same channel-folding reshape for every subsequent sub-layer output.
127 out.reshape(out.n_rows / channels, out.n_cols * channels);
// Vertical (row-wise) concatenation of the reshaped outputs.
130 output = arma::join_cols(output, out);
// Undo the channel folding: restore the original rows-per-column layout.
133 output.reshape(output.n_rows * channels, output.n_cols / channels);
136 template<
typename InputDataType,
typename OutputDataType,
137 typename... CustomLayers>
138 template<
typename eT>
// Backward(/* input */, gy, g): splits the incoming error `gy` back into
// per-sub-layer slices (inverting Forward's join), backpropagates each slice
// through its sub-layer, and sums the resulting deltas into `g`.
140 const arma::Mat<eT>& ,
const arma::Mat<eT>& gy, arma::Mat<eT>& g)
// Non-owning alias over gy's memory, viewed with channels folded into the
// columns (same layout Forward used before join_cols). The const_cast-style
// (arma::Mat<eT>&) cast is needed because memptr() is non-const; the data
// is not modified through this alias.
146 arma::Mat<eT> gyTmp(((arma::Mat<eT>&) gy).memptr(), gy.n_rows / channels,
147 gy.n_cols * channels,
false,
false);
148 for (
size_t i = 0; i < network.size(); ++i)
// Each sub-layer's slice height equals its forward output's row count.
151 size_t rows = boost::apply_visitor(
152 outputParameterVisitor, network[i]).n_rows;
// Extract this sub-layer's share of the error and restore its original
// (unfolded) shape.
155 delta = gyTmp.rows(rowCount / channels, (rowCount + rows) / channels - 1);
156 delta.reshape(delta.n_rows * channels, delta.n_cols / channels);
// Backpropagate the slice through the sub-layer (BackwardVisitor call is
// only partially visible in this extraction).
159 boost::apply_visitor(outputParameterVisitor,
161 boost::apply_visitor(deltaVisitor, network[i])), network[i]);
// All sub-layers saw the same input, so their input-deltas accumulate.
165 g = boost::apply_visitor(deltaVisitor, network[0]);
166 for (
size_t i = 1; i < network.size(); ++i)
168 g += boost::apply_visitor(deltaVisitor, network[i]);
177 template<
typename InputDataType,
typename OutputDataType,
178 typename... CustomLayers>
179 template<
typename eT>
// Backward(/* input */, gy, g, index) — single-layer overload: backpropagates
// only the slice of `gy` belonging to network[index] and returns that
// sub-layer's delta in `g` (no summation across layers).
181 const arma::Mat<eT>& ,
182 const arma::Mat<eT>& gy,
186 size_t rowCount = 0, rows = 0;
// Row offset of network[index]'s slice = sum of the preceding layers'
// output heights.
188 for (
size_t i = 0; i < index; ++i)
190 rowCount += boost::apply_visitor(
191 outputParameterVisitor, network[i]).n_rows;
193 rows = boost::apply_visitor(outputParameterVisitor, network[index]).n_rows;
// Non-owning channel-folded alias over gy (see the all-layers Backward).
196 arma::Mat<eT> gyTmp(((arma::Mat<eT>&) gy).memptr(), gy.n_rows / channels,
197 gy.n_cols * channels,
false,
false);
// Slice out this layer's error and restore the unfolded shape.
199 arma::Mat<eT> delta = gyTmp.rows(rowCount / channels, (rowCount + rows) /
201 delta.reshape(delta.n_rows * channels, delta.n_cols / channels);
// BackwardVisitor call on network[index] (partially visible here).
204 outputParameterVisitor, network[index]), delta,
205 boost::apply_visitor(deltaVisitor, network[index])), network[index]);
207 g = boost::apply_visitor(deltaVisitor, network[index]);
210 template<
typename InputDataType,
typename OutputDataType,
211 typename... CustomLayers>
212 template<
typename eT>
// Gradient(input, error, ...): computes parameter gradients for every
// sub-layer, handing each one its own slice of `error` (same slicing scheme
// as Backward).
214 const arma::Mat<eT>& input,
215 const arma::Mat<eT>& error,
// Non-owning channel-folded alias over error's memory; the cast is only to
// reach the non-const memptr(), the data is not written through it.
222 arma::Mat<eT> errorTmp(((arma::Mat<eT>&) error).memptr(),
223 error.n_rows / channels, error.n_cols * channels,
false,
false);
224 for (
size_t i = 0; i < network.size(); ++i)
226 size_t rows = boost::apply_visitor(
227 outputParameterVisitor, network[i]).n_rows;
// Per-layer error slice, restored to its unfolded shape; the subsequent
// GradientVisitor invocation is not visible in this extraction.
230 arma::Mat<eT> err = errorTmp.rows(rowCount / channels, (rowCount + rows) /
232 err.reshape(err.n_rows * channels, err.n_cols / channels);
240 template<
typename InputDataType,
typename OutputDataType,
241 typename... CustomLayers>
242 template<
typename eT>
// Gradient(input, error, ..., index) — single-layer overload: computes the
// parameter gradient only for network[index] from its slice of `error`.
244 const arma::Mat<eT>& input,
245 const arma::Mat<eT>& error,
// Row offset of network[index]'s slice = sum of preceding output heights.
250 for (
size_t i = 0; i < index; ++i)
252 rowCount += boost::apply_visitor(outputParameterVisitor,
255 size_t rows = boost::apply_visitor(
256 outputParameterVisitor, network[index]).n_rows;
// Non-owning channel-folded alias over error (see all-layers Gradient).
258 arma::Mat<eT> errorTmp(((arma::Mat<eT>&) error).memptr(),
259 error.n_rows / channels, error.n_cols * channels,
false,
false);
// Slice and unfold this layer's error; the GradientVisitor call that
// consumes `err` is not visible in this extraction.
260 arma::Mat<eT> err = errorTmp.rows(rowCount / channels, (rowCount + rows) /
262 err.reshape(err.n_rows * channels, err.n_cols / channels);
267 template<
typename InputDataType,
typename OutputDataType,
268 typename... CustomLayers>
269 template<
typename Archive>
// serialize(ar, version): cereal serialization of the layer. The unnamed
// uint32_t parameter is the (unused) archive version.
271 Archive& ar,
const uint32_t )
// The `model` flag is serialized; serialization of further members (e.g.
// the network itself) is not visible in this extraction.
273 ar(CEREAL_NVP(model));
// When loading, free the currently held sub-layers first so the incoming
// archive's layers do not leak the old ones.
280 if (cereal::is_loading<Archive>())
282 std::for_each(network.begin(), network.end(),
283 boost::apply_visitor(deleteVisitor));
BackwardVisitor executes the Backward() function given the input, error and delta parameter...
Definition: backward_visitor.hpp:28
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
Definition: concat_impl.hpp:270
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
Concat(const bool model=false, const bool run=true)
Create the Concat object using the specified parameters.
Definition: concat_impl.hpp:28
arma::mat const & Gradient() const
Get the gradient.
Definition: concat.hpp:190
ForwardVisitor executes the Forward() function given the input and output parameter.
Definition: forward_visitor.hpp:28
~Concat()
Destroy the layers held by the model.
Definition: concat_impl.hpp:91
GradientVisitor executes the Gradient() method of the given module using the input and delta parame...
Definition: gradient_visitor.hpp:28
void Backward(const arma::Mat< eT > &, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Ordinary feed backward pass of a neural network, using 3rd-order tensors as input, calculating the function f(x) by propagating x backwards through f.
Definition: concat_impl.hpp:139
void Forward(const arma::Mat< eT > &input, arma::Mat< eT > &output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activ...
Definition: concat_impl.hpp:104
#define CEREAL_VECTOR_VARIANT_POINTER(T)
Cereal does not support the serialization of raw pointers.
Definition: pointer_vector_variant_wrapper.hpp:92