12 #ifndef MLPACK_METHODS_ANN_LAYER_MINIBATCH_DISCRIMINATION_IMPL_HPP 13 #define MLPACK_METHODS_ANN_LAYER_MINIBATCH_DISCRIMINATION_IMPL_HPP 21 template<
typename InputDataType,
typename OutputDataType>
22 MiniBatchDiscrimination<InputDataType, OutputDataType
32 template <
typename InputDataType,
typename OutputDataType>
37 const size_t features) :
43 weights.set_size(
A *
B *
C, 1);
46 template<
typename InputDataType,
typename OutputDataType>
49 weight = arma::mat(weights.memptr(),
B *
C,
A,
false,
false);
52 template<
typename InputDataType,
typename OutputDataType>
55 const arma::Mat<eT>& input, arma::Mat<eT>& output)
57 batchSize = input.n_cols;
58 tempM = weight * input;
59 M = arma::cube(tempM.memptr(),
B,
C, batchSize,
false,
false);
60 distances.set_size(
B, batchSize, batchSize);
61 output.set_size(
B, batchSize);
63 for (
size_t i = 0; i < M.n_slices; ++i)
66 for (
size_t j = 0; j < M.n_slices; ++j)
70 output.col(i) += distances.slice(j).col(i);
78 distances.slice(i).col(j) =
79 arma::exp(-arma::sum(abs(M.slice(i) - M.slice(j)), 1));
80 output.col(i) += distances.slice(i).col(j);
85 output = join_cols(input, output);
88 template<
typename InputDataType,
typename OutputDataType>
91 const arma::Mat<eT>& ,
const arma::Mat<eT>& gy, arma::Mat<eT>& g)
94 arma::Mat<eT> gM = gy.tail_rows(
B);
95 deltaM.zeros(
B,
C, batchSize);
97 for (
size_t i = 0; i < M.n_slices; ++i)
99 for (
size_t j = 0; j < M.n_slices; ++j)
105 arma::mat t = arma::sign(M.slice(i) - M.slice(j));
107 distances.slice(std::min(i, j)).col(std::max(i, j)) % gM.col(i);
108 deltaM.slice(i) -= t;
109 deltaM.slice(j) += t;
113 deltaTemp = arma::mat(deltaM.memptr(),
B *
C, batchSize,
false,
false);
114 g += weight.t() * deltaTemp;
117 template<
typename InputDataType,
typename OutputDataType>
118 template<
typename eT>
120 const arma::Mat<eT>& input,
121 const arma::Mat<eT>& ,
122 arma::Mat<eT>& gradient)
124 gradient = arma::vectorise(deltaTemp * input.t());
127 template<
typename InputDataType,
typename OutputDataType>
128 template<
typename Archive>
130 Archive& ar,
const uint32_t )
138 if (cereal::is_loading<Archive>())
140 weights.set_size(
A *
B *
C, 1);
Definition: sfinae_test.cpp:28
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
Definition: sfinae_test.cpp:40
Definition: sfinae_test.cpp:18
MiniBatchDiscrimination()
Create the MiniBatchDiscrimination object.
Definition: minibatch_discrimination_impl.hpp:23
void Reset()
Reset the layer parameter.
Definition: minibatch_discrimination_impl.hpp:47
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
Definition: minibatch_discrimination_impl.hpp:129
Implementation of the MiniBatchDiscrimination layer.
Definition: layer_types.hpp:122
void Backward(const arma::Mat< eT > &, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Ordinary feed-backward pass of a neural network, calculating the function f(x) by propagating x backwards through f.
Definition: minibatch_discrimination_impl.hpp:90
OutputDataType const & Gradient() const
Get the gradient.
Definition: minibatch_discrimination.hpp:133
void Forward(const arma::Mat< eT > &input, arma::Mat< eT > &output)
Ordinary feed-forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
Definition: minibatch_discrimination_impl.hpp:54