11 #ifndef MLPACK_METHODS_ANN_RBM_RBM_HPP 12 #define MLPACK_METHODS_ANN_RBM_RBM_HPP 34 typename InitializationRuleType,
35 typename DataType = arma::mat,
36 typename PolicyType = BinaryRBM
/// Element type of DataType, taken from its nested `elem_type` member type.
42 typedef typename DataType::elem_type ElemType;
/// Initialize all the parameters of the network using initializeRule.
///
/// `predictors` and `initializeRule` are taken by value. Defaults: batchSize,
/// numSteps, negSteps and radius are 1; poolSize is 2; slabPenalty is 8;
/// persistence is false.
/// NOTE(review): the exact meaning of each size/step argument is not visible
/// in this declaration — confirm against rbm_impl.hpp.
60 RBM(arma::Mat<ElemType> predictors,
61 InitializationRuleType initializeRule,
62 const size_t visibleSize,
63 const size_t hiddenSize,
64 const size_t batchSize = 1,
65 const size_t numSteps = 1,
66 const size_t negSteps = 1,
67 const size_t poolSize = 2,
68 const ElemType slabPenalty = 8,
69 const ElemType radius = 1,
70 const bool persistence =
false);
73 template<
typename Policy = PolicyType,
typename InputType = DataType>
74 typename std::enable_if<std::is_same<Policy, BinaryRBM>::value,
void>::type
78 template<
typename Policy = PolicyType,
typename InputType = DataType>
79 typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value,
void>::type
/// Train the RBM on the given input data.
///
/// @tparam OptimizerType Type of the optimizer passed in.
/// @tparam CallbackType Types of the callback arguments.
/// @param optimizer Optimizer instance, taken by non-const reference.
/// @param callbacks Zero or more callbacks, taken as forwarding references.
/// @return A double. NOTE(review): presumably the final objective value —
///     confirm against rbm_impl.hpp.
97 template<
typename OptimizerType,
typename... CallbackType>
98 double Train(OptimizerType& optimizer, CallbackType&&... callbacks);
/// Evaluate the RBM network with the given parameters.
///
/// @param parameters Parameter matrix to evaluate with.
/// @param i Index argument. NOTE(review): presumably the index of the first
///     data point to use — confirm against rbm_impl.hpp.
/// @param batchSize Batch size argument. NOTE(review): presumably the number
///     of data points to use — confirm against rbm_impl.hpp.
/// @return A double. NOTE(review): presumably the objective value — confirm.
108 double Evaluate(
const arma::Mat<ElemType>& parameters,
109 const size_t i,
110 const size_t batchSize);
119 template<
typename Policy = PolicyType,
typename InputType = DataType>
120 typename std::enable_if<std::is_same<Policy, BinaryRBM>::value,
double>::type
133 template<
typename Policy = PolicyType,
typename InputType = DataType>
134 typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value,
/// Calculates the gradient of the RBM network on the provided input.
///
/// This overload is enabled only when Policy is BinaryRBM.
///
/// @param input Input to compute the gradient on.
/// @param gradient Output argument, taken by non-const reference.
144 template<
typename Policy = PolicyType,
typename InputType = DataType>
145 typename std::enable_if<std::is_same<Policy, BinaryRBM>::value,
void>::type
146 Phase(
const InputType& input, DataType& gradient);
/// Phase overload that is enabled only when Policy is SpikeSlabRBM.
///
/// NOTE(review): presumably the spike-slab counterpart of the BinaryRBM
/// Phase (gradient of the network on the provided input) — confirm against
/// the implementation.
///
/// @param input Input, taken by const reference.
/// @param gradient Output argument, taken by non-const reference.
154 template<
typename Policy = PolicyType,
typename InputType = DataType>
155 typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value,
void>::type
156 Phase(
const InputType& input, DataType& gradient);
/// This function samples the hidden layer given the visible layer using
/// Bernoulli function.
///
/// This overload is enabled only when Policy is BinaryRBM.
///
/// @param input Visible layer input, taken by const reference.
/// @param output Output argument, taken by non-const reference.
165 template<
typename Policy = PolicyType,
typename InputType = DataType>
166 typename std::enable_if<std::is_same<Policy, BinaryRBM>::value,
void>::type
167 SampleHidden(
const arma::Mat<ElemType>& input, arma::Mat<ElemType>& output);
/// SampleHidden overload that is enabled only when Policy is SpikeSlabRBM.
///
/// NOTE(review): presumably samples the hidden layer given the visible layer
/// for the spike-slab model — confirm against the implementation.
///
/// @param input Input, taken by const reference.
/// @param output Output argument, taken by non-const reference.
179 template<
typename Policy = PolicyType,
typename InputType = DataType>
180 typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value,
void>::type
181 SampleHidden(
const arma::Mat<ElemType>& input, arma::Mat<ElemType>& output);
/// This function samples the visible layer given the hidden layer using
/// Bernoulli function.
///
/// This overload is enabled only when Policy is BinaryRBM.
///
/// @param input Hidden layer input; note it is taken by non-const reference.
/// @param output Output argument, taken by non-const reference.
190 template<
typename Policy = PolicyType,
typename InputType = DataType>
191 typename std::enable_if<std::is_same<Policy, BinaryRBM>::value,
void>::type
192 SampleVisible(arma::Mat<ElemType>& input, arma::Mat<ElemType>& output);
/// SampleVisible overload that is enabled only when Policy is SpikeSlabRBM.
///
/// NOTE(review): presumably samples the visible layer given the hidden layer
/// for the spike-slab model — confirm against the implementation.
///
/// @param input Input; note it is taken by non-const reference.
/// @param output Output argument, taken by non-const reference.
204 template<
typename Policy = PolicyType,
typename InputType = DataType>
205 typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value,
void>::type
206 SampleVisible(arma::Mat<ElemType>& input, arma::Mat<ElemType>& output);
214 template<
typename Policy = PolicyType,
typename InputType = DataType>
215 typename std::enable_if<std::is_same<Policy, BinaryRBM>::value,
void>::type
226 template<
typename Policy = PolicyType,
typename InputType = DataType>
227 typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value,
void>::type
/// The function calculates the mean for the hidden layer.
///
/// This overload is enabled only when Policy is BinaryRBM.
///
/// @param input Input, taken by const reference.
/// @param output Output argument, taken by non-const reference.
236 template<
typename Policy = PolicyType,
typename InputType = DataType>
237 typename std::enable_if<std::is_same<Policy, BinaryRBM>::value,
void>::type
238 HiddenMean(
const InputType& input, DataType& output);
/// HiddenMean overload that is enabled only when Policy is SpikeSlabRBM.
///
/// NOTE(review): presumably computes the hidden-layer mean for the
/// spike-slab model — confirm against the implementation.
///
/// @param input Input, taken by const reference.
/// @param output Output argument, taken by non-const reference.
250 template<
typename Policy = PolicyType,
typename InputType = DataType>
251 typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value,
void>::type
252 HiddenMean(
const InputType& input, DataType& output);
/// The function calculates the mean of the distribution P(h|v).
///
/// Enabled only when Policy is SpikeSlabRBM.
///
/// @param visible Visible layer, taken by const reference.
/// @param spikeMean Output argument, taken by non-const reference.
262 template<
typename Policy = PolicyType,
typename InputType = DataType>
263 typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value,
void>::type
264 SpikeMean(
const InputType& visible, DataType& spikeMean);
/// The function samples the spike function using Bernoulli distribution.
///
/// Enabled only when Policy is SpikeSlabRBM.
///
/// @param spikeMean Spike mean; note it is taken by non-const reference.
/// @param spike Output argument, taken by non-const reference.
271 template<
typename Policy = PolicyType,
typename InputType = DataType>
272 typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value,
void>::type
273 SampleSpike(InputType& spikeMean, DataType& spike);
/// The function calculates the mean of Normal distribution of P(s|v, h).
///
/// Enabled only when Policy is SpikeSlabRBM. The parameters use DataType
/// directly; the InputType template parameter does not appear in the
/// parameter list.
///
/// @param visible Visible layer, taken by const reference.
/// @param spike Spike variables; note it is taken by non-const reference.
/// @param slabMean Output argument, taken by non-const reference.
284 template<
typename Policy = PolicyType,
typename InputType = DataType>
285 typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value,
void>::type
286 SlabMean(
const DataType& visible, DataType& spike, DataType& slabMean);
/// The function samples from the Normal distribution of P(s|v, h).
///
/// Enabled only when Policy is SpikeSlabRBM.
///
/// @param slabMean Slab mean; note it is taken by non-const reference.
/// @param slab Output argument, taken by non-const reference.
298 template<
typename Policy = PolicyType,
typename InputType = DataType>
299 typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value,
void>::type
300 SampleSlab(InputType& slabMean, DataType& slab);
/// This function does the k-step Gibbs Sampling.
///
/// @param input Input, taken by const reference.
/// @param output Output argument, taken by non-const reference.
/// @param steps Number of steps; defaults to SIZE_MAX.
///     NOTE(review): SIZE_MAX looks like a sentinel for "use the configured
///     number of steps" — confirm against rbm_impl.hpp.
309 void Gibbs(
const arma::Mat<ElemType>& input,
310 arma::Mat<ElemType>& output,
311 const size_t steps = SIZE_MAX);
/// Calculates the gradients for the RBM network.
///
/// @param parameters Parameter matrix, taken by const reference.
/// @param i Index argument. NOTE(review): presumably the index of the first
///     data point to use — confirm against rbm_impl.hpp.
/// @param gradient Output argument, taken by non-const reference.
/// @param batchSize Batch size argument. NOTE(review): presumably the number
///     of data points to use — confirm against rbm_impl.hpp.
321 void Gradient(
const arma::Mat<ElemType>& parameters,
322 const size_t i,
323 arma::Mat<ElemType>& gradient,
324 const size_t batchSize);
/// Return the parameters of the network (read-only reference to `parameter`).
339 const arma::Mat<ElemType>&
Parameters()
const {
return parameter; }
/// Get the weights of the network (read-only reference to `weight`).
344 arma::Cube<ElemType>
const&
Weight()
const {
return weight; }
/// Modify the weights of the network (mutable reference to `weight`).
346 arma::Cube<ElemType>&
Weight() {
return weight; }
/// Get the pool size (read-only reference to `poolSize`).
376 size_t const&
PoolSize()
const {
return poolSize; }
/// Serialize the model.
///
/// @tparam Archive Type of the archive.
/// @param ar Archive, taken by non-const reference.
/// @param version Version number passed alongside the archive.
379 template<
typename Archive>
380 void serialize(Archive& ar,
const uint32_t version);
// Data members. The listing's line numbers jump (e.g. 390 -> 408), so some
// members of the original header are not shown in this excerpt.

