mlpack
radial_basis_function_impl.hpp
Go to the documentation of this file.
1 
11 #ifndef MLPACK_METHODS_ANN_LAYER_RBF_IMPL_HPP
12 #define MLPACK_METHODS_ANN_LAYER_RBF_IMPL_HPP
13 
14 // In case it hasn't yet been included.
16 
17 namespace mlpack {
18 namespace ann {
19 
// Default constructor (its `RBF<...>::RBF() :` signature line is missing from
// this listing). Sets inSize, outSize, sigmas and betas to 0; centres and
// distances are not touched here.
// NOTE(review): betas stays 0 here, and Forward() scales the distances by
// std::pow(betas, 0.5). A default-constructed layer therefore feeds all-zero
// inputs to the activation unless betas is set some other way.
20 template<typename InputDataType, typename OutputDataType,
21  typename Activation>
23  inSize(0),
24  outSize(0),
25  sigmas(0),
26  betas(0)
27 {
28  // Nothing to do here.
29 }
30 
// Construct the layer from a fixed set of centres, one centre per column of
// `centres` (the signature line is missing from this listing).
// If `betas` is 0, it is derived from the centres: sigmas becomes the largest
// Euclidean distance between any two centre columns, and betas is set to
// sqrt(2 * outSize) / sigmas. A non-zero `betas` is used as given, and
// sigmas stays 0.
// NOTE(review): if centres has fewer than two distinct columns (including
// zero columns), sigmas stays 0. The division below then yields inf, or NaN
// when outSize is also 0. Consider guarding.
// NOTE(review): `centres` is only read here, so it could be taken by const
// reference.
31 template<typename InputDataType, typename OutputDataType,
32  typename Activation>
34  const size_t inSize,
35  const size_t outSize,
36  arma::mat& centres,
37  double betas) :
38  inSize(inSize),
39  outSize(outSize),
40  betas(betas),
41  centres(centres)
42 {
43  sigmas = 0;
44  if (betas == 0)
45  {
46  for (size_t i = 0; i < centres.n_cols; i++)
47  {
48  double max_dis = 0;
  // Difference between every centre and centre i, one column per centre.
49  arma::mat temp = centres.each_col() - centres.col(i);
  // Column-wise sum of squares, then sqrt: the distance from centre i to every
  // centre. Take the largest of these distances.
50  max_dis = arma::accu(arma::max(arma::pow(arma::sum(
51  arma::pow((temp), 2), 0), 0.5).t()));
52  if (max_dis > sigmas)
53  sigmas = max_dis;
54  }
  // The parameter `betas` shadows the member, so assign through this->.
55  this->betas = std::pow(2 * outSize, 0.5) / sigmas;
56  }
57 }
58 
// Forward pass (the signature line is missing from this listing).
// For each input column, compute the Euclidean distance to every centre
// (column of `centres`). Store the results in the `distances` member, with one
// row per centre and one column per input point. Then apply Activation::Fn to
// the distances scaled by sqrt(betas), writing the result into `output`.
// NOTE(review): distances is sized outSize x input.n_cols, but each column
// receives centres.n_cols values. This assumes centres.n_cols == outSize, and
// also centres.n_rows == input.n_rows for the subtraction. Confirm, or assert.
// NOTE(review): the intermediates are arma::mat (double) regardless of eT.
59 template<typename InputDataType, typename OutputDataType,
60  typename Activation>
61 template<typename eT>
63  const arma::Mat<eT>& input,
64  arma::Mat<eT>& output)
65 {
66  distances = arma::mat(outSize, input.n_cols);
67 
68  for (size_t i = 0; i < input.n_cols; i++)
69  {
  // Difference between every centre and the i-th input point.
70  arma::mat temp = centres.each_col() - input.col(i);
  // Column-wise Euclidean norm gives a row vector; transpose it to fill column i.
71  distances.col(i) = arma::pow(arma::sum(
72  arma::pow((temp), 2), 0), 0.5).t();
73  }
74  Activation::Fn(distances * std::pow(betas, 0.5),
75  output);
76 }
77 
78 
// Backward pass (the signature line is missing from this listing). This is a
// no-op: no gradient is propagated, and `g` is not written, so it keeps
// whatever contents the caller passed in.
// NOTE(review): presumably intentional because the centres are fixed (no
// trainable parameters are visible in this file). Confirm that no layer before
// RBF relies on a gradient from it.
79 template<typename InputDataType, typename OutputDataType,
80  typename Activation>
81 template<typename eT>
83  const arma::Mat<eT>& /* input */,
84  const arma::Mat<eT>& /* gy */,
85  arma::Mat<eT>& /* g */)
86 {
87  // Nothing to do here.
88 }
89 
// Serialize the layer (the signature line is missing from this listing).
// Only `distances` and `centres` are archived.
// NOTE(review): betas, sigmas, inSize and outSize are not archived, even though
// Forward() reads betas and outSize. The default constructor sets these to 0,
// so a layer loaded into a default-constructed object would not reproduce the
// original outputs. Adding them changes the archive format, so the change
// should presumably be gated on the `version` argument.
90 template<typename InputDataType, typename OutputDataType,
91  typename Activation>
92 template<typename Archive>
94  Archive& ar,
95  const uint32_t /* version */)
96 {
97  ar(CEREAL_NVP(distances));
98  ar(CEREAL_NVP(centres));
99 }
100 
101 } // namespace ann
102 } // namespace mlpack
103 
104 #endif
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
Definition: radial_basis_function_impl.hpp:93
void Backward(const arma::Mat< eT > &, const arma::Mat< eT > &, arma::Mat< eT > &)
Ordinary backward pass of the radial basis function.
Definition: radial_basis_function_impl.hpp:82
RBF()
Create the RBF object.
Definition: radial_basis_function_impl.hpp:22
void Forward(const arma::Mat< eT > &input, arma::Mat< eT > &output)
Ordinary feed forward pass of the radial basis function.
Definition: radial_basis_function_impl.hpp:62