mlpack
gan_impl.hpp
Go to the documentation of this file.
1 
11 #ifndef MLPACK_METHODS_ANN_GAN_GAN_IMPL_HPP
12 #define MLPACK_METHODS_ANN_GAN_GAN_IMPL_HPP
13 
14 #include "gan.hpp"
15 
16 #include <mlpack/core.hpp>
17 
22 
23 namespace mlpack {
24 namespace ann {
25 template<
26  typename Model,
27  typename InitializationRuleType,
28  typename Noise,
29  typename PolicyType
30 >
32  Model generator,
33  Model discriminator,
34  InitializationRuleType& initializeRule,
35  Noise& noiseFunction,
36  const size_t noiseDim,
37  const size_t batchSize,
38  const size_t generatorUpdateStep,
39  const size_t preTrainSize,
40  const double multiplier,
41  const double clippingParameter,
42  const double lambda):
43  generator(std::move(generator)),
44  discriminator(std::move(discriminator)),
45  initializeRule(initializeRule),
46  noiseFunction(noiseFunction),
47  noiseDim(noiseDim),
48  numFunctions(0),
49  batchSize(batchSize),
50  currentBatch(0),
51  generatorUpdateStep(generatorUpdateStep),
52  preTrainSize(preTrainSize),
53  multiplier(multiplier),
54  clippingParameter(clippingParameter),
55  lambda(lambda),
56  reset(false),
57  deterministic(false),
58  genWeights(0),
59  discWeights(0)
60 {
61  // Insert IdentityLayer for joining the Generator and Discriminator.
62  this->discriminator.network.insert(
63  this->discriminator.network.begin(),
64  new IdentityLayer<>());
65 }
66 
67 template<
68  typename Model,
69  typename InitializationRuleType,
70  typename Noise,
71  typename PolicyType
72 >
74  const GAN& network):
75  predictors(network.predictors),
76  responses(network.responses),
77  generator(network.generator),
78  discriminator(network.discriminator),
79  initializeRule(network.initializeRule),
80  noiseFunction(network.noiseFunction),
81  noiseDim(network.noiseDim),
82  batchSize(network.batchSize),
83  generatorUpdateStep(network.generatorUpdateStep),
84  preTrainSize(network.preTrainSize),
85  multiplier(network.multiplier),
86  clippingParameter(network.clippingParameter),
87  lambda(network.lambda),
88  reset(network.reset),
89  currentBatch(network.currentBatch),
90  parameter(network.parameter),
91  numFunctions(network.numFunctions),
92  noise(network.noise),
93  deterministic(network.deterministic),
94  genWeights(network.genWeights),
95  discWeights(network.discWeights)
96 {
97  /* Nothing to do here */
98 }
99 
100 template<
101  typename Model,
102  typename InitializationRuleType,
103  typename Noise,
104  typename PolicyType
105 >
107  GAN&& network):
108  predictors(std::move(network.predictors)),
109  responses(std::move(network.responses)),
110  generator(std::move(network.generator)),
111  discriminator(std::move(network.discriminator)),
112  initializeRule(std::move(network.initializeRule)),
113  noiseFunction(std::move(network.noiseFunction)),
114  noiseDim(network.noiseDim),
115  batchSize(network.batchSize),
116  generatorUpdateStep(network.generatorUpdateStep),
117  preTrainSize(network.preTrainSize),
118  multiplier(network.multiplier),
119  clippingParameter(network.clippingParameter),
120  lambda(network.lambda),
121  reset(network.reset),
122  currentBatch(network.currentBatch),
123  parameter(std::move(network.parameter)),
124  numFunctions(network.numFunctions),
125  noise(std::move(network.noise)),
126  deterministic(network.deterministic),
127  genWeights(network.genWeights),
128  discWeights(network.discWeights)
129 {
130  /* Nothing to do here */
131 }
132 
133 template<
134  typename Model,
135  typename InitializationRuleType,
136  typename Noise,
137  typename PolicyType
138 >
140  arma::mat trainData)
141 {
142  currentBatch = 0;
143 
144  numFunctions = trainData.n_cols;
145  noise.set_size(noiseDim, batchSize);
146 
147  deterministic = true;
148  ResetDeterministic();
149 
155  this->predictors.set_size(trainData.n_rows, numFunctions + batchSize);
156  this->predictors.cols(0, numFunctions - 1) = std::move(trainData);
157  this->discriminator.predictors = arma::mat(this->predictors.memptr(),
158  this->predictors.n_rows, this->predictors.n_cols, false, false);
159 
160  responses.ones(1, numFunctions + batchSize);
161  responses.cols(numFunctions, numFunctions + batchSize - 1) =
162  arma::zeros(1, batchSize);
163  this->discriminator.responses = arma::mat(this->responses.memptr(),
164  this->responses.n_rows, this->responses.n_cols, false, false);
165 
166  this->generator.predictors.set_size(noiseDim, batchSize);
167  this->generator.responses.set_size(predictors.n_rows, batchSize);
168 
169  if (!reset)
170  {
171  Reset();
172  }
173 }
174 
175 template<
176  typename Model,
177  typename InitializationRuleType,
178  typename Noise,
179  typename PolicyType
180 >
182 {
183  genWeights = 0;
184  discWeights = 0;
185 
186  NetworkInitialization<InitializationRuleType> networkInit(initializeRule);
187 
188  for (size_t i = 0; i < generator.network.size(); ++i)
189  {
190  genWeights += boost::apply_visitor(weightSizeVisitor, generator.network[i]);
191  }
192 
193  for (size_t i = 0; i < discriminator.network.size(); ++i)
194  {
195  discWeights += boost::apply_visitor(weightSizeVisitor,
196  discriminator.network[i]);
197  }
198 
199  parameter.set_size(genWeights + discWeights, 1);
200  generator.Parameters() = arma::mat(parameter.memptr(), genWeights, 1, false,
201  false);
202  discriminator.Parameters() = arma::mat(parameter.memptr() + genWeights,
203  discWeights, 1, false, false);
204 
205  // Initialize the parameters generator
206  networkInit.Initialize(generator.network, parameter);
207  // Initialize the parameters discriminator
208  networkInit.Initialize(discriminator.network, parameter, genWeights);
209 
210  reset = true;
211 }
212 
213 template<
214  typename Model,
215  typename InitializationRuleType,
216  typename Noise,
217  typename PolicyType
218 >
219 template<typename OptimizerType, typename... CallbackTypes>
221  arma::mat trainData,
222  OptimizerType& Optimizer,
223  CallbackTypes&&... callbacks)
224 {
225  ResetData(std::move(trainData));
226 
227  return Optimizer.Optimize(*this, parameter, callbacks...);
228 }
229 
230 template<
231  typename Model,
232  typename InitializationRuleType,
233  typename Noise,
234  typename PolicyType
235 >
236 template<typename Policy>
237 typename std::enable_if<std::is_same<Policy, StandardGAN>::value ||
238  std::is_same<Policy, DCGAN>::value, double>::type
240  const arma::mat& /* parameters */,
241  const size_t i,
242  const size_t /* batchSize */)
243 {
244  if (parameter.is_empty())
245  {
246  Reset();
247  }
248 
249  if (!deterministic)
250  {
251  deterministic = true;
252  ResetDeterministic();
253  }
254 
255  currentInput = arma::mat(predictors.memptr() + (i * predictors.n_rows),
256  predictors.n_rows, batchSize, false, false);
257  currentTarget = arma::mat(responses.memptr() + i, 1, batchSize, false,
258  false);
259 
260  discriminator.Forward(currentInput);
261  double res = discriminator.outputLayer.Forward(
262  boost::apply_visitor(
263  outputParameterVisitor,
264  discriminator.network.back()), currentTarget);
265 
266  noise.imbue( [&]() { return noiseFunction();} );
267  generator.Forward(noise);
268 
269  predictors.cols(numFunctions, numFunctions + batchSize - 1) =
270  boost::apply_visitor(outputParameterVisitor, generator.network.back());
271  discriminator.Forward(predictors.cols(numFunctions,
272  numFunctions + batchSize - 1));
273  responses.cols(numFunctions, numFunctions + batchSize - 1) =
274  arma::zeros(1, batchSize);
275 
276  currentTarget = arma::mat(responses.memptr() + numFunctions,
277  1, batchSize, false, false);
278  res += discriminator.outputLayer.Forward(
279  boost::apply_visitor(outputParameterVisitor,
280  discriminator.network.back()), currentTarget);
281 
282  return res;
283 }
284 
285 template<
286  typename Model,
287  typename InitializationRuleType,
288  typename Noise,
289  typename PolicyType
290 >
291 template<typename GradType, typename Policy>
292 typename std::enable_if<std::is_same<Policy, StandardGAN>::value ||
293  std::is_same<Policy, DCGAN>::value, double>::type
295 EvaluateWithGradient(const arma::mat& /* parameters */,
296  const size_t i,
297  GradType& gradient,
298  const size_t /* batchSize */)
299 {
300  if (parameter.is_empty())
301  {
302  Reset();
303  }
304 
305  if (gradient.is_empty())
306  {
307  if (parameter.is_empty())
308  Reset();
309  gradient = arma::zeros<arma::mat>(parameter.n_elem, 1);
310  }
311  else
312  gradient.zeros();
313 
314  if (this->deterministic)
315  {
316  this->deterministic = false;
317  ResetDeterministic();
318  }
319 
320  if (noiseGradientDiscriminator.is_empty())
321  {
322  noiseGradientDiscriminator = arma::zeros<arma::mat>(
323  gradientDiscriminator.n_elem, 1);
324  }
325  else
326  {
327  noiseGradientDiscriminator.zeros();
328  }
329 
330  gradientGenerator = arma::mat(gradient.memptr(),
331  generator.Parameters().n_elem, 1, false, false);
332 
333  gradientDiscriminator = arma::mat(gradient.memptr() +
334  gradientGenerator.n_elem,
335  discriminator.Parameters().n_elem, 1, false, false);
336 
337  // Get the gradients of the Discriminator.
338  double res = discriminator.EvaluateWithGradient(discriminator.parameter,
339  i, gradientDiscriminator, batchSize);
340 
341  noise.imbue( [&]() { return noiseFunction();} );
342  generator.Forward(noise);
343  predictors.cols(numFunctions, numFunctions + batchSize - 1) =
344  boost::apply_visitor(outputParameterVisitor, generator.network.back());
345  responses.cols(numFunctions, numFunctions + batchSize - 1) =
346  arma::zeros(1, batchSize);
347 
348  // Get the gradients of the Generator.
349  res += discriminator.EvaluateWithGradient(discriminator.parameter,
350  numFunctions, noiseGradientDiscriminator, batchSize);
351  gradientDiscriminator += noiseGradientDiscriminator;
352 
353  if (currentBatch % generatorUpdateStep == 0 && preTrainSize == 0)
354  {
355  // Minimize -log(D(G(noise))).
356  // Pass the error from Discriminator to Generator.
357  responses.cols(numFunctions, numFunctions + batchSize - 1) =
358  arma::ones(1, batchSize);
359 
360  discriminator.outputLayer.Backward(
361  boost::apply_visitor(outputParameterVisitor,
362  discriminator.network.back()), discriminator.responses.cols(
363  numFunctions, numFunctions + batchSize - 1), discriminator.error);
364  discriminator.Backward();
365 
366  generator.error = boost::apply_visitor(deltaVisitor,
367  discriminator.network[1]);
368 
369  generator.Predictors() = noise;
370  generator.Backward();
371  generator.ResetGradients(gradientGenerator);
372  generator.Gradient(generator.Predictors().cols(0, batchSize - 1));
373 
374  gradientGenerator *= multiplier;
375  }
376 
377  currentBatch++;
378 
379 
380  if (preTrainSize > 0)
381  {
382  preTrainSize--;
383  }
384 
385  return res;
386 }
387 
388 template<
389  typename Model,
390  typename InitializationRuleType,
391  typename Noise,
392  typename PolicyType
393 >
394 template<typename Policy>
395 typename std::enable_if<std::is_same<Policy, StandardGAN>::value ||
396  std::is_same<Policy, DCGAN>::value, void>::type
398 Gradient(const arma::mat& parameters,
399  const size_t i,
400  arma::mat& gradient,
401  const size_t batchSize)
402 {
403  this->EvaluateWithGradient(parameters, i, gradient, batchSize);
404 }
405 
406 template<
407  typename Model,
408  typename InitializationRuleType,
409  typename Noise,
410  typename PolicyType
411 >
413 {
414  const arma::uvec ordering = arma::shuffle(arma::linspace<arma::uvec>(0,
415  numFunctions - 1, numFunctions));
416  predictors.cols(0, numFunctions - 1) = predictors.cols(ordering);
417 }
418 
419 template<
420  typename Model,
421  typename InitializationRuleType,
422  typename Noise,
423  typename PolicyType
424 >
426  const arma::mat& input)
427 {
428  if (parameter.is_empty())
429  {
430  Reset();
431  }
432 
433  generator.Forward(input);
434  arma::mat ganOutput = boost::apply_visitor(outputParameterVisitor,
435  generator.network.back());
436 
437  discriminator.Forward(ganOutput);
438 }
439 
440 template<
441  typename Model,
442  typename InitializationRuleType,
443  typename Noise,
444  typename PolicyType
445 >
447 Predict(arma::mat input, arma::mat& output)
448 {
449  if (parameter.is_empty())
450  {
451  Reset();
452  }
453 
454  if (!deterministic)
455  {
456  deterministic = true;
457  ResetDeterministic();
458  }
459 
460  Forward(input);
461 
462  output = boost::apply_visitor(outputParameterVisitor,
463  discriminator.network.back());
464 }
465 
466 template<
467  typename Model,
468  typename InitializationRuleType,
469  typename Noise,
470  typename PolicyType
471 >
474 {
475  this->discriminator.deterministic = deterministic;
476  this->generator.deterministic = deterministic;
477  this->discriminator.ResetDeterministic();
478  this->generator.ResetDeterministic();
479 }
480 
481 template<
482  typename Model,
483  typename InitializationRuleType,
484  typename Noise,
485  typename PolicyType
486 >
487 template<typename Archive>
489 serialize(Archive& ar, const uint32_t /* version */)
490 {
491  ar(CEREAL_NVP(parameter));
492  ar(CEREAL_NVP(generator));
493  ar(CEREAL_NVP(discriminator));
494  ar(CEREAL_NVP(reset));
495  ar(CEREAL_NVP(genWeights));
496  ar(CEREAL_NVP(discWeights));
497 
498  if (cereal::is_loading<Archive>())
499  {
500  // Share the parameters between the network.
501  generator.Parameters() = arma::mat(parameter.memptr(), genWeights, 1, false,
502  false);
503  discriminator.Parameters() = arma::mat(parameter.memptr() + genWeights,
504  discWeights, 1, false, false);
505 
506  size_t offset = 0;
507  for (size_t i = 0; i < generator.network.size(); ++i)
508  {
509  offset += boost::apply_visitor(WeightSetVisitor(
510  generator.parameter, offset), generator.network[i]);
511 
512  boost::apply_visitor(resetVisitor, generator.network[i]);
513  }
514 
515  offset = 0;
516  for (size_t i = 0; i < discriminator.network.size(); ++i)
517  {
518  offset += boost::apply_visitor(WeightSetVisitor(
519  discriminator.parameter, offset), discriminator.network[i]);
520 
521  boost::apply_visitor(resetVisitor, discriminator.network[i]);
522  }
523 
524  deterministic = true;
525  ResetDeterministic();
526  }
527 }
528 
529 } // namespace ann
530 } // namespace mlpack
531 # endif
void Shuffle()
Shuffle the order of function visitation.
Definition: gan_impl.hpp:412
std::enable_if< std::is_same< Policy, StandardGAN >::value||std::is_same< Policy, DCGAN >::value, void >::type Gradient(const arma::mat &parameters, const size_t i, arma::mat &gradient, const size_t batchSize)
Gradient function for Standard GAN and DCGAN.
Definition: gan_impl.hpp:398
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
void Initialize(const std::vector< LayerTypes< CustomLayers... > > &network, arma::Mat< eT > &parameter, size_t parameterOffset=0)
Initialize the specified network and store the results in the given parameter.
Definition: network_init.hpp:57
Definition: pointer_wrapper.hpp:23
double Train(arma::mat trainData, OptimizerType &Optimizer, CallbackTypes &&... callbacks)
Train function.
Definition: gan_impl.hpp:220
GAN(Model generator, Model discriminator, InitializationRuleType &initializeRule, Noise &noiseFunction, const size_t noiseDim, const size_t batchSize, const size_t generatorUpdateStep, const size_t preTrainSize, const double multiplier, const double clippingParameter=0.01, const double lambda=10.0)
Constructor for GAN class.
Definition: gan_impl.hpp:31
Implementation of the base layer.
Definition: base_layer.hpp:71
void Predict(arma::mat input, arma::mat &output)
This function predicts the output of the network on the given input.
Definition: gan_impl.hpp:447
void Forward(const arma::mat &input)
This function does a forward pass through the GAN network.
Definition: gan_impl.hpp:425
std::enable_if< std::is_same< Policy, StandardGAN >::value||std::is_same< Policy, DCGAN >::value, double >::type Evaluate(const arma::mat &parameters, const size_t i, const size_t batchSize)
Evaluate function for the Standard GAN and DCGAN.
Definition: gan_impl.hpp:239
WeightSetVisitor updates the module parameters given the parameter set.
Definition: weight_set_visitor.hpp:26
Include all of the base components required to write mlpack methods, and the main mlpack Doxygen documentation.
void serialize(Archive &ar, const uint32_t)
Serialize the model.
Definition: gan_impl.hpp:489
The implementation of the standard GAN module.
Definition: gan.hpp:63
void ResetData(arma::mat trainData)
Initialize the generator, discriminator and weights of the model for training.
Definition: gan_impl.hpp:139
This class is used to initialize the network with the given initialization rule.
Definition: network_init.hpp:33
std::enable_if< std::is_same< Policy, StandardGAN >::value||std::is_same< Policy, DCGAN >::value, double >::type EvaluateWithGradient(const arma::mat &parameters, const size_t i, GradType &gradient, const size_t batchSize)
EvaluateWithGradient function for the Standard GAN and DCGAN.
Definition: gan_impl.hpp:295