mlpack
batch_norm_impl.hpp — implementation of the BatchNorm layer in mlpack's ANN module.
Go to the documentation of this file.
1 
15 #ifndef MLPACK_METHODS_ANN_LAYER_BATCHNORM_IMPL_HPP
16 #define MLPACK_METHODS_ANN_LAYER_BATCHNORM_IMPL_HPP
17 
18 // In case it is not included.
19 #include "batch_norm.hpp"
20 
21 namespace mlpack {
22 namespace ann {
24 template<typename InputDataType, typename OutputDataType>
26  size(0),
27  eps(1e-8),
28  average(true),
29  momentum(0.0),
30  loading(false),
31  deterministic(false),
32  count(0),
33  averageFactor(0.0)
34 {
35  // Nothing to do here.
36 }
37 
38 template <typename InputDataType, typename OutputDataType>
40  const size_t size,
41  const double eps,
42  const bool average,
43  const double momentum) :
44  size(size),
45  eps(eps),
46  average(average),
47  momentum(momentum),
48  loading(false),
49  deterministic(false),
50  count(0),
51  averageFactor(0.0)
52 {
53  weights.set_size(WeightSize(), 1);
54  runningMean.zeros(size, 1);
55  runningVariance.ones(size, 1);
56 }
57 
58 template<typename InputDataType, typename OutputDataType>
60 {
61  // Gamma acts as the scaling parameters for the normalized output.
62  gamma = arma::mat(weights.memptr(), size, 1, false, false);
63  // Beta acts as the shifting parameters for the normalized output.
64  beta = arma::mat(weights.memptr() + gamma.n_elem, size, 1, false, false);
65 
66  if (!loading)
67  {
68  gamma.fill(1.0);
69  beta.fill(0.0);
70  }
71 
72  deterministic = false;
73  loading = false;
74 }
75 
76 template<typename InputDataType, typename OutputDataType>
77 template<typename eT>
79  const arma::Mat<eT>& input,
80  arma::Mat<eT>& output)
81 {
82  Log::Assert(input.n_rows % size == 0, "Input features must be divisible \
83  by feature maps.");
84 
85  const size_t batchSize = input.n_cols;
86  const size_t inputSize = input.n_rows / size;
87 
88  // Set size of output equal to the size of input.
89  output.set_size(arma::size(input));
90 
91  // We will calculate minibatch norm on each channel / feature map.
92  if (!deterministic)
93  {
94  // Check only during training, batch-size can be one during inference.
95  if (batchSize == 1 && inputSize == 1)
96  {
97  Log::Warn << "Variance for single element isn't defined and" <<
98  " will be set to 0.0 for training. Use a batch-size" <<
99  " greater than 1 to fix the warning." << std::endl;
100  }
101 
102  // Input corresponds to output from convolution layer.
103  // Use a cube for simplicity.
104  arma::cube inputTemp(const_cast<arma::Mat<eT>&>(input).memptr(),
105  inputSize, size, batchSize, false, false);
106 
107  // Initialize output to same size and values for convenience.
108  arma::cube outputTemp(const_cast<arma::Mat<eT>&>(output).memptr(),
109  inputSize, size, batchSize, false, false);
110  outputTemp = inputTemp;
111 
112  // Calculate mean and variance over all channels.
113  mean = arma::mean(arma::mean(inputTemp, 2), 0);
114  variance = arma::mean(arma::mean(arma::pow(
115  inputTemp.each_slice() - arma::repmat(mean,
116  inputSize, 1), 2), 2), 0);
117 
118  outputTemp.each_slice() -= arma::repmat(mean, inputSize, 1);
119 
120  // Used in backward propagation.
121  inputMean.set_size(arma::size(inputTemp));
122  inputMean = outputTemp;
123 
124  // Normalize output.
125  outputTemp.each_slice() /= arma::sqrt(arma::repmat(variance,
126  inputSize, 1) + eps);
127 
128  // Re-used in backward propagation.
129  normalized.set_size(arma::size(inputTemp));
130  normalized = outputTemp;
131 
132  outputTemp.each_slice() %= arma::repmat(gamma.t(),
133  inputSize, 1);
134  outputTemp.each_slice() += arma::repmat(beta.t(),
135  inputSize, 1);
136 
137  count += 1;
138  averageFactor = average ? 1.0 / count : momentum;
139 
140  double nElements = 0.0;
141  if (input.n_elem - size != 0)
142  nElements = 1.0 / (input.n_elem - size + eps);
143 
144  // Update running mean and running variance.
145  runningMean = (1 - averageFactor) * runningMean + averageFactor *
146  mean.t();
147  runningVariance = (1 - averageFactor) * runningVariance +
148  input.n_elem * nElements *
149  averageFactor * variance.t();
150  }
151  else
152  {
153  // Normalize the input and scale and shift the output.
154  output = input;
155  arma::cube outputTemp(const_cast<arma::Mat<eT>&>(output).memptr(),
156  input.n_rows / size, size, batchSize, false, false);
157 
158  outputTemp.each_slice() -= arma::repmat(runningMean.t(),
159  input.n_rows / size, 1);
160  outputTemp.each_slice() /= arma::sqrt(arma::repmat(runningVariance.t(),
161  input.n_rows / size, 1) + eps);
162  outputTemp.each_slice() %= arma::repmat(gamma.t(),
163  input.n_rows / size, 1);
164  outputTemp.each_slice() += arma::repmat(beta.t(),
165  input.n_rows / size, 1);
166  }
167 }
168 
169 template<typename InputDataType, typename OutputDataType>
170 template<typename eT>
172  const arma::Mat<eT>& input,
173  const arma::Mat<eT>& gy,
174  arma::Mat<eT>& g)
175 {
176  const arma::mat stdInv = 1.0 / arma::sqrt(variance + eps);
177 
178  g.set_size(arma::size(input));
179  arma::cube gyTemp(const_cast<arma::Mat<eT>&>(gy).memptr(),
180  input.n_rows / size, size, input.n_cols, false, false);
181  arma::cube gTemp(const_cast<arma::Mat<eT>&>(g).memptr(),
182  input.n_rows / size, size, input.n_cols, false, false);
183 
184  // Step 1: dl / dxhat.
185  arma::cube norm = gyTemp.each_slice() % arma::repmat(gamma.t(),
186  input.n_rows / size, 1);
187 
188  // Step 2: sum dl / dxhat * (x - mu) * -0.5 * stdInv^3.
189  arma::mat temp = arma::sum(norm % inputMean, 2);
190  arma::mat vars = temp % arma::repmat(arma::pow(stdInv, 3),
191  input.n_rows / size, 1) * -0.5;
192 
193  // Step 3: dl / dxhat * 1 / stdInv + variance * 2 * (x - mu) / m +
194  // dl / dmu * 1 / m.
195  gTemp = (norm.each_slice() % arma::repmat(stdInv,
196  input.n_rows / size, 1) +
197  (inputMean.each_slice() % vars * 2)) / input.n_cols;
198 
199  // Step 4: sum (dl / dxhat * -1 / stdInv) + variance *
200  // (sum -2 * (x - mu)) / m.
201  arma::mat normTemp = arma::sum(norm.each_slice() %
202  arma::repmat(-stdInv, input.n_rows / size, 1) , 2) /
203  input.n_cols;
204  gTemp.each_slice() += normTemp;
205 }
206 
207 template<typename InputDataType, typename OutputDataType>
208 template<typename eT>
210  const arma::Mat<eT>& /* input */,
211  const arma::Mat<eT>& error,
212  arma::Mat<eT>& gradient)
213 {
214  gradient.set_size(size + size, 1);
215  arma::cube errorTemp(const_cast<arma::Mat<eT>&>(error).memptr(),
216  error.n_rows / size, size, error.n_cols, false, false);
217 
218  // Step 5: dl / dy * xhat.
219  arma::mat temp = arma::sum(arma::sum(normalized % errorTemp, 0), 2);
220  gradient.submat(0, 0, gamma.n_elem - 1, 0) = temp.t();
221 
222  // Step 6: dl / dy.
223  temp = arma::sum(arma::sum(errorTemp, 0), 2);
224  gradient.submat(gamma.n_elem, 0, gradient.n_elem - 1, 0) = temp.t();
225 }
226 
227 template<typename InputDataType, typename OutputDataType>
228 template<typename Archive>
230  Archive& ar, const uint32_t /* version */)
231 {
232  ar(CEREAL_NVP(size));
233 
234  if (cereal::is_loading<Archive>())
235  {
236  weights.set_size(size + size, 1);
237  loading = true;
238  }
239 
240  ar(CEREAL_NVP(eps));
241  ar(CEREAL_NVP(gamma));
242  ar(CEREAL_NVP(beta));
243  ar(CEREAL_NVP(count));
244  ar(CEREAL_NVP(averageFactor));
245  ar(CEREAL_NVP(momentum));
246  ar(CEREAL_NVP(average));
247  ar(CEREAL_NVP(runningMean));
248  ar(CEREAL_NVP(runningVariance));
249 }
250 
251 } // namespace ann
252 } // namespace mlpack
253 
254 #endif
BatchNorm()
Create the BatchNorm object.
Definition: batch_norm_impl.hpp:25
void Reset()
Reset the layer parameters.
Definition: batch_norm_impl.hpp:59
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
size_t WeightSize() const
Get size of weights.
Definition: batch_norm.hpp:164
void Forward(const arma::Mat< eT > &input, arma::Mat< eT > &output)
Forward pass of the Batch Normalization layer.
Definition: batch_norm_impl.hpp:78
void Backward(const arma::Mat< eT > &input, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Backward pass through the layer.
Definition: batch_norm_impl.hpp:171
OutputDataType const & Gradient() const
Get the gradient.
Definition: batch_norm.hpp:132
static MLPACK_EXPORT util::PrefixedOutStream Warn
Prints warning messages prefixed with [WARN ].
Definition: log.hpp:87
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
Definition: batch_norm_impl.hpp:229
static void Assert(bool condition, const std::string &message="Assert Failed.")
Checks if the specified condition is true.
Definition: log.cpp:38