mlpack
minibatch_discrimination_impl.hpp
Go to the documentation of this file.
1 
#ifndef MLPACK_METHODS_ANN_LAYER_MINIBATCH_DISCRIMINATION_IMPL_HPP
#define MLPACK_METHODS_ANN_LAYER_MINIBATCH_DISCRIMINATION_IMPL_HPP

// In case it hasn't yet been included.
#include "minibatch_discrimination.hpp"

#include <stdexcept>

18 namespace mlpack {
19 namespace ann {
20 
21 template<typename InputDataType, typename OutputDataType>
22 MiniBatchDiscrimination<InputDataType, OutputDataType
24  A(0),
25  B(0),
26  C(0),
27  batchSize(0)
28 {
29  // Nothing to do here.
30 }
31 
32 template <typename InputDataType, typename OutputDataType>
33 MiniBatchDiscrimination<InputDataType, OutputDataType
35  const size_t inSize,
36  const size_t outSize,
37  const size_t features) :
38  A(inSize),
39  B(outSize - inSize),
40  C(features),
41  batchSize(0)
42 {
43  weights.set_size(A * B * C, 1);
44 }
45 
46 template<typename InputDataType, typename OutputDataType>
48 {
49  weight = arma::mat(weights.memptr(), B * C, A, false, false);
50 }
51 
52 template<typename InputDataType, typename OutputDataType>
53 template<typename eT>
55  const arma::Mat<eT>& input, arma::Mat<eT>& output)
56 {
57  batchSize = input.n_cols;
58  tempM = weight * input;
59  M = arma::cube(tempM.memptr(), B, C, batchSize, false, false);
60  distances.set_size(B, batchSize, batchSize);
61  output.set_size(B, batchSize);
62 
63  for (size_t i = 0; i < M.n_slices; ++i)
64  {
65  output.col(i).ones();
66  for (size_t j = 0; j < M.n_slices; ++j)
67  {
68  if (j < i)
69  {
70  output.col(i) += distances.slice(j).col(i);
71  }
72  else if (i == j)
73  {
74  continue;
75  }
76  else
77  {
78  distances.slice(i).col(j) =
79  arma::exp(-arma::sum(abs(M.slice(i) - M.slice(j)), 1));
80  output.col(i) += distances.slice(i).col(j);
81  }
82  }
83  }
84 
85  output = join_cols(input, output); // (A + B) x batchSize
86 }
87 
88 template<typename InputDataType, typename OutputDataType>
89 template<typename eT>
91  const arma::Mat<eT>& /* input */, const arma::Mat<eT>& gy, arma::Mat<eT>& g)
92 {
93  g = gy.head_rows(A);
94  arma::Mat<eT> gM = gy.tail_rows(B);
95  deltaM.zeros(B, C, batchSize);
96 
97  for (size_t i = 0; i < M.n_slices; ++i)
98  {
99  for (size_t j = 0; j < M.n_slices; ++j)
100  {
101  if (i == j)
102  {
103  continue;
104  }
105  arma::mat t = arma::sign(M.slice(i) - M.slice(j));
106  t.each_col() %=
107  distances.slice(std::min(i, j)).col(std::max(i, j)) % gM.col(i);
108  deltaM.slice(i) -= t;
109  deltaM.slice(j) += t;
110  }
111  }
112 
113  deltaTemp = arma::mat(deltaM.memptr(), B * C, batchSize, false, false);
114  g += weight.t() * deltaTemp;
115 }
116 
117 template<typename InputDataType, typename OutputDataType>
118 template<typename eT>
120  const arma::Mat<eT>& input,
121  const arma::Mat<eT>& /* error */,
122  arma::Mat<eT>& gradient)
123 {
124  gradient = arma::vectorise(deltaTemp * input.t());
125 }
126 
127 template<typename InputDataType, typename OutputDataType>
128 template<typename Archive>
130  Archive& ar, const uint32_t /* version */)
131 {
132  ar(CEREAL_NVP(A));
133  ar(CEREAL_NVP(B));
134  ar(CEREAL_NVP(C));
135 
136  // This is inefficient, but we have to allocate this memory so that
137  // WeightSetVisitor gets the right size.
138  if (cereal::is_loading<Archive>())
139  {
140  weights.set_size(A * B * C, 1);
141  }
142 }
143 
144 } // namespace ann
145 } // namespace mlpack
146 
147 #endif
Definition: sfinae_test.cpp:28
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
Definition: sfinae_test.cpp:40
Definition: sfinae_test.cpp:18
MiniBatchDiscrimination()
Create the MiniBatchDiscrimination object.
Definition: minibatch_discrimination_impl.hpp:23
void Reset()
Reset the layer parameter.
Definition: minibatch_discrimination_impl.hpp:47
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
Definition: minibatch_discrimination_impl.hpp:129
Implementation of the MiniBatchDiscrimination layer.
Definition: layer_types.hpp:122
void Backward(const arma::Mat< eT > &, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Ordinary feed-backward pass of a neural network, calculating the function f(x) by propagating x backwards through f, using the results from the feed-forward pass.
Definition: minibatch_discrimination_impl.hpp:90
OutputDataType const & Gradient() const
Get the gradient.
Definition: minibatch_discrimination.hpp:133
void Forward(const arma::Mat< eT > &input, arma::Mat< eT > &output)
Ordinary feed-forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
Definition: minibatch_discrimination_impl.hpp:54