mlpack
rbm_impl.hpp
Go to the documentation of this file.
1 
11 #ifndef MLPACK_METHODS_ANN_RBM_RBM_IMPL_HPP
12 #define MLPACK_METHODS_ANN_RBM_RBM_IMPL_HPP
13 
14 // In case it hasn't been included yet.
15 #include "rbm.hpp"
16 
18 
19 namespace mlpack {
20 namespace ann {
21 
22 template<
23  typename InitializationRuleType,
24  typename DataType,
25  typename PolicyType
26 >
28  arma::Mat<ElemType> predictors,
29  InitializationRuleType initializeRule,
30  const size_t visibleSize,
31  const size_t hiddenSize,
32  const size_t batchSize,
33  const size_t numSteps,
34  const size_t negSteps,
35  const size_t poolSize,
36  const ElemType slabPenalty,
37  const ElemType radius,
38  const bool persistence) :
39  predictors(std::move(predictors)),
40  initializeRule(initializeRule),
41  visibleSize(visibleSize),
42  hiddenSize(hiddenSize),
43  batchSize(batchSize),
44  numSteps(numSteps),
45  negSteps(negSteps),
46  poolSize(poolSize),
47  steps(0),
48  slabPenalty(slabPenalty),
49  radius(2 * radius),
50  persistence(persistence),
51  reset(false)
52 {
53  numFunctions = this->predictors.n_cols;
54 }
55 
56 template<
57  typename InitializationRuleType,
58  typename DataType,
59  typename PolicyType
60 >
61 template<typename Policy, typename InputType>
62 typename std::enable_if<std::is_same<Policy, BinaryRBM>::value, void>::type
64 {
65  size_t shape = (visibleSize * hiddenSize) + visibleSize + hiddenSize;
66 
67  parameter.set_size(shape, 1);
68  positiveGradient.set_size(shape, 1);
69  negativeGradient.set_size(shape, 1);
70  tempNegativeGradient.set_size(shape, 1);
71  negativeSamples.set_size(visibleSize, batchSize);
72 
73  weight = arma::Cube<ElemType>(parameter.memptr(), hiddenSize, visibleSize, 1,
74  false, false);
75  hiddenBias = DataType(parameter.memptr() + weight.n_elem,
76  hiddenSize, 1, false, false);
77  visibleBias = DataType(parameter.memptr() + weight.n_elem +
78  hiddenBias.n_elem, visibleSize, 1, false, false);
79 
80  parameter.zeros();
81  positiveGradient.zeros();
82  negativeGradient.zeros();
83  tempNegativeGradient.zeros();
84  initializeRule.Initialize(parameter, parameter.n_elem, 1);
85 
86  reset = true;
87 }
88 
89 template<
90  typename InitializationRuleType,
91  typename DataType,
92  typename PolicyType
93 >
94 template<typename OptimizerType, typename... CallbackType>
96  OptimizerType& optimizer, CallbackType&&... callbacks)
97 {
98  if (!reset)
99  {
100  Reset();
101  }
102 
103  return optimizer.Optimize(*this, parameter, callbacks...);
104 }
105 
106 template<
107  typename InitializationRuleType,
108  typename DataType,
109  typename PolicyType
110 >
111 template<typename Policy, typename InputType>
112 typename std::enable_if<std::is_same<Policy, BinaryRBM>::value, double>::type
114  const arma::Mat<ElemType>& input)
115 {
116  preActivation = (weight.slice(0) * input);
117  preActivation.each_col() += hiddenBias;
118  return -(arma::accu(arma::log(1 + arma::trunc_exp(preActivation))) +
119  arma::dot(input, arma::repmat(visibleBias, 1, input.n_cols)));
120 }
121 
122 template<
123  typename InitializationRuleType,
124  typename DataType,
125  typename PolicyType
126 >
127 template<typename Policy, typename InputType>
128 typename std::enable_if<std::is_same<Policy, BinaryRBM>::value, void>::type
130  const InputType& input,
131  DataType& gradient)
132 {
133  arma::Cube<ElemType> weightGrad = arma::Cube<ElemType>(gradient.memptr(),
134  hiddenSize, visibleSize, 1, false, false);
135 
136  DataType hiddenBiasGrad = DataType(gradient.memptr() + weightGrad.n_elem,
137  hiddenSize, 1, false, false);
138 
139  HiddenMean(input, hiddenBiasGrad);
140  weightGrad.slice(0) = hiddenBiasGrad * input.t();
141 }
142 
143 template<
144  typename InitializationRuleType,
145  typename DataType,
146  typename PolicyType
147 >
149  const arma::Mat<ElemType>& /* parameters*/,
150  const size_t i,
151  const size_t batchSize)
152 {
153  Gibbs(predictors.cols(i, i + batchSize - 1),
154  negativeSamples);
155  return std::fabs(FreeEnergy(predictors.cols(i,
156  i + batchSize - 1)) - FreeEnergy(negativeSamples));
157 }
158 
159 template<
160  typename InitializationRuleType,
161  typename DataType,
162  typename PolicyType
163 >
164 template<typename Policy, typename InputType>
165 typename std::enable_if<std::is_same<Policy, BinaryRBM>::value, void>::type
167  const arma::Mat<ElemType>& input,
168  arma::Mat<ElemType>& output)
169 {
170  HiddenMean(input, output);
171 
172  for (size_t i = 0; i < output.n_elem; ++i)
173  {
174  output(i) = math::RandBernoulli(output(i));
175  }
176 }
177 
178 template<
179  typename InitializationRuleType,
180  typename DataType,
181  typename PolicyType
182 >
183 template<typename Policy, typename InputType>
184 typename std::enable_if<std::is_same<Policy, BinaryRBM>::value, void>::type
186  arma::Mat<ElemType>& input,
187  arma::Mat<ElemType>& output)
188 {
189  VisibleMean(input, output);
190 
191  for (size_t i = 0; i < output.n_elem; ++i)
192  {
193  output(i) = math::RandBernoulli(output(i));
194  }
195 }
196 
197 template<
198  typename InitializationRuleType,
199  typename DataType,
200  typename PolicyType
201 >
202 template<typename Policy, typename InputType>
203 typename std::enable_if<std::is_same<Policy, BinaryRBM>::value, void>::type
205  InputType& input,
206  DataType& output)
207 {
208  output = weight.slice(0).t() * input;
209  output.each_col() += visibleBias;
210  LogisticFunction::Fn(output, output);
211 }
212 
213 template<
214  typename InitializationRuleType,
215  typename DataType,
216  typename PolicyType
217 >
218 template<typename Policy, typename InputType>
219 typename std::enable_if<std::is_same<Policy, BinaryRBM>::value, void>::type
221  const InputType& input,
222  DataType& output)
223 {
224  output = weight.slice(0) * input;
225  output.each_col() += hiddenBias;
226  LogisticFunction::Fn(output, output);
227 }
228 
229 template<
230  typename InitializationRuleType,
231  typename DataType,
232  typename PolicyType
233 >
235  const arma::Mat<ElemType>& input,
236  arma::Mat<ElemType>& output,
237  const size_t steps)
238 {
239  this->steps = (steps == SIZE_MAX) ? this->numSteps : steps;
240 
241  if (persistence && !state.is_empty())
242  {
243  SampleHidden(state, gibbsTemporary);
244  SampleVisible(gibbsTemporary, output);
245  }
246  else
247  {
248  SampleHidden(input, gibbsTemporary);
249  SampleVisible(gibbsTemporary, output);
250  }
251 
252  for (size_t j = 1; j < this->steps; ++j)
253  {
254  SampleHidden(output, gibbsTemporary);
255  SampleVisible(gibbsTemporary, output);
256  }
257  if (persistence)
258  {
259  state = output;
260  }
261 }
262 
263 template<
264  typename InitializationRuleType,
265  typename DataType,
266  typename PolicyType
267 >
269  const arma::Mat<ElemType>& /*parameters*/,
270  const size_t i,
271  arma::Mat<ElemType>& gradient,
272  const size_t batchSize)
273 {
274  positiveGradient.zeros();
275  negativeGradient.zeros();
276 
277  Phase(predictors.cols(i, i + batchSize - 1),
278  positiveGradient);
279 
280  for (size_t i = 0; i < negSteps; ++i)
281  {
282  Gibbs(predictors.cols(i, i + batchSize - 1),
283  negativeSamples);
284  Phase(negativeSamples, tempNegativeGradient);
285 
286  negativeGradient += tempNegativeGradient;
287  }
288 
289  gradient = ((negativeGradient / negSteps) - positiveGradient);
290 }
291 
292 template<
293  typename InitializationRuleType,
294  typename DataType,
295  typename PolicyType
296 >
298 {
299  predictors = predictors.cols(arma::shuffle(arma::linspace<arma::uvec>(0,
300  predictors.n_cols - 1, predictors.n_cols)));
301 }
302 
303 template<
304  typename InitializationRuleType,
305  typename DataType,
306  typename PolicyType
307 >
308 template<typename Archive>
310  Archive& ar, const uint32_t /* version */)
311 {
312  ar(CEREAL_NVP(parameter));
313  ar(CEREAL_NVP(visibleSize));
314  ar(CEREAL_NVP(hiddenSize));
315  ar(CEREAL_NVP(state));
316  ar(CEREAL_NVP(numFunctions));
317  ar(CEREAL_NVP(numSteps));
318  ar(CEREAL_NVP(negSteps));
319  ar(CEREAL_NVP(persistence));
320  ar(CEREAL_NVP(poolSize));
321  ar(CEREAL_NVP(visibleBias));
322  ar(CEREAL_NVP(hiddenBias));
323  ar(CEREAL_NVP(weight));
324  ar(CEREAL_NVP(spikeBias));
325  ar(CEREAL_NVP(slabPenalty));
326  ar(CEREAL_NVP(radius));
327  ar(CEREAL_NVP(visiblePenalty));
328 
329  // If we are loading, we need to initialize the weights.
330  if (cereal::is_loading<Archive>())
331  {
332  size_t shape = parameter.n_elem;
333  positiveGradient.set_size(shape, 1);
334  negativeGradient.set_size(shape, 1);
335  negativeSamples.set_size(visibleSize, batchSize);
336  tempNegativeGradient.set_size(shape, 1);
337  spikeMean.set_size(hiddenSize, 1);
338  spikeSamples.set_size(hiddenSize, 1);
339  slabMean.set_size(poolSize, hiddenSize);
340  positiveGradient.zeros();
341  negativeGradient.zeros();
342  tempNegativeGradient.zeros();
343  reset = true;
344  }
345 }
346 
347 } // namespace ann
348 } // namespace mlpack
349 #endif
std::enable_if< std::is_same< Policy, BinaryRBM >::value, void >::type SampleHidden(const arma::Mat< ElemType > &input, arma::Mat< ElemType > &output)
This function samples the hidden layer given the visible layer by drawing Bernoulli samples from the hidden-unit means.
Definition: rbm_impl.hpp:166
double RandBernoulli(const double input)
Generates 0 or 1 at random, where the input is the probability of generating 1.
Definition: random.hpp:99
std::enable_if< std::is_same< Policy, BinaryRBM >::value, double >::type FreeEnergy(const arma::Mat< ElemType > &input)
This function calculates the free energy of the BinaryRBM.
Definition: rbm_impl.hpp:113
RBM(arma::Mat< ElemType > predictors, InitializationRuleType initializeRule, const size_t visibleSize, const size_t hiddenSize, const size_t batchSize=1, const size_t numSteps=1, const size_t negSteps=1, const size_t poolSize=2, const ElemType slabPenalty=8, const ElemType radius=1, const bool persistence=false)
Initialize all the parameters of the network using initializeRule.
Definition: rbm_impl.hpp:27
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
void Gibbs(const arma::Mat< ElemType > &input, arma::Mat< ElemType > &output, const size_t steps=SIZE_MAX)
This function performs k-step Gibbs sampling.
Definition: rbm_impl.hpp:234
Definition: pointer_wrapper.hpp:23
std::enable_if< std::is_same< Policy, BinaryRBM >::value, void >::type HiddenMean(const InputType &input, DataType &output)
The function calculates the mean for the hidden layer.
Definition: rbm_impl.hpp:220
The implementation of the RBM module.
Definition: rbm.hpp:38
void Gradient(const arma::Mat< ElemType > &parameters, const size_t i, arma::Mat< ElemType > &gradient, const size_t batchSize)
Calculates the gradients for the RBM network.
Definition: rbm_impl.hpp:268
double Train(OptimizerType &optimizer, CallbackType &&... callbacks)
Train the RBM on the given input data.
Definition: rbm_impl.hpp:95
std::enable_if< std::is_same< Policy, BinaryRBM >::value, void >::type Phase(const InputType &input, DataType &gradient)
Calculates the gradient of the RBM network on the provided input.
Definition: rbm_impl.hpp:129
void serialize(Archive &ar, const uint32_t version)
Serialize the model.
Definition: rbm_impl.hpp:309
static double Fn(const eT x)
Computes the logistic function.
Definition: logistic_function.hpp:39
double Evaluate(const arma::Mat< ElemType > &parameters, const size_t i, const size_t batchSize)
Evaluate the RBM network with the given parameters.
Definition: rbm_impl.hpp:148
void Shuffle()
Shuffle the order of function visitation.
Definition: rbm_impl.hpp:297
std::enable_if< std::is_same< Policy, BinaryRBM >::value, void >::type VisibleMean(InputType &input, DataType &output)
The function calculates the mean for the visible layer.
Definition: rbm_impl.hpp:204
std::enable_if< std::is_same< Policy, BinaryRBM >::value, void >::type SampleVisible(arma::Mat< ElemType > &input, arma::Mat< ElemType > &output)
This function samples the visible layer given the hidden layer by drawing Bernoulli samples from the visible-unit means.
Definition: rbm_impl.hpp:185