/// Parameters of the network; returned by Parameters().
384 arma::Mat<ElemType> parameter;
/// Stored predictors. NOTE(review): presumably the training data handed to
/// the constructor — confirm against rbm_impl.hpp.
386 arma::Mat<ElemType> predictors;
/// Stored initialization rule. NOTE(review): presumably the rule handed to
/// the constructor — confirm against rbm_impl.hpp.
388 InitializationRuleType initializeRule;
/// NOTE(review): purpose not visible here; presumably the sampler state.
390 arma::Mat<ElemType> state;
/// Weights of the network; returned by Weight().
408 arma::Cube<ElemType> weight;
/// NOTE(review): presumably the visible bias exposed by VisibleBias().
410 DataType visibleBias;
/// NOTE(review): purpose not visible here; presumably a pre-activation buffer.
414 DataType preActivation;
/// NOTE(review): presumably the regularizer associated with visible
/// variables exposed by VisiblePenalty().
418 DataType visiblePenalty;
/// NOTE(review): purpose not visible here; presumably the visible-layer mean.
420 DataType visibleMean;
/// NOTE(review): purpose not visible here; presumably sampled spike variables.
424 DataType spikeSamples;
/// NOTE(review): presumably the regularizer associated with slab variables
/// exposed by SlabPenalty().
428 ElemType slabPenalty;
/// NOTE(review): purpose not visible here; presumably a hidden-layer
/// reconstruction buffer.
432 arma::Mat<ElemType> hiddenReconstruction;
/// NOTE(review): purpose not visible here; presumably a visible-layer
/// reconstruction buffer.
434 arma::Mat<ElemType> visibleReconstruction;
/// NOTE(review): purpose not visible here; presumably negative-phase samples.
436 arma::Mat<ElemType> negativeSamples;
/// NOTE(review): purpose not visible here; presumably the negative-phase
/// gradient.
438 arma::Mat<ElemType> negativeGradient;
/// NOTE(review): purpose not visible here; presumably scratch storage for the
/// negative-phase gradient.
440 arma::Mat<ElemType> tempNegativeGradient;
/// NOTE(review): purpose not visible here; presumably the positive-phase
/// gradient.
442 arma::Mat<ElemType> positiveGradient;
/// NOTE(review): purpose not visible here; presumably scratch storage used by
/// Gibbs().
444 arma::Mat<ElemType> gibbsTemporary;
DataType & SpikeBias()
Modify the regularizer associated with spike variables.
Definition: rbm.hpp:361
std::enable_if< std::is_same< Policy, BinaryRBM >::value, void >::type SampleHidden(const arma::Mat< ElemType > &input, arma::Mat< ElemType > &output)
This function samples the hidden layer given the visible layer using Bernoulli function.
Definition: rbm_impl.hpp:166
size_t const & HiddenSize() const
Get the hidden size.
Definition: rbm.hpp:374
DataType & VisibleBias()
Modify the visible bias of the network.
Definition: rbm.hpp:351
std::enable_if< std::is_same< Policy, BinaryRBM >::value, double >::type FreeEnergy(const arma::Mat< ElemType > &input)
This function calculates the free energy of the BinaryRBM.
Definition: rbm_impl.hpp:113
arma::Mat< ElemType > & Parameters()
Modify the parameters of the network.
Definition: rbm.hpp:341
DataType & HiddenBias()
Modify the hidden bias of the network.
Definition: rbm.hpp:356
RBM(arma::Mat< ElemType > predictors, InitializationRuleType initializeRule, const size_t visibleSize, const size_t hiddenSize, const size_t batchSize=1, const size_t numSteps=1, const size_t negSteps=1, const size_t poolSize=2, const ElemType slabPenalty=8, const ElemType radius=1, const bool persistence=false)
Initialize all the parameters of the network using initializeRule.
Definition: rbm_impl.hpp:27
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
void Gibbs(const arma::Mat< ElemType > &input, arma::Mat< ElemType > &output, const size_t steps=SIZE_MAX)
This function does the k-step Gibbs Sampling.
Definition: rbm_impl.hpp:234
DataType const & VisibleBias() const
Return the visible bias of the network.
Definition: rbm.hpp:349
std::enable_if< std::is_same< Policy, SpikeSlabRBM >::value, void >::type SlabMean(const DataType &visible, DataType &spike, DataType &slabMean)
The function calculates the mean of Normal distribution of P(s|v, h), where the mean is given by: ...
Definition: spike_slab_rbm_impl.hpp:272
std::enable_if< std::is_same< Policy, BinaryRBM >::value, void >::type HiddenMean(const InputType &input, DataType &output)
The function calculates the mean for the hidden layer.
Definition: rbm_impl.hpp:220
size_t NumFunctions() const
Return the number of separable functions (the number of predictor points).
Definition: rbm.hpp:333
ElemType const & SlabPenalty() const
Get the regularizer associated with slab variables.
Definition: rbm.hpp:364
size_t const & VisibleSize() const
Get the visible size.
Definition: rbm.hpp:372
The implementation of the RBM module.
Definition: rbm.hpp:38
void Gradient(const arma::Mat< ElemType > &parameters, const size_t i, arma::Mat< ElemType > &gradient, const size_t batchSize)
Calculates the gradients for the RBM network.
Definition: rbm_impl.hpp:268
double Train(OptimizerType &optimizer, CallbackType &&... callbacks)
Train the RBM on the given input data.
Definition: rbm_impl.hpp:95
DataType & VisiblePenalty()
Modify the regularizer associated with visible variables.
Definition: rbm.hpp:369
std::enable_if< std::is_same< Policy, BinaryRBM >::value, void >::type Phase(const InputType &input, DataType &gradient)
Calculates the gradient of the RBM network on the provided input.
Definition: rbm_impl.hpp:129
void serialize(Archive &ar, const uint32_t version)
Serialize the model.
Definition: rbm_impl.hpp:309
const arma::Mat< ElemType > & Parameters() const
Return the parameters of the network.
Definition: rbm.hpp:339
size_t NumSteps() const
Return the number of steps of Gibbs Sampling.
Definition: rbm.hpp:336
Include all of the base components required to write mlpack methods, and the main mlpack Doxygen docu...
std::enable_if< std::is_same< Policy, SpikeSlabRBM >::value, void >::type SampleSlab(InputType &slabMean, DataType &slab)
The function samples from the Normal distribution of P(s|v, h), where the mean is given by: and vari...
Definition: spike_slab_rbm_impl.hpp:291
arma::Cube< ElemType > & Weight()
Modify the weights of the network.
Definition: rbm.hpp:346
double Evaluate(const arma::Mat< ElemType > &parameters, const size_t i, const size_t batchSize)
Evaluate the RBM network with the given parameters.
Definition: rbm_impl.hpp:148
DataType const & VisiblePenalty() const
Get the regularizer associated with visible variables.
Definition: rbm.hpp:367
void Shuffle()
Shuffle the order of function visitation.
Definition: rbm_impl.hpp:297
std::enable_if< std::is_same< Policy, SpikeSlabRBM >::value, void >::type SpikeMean(const InputType &visible, DataType &spikeMean)
The function calculates the mean of the distribution P(h|v), where mean is given by: ...
Definition: spike_slab_rbm_impl.hpp:236
std::enable_if< std::is_same< Policy, SpikeSlabRBM >::value, void >::type SampleSpike(InputType &spikeMean, DataType &spike)
The function samples the spike function using Bernoulli distribution.
Definition: spike_slab_rbm_impl.hpp:255
std::enable_if< std::is_same< Policy, BinaryRBM >::value, void >::type VisibleMean(InputType &input, DataType &output)
The function calculates the mean for the visible layer.
Definition: rbm_impl.hpp:204
size_t const & PoolSize() const
Get the pool size.
Definition: rbm.hpp:376
DataType const & SpikeBias() const
Get the regularizer associated with spike variables.
Definition: rbm.hpp:359
std::enable_if< std::is_same< Policy, BinaryRBM >::value, void >::type SampleVisible(arma::Mat< ElemType > &input, arma::Mat< ElemType > &output)
This function samples the visible layer given the hidden layer using Bernoulli function.
Definition: rbm_impl.hpp:185
arma::Cube< ElemType > const & Weight() const
Get the weights of the network.
Definition: rbm.hpp:344
DataType const & HiddenBias() const
Return the hidden bias of the network.
Definition: rbm.hpp:354