11 #ifndef MLPACK_METHODS_ANN_RBM_SPIKE_SLAB_RBM_IMPL_HPP 12 #define MLPACK_METHODS_ANN_RBM_SPIKE_SLAB_RBM_IMPL_HPP 24 typename InitializationRuleType,
// Reset(): sizes the parameter, gradient and sampling buffers, then re-binds
// weight, spikeBias and visiblePenalty as views into `parameter`.
// This overload is enabled only when Policy is SpikeSlabRBM.
// NOTE(review): the embedded listing numbers skip (e.g. 32 -> 34), so some
// original lines (braces, expression tails) are not present in this listing.
28 template<
typename Policy,
typename InputType>
29 typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value,
void>::type
30 RBM<InitializationRuleType, DataType, PolicyType>::Reset()
// Element count of the flat parameter vector. The weight-cube term and a
// visibleSize term are visible; the rest of the sum is cut off here.
32 size_t shape = (visibleSize * hiddenSize * poolSize) + visibleSize +
// Parameters and all gradient buffers are column vectors of `shape` elements.
34 parameter.set_size(shape, 1);
35 positiveGradient.set_size(shape, 1);
36 negativeGradient.set_size(shape, 1);
37 tempNegativeGradient.set_size(shape, 1);
// Scratch buffers for samples and means.
38 negativeSamples.set_size(visibleSize, batchSize);
39 visibleMean.set_size(visibleSize, 1);
40 spikeMean.set_size(hiddenSize, 1);
41 spikeSamples.set_size(hiddenSize, 1);
42 slabMean.set_size(poolSize, hiddenSize);
// weight is a visibleSize x poolSize x hiddenSize cube built on parameter's
// memory. With Armadillo's copy_aux_mem = false the memory is used directly
// instead of being copied, so writes to weight land in `parameter`.
45 weight = arma::Cube<ElemType>(parameter.memptr(), visibleSize, poolSize,
46 hiddenSize,
false,
false);
// spikeBias (hiddenSize x 1) starts right after the weight elements.
// NOTE(review): the trailing constructor arguments are missing from this
// listing; presumably the same non-copying flags as above -- confirm.
48 spikeBias = DataType(parameter.memptr() + weight.n_elem, hiddenSize, 1,
// visiblePenalty is a 1 x 1 view placed after weight and spikeBias.
51 visiblePenalty = DataType(parameter.memptr() + weight.n_elem +
52 spikeBias.n_elem, 1, 1,
false,
false);
// Clear the gradient accumulators, then let the initialization rule fill the
// whole parameter vector (and therefore the views bound above).
55 positiveGradient.zeros();
56 negativeGradient.zeros();
57 tempNegativeGradient.zeros();
58 initializeRule.Initialize(parameter, parameter.n_elem, 1);
// Free-energy computation for the SpikeSlabRBM policy (returns double).
// NOTE(review): the line carrying the function name is missing from this
// listing; presumably RBM<...>::FreeEnergy -- confirm against the header.
64 typename InitializationRuleType,
68 template<
typename Policy,
typename InputType>
69 typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value,
double>::type
71 const arma::Mat<ElemType>& input)
// Quadratic term: 0.5 * visiblePenalty * <input, input>.
73 ElemType freeEnergy = 0.5 * visiblePenalty(0) * arma::dot(input, input);
// Constant term: subtracts 0.5 * hiddenSize * poolSize * log(2*pi / slabPenalty).
75 freeEnergy -= 0.5 * hiddenSize * poolSize *
76 std::log((2.0 * M_PI) / slabPenalty);
// Per-hidden-unit contribution.
78 for (
size_t i = 0; i < hiddenSize; ++i)
// Sum of squared entries of input^T * W_i; the divisor and the rest of the
// loop body are cut off in this listing.
80 ElemType sum = arma::accu(arma::square(input.t() * weight.slice(i))) /
// Gradient computation for the SpikeSlabRBM policy; writes into `gradient`.
// NOTE(review): the function-name line and the `gradient` parameter line are
// missing from this listing; presumably RBM<...>::Phase(input, gradient).
89 typename InitializationRuleType,
93 template<
typename Policy,
typename InputType>
94 typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value,
void>::type
96 const InputType& input,
// weightGrad views the leading visibleSize * poolSize * hiddenSize elements of
// `gradient` as a cube (non-copying Armadillo constructor).
99 arma::Cube<ElemType> weightGrad = arma::Cube<ElemType>
100 (gradient.memptr(), visibleSize, poolSize, hiddenSize,
false,
false);
// spikeBiasGrad views the next hiddenSize elements of `gradient`.
102 DataType spikeBiasGrad = DataType(gradient.memptr() + weightGrad.n_elem,
103 hiddenSize, 1,
false,
false);
// Fill the member buffers: spike means, a spike sample, then the slab means
// computed from that sample.
105 SpikeMean(input, spikeMean);
106 SampleSpike(spikeMean, spikeSamples);
107 SlabMean(input, spikeSamples, slabMean);
// Weight gradient for hidden unit i: input times the slab mean of unit i
// (replicated over input.n_cols rows), scaled by the spike mean of unit i.
109 for (
size_t i = 0 ; i < hiddenSize; ++i)
111 weightGrad.slice(i) = input * arma::repmat(slabMean.col(i).t(),
112 input.n_cols, 1) * spikeMean(i);
// Spike-bias gradient is the spike mean itself.
115 spikeBiasGrad = spikeMean;
// The row after the weight and spike-bias segments (the offset at which
// Reset() places visiblePenalty inside `parameter`) receives
// -0.5 * <input, input> / input.n_cols^2.
117 gradient.row(weightGrad.n_elem + spikeBiasGrad.n_elem) = -0.5 * arma::dot(
118 input, input) / std::pow(input.n_cols, 2);
// Samples the hidden state (spike and slab values) for `input` into `output`.
// NOTE(review): the function-name line is missing from this listing;
// presumably RBM<...>::SampleHidden -- confirm against the header.
122 typename InitializationRuleType,
126 template<
typename Policy,
typename InputType>
127 typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value,
void>::type
129 const arma::Mat<ElemType>& input,
130 arma::Mat<ElemType>& output)
// output is one column: hiddenSize spike entries followed by
// poolSize * hiddenSize slab entries.
132 output.set_size(hiddenSize + poolSize * hiddenSize, 1);
// Non-copying view of the first hiddenSize entries of output.
134 DataType spike(output.memptr(), hiddenSize, 1,
false,
false);
// View of the remaining entries as a poolSize x hiddenSize matrix.
// NOTE(review): the last constructor argument line is missing here.
135 DataType slab(output.memptr() + hiddenSize, poolSize, hiddenSize,
false,
// Compute the spike means, sample them in place, then compute the slab means
// from the sampled spikes and sample those in place; results land in output.
138 SpikeMean(input, spike);
139 SampleSpike(spike, spike);
140 SlabMean(input, spike, slab);
141 SampleSlab(slab, slab);
// Samples the visible layer from a hidden state `input` into `output`, retrying
// up to numMaxTrials times until the sample's norm is below `radius`.
// NOTE(review): the function-name line, the declaration of `k` and the loop
// bodies are missing from this listing; presumably RBM<...>::SampleVisible.
145 typename InitializationRuleType,
149 template<
typename Policy,
typename InputType>
150 typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value,
void>::type
152 arma::Mat<ElemType>& input,
153 arma::Mat<ElemType>& output)
// Upper bound on the number of sampling attempts.
155 const size_t numMaxTrials = 10;
// Compute the visible mean into the member buffer and size the output column.
158 VisibleMean(input, visibleMean);
159 output.set_size(visibleSize, 1);
// One attempt per iteration.
161 for (k = 0; k < numMaxTrials; ++k)
// Per-visible-unit work (statement not shown in this listing).
163 for (
size_t i = 0; i < visibleSize; ++i)
// Acceptance test on the L2 norm of the sample; the statement executed on
// success is not shown (presumably a break -- confirm).
167 if (arma::norm(output, 2) < radius)
// k reaches numMaxTrials only if the loop ran through every trial.
173 if (k == numMaxTrials)
175 Log::Warn <<
"Outputs are still not in visible unit " 176 << arma::norm(output, 2)
177 <<
" terminating optimization." 183 typename InitializationRuleType,
// Computes the visible-layer mean from a packed hidden state `input` into
// `output`.
// NOTE(review): the function-name and parameter lines are missing from this
// listing; presumably RBM<...>::VisibleMean(input, output) -- confirm.
187 template<
typename Policy,
typename InputType>
188 typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value,
void>::type
// Start from a zero column of visibleSize entries.
193 output.zeros(visibleSize, 1);
// Unpack input without copying: the first hiddenSize entries are the spikes,
// the rest is read as a poolSize x hiddenSize slab matrix.
195 DataType spike(input.memptr(), hiddenSize, 1,
false,
false);
196 DataType slab(input.memptr() + hiddenSize, poolSize, hiddenSize,
false,
// Accumulate W_i * slab_i * spike_i over all hidden units.
199 for (
size_t i = 0; i < hiddenSize; ++i)
201 output += weight.slice(i) * slab.col(i) * spike(i);
// Scale the accumulated sum by 1 / visiblePenalty.
204 output = ((1.0 / visiblePenalty(0)) * output);
// Computes the hidden-layer means for `input` and packs them into `output`.
// NOTE(review): the function-name line and the `output` parameter line are
// missing from this listing; presumably RBM<...>::HiddenMean(input, output).
208 typename InitializationRuleType,
212 template<
typename Policy,
typename InputType>
213 typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value,
void>::type
215 const InputType& input,
// output is one column: hiddenSize spike means followed by
// poolSize * hiddenSize slab means.
218 output.set_size(hiddenSize + poolSize * hiddenSize, 1);
// Non-copying views into output for the two segments.
220 DataType spike(output.memptr(), hiddenSize, 1,
false,
false);
221 DataType slab(output.memptr() + hiddenSize, poolSize, hiddenSize,
false,
// The spike means go to output. A spike sample is drawn into the member
// spikeSamples, and the slab means are computed from that sample, not from the
// spike means.
224 SpikeMean(input, spike);
225 SampleSpike(spike, spikeSamples);
226 SlabMean(input, spikeSamples, slab);
// SpikeMean(visible, spikeMean): calculates the mean of the distribution
// P(h|v), one entry per hidden unit.
// NOTE(review): the function-name line, the second parameter and the start of
// the loop statement are missing from this listing.
230 typename InitializationRuleType,
234 template<
typename Policy,
typename InputType>
235 typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value,
void>::type
237 const InputType& visible,
240 for (
size_t i = 0; i < hiddenSize; ++i)
// Visible tail of the per-unit expression: visible^T (W_i W_i^T) visible,
// divided by visible.n_cols^2, plus the spike bias of unit i.
243 visible.t() * (weight.slice(i) * weight.slice(i).t()) * visible)
244 / std::pow(visible.n_cols, 2) + spikeBias(i));
// SampleSpike(spikeMean, spike): samples the spike values using a Bernoulli
// distribution, iterating over the hidden units.
// NOTE(review): the function-name line, the second parameter and the loop body
// are missing from this listing.
249 typename InitializationRuleType,
253 template<
typename Policy,
typename InputType>
254 typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value,
void>::type
256 InputType& spikeMean,
259 for (
size_t i = 0; i < hiddenSize; ++i)
// SlabMean(visible, spike, slabMean): calculates the mean of the Normal
// distribution P(s|v, h), one column of slabMean per hidden unit.
// NOTE(review): the function-name line and the remaining parameter lines are
// missing from this listing.
266 typename InitializationRuleType,
270 template<
typename Policy,
typename InputType>
271 typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value,
void>::type
273 const DataType& visible,
277 for (
size_t i = 0; i < hiddenSize; ++i)
// Column i = (1 / slabPenalty) * spike(i) * W_i^T * visible, averaged along
// each row (arma::mean with dim = 1), i.e. across the columns of `visible`.
279 slabMean.col(i) = arma::mean((1.0 / slabPenalty) * spike(i) *
280 weight.slice(i).t() * visible, 1);
// SampleSlab(slabMean, slab): samples from the Normal distribution P(s|v, h).
// NOTE(review): the function-name line, the parameter lines and the innermost
// loop body are missing from this listing.
285 typename InitializationRuleType,
289 template<
typename Policy,
typename InputType>
290 typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value,
void>::type
// Visits every (hidden unit i, pool index j) pair.
295 for (
size_t i = 0; i < hiddenSize; ++i)
297 for (
size_t j = 0; j < poolSize; ++j)
std::enable_if< std::is_same< Policy, BinaryRBM >::value, void >::type SampleHidden(const arma::Mat< ElemType > &input, arma::Mat< ElemType > &output)
This function samples the hidden layer given the visible layer using the Bernoulli function.
Definition: rbm_impl.hpp:166
double RandBernoulli(const double input)
Generates a 0/1 specified by the input.
Definition: random.hpp:99
std::enable_if< std::is_same< Policy, BinaryRBM >::value, double >::type FreeEnergy(const arma::Mat< ElemType > &input)
This function calculates the free energy of the BinaryRBM.
Definition: rbm_impl.hpp:113
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
std::enable_if< std::is_same< Policy, SpikeSlabRBM >::value, void >::type SlabMean(const DataType &visible, DataType &spike, DataType &slabMean)
The function calculates the mean of the Normal distribution of P(s|v, h), where the mean is given by: ...
Definition: spike_slab_rbm_impl.hpp:272
std::enable_if< std::is_same< Policy, BinaryRBM >::value, void >::type HiddenMean(const InputType &input, DataType &output)
The function calculates the mean for the hidden layer.
Definition: rbm_impl.hpp:220
double RandNormal()
Generates a normally distributed random number with mean 0 and variance 1.
Definition: random.hpp:127
std::enable_if< std::is_same< Policy, BinaryRBM >::value, void >::type Phase(const InputType &input, DataType &gradient)
Calculates the gradient of the RBM network on the provided input.
Definition: rbm_impl.hpp:129
static MLPACK_EXPORT util::PrefixedOutStream Warn
Prints warning messages prefixed with [WARN ].
Definition: log.hpp:87
static double Fn(const double x)
Computes the softplus function.
Definition: softplus_function.hpp:52
static double Fn(const eT x)
Computes the logistic function.
Definition: logistic_function.hpp:39
std::enable_if< std::is_same< Policy, SpikeSlabRBM >::value, void >::type SampleSlab(InputType &slabMean, DataType &slab)
The function samples from the Normal distribution of P(s|v, h), where the mean is given by: and vari...
Definition: spike_slab_rbm_impl.hpp:291
std::enable_if< std::is_same< Policy, SpikeSlabRBM >::value, void >::type SpikeMean(const InputType &visible, DataType &spikeMean)
The function calculates the mean of the distribution P(h|v), where mean is given by: ...
Definition: spike_slab_rbm_impl.hpp:236
std::enable_if< std::is_same< Policy, SpikeSlabRBM >::value, void >::type SampleSpike(InputType &spikeMean, DataType &spike)
The function samples the spike function using a Bernoulli distribution.
Definition: spike_slab_rbm_impl.hpp:255
std::enable_if< std::is_same< Policy, BinaryRBM >::value, void >::type VisibleMean(InputType &input, DataType &output)
The function calculates the mean for the visible layer.
Definition: rbm_impl.hpp:204
std::enable_if< std::is_same< Policy, BinaryRBM >::value, void >::type SampleVisible(arma::Mat< ElemType > &input, arma::Mat< ElemType > &output)
This function samples the visible layer given the hidden layer using the Bernoulli function.
Definition: rbm_impl.hpp:185