11 #ifndef MLPACK_METHODS_ANN_RBM_RBM_IMPL_HPP 12 #define MLPACK_METHODS_ANN_RBM_RBM_IMPL_HPP 23 typename InitializationRuleType,
// RBM constructor (excerpt: not every line of the original definition is
// shown in this listing). Stores the training data and configuration in the
// object; the one visible body statement caches the number of data columns.
28 arma::Mat<ElemType> predictors,
29 InitializationRuleType initializeRule,
30 const size_t visibleSize,
31 const size_t hiddenSize,
32 const size_t batchSize,
33 const size_t numSteps,
34 const size_t negSteps,
35 const size_t poolSize,
36 const ElemType slabPenalty,
37 const ElemType radius,
38 const bool persistence) :
// `predictors` is received by value and moved into the member of the same
// name, so the constructor argument is left in a moved-from state.
39 predictors(
std::move(predictors)),
40 initializeRule(initializeRule),
41 visibleSize(visibleSize),
42 hiddenSize(hiddenSize),
48 slabPenalty(slabPenalty),
50 persistence(persistence),
// `this->` is required: the bare name would refer to the moved-from argument,
// not to the member that now holds the data.
53 numFunctions = this->predictors.n_cols;
// BinaryRBM-only member (selected via std::enable_if on Policy). The function
// name line is not visible in this excerpt — presumably the parameter
// reset/initialisation routine; confirm against the full file.
// Sizes the parameter and gradient storage, builds weight / hiddenBias /
// visibleBias from the parameter buffer, zeroes the gradients and asks
// initializeRule to fill the parameters.
57 typename InitializationRuleType,
61 template<
typename Policy,
typename InputType>
62 typename std::enable_if<std::is_same<Policy, BinaryRBM>::value,
void>::type
// Total element count: the weight matrix plus the two bias vectors.
65 size_t shape = (visibleSize * hiddenSize) + visibleSize + hiddenSize;
// Parameters and all gradient buffers are single columns of `shape` elements.
67 parameter.set_size(shape, 1);
68 positiveGradient.set_size(shape, 1);
69 negativeGradient.set_size(shape, 1);
70 tempNegativeGradient.set_size(shape, 1);
71 negativeSamples.set_size(visibleSize, batchSize);
// Layout inside `parameter`: weight starts at offset 0, hiddenBias at
// weight.n_elem, visibleBias at weight.n_elem + hiddenBias.n_elem.
// NOTE(review): the trailing `false, false` arguments look like Armadillo's
// copy_aux_mem / strict flags (use the buffer in place, no copy) — confirm
// that DataType follows the same constructor convention.
73 weight = arma::Cube<ElemType>(parameter.memptr(), hiddenSize, visibleSize, 1,
75 hiddenBias = DataType(parameter.memptr() + weight.n_elem,
76 hiddenSize, 1,
false,
false);
77 visibleBias = DataType(parameter.memptr() + weight.n_elem +
78 hiddenBias.n_elem, visibleSize, 1,
false,
false);
// Clear the gradient buffers before the parameters are initialised.
81 positiveGradient.zeros();
82 negativeGradient.zeros();
83 tempNegativeGradient.zeros();
// The rule is handed the whole flat parameter column (n_elem rows, 1 column).
84 initializeRule.Initialize(parameter, parameter.n_elem, 1);
// Train(): runs the supplied optimizer on this RBM.
// Passes *this, `parameter` and the callbacks to optimizer.Optimize() and
// returns that call's result.
90 typename InitializationRuleType,
94 template<
typename OptimizerType,
typename... CallbackType>
96 OptimizerType& optimizer, CallbackType&&... callbacks)
// The callback pack is expanded as plain lvalues (no std::forward).
103 return optimizer.Optimize(*
this, parameter, callbacks...);
// FreeEnergy() for BinaryRBM (overload selected via std::enable_if on Policy).
// Returns, as a double,
//   -( sum(log(1 + exp(W * input + hiddenBias))) + sum(input % visibleBias) )
// accumulated over every column of `input`.
107 typename InitializationRuleType,
111 template<
typename Policy,
typename InputType>
112 typename std::enable_if<std::is_same<Policy, BinaryRBM>::value,
double>::type
114 const arma::Mat<ElemType>& input)
// Hidden pre-activation: the hidden bias is added to every column.
116 preActivation = (weight.slice(0) * input);
117 preActivation.each_col() += hiddenBias;
// trunc_exp is used rather than exp — presumably to keep large
// pre-activations from overflowing. visibleBias is replicated once per input
// column so the dot product covers the whole batch.
118 return -(arma::accu(arma::log(1 + arma::trunc_exp(preActivation))) +
119 arma::dot(input, arma::repmat(visibleBias, 1, input.n_cols)));
// Phase() for BinaryRBM (overload selected via std::enable_if on Policy).
// Writes gradient terms for `input` into the caller-supplied `gradient`
// buffer. Several original lines (including the step that fills
// hiddenBiasGrad) are not shown in this excerpt.
123 typename InitializationRuleType,
127 template<
typename Policy,
typename InputType>
128 typename std::enable_if<std::is_same<Policy, BinaryRBM>::value,
void>::type
130 const InputType& input,
// Weight-gradient view: hiddenSize x visibleSize x 1, built from the start of
// gradient.memptr().
133 arma::Cube<ElemType> weightGrad = arma::Cube<ElemType>(gradient.memptr(),
134 hiddenSize, visibleSize, 1,
false,
false);
// Hidden-bias-gradient view: hiddenSize x 1, starting right after the weights.
136 DataType hiddenBiasGrad = DataType(gradient.memptr() + weightGrad.n_elem,
137 hiddenSize, 1,
false,
false);
// The weight gradient is the product of the hidden-bias gradient and the
// transposed input.
140 weightGrad.slice(0) = hiddenBiasGrad * input.t();
// Evaluate(): objective value for the batch of data columns
// [i, i + batchSize - 1]. The first (parameters) argument is unnamed and so is
// not used by this function.
144 typename InitializationRuleType,
149 const arma::Mat<ElemType>& ,
151 const size_t batchSize)
// Run Gibbs sampling starting from the data batch; the remaining arguments of
// this call are not shown in this excerpt.
153 Gibbs(predictors.cols(i, i + batchSize - 1),
// Result: absolute difference between the free energy of the data batch and
// the free energy of negativeSamples.
155 return std::fabs(
FreeEnergy(predictors.cols(i,
156 i + batchSize - 1)) -
FreeEnergy(negativeSamples));
// SampleHidden() for BinaryRBM (overload selected via std::enable_if on
// Policy). Reads `input` and writes the result into `output`.
// The statements before the loop and the loop body are not shown in this
// excerpt.
160 typename InitializationRuleType,
164 template<
typename Policy,
typename InputType>
165 typename std::enable_if<std::is_same<Policy, BinaryRBM>::value,
void>::type
167 const arma::Mat<ElemType>& input,
168 arma::Mat<ElemType>& output)
// Visits every element of `output` by flat index.
172 for (
size_t i = 0; i < output.n_elem; ++i)
// SampleVisible() for BinaryRBM (overload selected via std::enable_if on
// Policy). Reads `input` and writes the result into `output`.
// Unlike SampleHidden(), `input` is taken by non-const reference here.
// The statements before the loop and the loop body are not shown in this
// excerpt.
179 typename InitializationRuleType,
183 template<
typename Policy,
typename InputType>
184 typename std::enable_if<std::is_same<Policy, BinaryRBM>::value,
void>::type
186 arma::Mat<ElemType>& input,
187 arma::Mat<ElemType>& output)
// Visits every element of `output` by flat index.
191 for (
size_t i = 0; i < output.n_elem; ++i)
// VisibleMean() for BinaryRBM (overload selected via std::enable_if on
// Policy). The parameter list and any statements after the bias addition are
// not shown in this excerpt.
198 typename InitializationRuleType,
202 template<
typename Policy,
typename InputType>
203 typename std::enable_if<std::is_same<Policy, BinaryRBM>::value,
void>::type
// Project the input through the transposed weights, then add the visible bias
// to every column.
208 output = weight.slice(0).t() * input;
209 output.each_col() += visibleBias;
// HiddenMean() for BinaryRBM (overload selected via std::enable_if on Policy).
// The `output` parameter declaration and any statements after the bias
// addition are not shown in this excerpt.
214 typename InitializationRuleType,
218 template<
typename Policy,
typename InputType>
219 typename std::enable_if<std::is_same<Policy, BinaryRBM>::value,
void>::type
221 const InputType& input,
// Project the input through the weights, then add the hidden bias to every
// column.
224 output = weight.slice(0) * input;
225 output.each_col() += hiddenBias;
// Gibbs(): multi-step sampling from `input` into `output`. The bodies of the
// `if` and of the loop, and the statements between them, are not shown in this
// excerpt.
230 typename InitializationRuleType,
235 const arma::Mat<ElemType>& input,
236 arma::Mat<ElemType>& output,
// SIZE_MAX acts as a "not specified" sentinel: in that case the object's
// numSteps is used. The chosen value is stored in the member `steps`.
239 this->steps = (steps == SIZE_MAX) ? this->numSteps : steps;
// Branch taken only when persistence is enabled and a stored state exists.
241 if (persistence && !state.is_empty())
// Starts at 1, so this loop performs steps - 1 iterations; presumably the
// first step is done before the loop — not visible here, confirm.
252 for (
size_t j = 1; j < this->steps; ++j)
// Gradient(): writes into `gradient` the average negative-phase gradient minus
// the positive-phase gradient for the batch of data columns
// [i, i + batchSize - 1]. The first (parameters) argument is unnamed and so is
// not used by this function.
264 typename InitializationRuleType,
269 const arma::Mat<ElemType>& ,
271 arma::Mat<ElemType>& gradient,
272 const size_t batchSize)
// Reset the accumulators for this batch.
274 positiveGradient.zeros();
275 negativeGradient.zeros();
// Positive phase on the data batch (second argument of the call is not shown
// in this excerpt).
277 Phase(predictors.cols(i, i + batchSize - 1),
// NOTE(review): this loop declares its own `i`, which appears to shadow the
// batch offset `i` used by the Phase call above. If the Gibbs call below is
// inside the loop body (braces are not shown here), it selects columns
// [step, step + batchSize - 1] instead of the requested batch — confirm and
// consider renaming the loop counter.
280 for (
size_t i = 0; i < negSteps; ++i)
282 Gibbs(predictors.cols(i, i + batchSize - 1),
// Negative phase on the sampled data; results are summed over all steps.
284 Phase(negativeSamples, tempNegativeGradient);
286 negativeGradient += tempNegativeGradient;
// Average over the negative steps. NOTE(review): negSteps == 0 would divide by
// zero here — confirm callers never pass 0.
289 gradient = ((negativeGradient / negSteps) - positiveGradient);
// Shuffle(): reorders the columns of `predictors` in place.
293 typename InitializationRuleType,
// Builds the index vector 0 .. n_cols - 1, shuffles it, and selects the
// columns in that order.
// NOTE(review): with zero columns, `n_cols - 1` wraps around on an unsigned
// type — confirm the data set can never be empty here.
299 predictors = predictors.cols(arma::shuffle(arma::linspace<arma::uvec>(0,
300 predictors.n_cols - 1, predictors.n_cols)));
// serialize(): cereal entry point used for both saving and loading. The
// version argument is unnamed and so is not used. The closing lines of the
// function are not shown in this excerpt.
304 typename InitializationRuleType,
308 template<
typename Archive>
310 Archive& ar,
const uint32_t )
// Archived fields, each stored under its own member name via CEREAL_NVP.
// NOTE(review): weight, hiddenBias and visibleBias are archived in addition to
// `parameter`; if they are meant to be views into `parameter`, check that the
// link is re-established after loading.
312 ar(CEREAL_NVP(parameter));
313 ar(CEREAL_NVP(visibleSize));
314 ar(CEREAL_NVP(hiddenSize));
315 ar(CEREAL_NVP(state));
316 ar(CEREAL_NVP(numFunctions));
317 ar(CEREAL_NVP(numSteps));
318 ar(CEREAL_NVP(negSteps));
319 ar(CEREAL_NVP(persistence));
320 ar(CEREAL_NVP(poolSize));
321 ar(CEREAL_NVP(visibleBias));
322 ar(CEREAL_NVP(hiddenBias));
323 ar(CEREAL_NVP(weight));
324 ar(CEREAL_NVP(spikeBias));
325 ar(CEREAL_NVP(slabPenalty));
326 ar(CEREAL_NVP(radius));
327 ar(CEREAL_NVP(visiblePenalty));
// When loading, the working buffers are not read from the archive; they are
// re-created here from the loaded sizes and the gradients are zeroed.
330 if (cereal::is_loading<Archive>())
332 size_t shape = parameter.n_elem;
333 positiveGradient.set_size(shape, 1);
334 negativeGradient.set_size(shape, 1);
// NOTE(review): batchSize is used here but is not among the fields archived
// in the lines shown above — confirm it holds a valid value on load.
335 negativeSamples.set_size(visibleSize, batchSize);
336 tempNegativeGradient.set_size(shape, 1);
337 spikeMean.set_size(hiddenSize, 1);
338 spikeSamples.set_size(hiddenSize, 1);
339 slabMean.set_size(poolSize, hiddenSize);
340 positiveGradient.zeros();
341 negativeGradient.zeros();
342 tempNegativeGradient.zeros();
std::enable_if< std::is_same< Policy, BinaryRBM >::value, void >::type SampleHidden(const arma::Mat< ElemType > &input, arma::Mat< ElemType > &output)
This function samples the hidden layer given the visible layer using a Bernoulli function.
Definition: rbm_impl.hpp:166
double RandBernoulli(const double input)
Generates a 0/1 specified by the input.
Definition: random.hpp:99
std::enable_if< std::is_same< Policy, BinaryRBM >::value, double >::type FreeEnergy(const arma::Mat< ElemType > &input)
This function calculates the free energy of the BinaryRBM.
Definition: rbm_impl.hpp:113
RBM(arma::Mat< ElemType > predictors, InitializationRuleType initializeRule, const size_t visibleSize, const size_t hiddenSize, const size_t batchSize=1, const size_t numSteps=1, const size_t negSteps=1, const size_t poolSize=2, const ElemType slabPenalty=8, const ElemType radius=1, const bool persistence=false)
Initialize all the parameters of the network using initializeRule.
Definition: rbm_impl.hpp:27
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
void Gibbs(const arma::Mat< ElemType > &input, arma::Mat< ElemType > &output, const size_t steps=SIZE_MAX)
This function does the k-step Gibbs Sampling.
Definition: rbm_impl.hpp:234
Definition: pointer_wrapper.hpp:23
std::enable_if< std::is_same< Policy, BinaryRBM >::value, void >::type HiddenMean(const InputType &input, DataType &output)
The function calculates the mean for the hidden layer.
Definition: rbm_impl.hpp:220
The implementation of the RBM module.
Definition: rbm.hpp:38
void Gradient(const arma::Mat< ElemType > &parameters, const size_t i, arma::Mat< ElemType > &gradient, const size_t batchSize)
Calculates the gradients for the RBM network.
Definition: rbm_impl.hpp:268
double Train(OptimizerType &optimizer, CallbackType &&... callbacks)
Train the RBM on the given input data.
Definition: rbm_impl.hpp:95
std::enable_if< std::is_same< Policy, BinaryRBM >::value, void >::type Phase(const InputType &input, DataType &gradient)
Calculates the gradient of the RBM network on the provided input.
Definition: rbm_impl.hpp:129
void serialize(Archive &ar, const uint32_t version)
Serialize the model.
Definition: rbm_impl.hpp:309
static double Fn(const eT x)
Computes the logistic function.
Definition: logistic_function.hpp:39
double Evaluate(const arma::Mat< ElemType > &parameters, const size_t i, const size_t batchSize)
Evaluate the RBM network with the given parameters.
Definition: rbm_impl.hpp:148
void Shuffle()
Shuffle the order of function visitation.
Definition: rbm_impl.hpp:297
std::enable_if< std::is_same< Policy, BinaryRBM >::value, void >::type VisibleMean(InputType &input, DataType &output)
The function calculates the mean for the visible layer.
Definition: rbm_impl.hpp:204
std::enable_if< std::is_same< Policy, BinaryRBM >::value, void >::type SampleVisible(arma::Mat< ElemType > &input, arma::Mat< ElemType > &output)
This function samples the visible layer given the hidden layer using a Bernoulli function.
Definition: rbm_impl.hpp:185