mlpack
base_layer.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_LAYER_BASE_LAYER_HPP
14 #define MLPACK_METHODS_ANN_LAYER_BASE_LAYER_HPP
15 
16 #include <mlpack/prereqs.hpp>
33 
34 namespace mlpack {
35 namespace ann {
36 
66 template <
67  class ActivationFunction = LogisticFunction,
68  typename InputDataType = arma::mat,
69  typename OutputDataType = arma::mat
70 >
71 class BaseLayer
72 {
73  public:
78  {
79  // Nothing to do here.
80  }
81 
89  template<typename InputType, typename OutputType>
90  void Forward(const InputType& input, OutputType& output)
91  {
92  ActivationFunction::Fn(input, output);
93  }
94 
104  template<typename eT>
105  void Backward(const arma::Mat<eT>& input,
106  const arma::Mat<eT>& gy,
107  arma::Mat<eT>& g)
108  {
109  arma::Mat<eT> derivative;
110  ActivationFunction::Deriv(input, derivative);
111  g = gy % derivative;
112  }
113 
115  OutputDataType const& OutputParameter() const { return outputParameter; }
117  OutputDataType& OutputParameter() { return outputParameter; }
118 
120  OutputDataType const& Delta() const { return delta; }
122  OutputDataType& Delta() { return delta; }
123 
127  template<typename Archive>
128  void serialize(Archive& /* ar */, const uint32_t /* version */)
129  {
130  /* Nothing to do here */
131  }
132 
133  private:
135  OutputDataType delta;
136 
138  OutputDataType outputParameter;
139 }; // class BaseLayer
140 
141 // Convenience typedefs.
142 
146 template <
147  class ActivationFunction = LogisticFunction,
148  typename InputDataType = arma::mat,
149  typename OutputDataType = arma::mat
150 >
151 using SigmoidLayer = BaseLayer<
152  ActivationFunction, InputDataType, OutputDataType>;
153 
157 template <
158  class ActivationFunction = IdentityFunction,
159  typename InputDataType = arma::mat,
160  typename OutputDataType = arma::mat
161 >
162 using IdentityLayer = BaseLayer<
163  ActivationFunction, InputDataType, OutputDataType>;
164 
168 template <
169  class ActivationFunction = RectifierFunction,
170  typename InputDataType = arma::mat,
171  typename OutputDataType = arma::mat
172 >
173 using ReLULayer = BaseLayer<
174  ActivationFunction, InputDataType, OutputDataType>;
175 
179 template <
180  class ActivationFunction = TanhFunction,
181  typename InputDataType = arma::mat,
182  typename OutputDataType = arma::mat
183 >
184 using TanHLayer = BaseLayer<
185  ActivationFunction, InputDataType, OutputDataType>;
186 
190 template <
191  class ActivationFunction = SoftplusFunction,
192  typename InputDataType = arma::mat,
193  typename OutputDataType = arma::mat
194 >
195 using SoftPlusLayer = BaseLayer<
196  ActivationFunction, InputDataType, OutputDataType>;
197 
/**
 * Convenience alias template whose default ActivationFunction is
 * HardSigmoidFunction (presumably an alias of BaseLayer, like the ones above).
 * NOTE(review): the line naming this alias (source line 206) is missing from
 * this listing, so the declaration below is incomplete as shown.
 */
201 template <
202  class ActivationFunction = HardSigmoidFunction,
203  typename InputDataType = arma::mat,
204  typename OutputDataType = arma::mat
205 >
// MISSING source line 206: presumably `using HardSigmoidLayer = BaseLayer<` in upstream mlpack; confirm.
207  ActivationFunction, InputDataType, OutputDataType>;
208 
/**
 * Convenience alias template whose default ActivationFunction is
 * SwishFunction (presumably an alias of BaseLayer, like the ones above).
 * NOTE(review): the line naming this alias (source line 217) is missing from
 * this listing, so the declaration below is incomplete as shown.
 */
212 template <
213  class ActivationFunction = SwishFunction,
214  typename InputDataType = arma::mat,
215  typename OutputDataType = arma::mat
216 >
// MISSING source line 217: presumably `using SwishFunctionLayer = BaseLayer<` in upstream mlpack; confirm.
218  ActivationFunction, InputDataType, OutputDataType>;
219 
/**
 * Convenience alias template whose default ActivationFunction is
 * MishFunction (presumably an alias of BaseLayer, like the ones above).
 * NOTE(review): the line naming this alias (source line 228) is missing from
 * this listing, so the declaration below is incomplete as shown.
 */
223 template <
224  class ActivationFunction = MishFunction,
225  typename InputDataType = arma::mat,
226  typename OutputDataType = arma::mat
227 >
// MISSING source line 228: presumably `using MishFunctionLayer = BaseLayer<` in upstream mlpack; confirm.
229  ActivationFunction, InputDataType, OutputDataType>;
230 
/**
 * Convenience alias template whose default ActivationFunction is
 * LiSHTFunction (presumably an alias of BaseLayer, like the ones above).
 * NOTE(review): the line naming this alias (source line 239) is missing from
 * this listing, so the declaration below is incomplete as shown.
 */
234 template <
235  class ActivationFunction = LiSHTFunction,
236  typename InputDataType = arma::mat,
237  typename OutputDataType = arma::mat
238 >
// MISSING source line 239: presumably `using LiSHTFunctionLayer = BaseLayer<` in upstream mlpack; confirm.
240  ActivationFunction, InputDataType, OutputDataType>;
241 
/**
 * Convenience alias template whose default ActivationFunction is
 * GELUFunction (presumably an alias of BaseLayer, like the ones above).
 * NOTE(review): the line naming this alias (source line 250) is missing from
 * this listing, so the declaration below is incomplete as shown.
 */
245 template <
246  class ActivationFunction = GELUFunction,
247  typename InputDataType = arma::mat,
248  typename OutputDataType = arma::mat
249 >
// MISSING source line 250: presumably `using GELUFunctionLayer = BaseLayer<` in upstream mlpack; confirm.
251  ActivationFunction, InputDataType, OutputDataType>;
252 
/**
 * Convenience alias template whose default ActivationFunction is
 * ElliotFunction (presumably an alias of BaseLayer, like the ones above).
 * NOTE(review): the line naming this alias (source line 261) is missing from
 * this listing, so the declaration below is incomplete as shown.
 */
256 template <
257  class ActivationFunction = ElliotFunction,
258  typename InputDataType = arma::mat,
259  typename OutputDataType = arma::mat
260 >
// MISSING source line 261: presumably `using ElliotFunctionLayer = BaseLayer<` in upstream mlpack; confirm.
262  ActivationFunction, InputDataType, OutputDataType>;
263 
/**
 * Convenience alias template whose default ActivationFunction is
 * ElishFunction (presumably an alias of BaseLayer, like the ones above).
 * NOTE(review): the line naming this alias (source line 272) is missing from
 * this listing, so the declaration below is incomplete as shown.
 */
267 template <
268  class ActivationFunction = ElishFunction,
269  typename InputDataType = arma::mat,
270  typename OutputDataType = arma::mat
271 >
// MISSING source line 272: presumably `using ElishFunctionLayer = BaseLayer<` in upstream mlpack; confirm.
273  ActivationFunction, InputDataType, OutputDataType>;
274 
/**
 * Convenience alias template whose default ActivationFunction is
 * GaussianFunction (presumably an alias of BaseLayer, like the ones above).
 * NOTE(review): the line naming this alias (source line 283) is missing from
 * this listing, so the declaration below is incomplete as shown.
 */
278 template <
279  class ActivationFunction = GaussianFunction,
280  typename InputDataType = arma::mat,
281  typename OutputDataType = arma::mat
282 >
// MISSING source line 283: presumably `using GaussianFunctionLayer = BaseLayer<` in upstream mlpack; confirm.
284  ActivationFunction, InputDataType, OutputDataType>;
285 
/**
 * Convenience alias template whose default ActivationFunction is
 * HardSwishFunction (presumably an alias of BaseLayer, like the ones above).
 * NOTE(review): the line naming this alias (source line 294) is missing from
 * this listing, so the declaration below is incomplete as shown.
 */
289 template <
290  class ActivationFunction = HardSwishFunction,
291  typename InputDataType = arma::mat,
292  typename OutputDataType = arma::mat
293 >
// MISSING source line 294: presumably `using HardSwishFunctionLayer = BaseLayer<` in upstream mlpack; confirm.
295  ActivationFunction, InputDataType, OutputDataType>;
296 
/**
 * Convenience alias template whose default ActivationFunction is
 * TanhExpFunction (presumably an alias of BaseLayer, like the ones above).
 * NOTE(review): the line naming this alias (source line 305) is missing from
 * this listing, so the declaration below is incomplete as shown.
 */
300 template <
301  class ActivationFunction = TanhExpFunction,
302  typename InputDataType = arma::mat,
303  typename OutputDataType = arma::mat
304 >
// MISSING source line 305: presumably `using TanhExpFunctionLayer = BaseLayer<` in upstream mlpack; confirm.
306  ActivationFunction, InputDataType, OutputDataType>;
307 
/**
 * Convenience alias template whose default ActivationFunction is
 * SILUFunction (presumably an alias of BaseLayer, like the ones above).
 * NOTE(review): the line naming this alias (source line 316) is missing from
 * this listing, so the declaration below is incomplete as shown.
 */
311 template <
312  class ActivationFunction = SILUFunction,
313  typename InputDataType = arma::mat,
314  typename OutputDataType = arma::mat
315 >
// MISSING source line 316: presumably `using SILUFunctionLayer = BaseLayer<` in upstream mlpack; confirm.
317  ActivationFunction, InputDataType, OutputDataType
318 >;
319 
320 } // namespace ann
321 } // namespace mlpack
322 
323 #endif
The identity function, defined by $f(x) = x$.
Definition: identity_function.hpp:28
The Hard Swish function, defined by.
Definition: hard_swish_function.hpp:47
The LiSHT function, defined by $f(x) = x \tanh(x)$.
Definition: lisht_function.hpp:42
The tanh function, defined by $f(x) = \tanh(x) = \frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$.
Definition: tanh_function.hpp:29
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.
void Backward(const arma::Mat< eT > &input, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backwards through f, using the results from the feed forward pass.
Definition: base_layer.hpp:105
The ELiSH function, defined by.
Definition: elish_function.hpp:48
Implementation of the base layer.
Definition: base_layer.hpp:71
The SILU function, defined by $f(x) = x \, \sigma(x) = \frac{x}{1 + e^{-x}}$.
Definition: silu_function.hpp:43
OutputDataType const & Delta() const
Get the delta.
Definition: base_layer.hpp:120
The Mish function, defined by $f(x) = x \tanh\left(\ln(1 + e^{x})\right)$.
Definition: mish_function.hpp:40
The TanhExp function, defined by $f(x) = x \tanh(e^{x})$.
Definition: tanh_exponential_function.hpp:42
The logistic function, defined by $f(x) = \frac{1}{1 + e^{-x}}$.
Definition: logistic_function.hpp:29
The Gaussian function, defined by $f(x) = e^{-x^{2}}$.
Definition: gaussian_function.hpp:28
void Forward(const InputType &input, OutputType &output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
Definition: base_layer.hpp:90
The Elliot function, defined by $f(x) = \frac{x}{1 + |x|}$.
Definition: elliot_function.hpp:40
void serialize(Archive &, const uint32_t)
Serialize the layer.
Definition: base_layer.hpp:128
The swish function, defined by $f(x) = x \, \sigma(x) = \frac{x}{1 + e^{-x}}$.
Definition: swish_function.hpp:30
BaseLayer()
Create the BaseLayer object.
Definition: base_layer.hpp:77
The softplus function, defined by $f(x) = \ln(1 + e^{x})$.
Definition: softplus_function.hpp:43
OutputDataType const & OutputParameter() const
Get the output parameter.
Definition: base_layer.hpp:115
The hard sigmoid function, defined by $f(x) = \max\left(0, \min(1, 0.2x + 0.5)\right)$.
Definition: hard_sigmoid_function.hpp:34
The GELU function, defined by $f(x) = 0.5x\left(1 + \tanh\left(\sqrt{2/\pi}\,(x + 0.044715x^{3})\right)\right)$.
Definition: gelu_function.hpp:31
OutputDataType & OutputParameter()
Modify the output parameter.
Definition: base_layer.hpp:117
The rectifier function, defined by $f(x) = \max(0, x)$.
Definition: rectifier_function.hpp:45
OutputDataType & Delta()
Modify the delta.
Definition: base_layer.hpp:122