#ifndef MLPACK_METHODS_ANN_LAYER_BATCHNORM_IMPL_HPP
#define MLPACK_METHODS_ANN_LAYER_BATCHNORM_IMPL_HPP

// In case it is not included yet.
#include "batch_norm.hpp"

namespace mlpack {
namespace ann /** Artificial Neural Network. */ {

template<typename InputDataType, typename OutputDataType>
BatchNorm<InputDataType, OutputDataType>::BatchNorm() :
    size(10),
    eps(1e-8),
    average(true),
    momentum(0.0),
    deterministic(false),
    count(0),
    averageFactor(0.0)
{
  weights.set_size(WeightSize(), 1);
}
template<typename InputDataType, typename OutputDataType>
BatchNorm<InputDataType, OutputDataType>::BatchNorm(
    const size_t size,
    const double eps,
    const bool average,
    const double momentum) :
    size(size),
    eps(eps),
    average(average),
    momentum(momentum),
    deterministic(false),
    count(0),
    averageFactor(0.0)
{
  weights.set_size(WeightSize(), 1);
  runningMean.zeros(size, 1);
  runningVariance.ones(size, 1);
}
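// A minimal usage sketch (not part of this file; the network shape and layer
// sizes are hypothetical): BatchNorm is normally added to an FFN model right
// after the layer whose outputs it should normalize.
//
//   FFN<NegativeLogLikelihood<>, RandomInitialization> model;
//   model.Add<Linear<>>(10, 5);
//   model.Add<BatchNorm<>>(5);  // Normalize the five Linear outputs.
//   model.Add<ReLULayer<>>();
//   model.Add<Linear<>>(5, 2);
//   model.Add<LogSoftMax<>>();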
template<typename InputDataType, typename OutputDataType>
void BatchNorm<InputDataType, OutputDataType>::Reset()
{
  // Gamma acts as the scaling parameter for the normalized output; it
  // aliases the first half of the weights memory (no copy is made).
  gamma = arma::mat(weights.memptr(), size, 1, false, false);

  // Beta acts as the shifting parameter and aliases the second half.
  beta = arma::mat(weights.memptr() + gamma.n_elem, size, 1, false, false);

  deterministic = false;
}
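// A standalone sketch of the aliasing trick used in Reset() above: an
// arma::mat built from an existing memptr() with copy_aux_mem = false shares
// that memory, so writes through gamma and beta land directly in weights.
// The sizes here are illustrative.
//
//   arma::mat weights(4, 1, arma::fill::zeros);
//   arma::mat gamma(weights.memptr(), 2, 1, false, false);
//   gamma.fill(1.0);  // weights(0) and weights(1) are now 1.0 as well.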
template<typename InputDataType, typename OutputDataType>
template<typename eT>
void BatchNorm<InputDataType, OutputDataType>::Forward(
    const arma::Mat<eT>& input,
    arma::Mat<eT>& output)
{
  Log::Assert(input.n_rows % size == 0, "Input features must be divisible "
      "by feature maps.");

  const size_t batchSize = input.n_cols;
  const size_t inputSize = input.n_rows / size;

  // Set the size of the output equal to the size of the input.
  output.set_size(arma::size(input));

  // During training we compute minibatch statistics on each channel /
  // feature map; in deterministic (inference) mode we use the stored
  // running statistics instead.
  if (!deterministic)
  {
    if (batchSize == 1 && inputSize == 1)
    {
      Log::Warn << "Variance for a single element isn't defined and"
          << " will be set to 0.0 for training. Use a batch size"
          << " greater than 1 to fix the warning." << std::endl;
    }
    // The input corresponds to the output of a convolution layer, so view
    // it as an (inputSize x size x batchSize) cube that aliases the input
    // memory without copying.
    arma::cube inputTemp(const_cast<arma::Mat<eT>&>(input).memptr(),
        inputSize, size, batchSize, false, false);

    // Alias the output memory the same way and initialize it to the input.
    arma::cube outputTemp(const_cast<arma::Mat<eT>&>(output).memptr(),
        inputSize, size, batchSize, false, false);
    outputTemp = inputTemp;
    // Calculate the mean and (biased) variance over all channels.
    mean = arma::mean(arma::mean(inputTemp, 2), 0);
    variance = arma::mean(arma::mean(arma::pow(
        inputTemp.each_slice() - arma::repmat(mean, inputSize, 1), 2), 2), 0);

    outputTemp.each_slice() -= arma::repmat(mean, inputSize, 1);

    // Store the mean-centered input; it is re-used in Backward().
    inputMean.set_size(arma::size(inputTemp));
    inputMean = outputTemp;

    // Normalize the output.
    outputTemp.each_slice() /= arma::sqrt(arma::repmat(variance,
        inputSize, 1) + eps);

    // Store the normalized input; it is re-used in Gradient().
    normalized.set_size(arma::size(inputTemp));
    normalized = outputTemp;

    // Scale by gamma and shift by beta.
    outputTemp.each_slice() %= arma::repmat(gamma.t(), inputSize, 1);
    outputTemp.each_slice() += arma::repmat(beta.t(), inputSize, 1);
    count += 1;
    // With average == true the running statistics are a cumulative average;
    // otherwise an exponential moving average with the given momentum.
    averageFactor = average ? 1.0 / count : momentum;

    // Bessel's correction: input.n_elem * nElements equals m / (m - 1),
    // where m is the number of elements per channel, so the running
    // variance stores an unbiased estimate.
    double nElements = 0.0;
    if (input.n_elem - size != 0)
      nElements = 1.0 / (input.n_elem - size + eps);

    // Update the running statistics.
    runningMean = (1 - averageFactor) * runningMean + averageFactor *
        mean.t();
    runningVariance = (1 - averageFactor) * runningVariance +
        input.n_elem * nElements * averageFactor * variance.t();
  }
  else
  {
    // Inference: normalize with the running statistics, then scale and
    // shift.
    output = input;
    arma::cube outputTemp(const_cast<arma::Mat<eT>&>(output).memptr(),
        input.n_rows / size, size, batchSize, false, false);

    outputTemp.each_slice() -= arma::repmat(runningMean.t(),
        input.n_rows / size, 1);
    outputTemp.each_slice() /= arma::sqrt(arma::repmat(runningVariance.t(),
        input.n_rows / size, 1) + eps);
    outputTemp.each_slice() %= arma::repmat(gamma.t(),
        input.n_rows / size, 1);
    outputTemp.each_slice() += arma::repmat(beta.t(),
        input.n_rows / size, 1);
  }
}
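// A worked sketch of the training-mode arithmetic above for one channel
// (illustrative numbers): for the batch x = {1, 3}, mean = 2 and the biased
// variance = ((1 - 2)^2 + (3 - 2)^2) / 2 = 1, so with gamma = 1, beta = 0
// and negligible eps the outputs are (1 - 2) / sqrt(1) = -1 and
// (3 - 2) / sqrt(1) = 1.
//
//   arma::mat x = {{1.0, 3.0}};  // One feature, batch size two.
//   const double m = arma::as_scalar(arma::mean(x, 1));
//   const double v = arma::as_scalar(arma::var(x, 1, 1));  // norm_type 1.
//   arma::mat xhat = (x - m) / std::sqrt(v + 1e-8);  // Approximately {-1, 1}.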
template<typename InputDataType, typename OutputDataType>
template<typename eT>
void BatchNorm<InputDataType, OutputDataType>::Backward(
    const arma::Mat<eT>& input,
    const arma::Mat<eT>& gy,
    arma::Mat<eT>& g)
{
  const arma::mat stdInv = 1.0 / arma::sqrt(variance + eps);

  g.set_size(arma::size(input));
  arma::cube gyTemp(const_cast<arma::Mat<eT>&>(gy).memptr(),
      input.n_rows / size, size, input.n_cols, false, false);
  arma::cube gTemp(const_cast<arma::Mat<eT>&>(g).memptr(),
      input.n_rows / size, size, input.n_cols, false, false);

  // Step 1: dl / dxhat = dl / dy * gamma.
  arma::cube norm = gyTemp.each_slice() % arma::repmat(gamma.t(),
      input.n_rows / size, 1);

  // Step 2: dl / dvar = sum(dl / dxhat * (x - mu)) * -0.5 * stdInv^3.
  arma::mat temp = arma::sum(norm % inputMean, 2);
  arma::mat vars = temp % arma::repmat(arma::pow(stdInv, 3),
      input.n_rows / size, 1) * -0.5;

  // Step 3: (dl / dxhat) * stdInv + (dl / dvar) * 2 * (x - mu), averaged
  // over the batch.
  gTemp = (norm.each_slice() % arma::repmat(stdInv,
      input.n_rows / size, 1) +
      (inputMean.each_slice() % vars * 2)) / input.n_cols;

  // Step 4: the mean term, sum(dl / dxhat * -stdInv) / m, broadcast to
  // every element.
  arma::mat normTemp = arma::sum(norm.each_slice() %
      arma::repmat(-stdInv, input.n_rows / size, 1), 2) /
      input.n_cols;
  gTemp.each_slice() += normTemp;
}
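// For reference, a self-contained sketch of the textbook input gradient for
// a single feature map (Ioffe and Szegedy, 2015). It illustrates the chain
// rule that the steps above assemble; it is not part of this implementation.
//
//   // gy: 1 x m row of dl / dy; xhat: the normalized inputs.
//   arma::rowvec BatchNormInputGradient(const arma::rowvec& gy,
//                                       const arma::rowvec& xhat,
//                                       const double gamma,
//                                       const double stdInv)
//   {
//     const double m = gy.n_elem;
//     return (gamma * stdInv / m) *
//         (m * gy - arma::accu(gy) - xhat * arma::accu(gy % xhat));
//   }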
template<typename InputDataType, typename OutputDataType>
template<typename eT>
void BatchNorm<InputDataType, OutputDataType>::Gradient(
    const arma::Mat<eT>& /* input */,
    const arma::Mat<eT>& error,
    arma::Mat<eT>& gradient)
{
  // The gradient holds dgamma in its first half and dbeta in its second
  // half, matching the layout of the weights vector.
  gradient.set_size(size + size, 1);
  arma::cube errorTemp(const_cast<arma::Mat<eT>&>(error).memptr(),
      error.n_rows / size, size, error.n_cols, false, false);

  // dl / dgamma: sum of the error times the normalized input.
  arma::mat temp = arma::sum(arma::sum(normalized % errorTemp, 0), 2);
  gradient.submat(0, 0, gamma.n_elem - 1, 0) = temp.t();

  // dl / dbeta: sum of the error.
  temp = arma::sum(arma::sum(errorTemp, 0), 2);
  gradient.submat(gamma.n_elem, 0, gradient.n_elem - 1, 0) = temp.t();
}
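// In formula form, the two sums above compute, per feature map,
// dl / dgamma = sum_i(dl / dy_i * xhat_i) and dl / dbeta = sum_i(dl / dy_i),
// stacked as [dgamma; dbeta] to match the layout of the weights vector.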
template<typename InputDataType, typename OutputDataType>
template<typename Archive>
void BatchNorm<InputDataType, OutputDataType>::serialize(
    Archive& ar, const uint32_t /* version */)
{
  ar(CEREAL_NVP(size));

  // When loading, the weights matrix must be allocated before gamma and
  // beta can be deserialized into it.
  if (cereal::is_loading<Archive>())
  {
    weights.set_size(size + size, 1);
  }

  ar(CEREAL_NVP(eps));
  ar(CEREAL_NVP(gamma));
  ar(CEREAL_NVP(beta));
  ar(CEREAL_NVP(count));
  ar(CEREAL_NVP(averageFactor));
  ar(CEREAL_NVP(momentum));
  ar(CEREAL_NVP(average));
  ar(CEREAL_NVP(runningMean));
  ar(CEREAL_NVP(runningVariance));
}
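// A hypothetical round-trip sketch using mlpack's generic serialization
// support (the file and object names are illustrative):
//
//   BatchNorm<> layer(5);
//   data::Save("batch_norm.bin", "layer", layer);
//
//   BatchNorm<> loaded;
//   data::Load("batch_norm.bin", "layer", loaded);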
} // namespace ann
} // namespace mlpack

#endif