11 #ifndef MLPACK_METHODS_ANN_GAN_GAN_IMPL_HPP 12 #define MLPACK_METHODS_ANN_GAN_GAN_IMPL_HPP 27 typename InitializationRuleType,
34 InitializationRuleType& initializeRule,
36 const size_t noiseDim,
37 const size_t batchSize,
38 const size_t generatorUpdateStep,
39 const size_t preTrainSize,
40 const double multiplier,
41 const double clippingParameter,
43 generator(
std::move(generator)),
44 discriminator(
std::move(discriminator)),
45 initializeRule(initializeRule),
46 noiseFunction(noiseFunction),
51 generatorUpdateStep(generatorUpdateStep),
52 preTrainSize(preTrainSize),
53 multiplier(multiplier),
54 clippingParameter(clippingParameter),
62 this->discriminator.network.insert(
63 this->discriminator.network.begin(),
69 typename InitializationRuleType,
75 predictors(network.predictors),
76 responses(network.responses),
77 generator(network.generator),
78 discriminator(network.discriminator),
79 initializeRule(network.initializeRule),
80 noiseFunction(network.noiseFunction),
81 noiseDim(network.noiseDim),
82 batchSize(network.batchSize),
83 generatorUpdateStep(network.generatorUpdateStep),
84 preTrainSize(network.preTrainSize),
85 multiplier(network.multiplier),
86 clippingParameter(network.clippingParameter),
87 lambda(network.lambda),
89 currentBatch(network.currentBatch),
90 parameter(network.parameter),
91 numFunctions(network.numFunctions),
93 deterministic(network.deterministic),
94 genWeights(network.genWeights),
95 discWeights(network.discWeights)
102 typename InitializationRuleType,
108 predictors(
std::move(network.predictors)),
109 responses(
std::move(network.responses)),
110 generator(
std::move(network.generator)),
111 discriminator(
std::move(network.discriminator)),
112 initializeRule(
std::move(network.initializeRule)),
113 noiseFunction(
std::move(network.noiseFunction)),
114 noiseDim(network.noiseDim),
115 batchSize(network.batchSize),
116 generatorUpdateStep(network.generatorUpdateStep),
117 preTrainSize(network.preTrainSize),
118 multiplier(network.multiplier),
119 clippingParameter(network.clippingParameter),
120 lambda(network.lambda),
121 reset(network.reset),
122 currentBatch(network.currentBatch),
123 parameter(
std::move(network.parameter)),
124 numFunctions(network.numFunctions),
125 noise(
std::move(network.noise)),
126 deterministic(network.deterministic),
127 genWeights(network.genWeights),
128 discWeights(network.discWeights)
135 typename InitializationRuleType,
144 numFunctions = trainData.n_cols;
145 noise.set_size(noiseDim, batchSize);
147 deterministic =
true;
148 ResetDeterministic();
155 this->predictors.set_size(trainData.n_rows, numFunctions + batchSize);
156 this->predictors.cols(0, numFunctions - 1) = std::move(trainData);
157 this->discriminator.predictors = arma::mat(this->predictors.memptr(),
158 this->predictors.n_rows, this->predictors.n_cols,
false,
false);
160 responses.ones(1, numFunctions + batchSize);
161 responses.cols(numFunctions, numFunctions + batchSize - 1) =
162 arma::zeros(1, batchSize);
163 this->discriminator.responses = arma::mat(this->responses.memptr(),
164 this->responses.n_rows, this->responses.n_cols,
false,
false);
166 this->generator.predictors.set_size(noiseDim, batchSize);
167 this->generator.responses.set_size(predictors.n_rows, batchSize);
177 typename InitializationRuleType,
188 for (
size_t i = 0; i < generator.network.size(); ++i)
190 genWeights += boost::apply_visitor(weightSizeVisitor, generator.network[i]);
193 for (
size_t i = 0; i < discriminator.network.size(); ++i)
195 discWeights += boost::apply_visitor(weightSizeVisitor,
196 discriminator.network[i]);
199 parameter.set_size(genWeights + discWeights, 1);
200 generator.Parameters() = arma::mat(parameter.memptr(), genWeights, 1,
false,
202 discriminator.Parameters() = arma::mat(parameter.memptr() + genWeights,
203 discWeights, 1,
false,
false);
206 networkInit.
Initialize(generator.network, parameter);
208 networkInit.
Initialize(discriminator.network, parameter, genWeights);
215 typename InitializationRuleType,
219 template<
typename OptimizerType,
typename... CallbackTypes>
222 OptimizerType& Optimizer,
223 CallbackTypes&&... callbacks)
227 return Optimizer.Optimize(*
this, parameter, callbacks...);
232 typename InitializationRuleType,
236 template<
typename Policy>
237 typename std::enable_if<std::is_same<Policy, StandardGAN>::value ||
238 std::is_same<Policy, DCGAN>::value,
double>::type
244 if (parameter.is_empty())
251 deterministic =
true;
252 ResetDeterministic();
255 currentInput = arma::mat(predictors.memptr() + (i * predictors.n_rows),
256 predictors.n_rows, batchSize,
false,
false);
257 currentTarget = arma::mat(responses.memptr() + i, 1, batchSize,
false,
260 discriminator.Forward(currentInput);
261 double res = discriminator.outputLayer.Forward(
262 boost::apply_visitor(
263 outputParameterVisitor,
264 discriminator.network.back()), currentTarget);
266 noise.imbue( [&]() {
return noiseFunction();} );
267 generator.Forward(noise);
269 predictors.cols(numFunctions, numFunctions + batchSize - 1) =
270 boost::apply_visitor(outputParameterVisitor, generator.network.back());
271 discriminator.Forward(predictors.cols(numFunctions,
272 numFunctions + batchSize - 1));
273 responses.cols(numFunctions, numFunctions + batchSize - 1) =
274 arma::zeros(1, batchSize);
276 currentTarget = arma::mat(responses.memptr() + numFunctions,
277 1, batchSize,
false,
false);
278 res += discriminator.outputLayer.Forward(
279 boost::apply_visitor(outputParameterVisitor,
280 discriminator.network.back()), currentTarget);
287 typename InitializationRuleType,
291 template<
typename GradType,
typename Policy>
292 typename std::enable_if<std::is_same<Policy, StandardGAN>::value ||
293 std::is_same<Policy, DCGAN>::value,
double>::type
300 if (parameter.is_empty())
305 if (gradient.is_empty())
307 if (parameter.is_empty())
309 gradient = arma::zeros<arma::mat>(parameter.n_elem, 1);
314 if (this->deterministic)
316 this->deterministic =
false;
317 ResetDeterministic();
320 if (noiseGradientDiscriminator.is_empty())
322 noiseGradientDiscriminator = arma::zeros<arma::mat>(
323 gradientDiscriminator.n_elem, 1);
327 noiseGradientDiscriminator.zeros();
330 gradientGenerator = arma::mat(gradient.memptr(),
331 generator.Parameters().n_elem, 1,
false,
false);
333 gradientDiscriminator = arma::mat(gradient.memptr() +
334 gradientGenerator.n_elem,
335 discriminator.Parameters().n_elem, 1,
false,
false);
338 double res = discriminator.EvaluateWithGradient(discriminator.parameter,
339 i, gradientDiscriminator, batchSize);
341 noise.imbue( [&]() {
return noiseFunction();} );
342 generator.Forward(noise);
343 predictors.cols(numFunctions, numFunctions + batchSize - 1) =
344 boost::apply_visitor(outputParameterVisitor, generator.network.back());
345 responses.cols(numFunctions, numFunctions + batchSize - 1) =
346 arma::zeros(1, batchSize);
349 res += discriminator.EvaluateWithGradient(discriminator.parameter,
350 numFunctions, noiseGradientDiscriminator, batchSize);
351 gradientDiscriminator += noiseGradientDiscriminator;
353 if (currentBatch % generatorUpdateStep == 0 && preTrainSize == 0)
357 responses.cols(numFunctions, numFunctions + batchSize - 1) =
358 arma::ones(1, batchSize);
360 discriminator.outputLayer.Backward(
361 boost::apply_visitor(outputParameterVisitor,
362 discriminator.network.back()), discriminator.responses.cols(
363 numFunctions, numFunctions + batchSize - 1), discriminator.error);
364 discriminator.Backward();
366 generator.error = boost::apply_visitor(deltaVisitor,
367 discriminator.network[1]);
369 generator.Predictors() = noise;
370 generator.Backward();
371 generator.ResetGradients(gradientGenerator);
372 generator.Gradient(generator.Predictors().cols(0, batchSize - 1));
374 gradientGenerator *= multiplier;
380 if (preTrainSize > 0)
390 typename InitializationRuleType,
394 template<
typename Policy>
395 typename std::enable_if<std::is_same<Policy, StandardGAN>::value ||
396 std::is_same<Policy, DCGAN>::value,
void>::type
401 const size_t batchSize)
408 typename InitializationRuleType,
414 const arma::uvec ordering = arma::shuffle(arma::linspace<arma::uvec>(0,
415 numFunctions - 1, numFunctions));
416 predictors.cols(0, numFunctions - 1) = predictors.cols(ordering);
421 typename InitializationRuleType,
426 const arma::mat& input)
428 if (parameter.is_empty())
433 generator.Forward(input);
434 arma::mat ganOutput = boost::apply_visitor(outputParameterVisitor,
435 generator.network.back());
437 discriminator.Forward(ganOutput);
442 typename InitializationRuleType,
449 if (parameter.is_empty())
456 deterministic =
true;
457 ResetDeterministic();
462 output = boost::apply_visitor(outputParameterVisitor,
463 discriminator.network.back());
468 typename InitializationRuleType,
475 this->discriminator.deterministic = deterministic;
476 this->generator.deterministic = deterministic;
477 this->discriminator.ResetDeterministic();
478 this->generator.ResetDeterministic();
483 typename InitializationRuleType,
487 template<
typename Archive>
491 ar(CEREAL_NVP(parameter));
492 ar(CEREAL_NVP(generator));
493 ar(CEREAL_NVP(discriminator));
494 ar(CEREAL_NVP(reset));
495 ar(CEREAL_NVP(genWeights));
496 ar(CEREAL_NVP(discWeights));
498 if (cereal::is_loading<Archive>())
501 generator.Parameters() = arma::mat(parameter.memptr(), genWeights, 1,
false,
503 discriminator.Parameters() = arma::mat(parameter.memptr() + genWeights,
504 discWeights, 1,
false,
false);
507 for (
size_t i = 0; i < generator.network.size(); ++i)
510 generator.parameter, offset), generator.network[i]);
512 boost::apply_visitor(resetVisitor, generator.network[i]);
516 for (
size_t i = 0; i < discriminator.network.size(); ++i)
519 discriminator.parameter, offset), discriminator.network[i]);
521 boost::apply_visitor(resetVisitor, discriminator.network[i]);
524 deterministic =
true;
525 ResetDeterministic();
void Shuffle()
Shuffle the order of function visitation.
Definition: gan_impl.hpp:412
std::enable_if< std::is_same< Policy, StandardGAN >::value||std::is_same< Policy, DCGAN >::value, void >::type Gradient(const arma::mat &parameters, const size_t i, arma::mat &gradient, const size_t batchSize)
Gradient function for Standard GAN and DCGAN.
Definition: gan_impl.hpp:398
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
void Initialize(const std::vector< LayerTypes< CustomLayers... > > &network, arma::Mat< eT > &parameter, size_t parameterOffset=0)
Initialize the specified network and store the results in the given parameter.
Definition: network_init.hpp:57
Definition: pointer_wrapper.hpp:23
double Train(arma::mat trainData, OptimizerType &Optimizer, CallbackTypes &&... callbacks)
Train function.
Definition: gan_impl.hpp:220
GAN(Model generator, Model discriminator, InitializationRuleType &initializeRule, Noise &noiseFunction, const size_t noiseDim, const size_t batchSize, const size_t generatorUpdateStep, const size_t preTrainSize, const double multiplier, const double clippingParameter=0.01, const double lambda=10.0)
Constructor for GAN class.
Definition: gan_impl.hpp:31
Implementation of the base layer.
Definition: base_layer.hpp:71
void Predict(arma::mat input, arma::mat &output)
This function predicts the output of the network on the given input.
Definition: gan_impl.hpp:447
void Forward(const arma::mat &input)
This function does a forward pass through the GAN network.
Definition: gan_impl.hpp:425
std::enable_if< std::is_same< Policy, StandardGAN >::value||std::is_same< Policy, DCGAN >::value, double >::type Evaluate(const arma::mat &parameters, const size_t i, const size_t batchSize)
Evaluate function for the Standard GAN and DCGAN.
Definition: gan_impl.hpp:239
WeightSetVisitor updates the module parameters given the parameter set.
Definition: weight_set_visitor.hpp:26
Include all of the base components required to write mlpack methods, and the main mlpack Doxygen documentation.
void serialize(Archive &ar, const uint32_t)
Serialize the model.
Definition: gan_impl.hpp:489
The implementation of the standard GAN module.
Definition: gan.hpp:63
void ResetData(arma::mat trainData)
Initialize the generator, discriminator and weights of the model for training.
Definition: gan_impl.hpp:139
This class is used to initialize the network with the given initialization rule.
Definition: network_init.hpp:33
std::enable_if< std::is_same< Policy, StandardGAN >::value||std::is_same< Policy, DCGAN >::value, double >::type EvaluateWithGradient(const arma::mat &parameters, const size_t i, GradType &gradient, const size_t batchSize)
EvaluateWithGradient function for the Standard GAN and DCGAN.
Definition: gan_impl.hpp:295