// NOTE(review): This file is a text/doxygen extraction of an mlpack ANN layer
// implementation header ("virtual_batch_norm_impl.hpp"). The bare integers
// fused into the lines below (13, 14, 22, ...) are the original file's line
// numbers left behind by the extraction, and several signature/body lines are
// missing, so these fragments are not compilable as-is.
13 #ifndef MLPACK_METHODS_ANN_LAYER_VIRTUALBATCHNORM_IMPL_HPP 14 #define MLPACK_METHODS_ANN_LAYER_VIRTUALBATCHNORM_IMPL_HPP 22 template<
typename InputDataType,
typename OutputDataType>
// Template header above belongs to the default VirtualBatchNorm()
// constructor (body not visible in this extraction; the doxygen index below
// places its definition at original line 23).
32 template <
typename InputDataType,
typename OutputDataType>
// Second constructor: takes the fixed reference batch that virtual batch
// normalization blends with every future input. The function-name line and
// the remaining parameters (presumably a `size` and an epsilon -- TODO
// confirm against the full header) were dropped by the extraction.
35 const arma::Mat<eT>& referenceBatch,
// gamma (scale) and beta (shift) are packed into a single `weights` vector
// of length 2 * size; see Reset() below for how each half is aliased.
42 weights.set_size(size + size, 1);
// Precompute per-feature first and second moments of the reference batch
// once, up front; Forward() reuses them on every call.
44 referenceBatchMean = arma::mean(referenceBatch, 1);
45 referenceBatchMeanSquared = arma::mean(arma::square(referenceBatch), 1);
// Each incoming example column contributes with weight 1/(refCols + 1); the
// reference statistics keep the complementary weight, so the two
// coefficients always sum to 1.
46 newCoefficient = 1.0 / (referenceBatch.n_cols + 1);
47 oldCoefficient = 1 - newCoefficient;
// Reset(): re-bind gamma and beta as views into the packed `weights` vector.
50 template<
typename InputDataType,
typename OutputDataType>
// Armadillo advanced constructor mat(ptr, n_rows, n_cols, copy_aux_mem,
// strict): copy_aux_mem = false makes gamma alias the weights memory
// directly (no copy), so optimizer updates to `weights` are immediately
// visible through gamma; strict = false allows the alias to be resized away
// if reassigned later.
53 gamma = arma::mat(weights.memptr(), size, 1,
false,
false);
// beta aliases the second half of `weights`, directly after gamma's
// `size` elements.
54 beta = arma::mat(weights.memptr() + gamma.n_elem, size, 1,
false,
false);
// Forward(): virtual batch normalization. The normalization statistics are a
// convex blend of the precomputed reference-batch statistics and the current
// input's statistics (weights oldCoefficient / newCoefficient, which sum
// to 1).
65 template<
typename InputDataType,
typename OutputDataType>
68 const arma::Mat<eT>& input, arma::Mat<eT>& output)
// Input sanity check. NOTE(review): the extraction truncated this string
// literal mid-message and fused it with original line 73
// (`inputParameter = input;`); the cached copy of `input` is used later by
// Backward().
70 Log::Assert(input.n_rows % size == 0,
"Input features must be divisible \ 73 inputParameter = input;
// Per-feature first and second moments of the current batch.
74 arma::mat inputMean = arma::mean(input, 1);
75 arma::mat inputMeanSquared = arma::mean(arma::square(input), 1);
// Blend reference and current statistics.
77 mean = oldCoefficient * referenceBatchMean + newCoefficient * inputMean;
78 arma::mat meanSquared = oldCoefficient * referenceBatchMeanSquared +
79 newCoefficient * inputMeanSquared;
// Var[X] = E[X^2] - (E[X])^2 on the blended moments.
80 variance = meanSquared - arma::square(mean);
// Normalize: (x - mean) / sqrt(variance + eps); the centered input is kept
// in `inputSubMean` for the backward pass.
82 output = input.each_col() - mean;
83 inputSubMean = output;
84 output.each_col() /= arma::sqrt(variance + eps);
// NOTE(review): original lines 85-88 are missing from this extraction;
// Gradient() below reads a `normalized` member that was presumably assigned
// in that gap -- confirm against the full file.
// Learned affine transform: scale by gamma, shift by beta.
89 output.each_col() %= gamma;
90 output.each_col() += beta;
// Backward(): gradient of the loss w.r.t. the layer input, given the
// upstream gradient gy; result written into g. The first (unnamed) parameter
// is the usual unused input argument of the mlpack Backward() signature.
93 template<
typename InputDataType,
typename OutputDataType>
96 const arma::Mat<eT>& ,
const arma::Mat<eT>& gy, arma::Mat<eT>& g)
// 1 / sqrt(variance + eps), reused for every column below.
98 const arma::mat stdInv = 1.0 / arma::sqrt(variance + eps);
// dL/dxhat: upstream gradient scaled per-feature by gamma.
101 const arma::mat norm = gy.each_col() % gamma;
// dL/dvariance = sum(dL/dxhat * (x - mean)) * (-1/2) * (var + eps)^(-3/2).
104 const arma::mat var = arma::sum(norm % inputSubMean, 1) %
105 arma::pow(stdInv, 3.0) * -0.5;
// Direct term plus the variance term; newCoefficient accounts for the
// current input's share of the blended statistics.
109 g = (norm.each_col() % stdInv) + ((inputParameter.each_col() %
110 var) * 2 * newCoefficient / inputParameter.n_cols);
// Mean term, broadcast onto every column of g.
114 g.each_col() += (arma::sum(norm.each_col() % -stdInv, 1) + (var %
115 mean * -2)) * newCoefficient / inputParameter.n_cols;
// Gradient(): gradients of the loss w.r.t. the trainable (gamma, beta)
// parameters, packed into `gradient` with the same 2 * size layout as
// `weights` (gamma gradient first, then beta).
118 template<
typename InputDataType,
typename OutputDataType>
119 template<
typename eT>
121 const arma::Mat<eT>& ,
122 const arma::Mat<eT>& error,
123 arma::Mat<eT>& gradient)
125 gradient.set_size(size + size, 1);
// dL/dgamma = row-wise sum over the batch of (normalized input % error).
// NOTE(review): `normalized` is assigned in Forward() on lines omitted by
// this extraction (original lines 85-88) -- confirm against the full file.
128 gradient.submat(0, 0, gamma.n_elem - 1, 0) = arma::sum(normalized % error, 1);
// dL/dbeta goes in the second half. NOTE(review): the right-hand side of
// this statement was truncated by the extraction; presumably
// arma::sum(error, 1) -- TODO confirm.
131 gradient.submat(gamma.n_elem, 0, gradient.n_elem - 1, 0) =
// serialize(): cereal save/load of the layer's persistent state. The unnamed
// uint32_t parameter is the (unused) serialization version.
135 template<
typename InputDataType,
typename OutputDataType>
136 template<
typename Archive>
138 Archive& ar,
const uint32_t )
140 ar(CEREAL_NVP(size));
// When loading, allocate the packed weight vector before gamma/beta are
// read, so the aliases created in Reset() have backing storage of the right
// size (2 * size elements).
142 if (cereal::is_loading<Archive>())
144 weights.set_size(size + size, 1);
// NOTE(review): original lines 145-148 are missing from this extraction.
149 ar(CEREAL_NVP(gamma));
150 ar(CEREAL_NVP(beta));
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
VirtualBatchNorm()
Create the VirtualBatchNorm object.
Definition: virtual_batch_norm_impl.hpp:23
OutputDataType const & Gradient() const
Get the gradient.
Definition: virtual_batch_norm.hpp:121
void Backward(const arma::Mat< eT > &, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Backward pass through the layer.
Definition: virtual_batch_norm_impl.hpp:95
void Reset()
Reset the layer parameters.
Definition: virtual_batch_norm_impl.hpp:51
void Forward(const arma::Mat< eT > &input, arma::Mat< eT > &output)
Forward pass of the Virtual Batch Normalization layer.
Definition: virtual_batch_norm_impl.hpp:67
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
Definition: virtual_batch_norm_impl.hpp:137
static void Assert(bool condition, const std::string &message="Assert Failed.")
Checks if the specified condition is true.
Definition: log.cpp:38