mlpack
virtual_batch_norm_impl.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_LAYER_VIRTUALBATCHNORM_IMPL_HPP
14 #define MLPACK_METHODS_ANN_LAYER_VIRTUALBATCHNORM_IMPL_HPP
15 
16 // In case it is not included.
17 #include "virtual_batch_norm.hpp"
18 
19 namespace mlpack {
20 namespace ann {
22 template<typename InputDataType, typename OutputDataType>
24  size(0),
25  eps(1e-8),
26  loading(false),
27  oldCoefficient(0),
28  newCoefficient(0)
29 {
30  // Nothing to do here.
31 }
32 template <typename InputDataType, typename OutputDataType>
33 template<typename eT>
35  const arma::Mat<eT>& referenceBatch,
36  const size_t size,
37  const double eps) :
38  size(size),
39  eps(eps),
40  loading(false)
41 {
42  weights.set_size(size + size, 1);
43 
44  referenceBatchMean = arma::mean(referenceBatch, 1);
45  referenceBatchMeanSquared = arma::mean(arma::square(referenceBatch), 1);
46  newCoefficient = 1.0 / (referenceBatch.n_cols + 1);
47  oldCoefficient = 1 - newCoefficient;
48 }
49 
50 template<typename InputDataType, typename OutputDataType>
52 {
53  gamma = arma::mat(weights.memptr(), size, 1, false, false);
54  beta = arma::mat(weights.memptr() + gamma.n_elem, size, 1, false, false);
55 
56  if (!loading)
57  {
58  gamma.fill(1.0);
59  beta.fill(0.0);
60  }
61 
62  loading = false;
63 }
64 
65 template<typename InputDataType, typename OutputDataType>
66 template<typename eT>
68  const arma::Mat<eT>& input, arma::Mat<eT>& output)
69 {
70  Log::Assert(input.n_rows % size == 0, "Input features must be divisible \
71  by feature maps.");
72 
73  inputParameter = input;
74  arma::mat inputMean = arma::mean(input, 1);
75  arma::mat inputMeanSquared = arma::mean(arma::square(input), 1);
76 
77  mean = oldCoefficient * referenceBatchMean + newCoefficient * inputMean;
78  arma::mat meanSquared = oldCoefficient * referenceBatchMeanSquared +
79  newCoefficient * inputMeanSquared;
80  variance = meanSquared - arma::square(mean);
81  // Normalize the input.
82  output = input.each_col() - mean;
83  inputSubMean = output;
84  output.each_col() /= arma::sqrt(variance + eps);
85 
86  // Reused in the backward and gradient step.
87  normalized = output;
88  // Scale and shift the output.
89  output.each_col() %= gamma;
90  output.each_col() += beta;
91 }
92 
93 template<typename InputDataType, typename OutputDataType>
94 template<typename eT>
96  const arma::Mat<eT>& /* input */, const arma::Mat<eT>& gy, arma::Mat<eT>& g)
97 {
98  const arma::mat stdInv = 1.0 / arma::sqrt(variance + eps);
99 
100  // dl / dxhat.
101  const arma::mat norm = gy.each_col() % gamma;
102 
103  // sum dl / dxhat * (x - mu) * -0.5 * stdInv^3.
104  const arma::mat var = arma::sum(norm % inputSubMean, 1) %
105  arma::pow(stdInv, 3.0) * -0.5;
106 
107  // dl / dxhat * 1 / stdInv + variance * 2 * (x - mu) / m +
108  // dl / dmu * newCoefficient / m.
109  g = (norm.each_col() % stdInv) + ((inputParameter.each_col() %
110  var) * 2 * newCoefficient / inputParameter.n_cols);
111 
112  // (sum (dl / dxhat * -1 / stdInv) + (variance * mean * -2)) *
113  // newCoefficient / m.
114  g.each_col() += (arma::sum(norm.each_col() % -stdInv, 1) + (var %
115  mean * -2)) * newCoefficient / inputParameter.n_cols;
116 }
117 
118 template<typename InputDataType, typename OutputDataType>
119 template<typename eT>
121  const arma::Mat<eT>& /* input */,
122  const arma::Mat<eT>& error,
123  arma::Mat<eT>& gradient)
124 {
125  gradient.set_size(size + size, 1);
126 
127  // Step 5: dl / dy * xhat.
128  gradient.submat(0, 0, gamma.n_elem - 1, 0) = arma::sum(normalized % error, 1);
129 
130  // Step 6: dl / dy.
131  gradient.submat(gamma.n_elem, 0, gradient.n_elem - 1, 0) =
132  arma::sum(error, 1);
133 }
134 
135 template<typename InputDataType, typename OutputDataType>
136 template<typename Archive>
138  Archive& ar, const uint32_t /* version */)
139 {
140  ar(CEREAL_NVP(size));
141 
142  if (cereal::is_loading<Archive>())
143  {
144  weights.set_size(size + size, 1);
145  loading = true;
146  }
147 
148  ar(CEREAL_NVP(eps));
149  ar(CEREAL_NVP(gamma));
150  ar(CEREAL_NVP(beta));
151 }
152 
153 } // namespace ann
154 } // namespace mlpack
155 
156 #endif
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
VirtualBatchNorm()
Create the VirtualBatchNorm object.
Definition: virtual_batch_norm_impl.hpp:23
OutputDataType const & Gradient() const
Get the gradient.
Definition: virtual_batch_norm.hpp:121
void Backward(const arma::Mat< eT > &, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Backward pass through the layer.
Definition: virtual_batch_norm_impl.hpp:95
void Reset()
Reset the layer parameters.
Definition: virtual_batch_norm_impl.hpp:51
void Forward(const arma::Mat< eT > &input, arma::Mat< eT > &output)
Forward pass of the Virtual Batch Normalization layer.
Definition: virtual_batch_norm_impl.hpp:67
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
Definition: virtual_batch_norm_impl.hpp:137
static void Assert(bool condition, const std::string &message="Assert Failed.")
Checks if the specified condition is true.
Definition: log.cpp:38