13 #ifndef MLPACK_METHODS_ANN_LAYER_HIGHWAY_IMPL_HPP 14 #define MLPACK_METHODS_ANN_LAYER_HIGHWAY_IMPL_HPP 19 #include "../visitor/forward_visitor.hpp" 20 #include "../visitor/backward_visitor.hpp" 21 #include "../visitor/gradient_visitor.hpp" 22 #include "../visitor/set_input_height_visitor.hpp" 23 #include "../visitor/set_input_width_visitor.hpp" 28 template<
typename InputDataType,
typename OutputDataType,
29 typename... CustomLayers>
41 typename InputDataType,
typename OutputDataType,
typename... CustomLayers>
51 weights.set_size(inSize * inSize + inSize, 1);
54 template<
typename InputDataType,
typename OutputDataType,
55 typename... CustomLayers>
60 for (
size_t i = 0; i < network.size(); ++i)
62 if (networkOwnerships[i])
63 boost::apply_visitor(deleteVisitor, network[i]);
68 template<
typename InputDataType,
typename OutputDataType,
69 typename... CustomLayers>
72 transformWeight = arma::mat(weights.memptr(), inSize, inSize,
false,
false);
73 transformBias = arma::mat(weights.memptr() + transformWeight.n_elem,
74 inSize, 1,
false,
false);
77 template<
typename InputDataType,
typename OutputDataType,
78 typename... CustomLayers>
81 const arma::Mat<eT>& input, arma::Mat<eT>& output)
84 boost::apply_visitor(outputParameterVisitor, network.front())),
89 if (boost::apply_visitor(outputWidthVisitor, network.front()) != 0)
91 width = boost::apply_visitor(outputWidthVisitor, network.front());
94 if (boost::apply_visitor(outputHeightVisitor, network.front()) != 0)
96 height = boost::apply_visitor(outputHeightVisitor, network.front());
100 for (
size_t i = 1; i < network.size(); ++i)
112 outputParameterVisitor, network[i - 1]),
113 boost::apply_visitor(outputParameterVisitor, network[i])),
119 if (boost::apply_visitor(outputWidthVisitor, network[i]) != 0)
121 width = boost::apply_visitor(outputWidthVisitor, network[i]);
125 if (boost::apply_visitor(outputHeightVisitor, network[i]) != 0)
127 height = boost::apply_visitor(outputHeightVisitor, network[i]);
136 output = boost::apply_visitor(outputParameterVisitor, network.back());
138 if (arma::size(output) != arma::size(input))
140 Log::Fatal <<
"The sizes of the output and input matrices of the Highway" 141 <<
" network should be equal. Please examine the network layers.";
144 transformGate = transformWeight * input;
145 transformGate.each_col() += transformBias;
146 transformGateActivation = 1.0 /(1 + arma::exp(-transformGate));
147 inputParameter = input;
148 networkOutput = output;
149 output = (output % transformGateActivation) +
150 (input % (1 - transformGateActivation));
153 template<
typename InputDataType,
typename OutputDataType,
154 typename... CustomLayers>
155 template<
typename eT>
157 const arma::Mat<eT>& ,
158 const arma::Mat<eT>& gy,
161 arma::Mat<eT> gyTransform = gy % transformGateActivation;
163 outputParameterVisitor, network.back()),
165 boost::apply_visitor(deltaVisitor, network.back())),
168 for (
size_t i = 2; i < network.size() + 1; ++i)
171 outputParameterVisitor, network[network.size() - i]),
172 boost::apply_visitor(deltaVisitor, network[network.size() - i + 1]),
173 boost::apply_visitor(deltaVisitor,
174 network[network.size() - i])), network[network.size() - i]);
177 g = boost::apply_visitor(deltaVisitor, network.front());
179 transformGateError = gy % (networkOutput - inputParameter) %
180 transformGateActivation % (1.0 - transformGateActivation);
181 g += transformWeight.t() * transformGateError;
182 g += gy % (1 - transformGateActivation);
185 template<
typename InputDataType,
typename OutputDataType,
186 typename... CustomLayers>
187 template<
typename eT>
189 const arma::Mat<eT>& input,
190 const arma::Mat<eT>& error,
191 arma::Mat<eT>& gradient)
193 arma::Mat<eT> errorTransform = error % transformGateActivation;
195 outputParameterVisitor, network[network.size() - 2]),
196 errorTransform), network.back());
198 for (
size_t i = 2; i < network.size(); ++i)
201 outputParameterVisitor, network[network.size() - i - 1]),
202 boost::apply_visitor(deltaVisitor, network[network.size() - i + 1])),
203 network[network.size() - i]);
207 boost::apply_visitor(deltaVisitor, network[1])), network.front());
209 gradient.submat(0, 0, transformWeight.n_elem - 1, 0) = arma::vectorise(
210 transformGateError * input.t());
211 gradient.submat(transformWeight.n_elem, 0, gradient.n_elem - 1, 0) =
212 arma::sum(transformGateError, 1);
215 template<
typename InputDataType,
typename OutputDataType,
216 typename... CustomLayers>
217 template<
typename Archive>
219 Archive& ar,
const uint32_t )
222 if (cereal::is_loading<Archive>())
224 for (LayerTypes<CustomLayers...>& layer : network)
226 boost::apply_visitor(deleteVisitor, layer);
228 weights.set_size(inSize * inSize + inSize, 1);
231 ar(CEREAL_NVP(model));
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
Definition: highway_impl.hpp:218
~Highway()
Destroy the Highway object.
Definition: highway_impl.hpp:56
BackwardVisitor executes the Backward() function given the input, error and delta parameter...
Definition: backward_visitor.hpp:28
static MLPACK_EXPORT util::PrefixedOutStream Fatal
Prints fatal messages prefixed with [FATAL], then terminates the program.
Definition: log.hpp:90
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
void Forward(const arma::Mat< eT > &input, arma::Mat< eT > &output)
Ordinary feed-forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
Definition: highway_impl.hpp:80
ForwardVisitor executes the Forward() function given the input and output parameter.
Definition: forward_visitor.hpp:28
GradientVisitor executes the Gradient() method of the given module using the input and delta parameters.
Definition: gradient_visitor.hpp:28
void Backward(const arma::Mat< eT > &, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Ordinary feed-backward pass of a neural network, calculating the function f(x) by propagating x backwards through f, using the results from the feed-forward pass.
Definition: highway_impl.hpp:156
#define CEREAL_VECTOR_VARIANT_POINTER(T)
Cereal does not support the serialization of raw pointers.
Definition: pointer_vector_variant_wrapper.hpp:92
void Reset()
Reset the layer parameter.
Definition: highway_impl.hpp:70
Highway()
Create the Highway object.
Definition: highway_impl.hpp:30
OutputDataType const & Gradient() const
Get the gradient.
Definition: highway.hpp:171