mlpack
wgan_impl.hpp
Go to the documentation of this file.
1 
10 #ifndef MLPACK_METHODS_ANN_GAN_WGAN_IMPL_HPP
11 #define MLPACK_METHODS_ANN_GAN_WGAN_IMPL_HPP
12 
13 #include "gan.hpp"
14 
15 #include <mlpack/core.hpp>
16 
21 
22 namespace mlpack {
23 namespace ann {
24 template<
25  typename Model,
26  typename InitializationRuleType,
27  typename Noise,
28  typename PolicyType
29 >
30 template<typename Policy>
31 typename std::enable_if<std::is_same<Policy, WGAN>::value, double>::type
33  const arma::mat& /* parameters */,
34  const size_t i,
35  const size_t /* batchSize */)
36 {
37  if (parameter.is_empty())
38  {
39  Reset();
40  }
41 
42  if (!deterministic)
43  {
44  deterministic = true;
45  ResetDeterministic();
46  }
47 
48  currentInput = arma::mat(predictors.memptr() + (i * predictors.n_rows),
49  predictors.n_rows, batchSize, false, false);
50  currentTarget = arma::mat(responses.memptr() + i, 1, batchSize, false,
51  false);
52 
53  discriminator.Forward(currentInput);
54  double res = discriminator.outputLayer.Forward(
55  boost::apply_visitor(
56  outputParameterVisitor,
57  discriminator.network.back()), currentTarget);
58 
59  noise.imbue( [&]() { return noiseFunction();} );
60  generator.Forward(noise);
61 
62  predictors.cols(numFunctions, numFunctions + batchSize - 1) =
63  boost::apply_visitor(outputParameterVisitor, generator.network.back());
64  discriminator.Forward(predictors.cols(numFunctions,
65  numFunctions + batchSize - 1));
66  responses.cols(numFunctions, numFunctions + batchSize - 1) =
67  -arma::ones(1, batchSize);
68 
69  currentTarget = arma::mat(responses.memptr() + numFunctions,
70  1, batchSize, false, false);
71  res += discriminator.outputLayer.Forward(
72  boost::apply_visitor(
73  outputParameterVisitor,
74  discriminator.network.back()), currentTarget);
75 
76  return res;
77 }
78 
79 template<
80  typename Model,
81  typename InitializationRuleType,
82  typename Noise,
83  typename PolicyType
84 >
85 template<typename GradType, typename Policy>
86 typename std::enable_if<std::is_same<Policy, WGAN>::value, double>::type
88 EvaluateWithGradient(const arma::mat& /* parameters */,
89  const size_t i,
90  GradType& gradient,
91  const size_t /* batchSize */)
92 {
93  if (parameter.is_empty())
94  {
95  Reset();
96  }
97 
98  if (gradient.is_empty())
99  {
100  if (parameter.is_empty())
101  Reset();
102  gradient = arma::zeros<arma::mat>(parameter.n_elem, 1);
103  }
104  else
105  gradient.zeros();
106 
107  if (this->deterministic)
108  {
109  this->deterministic = false;
110  ResetDeterministic();
111  }
112 
113  if (noiseGradientDiscriminator.is_empty())
114  {
115  noiseGradientDiscriminator = arma::zeros<arma::mat>(
116  gradientDiscriminator.n_elem, 1);
117  }
118  else
119  {
120  noiseGradientDiscriminator.zeros();
121  }
122 
123  gradientGenerator = arma::mat(gradient.memptr(),
124  generator.Parameters().n_elem, 1, false, false);
125 
126  gradientDiscriminator = arma::mat(gradient.memptr() +
127  gradientGenerator.n_elem,
128  discriminator.Parameters().n_elem, 1, false, false);
129 
130  // Get the gradients of the Discriminator.
131  double res = discriminator.EvaluateWithGradient(discriminator.parameter,
132  i, gradientDiscriminator, batchSize);
133 
134  noise.imbue( [&]() { return noiseFunction();} );
135  generator.Forward(noise);
136  predictors.cols(numFunctions, numFunctions + batchSize - 1) =
137  boost::apply_visitor(outputParameterVisitor, generator.network.back());
138  responses.cols(numFunctions, numFunctions + batchSize - 1) =
139  -arma::ones(1, batchSize);
140 
141  // Get the gradients of the Generator.
142  res += discriminator.EvaluateWithGradient(discriminator.parameter,
143  numFunctions, noiseGradientDiscriminator, batchSize);
144  gradientDiscriminator += noiseGradientDiscriminator;
145  gradientDiscriminator = arma::clamp(gradientDiscriminator,
146  -clippingParameter, clippingParameter);
147 
148  if (currentBatch % generatorUpdateStep == 0 && preTrainSize == 0)
149  {
150  // Minimize -D(G(noise)).
151  // Pass the error from Discriminator to Generator.
152  responses.cols(numFunctions, numFunctions + batchSize - 1) =
153  arma::ones(1, batchSize);
154 
155  discriminator.outputLayer.Backward(
156  boost::apply_visitor(outputParameterVisitor,
157  discriminator.network.back()), discriminator.responses.cols(
158  numFunctions, numFunctions + batchSize - 1), discriminator.error);
159  discriminator.Backward();
160 
161  generator.error = boost::apply_visitor(deltaVisitor,
162  discriminator.network[1]);
163 
164  generator.Predictors() = noise;
165  generator.Backward();
166  generator.ResetGradients(gradientGenerator);
167  generator.Gradient(generator.Predictors().cols(0, batchSize - 1));
168 
169  gradientGenerator *= multiplier;
170  }
171 
172  currentBatch++;
173 
174  if (preTrainSize > 0)
175  {
176  preTrainSize--;
177  }
178 
179  return res;
180 }
181 
182 template<
183  typename Model,
184  typename InitializationRuleType,
185  typename Noise,
186  typename PolicyType
187 >
188 template<typename Policy>
189 typename std::enable_if<std::is_same<Policy, WGAN>::value, void>::type
191 Gradient(const arma::mat& parameters,
192  const size_t i,
193  arma::mat& gradient,
194  const size_t batchSize)
195 {
196  this->EvaluateWithGradient(parameters, i, gradient, batchSize);
197 }
198 
199 } // namespace ann
200 } // namespace mlpack
201 # endif
std::enable_if< std::is_same< Policy, StandardGAN >::value||std::is_same< Policy, DCGAN >::value, void >::type Gradient(const arma::mat &parameters, const size_t i, arma::mat &gradient, const size_t batchSize)
Gradient function for Standard GAN and DCGAN.
Definition: gan_impl.hpp:398
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
std::enable_if< std::is_same< Policy, StandardGAN >::value||std::is_same< Policy, DCGAN >::value, double >::type Evaluate(const arma::mat &parameters, const size_t i, const size_t batchSize)
Evaluate function for the Standard GAN and DCGAN.
Definition: gan_impl.hpp:239
Include all of the base components required to write mlpack methods, and the main mlpack Doxygen documentation.
std::enable_if< std::is_same< Policy, StandardGAN >::value||std::is_same< Policy, DCGAN >::value, double >::type EvaluateWithGradient(const arma::mat &parameters, const size_t i, GradType &gradient, const size_t batchSize)
EvaluateWithGradient function for the Standard GAN and DCGAN.
Definition: gan_impl.hpp:295