/**
 * @file wgangp_impl.hpp
 *
 * Implementation of the WGAN-GP (Wasserstein GAN with gradient penalty)
 * policy specializations of the mlpack GAN class.
 *
 * mlpack is free software; you may redistribute it and/or modify it under the
 * terms of the 3-clause BSD license.
 */
#ifndef MLPACK_METHODS_ANN_GAN_WGANGP_IMPL_HPP
#define MLPACK_METHODS_ANN_GAN_WGANGP_IMPL_HPP

#include "gan.hpp"

#include <mlpack/core.hpp>

namespace mlpack {
namespace ann {
24 template<
25  typename Model,
26  typename InitializationRuleType,
27  typename Noise,
28  typename PolicyType
29 >
30 template<typename Policy>
31 typename std::enable_if<std::is_same<Policy, WGANGP>::value,
32  double>::type
34  const arma::mat& /* parameters */,
35  const size_t i,
36  const size_t /* batchSize */)
37 {
38  if ((parameter.is_empty()))
39  {
40  Reset();
41  }
42 
43  if (!deterministic)
44  {
45  deterministic = true;
46  ResetDeterministic();
47  }
48 
49  currentInput = arma::mat(predictors.memptr() + (i * predictors.n_rows),
50  predictors.n_rows, batchSize, false, false);
51  currentTarget = arma::mat(responses.memptr() + i, 1, batchSize, false,
52  false);
53 
54  discriminator.Forward(std::move(currentInput));
55  double res = discriminator.outputLayer.Forward(
56  std::move(boost::apply_visitor(
57  outputParameterVisitor,
58  discriminator.network.back())), std::move(currentTarget));
59 
60  noise.imbue( [&]() { return noiseFunction();} );
61  generator.Forward(std::move(noise));
62 
63  arma::mat generatedData = boost::apply_visitor(outputParameterVisitor,
64  generator.network.back());
65  predictors.cols(numFunctions, numFunctions + batchSize - 1) =
66  generatedData;
67  discriminator.Forward(std::move(predictors.cols(numFunctions,
68  numFunctions + batchSize - 1)));
69  responses.cols(numFunctions, numFunctions + batchSize - 1) =
70  -arma::ones(1, batchSize);
71 
72  currentTarget = arma::mat(responses.memptr() + numFunctions,
73  1, batchSize, false, false);
74  res += discriminator.outputLayer.Forward(
75  std::move(boost::apply_visitor(
76  outputParameterVisitor,
77  discriminator.network.back())), std::move(currentTarget));
78 
79  // Gradient Penalty is calculated here.
80  double epsilon = math::Random();
81  predictors.cols(numFunctions, numFunctions + batchSize - 1) =
82  (epsilon * currentInput) + ((1.0 - epsilon) * generatedData);
83  responses.cols(numFunctions, numFunctions + batchSize - 1) =
84  -arma::ones(1, batchSize);
85  discriminator.Gradient(discriminator.parameter, numFunctions,
86  normGradientDiscriminator, batchSize);
87  res += lambda * std::pow(arma::norm(normGradientDiscriminator, 2) - 1, 2);
88 
89  return res;
90 }
91 
92 template<
93  typename Model,
94  typename InitializationRuleType,
95  typename Noise,
96  typename PolicyType
97 >
98 template<typename GradType, typename Policy>
99 typename std::enable_if<std::is_same<Policy, WGANGP>::value,
100  double>::type
102 EvaluateWithGradient(const arma::mat& /* parameters */,
103  const size_t i,
104  GradType& gradient,
105  const size_t /* batchSize */)
106 {
107  if (parameter.is_empty())
108  {
109  Reset();
110  }
111 
112  if (gradient.is_empty())
113  {
114  if (parameter.is_empty())
115  Reset();
116  gradient = arma::zeros<arma::mat>(parameter.n_elem, 1);
117  }
118  else
119  gradient.zeros();
120 
121  if (this->deterministic)
122  {
123  this->deterministic = false;
124  ResetDeterministic();
125  }
126 
127  if (noiseGradientDiscriminator.is_empty())
128  {
129  noiseGradientDiscriminator = arma::zeros<arma::mat>(
130  gradientDiscriminator.n_elem, 1);
131  }
132  else
133  {
134  noiseGradientDiscriminator.zeros();
135  }
136 
137  gradientGenerator = arma::mat(gradient.memptr(),
138  generator.Parameters().n_elem, 1, false, false);
139 
140  gradientDiscriminator = arma::mat(gradient.memptr() +
141  gradientGenerator.n_elem,
142  discriminator.Parameters().n_elem, 1, false, false);
143 
144  currentInput = arma::mat(predictors.memptr() + (i * predictors.n_rows),
145  predictors.n_rows, batchSize, false, false);
146 
147  // Get the gradients of the Discriminator.
148  double res = discriminator.EvaluateWithGradient(discriminator.parameter,
149  i, gradientDiscriminator, batchSize);
150 
151  noise.imbue( [&]() { return noiseFunction();} );
152  generator.Forward(std::move(noise));
153  arma::mat generatedData = boost::apply_visitor(outputParameterVisitor,
154  generator.network.back());
155 
156  // Gradient Penalty is calculated here.
157  double epsilon = math::Random();
158  predictors.cols(numFunctions, numFunctions + batchSize - 1) =
159  (epsilon * currentInput) + ((1.0 - epsilon) * generatedData);
160  responses.cols(numFunctions, numFunctions + batchSize - 1) =
161  -arma::ones(1, batchSize);
162  discriminator.Gradient(discriminator.parameter, numFunctions,
163  normGradientDiscriminator, batchSize);
164  res += lambda * std::pow(arma::norm(normGradientDiscriminator, 2) - 1, 2);
165 
166  predictors.cols(numFunctions, numFunctions + batchSize - 1) =
167  generatedData;
168  res += discriminator.EvaluateWithGradient(discriminator.parameter,
169  numFunctions, noiseGradientDiscriminator, batchSize);
170  gradientDiscriminator += noiseGradientDiscriminator;
171 
172  if (currentBatch % generatorUpdateStep == 0 && preTrainSize == 0)
173  {
174  // Minimize -D(G(noise)).
175  // Pass the error from Discriminator to Generator.
176  responses.cols(numFunctions, numFunctions + batchSize - 1) =
177  arma::ones(1, batchSize);
178 
179  discriminator.outputLayer.Backward(
180  boost::apply_visitor(outputParameterVisitor,
181  discriminator.network.back()), discriminator.responses.cols(
182  numFunctions, numFunctions + batchSize - 1), discriminator.error);
183  discriminator.Backward();
184 
185  generator.error = boost::apply_visitor(deltaVisitor,
186  discriminator.network[1]);
187 
188  generator.Predictors() = noise;
189  generator.Backward();
190  generator.ResetGradients(gradientGenerator);
191  generator.Gradient(generator.Predictors().cols(0, batchSize - 1));
192 
193  gradientGenerator *= multiplier;
194  }
195 
196  currentBatch++;
197 
198  if (preTrainSize > 0)
199  {
200  preTrainSize--;
201  }
202 
203  return res;
204 }
205 
206 template<
207  typename Model,
208  typename InitializationRuleType,
209  typename Noise,
210  typename PolicyType
211 >
212 template<typename Policy>
213 typename std::enable_if<std::is_same<Policy, WGANGP>::value,
214  void>::type
216 Gradient(const arma::mat& parameters,
217  const size_t i,
218  arma::mat& gradient,
219  const size_t batchSize)
220 {
221  this->EvaluateWithGradient(parameters, i, gradient, batchSize);
222 }
223 
} // namespace ann
} // namespace mlpack

#endif
// Doxygen cross-references retained from the generated documentation:
// - Gradient (StandardGAN/DCGAN policies): gan_impl.hpp:398
// - Evaluate (StandardGAN/DCGAN policies): gan_impl.hpp:239
// - EvaluateWithGradient (StandardGAN/DCGAN policies): gan_impl.hpp:295
// - math::Random(): generates a uniform random number between 0 and 1
//   (random.hpp:83)
// - mlpack core: linear algebra utility functions (cv.hpp:1) and the main
//   mlpack components (<mlpack/core.hpp>)