mlpack
highway_impl.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_LAYER_HIGHWAY_IMPL_HPP
14 #define MLPACK_METHODS_ANN_LAYER_HIGHWAY_IMPL_HPP
15 
16 // In case it hasn't yet been included.
17 #include "highway.hpp"
18 
19 #include "../visitor/forward_visitor.hpp"
20 #include "../visitor/backward_visitor.hpp"
21 #include "../visitor/gradient_visitor.hpp"
22 #include "../visitor/set_input_height_visitor.hpp"
23 #include "../visitor/set_input_width_visitor.hpp"
24 
25 namespace mlpack {
26 namespace ann {
27 
// Default constructor of the Highway layer (the declarator on listing line 30
// is not present in this listing; only the initializer list and body remain).
// Initial state: inSize = 0, model = true, reset = false, width = height = 0.
// No weights are sized here, unlike the (inSize, model) overload below.
28 template<typename InputDataType, typename OutputDataType,
29  typename... CustomLayers>
31  inSize(0),
32  model(true),
33  reset(false),
34  width(0),
35  height(0)
36 {
37  // Nothing to do here.
38 }
39 
// Parameterized constructor (presumably Highway(...); its declarator on
// listing line 42 is not present in this listing -- TODO confirm).
//
// inSize: stored in the member of the same name; determines the weight size.
// model:  stored in the member of the same name; the destructor only visits
//         the held layers with deleteVisitor when this flag is false.
//
// reset, width and height start at false / 0 / 0.
40 template<
41  typename InputDataType, typename OutputDataType, typename... CustomLayers>
43  const size_t inSize,
44  const bool model) :
45  inSize(inSize),
46  model(model),
47  reset(false),
48  width(0),
49  height(0)
50 {
// One column holding inSize * inSize entries followed by inSize entries; this
// is the layout Reset() carves into transformWeight and transformBias.
51  weights.set_size(inSize * inSize + inSize, 1);
52 }
53 
// Destructor (declarator on listing line 56 is not present in this listing).
// When `model` is false, walks every entry of `network` and applies
// deleteVisitor to those whose matching networkOwnerships flag is set.
// When `model` is true nothing is visited here.
54 template<typename InputDataType, typename OutputDataType,
55  typename... CustomLayers>
57 {
58  if (!model)
59  {
60  for (size_t i = 0; i < network.size(); ++i)
61  {
// Only layers flagged in networkOwnerships are handed to deleteVisitor
// (presumably the ones this object owns -- TODO confirm against highway.hpp).
62  if (networkOwnerships[i])
63  boost::apply_visitor(deleteVisitor, network[i]);
64  }
65  }
66 }
67 
// Reset() (declarator on listing line 70 is not present in this listing).
// Rebuilds transformWeight and transformBias from the memory of `weights`:
//   - transformWeight: inSize x inSize, starting at weights.memptr()
//   - transformBias:   inSize x 1, starting right after the weight elements
// The trailing `false, false` arguments are Armadillo's copy_aux_mem / strict
// flags; copy_aux_mem = false asks Armadillo to use the given memory in place
// rather than copying it (NOTE(review): confirm aliasing survives the
// assignment from the temporary on the Armadillo version in use).
68 template<typename InputDataType, typename OutputDataType,
69  typename... CustomLayers>
71 {
72  transformWeight = arma::mat(weights.memptr(), inSize, inSize, false, false);
73  transformBias = arma::mat(weights.memptr() + transformWeight.n_elem,
74  inSize, 1, false, false);
75 }
76 
// Forward(input, output) (declarator on listing line 80 is not present in
// this listing).
//
// Runs `input` through every layer held in `network`, then blends the result
// with the untouched input using a learned gate:
//   T      = 1 / (1 + exp(-(transformWeight * input + transformBias)))
//   output = networkOutput % T + input % (1 - T)      (% is element-wise)
//
// input:  matrix fed to the first layer and to the transform gate.
// output: receives the gated combination; must end up the same size as input.
//
// Side effects visible here: updates width/height/reset on the first pass and
// stores transformGate, transformGateActivation, inputParameter and
// networkOutput, which Backward() reads.
77 template<typename InputDataType, typename OutputDataType,
78  typename... CustomLayers>
79 template<typename eT>
81  const arma::Mat<eT>& input, arma::Mat<eT>& output)
82 {
// First layer: forward `input` into that layer's own output parameter.
83  boost::apply_visitor(ForwardVisitor(input,
84  boost::apply_visitor(outputParameterVisitor, network.front())),
85  network.front());
86 
// Until `reset` has been set, pick up the first layer's output width/height,
// ignoring layers that report 0.
87  if (!reset)
88  {
89  if (boost::apply_visitor(outputWidthVisitor, network.front()) != 0)
90  {
91  width = boost::apply_visitor(outputWidthVisitor, network.front());
92  }
93 
94  if (boost::apply_visitor(outputHeightVisitor, network.front()) != 0)
95  {
96  height = boost::apply_visitor(outputHeightVisitor, network.front());
97  }
98  }
99 
// Remaining layers: each consumes the previous layer's output parameter.
100  for (size_t i = 1; i < network.size(); ++i)
101  {
102  if (!reset)
103  {
104  // Set the input width.
105  boost::apply_visitor(SetInputWidthVisitor(width), network[i]);
106 
107  // Set the input height.
108  boost::apply_visitor(SetInputHeightVisitor(height), network[i]);
109  }
110 
111  boost::apply_visitor(ForwardVisitor(boost::apply_visitor(
112  outputParameterVisitor, network[i - 1]),
113  boost::apply_visitor(outputParameterVisitor, network[i])),
114  network[i]);
115 
116  if (!reset)
117  {
118  // Get the output width.
119  if (boost::apply_visitor(outputWidthVisitor, network[i]) != 0)
120  {
121  width = boost::apply_visitor(outputWidthVisitor, network[i]);
122  }
123 
124  // Get the output height.
125  if (boost::apply_visitor(outputHeightVisitor, network[i]) != 0)
126  {
127  height = boost::apply_visitor(outputHeightVisitor, network[i]);
128  }
129  }
130  }
// From now on the width/height propagation above is skipped (nothing in this
// listing sets `reset` back to false).
131  if (!reset)
132  {
133  reset = true;
134  }
135 
// Copy of the last layer's output; this is the "network" branch of the gate.
136  output = boost::apply_visitor(outputParameterVisitor, network.back());
137 
// The element-wise blend below needs matching shapes.
// NOTE(review): the message is not terminated with a newline / std::endl.
// Verify whether Log::Fatal only terminates once a newline is written; if so
// this check would not stop execution by itself.
138  if (arma::size(output) != arma::size(input))
139  {
140  Log::Fatal << "The sizes of the output and input matrices of the Highway"
141  << " network should be equal. Please examine the network layers.";
142  }
143 
// Gate pre-activation: linear map of the input plus the bias added to every
// column.
144  transformGate = transformWeight * input;
145  transformGate.each_col() += transformBias;
// Logistic sigmoid, applied element-wise.
146  transformGateActivation = 1.0 /(1 + arma::exp(-transformGate));
// Cached for Backward(), which uses (networkOutput - inputParameter).
147  inputParameter = input;
148  networkOutput = output;
149  output = (output % transformGateActivation) +
150  (input % (1 - transformGateActivation));
151 }
152 
// Backward(input, gy, g) (declarator on listing line 156 is not present in
// this listing). The first parameter is unnamed/unused.
//
// gy: error arriving from the layer above.
// g:  receives the error with respect to this layer's input, as the sum of
//     three terms computed below:
//       1. the delta of the first wrapped layer (error pushed through the
//          wrapped network, pre-scaled by the gate activation),
//       2. transformWeight.t() * transformGateError (path through the gate),
//       3. gy % (1 - transformGateActivation) (path through the carried input).
//
// Relies on transformGateActivation, networkOutput and inputParameter as
// stored by Forward(); also stores transformGateError for the gradient step.
153 template<typename InputDataType, typename OutputDataType,
154  typename... CustomLayers>
155 template<typename eT>
157  const arma::Mat<eT>& /* input */,
158  const arma::Mat<eT>& gy,
159  arma::Mat<eT>& g)
160 {
// Error reaching the wrapped network is gy scaled element-wise by the gate.
161  arma::Mat<eT> gyTransform = gy % transformGateActivation;
162  boost::apply_visitor(BackwardVisitor(boost::apply_visitor(
163  outputParameterVisitor, network.back()),
164  gyTransform,
165  boost::apply_visitor(deltaVisitor, network.back())),
166  network.back());
167 
// Walk the remaining layers from second-to-last down to the first
// (index network.size() - i), feeding each the delta of the layer after it.
168  for (size_t i = 2; i < network.size() + 1; ++i)
169  {
170  boost::apply_visitor(BackwardVisitor(boost::apply_visitor(
171  outputParameterVisitor, network[network.size() - i]),
172  boost::apply_visitor(deltaVisitor, network[network.size() - i + 1]),
173  boost::apply_visitor(deltaVisitor,
174  network[network.size() - i])), network[network.size() - i]);
175  }
176 
177  g = boost::apply_visitor(deltaVisitor, network.front());
178 
// Error at the gate pre-activation: gy times d(output)/dT, i.e.
// (networkOutput - inputParameter), times the sigmoid derivative T % (1 - T).
179  transformGateError = gy % (networkOutput - inputParameter) %
180  transformGateActivation % (1.0 - transformGateActivation);
181  g += transformWeight.t() * transformGateError;
182  g += gy % (1 - transformGateActivation);
183 }
184 
// Gradient computation (presumably Gradient(input, error, gradient); the
// declarator on listing line 188 is not present in this listing -- TODO
// confirm).
//
// input:    the input that was given to the forward pass.
// error:    error arriving from the layer above.
// gradient: column whose first transformWeight.n_elem rows receive the gate
//           weight gradient and whose remaining rows receive the gate bias
//           gradient (same layout as `weights` in Reset()).
//
// Also triggers GradientVisitor on every wrapped layer.
//
// NOTE(review): network[network.size() - 2] and network[1] are indexed
// unconditionally, so this appears to assume at least two layers in
// `network` -- confirm that callers guarantee this.
// Uses transformGateError as stored by the backward pass.
185 template<typename InputDataType, typename OutputDataType,
186  typename... CustomLayers>
187 template<typename eT>
189  const arma::Mat<eT>& input,
190  const arma::Mat<eT>& error,
191  arma::Mat<eT>& gradient)
192 {
// Last layer: its input is the output of the second-to-last layer, its error
// is `error` scaled element-wise by the gate activation.
193  arma::Mat<eT> errorTransform = error % transformGateActivation;
194  boost::apply_visitor(GradientVisitor(boost::apply_visitor(
195  outputParameterVisitor, network[network.size() - 2]),
196  errorTransform), network.back());
197 
// Interior layers, from second-to-last down to index 1: input is the output
// of the layer before, error is the delta of the layer after.
198  for (size_t i = 2; i < network.size(); ++i)
199  {
200  boost::apply_visitor(GradientVisitor(boost::apply_visitor(
201  outputParameterVisitor, network[network.size() - i - 1]),
202  boost::apply_visitor(deltaVisitor, network[network.size() - i + 1])),
203  network[network.size() - i]);
204  }
205 
// First layer: fed by `input` directly, error is the delta of layer 1.
206  boost::apply_visitor(GradientVisitor(input,
207  boost::apply_visitor(deltaVisitor, network[1])), network.front());
208 
// Gate weight gradient, flattened into the leading rows of column 0.
209  gradient.submat(0, 0, transformWeight.n_elem - 1, 0) = arma::vectorise(
210  transformGateError * input.t());
// Gate bias gradient: sum of the gate error along dimension 1 (one value per
// row), written to the remaining rows.
211  gradient.submat(transformWeight.n_elem, 0, gradient.n_elem - 1, 0) =
212  arma::sum(transformGateError, 1);
213 }
214 
// serialize(ar, version) (declarator on listing line 218 is not present in
// this listing). The version argument is unnamed/unused.
//
// When the archive is a loading archive, every layer currently in `network`
// is first passed to deleteVisitor and `weights` is resized; afterwards, for
// both saving and loading, `model` and then `network` are archived.
//
// NOTE(review): `inSize` is not archived anywhere in this function, and the
// resize below uses whatever value `inSize` holds before anything is read
// from the archive -- confirm this is intended. The loading branch also
// visits all layers regardless of networkOwnerships, unlike the destructor.
215 template<typename InputDataType, typename OutputDataType,
216  typename... CustomLayers>
217 template<typename Archive>
219  Archive& ar, const uint32_t /* version */)
220 {
221  // If loading, delete the old layers and set size for weights.
222  if (cereal::is_loading<Archive>())
223  {
224  for (LayerTypes<CustomLayers...>& layer : network)
225  {
226  boost::apply_visitor(deleteVisitor, layer);
227  }
228  weights.set_size(inSize * inSize + inSize, 1);
229  }
230 
231  ar(CEREAL_NVP(model));
232  ar(CEREAL_VECTOR_VARIANT_POINTER(network));
233 }
234 
235 } // namespace ann
236 } // namespace mlpack
237 
238 #endif
SetInputHeightVisitor updates the input height parameter with the given input height.
Definition: set_input_height_visitor.hpp:27
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
Definition: highway_impl.hpp:218
~Highway()
Destroy the Highway object.
Definition: highway_impl.hpp:56
BackwardVisitor executes the Backward() function given the input, error and delta parameter...
Definition: backward_visitor.hpp:28
static MLPACK_EXPORT util::PrefixedOutStream Fatal
Prints fatal messages prefixed with [FATAL], then terminates the program.
Definition: log.hpp:90
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
void Forward(const arma::Mat< eT > &input, arma::Mat< eT > &output)
Ordinary feed-forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
Definition: highway_impl.hpp:80
SetInputWidthVisitor updates the input width parameter with the given input width.
Definition: set_input_width_visitor.hpp:27
ForwardVisitor executes the Forward() function given the input and output parameter.
Definition: forward_visitor.hpp:28
GradientVisitor executes the Gradient() method of the given module using the input and delta parameter.
Definition: gradient_visitor.hpp:28
void Backward(const arma::Mat< eT > &, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Ordinary feed-backward pass of a neural network, calculating the function f(x) by propagating x backwards through f, using the results from the feed-forward pass.
Definition: highway_impl.hpp:156
#define CEREAL_VECTOR_VARIANT_POINTER(T)
Cereal does not support the serialization of raw pointer.
Definition: pointer_vector_variant_wrapper.hpp:92
void Reset()
Reset the layer parameter.
Definition: highway_impl.hpp:70
Highway()
Create the Highway object.
Definition: highway_impl.hpp:30
OutputDataType const & Gradient() const
Get the gradient.
Definition: highway.hpp:171