11 #ifndef MLPACK_METHODS_ANN_LAYER_RBF_IMPL_HPP 12 #define MLPACK_METHODS_ANN_LAYER_RBF_IMPL_HPP 20 template<
typename InputDataType,
typename OutputDataType,
31 template<
typename InputDataType,
typename OutputDataType,
46 for (
size_t i = 0; i < centres.n_cols; i++)
49 arma::mat temp = centres.each_col() - centres.col(i);
50 max_dis = arma::accu(arma::max(arma::pow(arma::sum(
51 arma::pow((temp), 2), 0), 0.5).t()));
55 this->betas = std::pow(2 * outSize, 0.5) / sigmas;
59 template<
typename InputDataType,
typename OutputDataType,
63 const arma::Mat<eT>& input,
64 arma::Mat<eT>& output)
66 distances = arma::mat(outSize, input.n_cols);
68 for (
size_t i = 0; i < input.n_cols; i++)
70 arma::mat temp = centres.each_col() - input.col(i);
71 distances.col(i) = arma::pow(arma::sum(
72 arma::pow((temp), 2), 0), 0.5).t();
74 Activation::Fn(distances * std::pow(betas, 0.5),
79 template<
typename InputDataType,
typename OutputDataType,
83 const arma::Mat<eT>& ,
84 const arma::Mat<eT>& ,
90 template<
typename InputDataType,
typename OutputDataType,
92 template<
typename Archive>
97 ar(CEREAL_NVP(distances));
98 ar(CEREAL_NVP(centres));
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
Definition: radial_basis_function_impl.hpp:93
void Backward(const arma::Mat< eT > &, const arma::Mat< eT > &, arma::Mat< eT > &)
Ordinary feed backward pass of the radial basis function.
Definition: radial_basis_function_impl.hpp:82
RBF()
Create the RBF object.
Definition: radial_basis_function_impl.hpp:22
void Forward(const arma::Mat< eT > &input, arma::Mat< eT > &output)
Ordinary feed forward pass of the radial basis function.
Definition: radial_basis_function_impl.hpp:62