mlpack
spike_slab_rbm_impl.hpp
Go to the documentation of this file.
1 
11 #ifndef MLPACK_METHODS_ANN_RBM_SPIKE_SLAB_RBM_IMPL_HPP
12 #define MLPACK_METHODS_ANN_RBM_SPIKE_SLAB_RBM_IMPL_HPP
13 
14 #include "rbm.hpp"
15 
18 
19 
20 namespace mlpack {
21 namespace ann {
22 
23 template<
24  typename InitializationRuleType,
25  typename DataType,
26  typename PolicyType
27 >
28 template<typename Policy, typename InputType>
29 typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value, void>::type
30 RBM<InitializationRuleType, DataType, PolicyType>::Reset()
31 {
32  size_t shape = (visibleSize * hiddenSize * poolSize) + visibleSize +
33  hiddenSize;
34  parameter.set_size(shape, 1);
35  positiveGradient.set_size(shape, 1);
36  negativeGradient.set_size(shape, 1);
37  tempNegativeGradient.set_size(shape, 1);
38  negativeSamples.set_size(visibleSize, batchSize);
39  visibleMean.set_size(visibleSize, 1);
40  spikeMean.set_size(hiddenSize, 1);
41  spikeSamples.set_size(hiddenSize, 1);
42  slabMean.set_size(poolSize, hiddenSize);
43 
44  // Weight shape D * K * N
45  weight = arma::Cube<ElemType>(parameter.memptr(), visibleSize, poolSize,
46  hiddenSize, false, false);
47  // Spike bias shape N * 1
48  spikeBias = DataType(parameter.memptr() + weight.n_elem, hiddenSize, 1,
49  false, false);
50  // Visible penalty 1 * 1 => D * D(when used)
51  visiblePenalty = DataType(parameter.memptr() + weight.n_elem +
52  spikeBias.n_elem, 1, 1, false, false);
53 
54  parameter.zeros();
55  positiveGradient.zeros();
56  negativeGradient.zeros();
57  tempNegativeGradient.zeros();
58  initializeRule.Initialize(parameter, parameter.n_elem, 1);
59 
60  reset = true;
61 }
62 
63 template<
64  typename InitializationRuleType,
65  typename DataType,
66  typename PolicyType
67 >
68 template<typename Policy, typename InputType>
69 typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value, double>::type
71  const arma::Mat<ElemType>& input)
72 {
73  ElemType freeEnergy = 0.5 * visiblePenalty(0) * arma::dot(input, input);
74 
75  freeEnergy -= 0.5 * hiddenSize * poolSize *
76  std::log((2.0 * M_PI) / slabPenalty);
77 
78  for (size_t i = 0; i < hiddenSize; ++i)
79  {
80  ElemType sum = arma::accu(arma::square(input.t() * weight.slice(i))) /
81  (2.0 * slabPenalty);
82  freeEnergy -= SoftplusFunction::Fn(spikeBias(i) - sum);
83  }
84 
85  return freeEnergy;
86 }
87 
88 template<
89  typename InitializationRuleType,
90  typename DataType,
91  typename PolicyType
92 >
93 template<typename Policy, typename InputType>
94 typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value, void>::type
96  const InputType& input,
97  DataType& gradient)
98 {
99  arma::Cube<ElemType> weightGrad = arma::Cube<ElemType>
100  (gradient.memptr(), visibleSize, poolSize, hiddenSize, false, false);
101 
102  DataType spikeBiasGrad = DataType(gradient.memptr() + weightGrad.n_elem,
103  hiddenSize, 1, false, false);
104 
105  SpikeMean(input, spikeMean);
106  SampleSpike(spikeMean, spikeSamples);
107  SlabMean(input, spikeSamples, slabMean);
108 
109  for (size_t i = 0 ; i < hiddenSize; ++i)
110  {
111  weightGrad.slice(i) = input * arma::repmat(slabMean.col(i).t(),
112  input.n_cols, 1) * spikeMean(i);
113  }
114 
115  spikeBiasGrad = spikeMean;
116  // Setting visiblePenaltyGrad.
117  gradient.row(weightGrad.n_elem + spikeBiasGrad.n_elem) = -0.5 * arma::dot(
118  input, input) / std::pow(input.n_cols, 2);
119 }
120 
121 template<
122  typename InitializationRuleType,
123  typename DataType,
124  typename PolicyType
125 >
126 template<typename Policy, typename InputType>
127 typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value, void>::type
129  const arma::Mat<ElemType>& input,
130  arma::Mat<ElemType>& output)
131 {
132  output.set_size(hiddenSize + poolSize * hiddenSize, 1);
133 
134  DataType spike(output.memptr(), hiddenSize, 1, false, false);
135  DataType slab(output.memptr() + hiddenSize, poolSize, hiddenSize, false,
136  false);
137 
138  SpikeMean(input, spike);
139  SampleSpike(spike, spike);
140  SlabMean(input, spike, slab);
141  SampleSlab(slab, slab);
142 }
143 
144 template<
145  typename InitializationRuleType,
146  typename DataType,
147  typename PolicyType
148 >
149 template<typename Policy, typename InputType>
150 typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value, void>::type
152  arma::Mat<ElemType>& input,
153  arma::Mat<ElemType>& output)
154 {
155  const size_t numMaxTrials = 10;
156  size_t k = 0;
157 
158  VisibleMean(input, visibleMean);
159  output.set_size(visibleSize, 1);
160 
161  for (k = 0; k < numMaxTrials; ++k)
162  {
163  for (size_t i = 0; i < visibleSize; ++i)
164  {
165  output(i) = math::RandNormal(visibleMean(i), 1.0 / visiblePenalty(0));
166  }
167  if (arma::norm(output, 2) < radius)
168  {
169  break;
170  }
171  }
172 
173  if (k == numMaxTrials)
174  {
175  Log::Warn << "Outputs are still not in visible unit "
176  << arma::norm(output, 2)
177  << " terminating optimization."
178  << std::endl;
179  }
180 }
181 
182 template<
183  typename InitializationRuleType,
184  typename DataType,
185  typename PolicyType
186 >
187 template<typename Policy, typename InputType>
188 typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value, void>::type
190  InputType& input,
191  DataType& output)
192 {
193  output.zeros(visibleSize, 1);
194 
195  DataType spike(input.memptr(), hiddenSize, 1, false, false);
196  DataType slab(input.memptr() + hiddenSize, poolSize, hiddenSize, false,
197  false);
198 
199  for (size_t i = 0; i < hiddenSize; ++i)
200  {
201  output += weight.slice(i) * slab.col(i) * spike(i);
202  }
203 
204  output = ((1.0 / visiblePenalty(0)) * output);
205 }
206 
207 template<
208  typename InitializationRuleType,
209  typename DataType,
210  typename PolicyType
211 >
212 template<typename Policy, typename InputType>
213 typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value, void>::type
215  const InputType& input,
216  DataType& output)
217 {
218  output.set_size(hiddenSize + poolSize * hiddenSize, 1);
219 
220  DataType spike(output.memptr(), hiddenSize, 1, false, false);
221  DataType slab(output.memptr() + hiddenSize, poolSize, hiddenSize, false,
222  false);
223 
224  SpikeMean(input, spike);
225  SampleSpike(spike, spikeSamples);
226  SlabMean(input, spikeSamples, slab);
227 }
228 
229 template<
230  typename InitializationRuleType,
231  typename DataType,
232  typename PolicyType
233 >
234 template<typename Policy, typename InputType>
235 typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value, void>::type
237  const InputType& visible,
238  DataType& spikeMean)
239 {
240  for (size_t i = 0; i < hiddenSize; ++i)
241  {
242  spikeMean(i) = LogisticFunction::Fn(0.5 * (1.0 / slabPenalty) * arma::accu(
243  visible.t() * (weight.slice(i) * weight.slice(i).t()) * visible)
244  / std::pow(visible.n_cols, 2) + spikeBias(i));
245  }
246 }
247 
248 template<
249  typename InitializationRuleType,
250  typename DataType,
251  typename PolicyType
252 >
253 template<typename Policy, typename InputType>
254 typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value, void>::type
256  InputType& spikeMean,
257  DataType& spike)
258 {
259  for (size_t i = 0; i < hiddenSize; ++i)
260  {
261  spike(i) = math::RandBernoulli(spikeMean(i));
262  }
263 }
264 
265 template<
266  typename InitializationRuleType,
267  typename DataType,
268  typename PolicyType
269 >
270 template<typename Policy, typename InputType>
271 typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value, void>::type
273  const DataType& visible,
274  DataType& spike,
275  DataType& slabMean)
276 {
277  for (size_t i = 0; i < hiddenSize; ++i)
278  {
279  slabMean.col(i) = arma::mean((1.0 / slabPenalty) * spike(i) *
280  weight.slice(i).t() * visible, 1);
281  }
282 }
283 
284 template<
285  typename InitializationRuleType,
286  typename DataType,
287  typename PolicyType
288 >
289 template<typename Policy, typename InputType>
290 typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value, void>::type
292  InputType& slabMean,
293  DataType& slab)
294 {
295  for (size_t i = 0; i < hiddenSize; ++i)
296  {
297  for (size_t j = 0; j < poolSize; ++j)
298  {
299  slab(j, i) = math::RandNormal(slabMean(j, i), 1.0 / slabPenalty);
300  }
301  }
302 }
303 
304 } // namespace ann
305 } // namespace mlpack
306 
307 #endif
std::enable_if< std::is_same< Policy, BinaryRBM >::value, void >::type SampleHidden(const arma::Mat< ElemType > &input, arma::Mat< ElemType > &output)
This function samples the hidden layer given the visible layer using Bernoulli function.
Definition: rbm_impl.hpp:166
double RandBernoulli(const double input)
Generates a 0/1 specified by the input.
Definition: random.hpp:99
std::enable_if< std::is_same< Policy, BinaryRBM >::value, double >::type FreeEnergy(const arma::Mat< ElemType > &input)
This function calculates the free energy of the BinaryRBM.
Definition: rbm_impl.hpp:113
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
std::enable_if< std::is_same< Policy, SpikeSlabRBM >::value, void >::type SlabMean(const DataType &visible, DataType &spike, DataType &slabMean)
The function calculates the mean of the Normal distribution P(s|v, h), where the mean for hidden unit i is given by: $h_i \, \alpha^{-1} W_i^{T} v$ (with $\alpha$ the slab penalty).
Definition: spike_slab_rbm_impl.hpp:272
std::enable_if< std::is_same< Policy, BinaryRBM >::value, void >::type HiddenMean(const InputType &input, DataType &output)
The function calculates the mean for the hidden layer.
Definition: rbm_impl.hpp:220
double RandNormal()
Generates a normally distributed random number with mean 0 and variance 1.
Definition: random.hpp:127
std::enable_if< std::is_same< Policy, BinaryRBM >::value, void >::type Phase(const InputType &input, DataType &gradient)
Calculates the gradient of the RBM network on the provided input.
Definition: rbm_impl.hpp:129
static MLPACK_EXPORT util::PrefixedOutStream Warn
Prints warning messages prefixed with [WARN ].
Definition: log.hpp:87
static double Fn(const double x)
Computes the softplus function.
Definition: softplus_function.hpp:52
static double Fn(const eT x)
Computes the logistic function.
Definition: logistic_function.hpp:39
std::enable_if< std::is_same< Policy, SpikeSlabRBM >::value, void >::type SampleSlab(InputType &slabMean, DataType &slab)
The function samples from the Normal distribution P(s|v, h), where the mean is given by: $h_i \, \alpha^{-1} W_i^{T} v$ and the variance is given by: $\alpha^{-1}$ (with $\alpha$ the slab penalty).
Definition: spike_slab_rbm_impl.hpp:291
std::enable_if< std::is_same< Policy, SpikeSlabRBM >::value, void >::type SpikeMean(const InputType &visible, DataType &spikeMean)
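The function calculates the mean of the distribution P(h|v), where the mean for hidden unit i is given by: $\sigma\left(\tfrac{1}{2} \, v^{T} W_i \, \alpha^{-1} W_i^{T} v + b_i\right)$ (with $\alpha$ the slab penalty and $b_i$ the spike bias).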
Definition: spike_slab_rbm_impl.hpp:236
std::enable_if< std::is_same< Policy, SpikeSlabRBM >::value, void >::type SampleSpike(InputType &spikeMean, DataType &spike)
The function samples the spike function using Bernoulli distribution.
Definition: spike_slab_rbm_impl.hpp:255
std::enable_if< std::is_same< Policy, BinaryRBM >::value, void >::type VisibleMean(InputType &input, DataType &output)
The function calculates the mean for the visible layer.
Definition: rbm_impl.hpp:204
std::enable_if< std::is_same< Policy, BinaryRBM >::value, void >::type SampleVisible(arma::Mat< ElemType > &input, arma::Mat< ElemType > &output)
This function samples the visible layer given the hidden layer using Bernoulli function.
Definition: rbm_impl.hpp:185