mlpack
mlpack::ann::GAN< Model, InitializationRuleType, Noise, PolicyType > Class Template Reference

The implementation of the standard GAN module.

#include <gan.hpp>

Public Member Functions

 GAN (Model generator, Model discriminator, InitializationRuleType &initializeRule, Noise &noiseFunction, const size_t noiseDim, const size_t batchSize, const size_t generatorUpdateStep, const size_t preTrainSize, const double multiplier, const double clippingParameter=0.01, const double lambda=10.0)
 Constructor for GAN class.
 
 GAN (const GAN &)
 Copy constructor.
 
 GAN (GAN &&)
 Move constructor.
 
void ResetData (arma::mat trainData)
 Initialize the generator, discriminator and weights of the model for training.
 
void Reset ()
 
template<typename OptimizerType , typename... CallbackTypes>
double Train (arma::mat trainData, OptimizerType &Optimizer, CallbackTypes &&... callbacks)
 Train function.
 
template<typename Policy = PolicyType>
std::enable_if< std::is_same< Policy, StandardGAN >::value||std::is_same< Policy, DCGAN >::value, double >::type Evaluate (const arma::mat &parameters, const size_t i, const size_t batchSize)
 Evaluate function for the Standard GAN and DCGAN.
 
template<typename Policy = PolicyType>
std::enable_if< std::is_same< Policy, WGAN >::value, double >::type Evaluate (const arma::mat &parameters, const size_t i, const size_t batchSize)
 Evaluate function for the WGAN.
 
template<typename Policy = PolicyType>
std::enable_if< std::is_same< Policy, WGANGP >::value, double >::type Evaluate (const arma::mat &parameters, const size_t i, const size_t batchSize)
 Evaluate function for the WGAN-GP.
 
template<typename GradType , typename Policy = PolicyType>
std::enable_if< std::is_same< Policy, StandardGAN >::value||std::is_same< Policy, DCGAN >::value, double >::type EvaluateWithGradient (const arma::mat &parameters, const size_t i, GradType &gradient, const size_t batchSize)
 EvaluateWithGradient function for the Standard GAN and DCGAN.
 
template<typename GradType , typename Policy = PolicyType>
std::enable_if< std::is_same< Policy, WGAN >::value, double >::type EvaluateWithGradient (const arma::mat &parameters, const size_t i, GradType &gradient, const size_t batchSize)
 EvaluateWithGradient function for the WGAN.
 
template<typename GradType , typename Policy = PolicyType>
std::enable_if< std::is_same< Policy, WGANGP >::value, double >::type EvaluateWithGradient (const arma::mat &parameters, const size_t i, GradType &gradient, const size_t batchSize)
 EvaluateWithGradient function for the WGAN-GP.
 
template<typename Policy = PolicyType>
std::enable_if< std::is_same< Policy, StandardGAN >::value||std::is_same< Policy, DCGAN >::value, void >::type Gradient (const arma::mat &parameters, const size_t i, arma::mat &gradient, const size_t batchSize)
 Gradient function for Standard GAN and DCGAN.
 
template<typename Policy = PolicyType>
std::enable_if< std::is_same< Policy, WGAN >::value, void >::type Gradient (const arma::mat &parameters, const size_t i, arma::mat &gradient, const size_t batchSize)
 Gradient function for WGAN.
 
template<typename Policy = PolicyType>
std::enable_if< std::is_same< Policy, WGANGP >::value, void >::type Gradient (const arma::mat &parameters, const size_t i, arma::mat &gradient, const size_t batchSize)
 Gradient function for WGAN-GP.
 
void Shuffle ()
 Shuffle the order of function visitation.
 
void Forward (const arma::mat &input)
 This function does a forward pass through the GAN network.
 
void Predict (arma::mat input, arma::mat &output)
 This function predicts the output of the network on the given input.
 
const arma::mat & Parameters () const
 Return the parameters of the network.
 
arma::mat & Parameters ()
 Modify the parameters of the network.
 
const Model & Generator () const
 Return the generator of the GAN.
 
Model & Generator ()
 Modify the generator of the GAN.
 
const Model & Discriminator () const
 Return the discriminator of the GAN.
 
Model & Discriminator ()
 Modify the discriminator of the GAN.
 
size_t NumFunctions () const
 Return the number of separable functions (the number of predictor points).
 
const arma::mat & Responses () const
 Get the matrix of responses to the input data points.
 
arma::mat & Responses ()
 Modify the matrix of responses to the input data points.
 
const arma::mat & Predictors () const
 Get the matrix of data points (predictors).
 
arma::mat & Predictors ()
 Modify the matrix of data points (predictors).
 
template<typename Archive >
void serialize (Archive &ar, const uint32_t)
 Serialize the model.
 

Detailed Description

template<typename Model, typename InitializationRuleType, typename Noise, typename PolicyType = StandardGAN>
class mlpack::ann::GAN< Model, InitializationRuleType, Noise, PolicyType >

The implementation of the standard GAN module.

Generative Adversarial Networks (GANs) are a class of machine learning algorithms used in unsupervised learning. A GAN consists of two neural networks, a generator and a discriminator, that compete with each other in a zero-sum game: the generator produces samples intended to resemble the training data, and the discriminator tries to tell them apart from real samples. GANs can generate photographs that look at least superficially authentic to human observers. They have been applied to text-to-image synthesis, drug discovery, high-resolution image generation and neural machine translation, among other tasks.

For more information, see the following paper:

@article{Goodfellow14,
author = {Ian J. Goodfellow and Jean Pouget-Abadie and Mehdi Mirza and
Bing Xu and David Warde-Farley and Sherjil Ozair and
Aaron Courville and Yoshua Bengio},
title = {Generative Adversarial Nets},
year = {2014},
url = {http://arxiv.org/abs/1406.2661},
eprint = {1406.2661},
}
Template Parameters
Model: The class type of the generator and discriminator.
InitializationRuleType: Type of the initializer.
Noise: The noise function to use.
PolicyType: The GAN variant to be used (StandardGAN, DCGAN, WGAN or WGANGP).
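The Evaluate(), EvaluateWithGradient() and Gradient() overloads listed above are chosen at compile time from PolicyType through std::enable_if: for a given policy, only one overload survives substitution. The following is a minimal, self-contained sketch of that pattern. The tag structs and the ToyGAN class are hypothetical stand-ins, not mlpack's actual types, and the returned strings only identify which overload was selected.

```cpp
#include <cassert>
#include <string>
#include <type_traits>

// Hypothetical tag types standing in for mlpack's policy classes.
struct StandardGAN {};
struct DCGAN {};
struct WGAN {};
struct WGANGP {};

// Toy class reproducing the overload-selection pattern of GAN::Evaluate().
template<typename PolicyType = StandardGAN>
class ToyGAN
{
 public:
  // Participates in overload resolution only for StandardGAN or DCGAN.
  template<typename Policy = PolicyType>
  typename std::enable_if<std::is_same<Policy, StandardGAN>::value ||
                          std::is_same<Policy, DCGAN>::value,
                          std::string>::type
  Evaluate() const { return "standard"; }

  // Participates in overload resolution only for WGAN.
  template<typename Policy = PolicyType>
  typename std::enable_if<std::is_same<Policy, WGAN>::value,
                          std::string>::type
  Evaluate() const { return "wgan"; }

  // Participates in overload resolution only for WGANGP.
  template<typename Policy = PolicyType>
  typename std::enable_if<std::is_same<Policy, WGANGP>::value,
                          std::string>::type
  Evaluate() const { return "wgan-gp"; }
};
```

For example, `ToyGAN<WGAN>().Evaluate()` resolves to the WGAN overload, while `ToyGAN<>().Evaluate()` falls back to the StandardGAN default. Because the choice happens at compile time, a GAN with an unsupported policy fails to compile rather than failing at run time.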

Constructor & Destructor Documentation

◆ GAN()

template<typename Model , typename InitializationRuleType , typename Noise , typename PolicyType >
mlpack::ann::GAN< Model, InitializationRuleType, Noise, PolicyType >::GAN ( Model  generator,
Model  discriminator,
InitializationRuleType &  initializeRule,
Noise &  noiseFunction,
const size_t  noiseDim,
const size_t  batchSize,
const size_t  generatorUpdateStep,
const size_t  preTrainSize,
const double  multiplier,
const double  clippingParameter = 0.01,
const double  lambda = 10.0 
)

Constructor for GAN class.

Parameters
generator: Generator network.
discriminator: Discriminator network.
initializeRule: Initialization rule to use for initializing parameters.
noiseFunction: Function to be used for generating noise.
noiseDim: Dimension of the noise vector to be created.
batchSize: Batch size to be used for training.
generatorUpdateStep: Number of steps to train the discriminator before updating the generator.
preTrainSize: Number of pre-training steps of the discriminator.
multiplier: Ratio of the learning rate of the discriminator to that of the generator.
clippingParameter: Weight range [-clippingParameter, clippingParameter] for enforcing the Lipschitz constraint (used by WGAN).
lambda: Coefficient of the gradient penalty (used by WGAN-GP).
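The interaction between preTrainSize and generatorUpdateStep can be pictured as a simple schedule. The sketch below assumes that no generator update happens during the first preTrainSize discriminator steps and that, afterwards, the generator is updated once after every generatorUpdateStep discriminator steps. UpdateGeneratorAt is a hypothetical helper, not part of mlpack, and mlpack's exact step counting may differ.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical helper (not part of mlpack): returns true if the generator
// should be updated at discriminator step `step` (0-based).
// Assumed schedule: no generator updates during the first `preTrainSize`
// steps; afterwards, one update after every `generatorUpdateStep` steps.
bool UpdateGeneratorAt(const std::size_t step,
                       const std::size_t generatorUpdateStep,
                       const std::size_t preTrainSize)
{
  assert(generatorUpdateStep > 0);
  if (step < preTrainSize)
    return false;  // Still pre-training the discriminator.
  // Number of discriminator steps completed since pre-training ended.
  const std::size_t stepsAfterPreTrain = step - preTrainSize + 1;
  return stepsAfterPreTrain % generatorUpdateStep == 0;
}
```

Under this assumed schedule, generatorUpdateStep = 5 and preTrainSize = 10 give the first generator update at step 14 and the next at step 19.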

Member Function Documentation

◆ Evaluate() [1/3]

template<typename Model , typename InitializationRuleType , typename Noise , typename PolicyType >
template<typename Policy >
std::enable_if< std::is_same< Policy, StandardGAN >::value||std::is_same< Policy, DCGAN >::value, double >::type mlpack::ann::GAN< Model, InitializationRuleType, Noise, PolicyType >::Evaluate ( const arma::mat &  parameters,
const size_t  i,
const size_t  batchSize 
)

Evaluate function for the Standard GAN and DCGAN.

This function computes the objective (loss) of the Standard GAN or DCGAN on the current batch.

Parameters
parameters: The parameters of the network.
i: Index of the first input in the current batch.
batchSize: Number of inputs in the current batch.
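For reference, the discriminator objective of the original GAN paper is the binary cross-entropy of real samples (label 1) against generated samples (label 0). The sketch below computes that textbook formula from discriminator outputs that have already been squashed to probabilities. It is not mlpack's code: in mlpack the loss actually used is whatever loss the discriminator Model was built with.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Discriminator loss of the original GAN formulation:
//   L_D = -(1/m) * sum_j [ log D(x_j) + log(1 - D(G(z_j))) ]
// `real` holds D(x_j) and `fake` holds D(G(z_j)); both are probabilities.
double DiscriminatorLoss(const std::vector<double>& real,
                         const std::vector<double>& fake)
{
  assert(real.size() == fake.size() && !real.empty());
  double sum = 0.0;
  for (std::size_t j = 0; j < real.size(); ++j)
    sum += std::log(real[j]) + std::log(1.0 - fake[j]);
  return -sum / static_cast<double>(real.size());
}
```

A discriminator that outputs 0.5 for everything has loss 2 log 2, while one that outputs 1 for real and 0 for generated samples reaches the minimum of 0.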

◆ Evaluate() [2/3]

template<typename Model , typename InitializationRuleType , typename Noise , typename PolicyType = StandardGAN>
template<typename Policy = PolicyType>
std::enable_if<std::is_same<Policy, WGAN>::value, double>::type mlpack::ann::GAN< Model, InitializationRuleType, Noise, PolicyType >::Evaluate ( const arma::mat &  parameters,
const size_t  i,
const size_t  batchSize 
)

Evaluate function for the WGAN.

This function computes the objective (loss) of the WGAN on the current batch.

Parameters
parameters: The parameters of the network.
i: Index of the first input in the current batch.
batchSize: Number of inputs in the current batch.
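In the WGAN formulation, the discriminator (the "critic") outputs unbounded scores instead of probabilities. It is trained to minimize the difference between its mean score on generated samples and its mean score on real samples. The sketch below computes that textbook critic loss from precomputed scores. It is not taken from mlpack's implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

// WGAN critic loss, minimized by the critic:
//   L_C = mean_j C(G(z_j)) - mean_j C(x_j)
// `real` holds critic scores C(x_j); `fake` holds C(G(z_j)).
double CriticLoss(const std::vector<double>& real,
                  const std::vector<double>& fake)
{
  assert(!real.empty() && !fake.empty());
  const double meanReal =
      std::accumulate(real.begin(), real.end(), 0.0) / real.size();
  const double meanFake =
      std::accumulate(fake.begin(), fake.end(), 0.0) / fake.size();
  return meanFake - meanReal;
}
```

For real scores {2, 4} and fake scores {1, 1}, the loss is 1 - 3 = -2. More negative values mean the critic separates the two distributions more strongly.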

◆ Evaluate() [3/3]

template<typename Model , typename InitializationRuleType , typename Noise , typename PolicyType = StandardGAN>
template<typename Policy = PolicyType>
std::enable_if<std::is_same<Policy, WGANGP>::value, double>::type mlpack::ann::GAN< Model, InitializationRuleType, Noise, PolicyType >::Evaluate ( const arma::mat &  parameters,
const size_t  i,
const size_t  batchSize 
)

Evaluate function for the WGAN-GP.

This function computes the objective (loss) of the WGAN-GP on the current batch.

Parameters
parameters: The parameters of the network.
i: Index of the first input in the current batch.
batchSize: Number of inputs in the current batch.
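WGAN-GP adds a gradient penalty, weighted by the constructor's lambda, to the WGAN critic loss. The penalty pushes the norm of the critic's input gradient toward 1. The sketch below evaluates the penalty term for one sample, given that gradient. The gradient itself (which mlpack would obtain by backpropagation) is assumed to be supplied by the caller.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// WGAN-GP gradient penalty for one interpolated sample x_hat:
//   penalty = lambda * (||grad_{x_hat} C(x_hat)||_2 - 1)^2
// `grad` is the critic's gradient with respect to its input at x_hat.
double GradientPenalty(const std::vector<double>& grad, const double lambda)
{
  assert(!grad.empty());
  double squaredNorm = 0.0;
  for (const double g : grad)
    squaredNorm += g * g;
  const double deviation = std::sqrt(squaredNorm) - 1.0;
  return lambda * deviation * deviation;
}
```

With the default lambda = 10.0, a gradient of {3, 4} (norm 5) is penalized by 10 * 4^2 = 160, and a unit-norm gradient is not penalized at all.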

◆ EvaluateWithGradient() [1/3]

template<typename Model , typename InitializationRuleType , typename Noise , typename PolicyType >
template<typename GradType , typename Policy >
std::enable_if< std::is_same< Policy, StandardGAN >::value||std::is_same< Policy, DCGAN >::value, double >::type mlpack::ann::GAN< Model, InitializationRuleType, Noise, PolicyType >::EvaluateWithGradient ( const arma::mat &  parameters,
const size_t  i,
GradType &  gradient,
const size_t  batchSize 
)

EvaluateWithGradient function for the Standard GAN and DCGAN.

This function computes the objective of the Standard GAN or DCGAN on the current batch and stores the corresponding gradient in gradient.

Parameters
parameters: The parameters of the network.
i: Index of the first input in the current batch.
gradient: Variable to store the present gradient.
batchSize: Number of inputs in the current batch.

◆ EvaluateWithGradient() [2/3]

template<typename Model , typename InitializationRuleType , typename Noise , typename PolicyType = StandardGAN>
template<typename GradType , typename Policy = PolicyType>
std::enable_if<std::is_same<Policy, WGAN>::value, double>::type mlpack::ann::GAN< Model, InitializationRuleType, Noise, PolicyType >::EvaluateWithGradient ( const arma::mat &  parameters,
const size_t  i,
GradType &  gradient,
const size_t  batchSize 
)

EvaluateWithGradient function for the WGAN.

This function computes the objective of the WGAN on the current batch and stores the corresponding gradient in gradient.

Parameters
parameters: The parameters of the network.
i: Index of the first input in the current batch.
gradient: Variable to store the present gradient.
batchSize: Number of inputs in the current batch.
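WGAN enforces its Lipschitz constraint crudely by clipping the critic's weights to [-clippingParameter, clippingParameter]. The sketch below shows that clamping step on a plain vector of weights. Where exactly mlpack applies the clamp during EvaluateWithGradient() is not stated on this page, so this shows only the operation itself.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// WGAN weight clipping: clamp every critic weight to
// [-clippingParameter, clippingParameter].
void ClipWeights(std::vector<double>& weights, const double clippingParameter)
{
  assert(clippingParameter > 0.0);
  for (double& w : weights)
    w = std::max(-clippingParameter, std::min(w, clippingParameter));
}
```

With the default clippingParameter = 0.01, the weights {-0.5, 0.005, 0.2} become {-0.01, 0.005, 0.01}.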

◆ EvaluateWithGradient() [3/3]

template<typename Model , typename InitializationRuleType , typename Noise , typename PolicyType = StandardGAN>
template<typename GradType , typename Policy = PolicyType>
std::enable_if<std::is_same<Policy, WGANGP>::value, double>::type mlpack::ann::GAN< Model, InitializationRuleType, Noise, PolicyType >::EvaluateWithGradient ( const arma::mat &  parameters,
const size_t  i,
GradType &  gradient,
const size_t  batchSize 
)

EvaluateWithGradient function for the WGAN-GP.

This function computes the objective of the WGAN-GP on the current batch and stores the corresponding gradient in gradient.

Parameters
parameters: The parameters of the network.
i: Index of the first input in the current batch.
gradient: Variable to store the present gradient.
batchSize: Number of inputs in the current batch.
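In the WGAN-GP formulation, the gradient penalty is evaluated at points on straight lines between real and generated samples, x_hat = epsilon * x + (1 - epsilon) * x_fake, with epsilon drawn uniformly from [0, 1] for each sample. The sketch below computes that interpolation for one pair of samples with a caller-supplied epsilon. It illustrates the formula and is not mlpack's code.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Point on the line between a real and a generated sample:
//   x_hat = epsilon * real + (1 - epsilon) * fake,  epsilon in [0, 1]
std::vector<double> Interpolate(const std::vector<double>& real,
                                const std::vector<double>& fake,
                                const double epsilon)
{
  assert(real.size() == fake.size());
  assert(epsilon >= 0.0 && epsilon <= 1.0);
  std::vector<double> xHat(real.size());
  for (std::size_t k = 0; k < real.size(); ++k)
    xHat[k] = epsilon * real[k] + (1.0 - epsilon) * fake[k];
  return xHat;
}
```

With epsilon = 0.5 the result is the midpoint of the two samples. With epsilon = 1 it is the real sample itself.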

◆ Forward()

template<typename Model , typename InitializationRuleType , typename Noise , typename PolicyType >
void mlpack::ann::GAN< Model, InitializationRuleType, Noise, PolicyType >::Forward ( const arma::mat &  input)

This function does a forward pass through the GAN network.

Parameters
input: Sampled noise.
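Conceptually, a forward pass on sampled noise feeds the noise through the generator and then passes the generated sample to the discriminator. The sketch below shows that composition with toy networks. It assumes this noise, then generator, then discriminator order (consistent with Predict(), whose input goes to the generator and whose output comes from the discriminator). ToyNetworks and Vec are hypothetical stand-ins, not mlpack's FFN.

```cpp
#include <cassert>
#include <functional>
#include <vector>

using Vec = std::vector<double>;

// Hypothetical stand-ins for the two networks: the generator maps a noise
// vector to a sample, and the discriminator maps a sample to a score.
struct ToyNetworks
{
  std::function<Vec(const Vec&)> generator;
  std::function<double(const Vec&)> discriminator;

  // Forward pass of the combined model: noise -> generator -> discriminator.
  double Forward(const Vec& noise) const
  {
    return discriminator(generator(noise));
  }
};
```

With a generator that doubles its input and a discriminator that sums its input, the noise {1, 2} yields a score of 6.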

◆ Gradient() [1/3]

template<typename Model , typename InitializationRuleType , typename Noise , typename PolicyType >
template<typename Policy >
std::enable_if< std::is_same< Policy, StandardGAN >::value||std::is_same< Policy, DCGAN >::value, void >::type mlpack::ann::GAN< Model, InitializationRuleType, Noise, PolicyType >::Gradient ( const arma::mat &  parameters,
const size_t  i,
arma::mat &  gradient,
const size_t  batchSize 
)

Gradient function for Standard GAN and DCGAN.

This function computes the gradient for whichever network is currently being trained, i.e., the generator or the discriminator.

Parameters
parameters: Present parameters of the network.
i: Index of the first predictor in the current batch.
gradient: Variable to store the present gradient.
batchSize: Number of inputs in the current batch.
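Because a single Parameters() matrix covers both networks, one way to train "whichever network is being trained" is to leave the gradient entries of the other network at zero, so that an optimizer step does not change them. The sketch below illustrates this. The flat layout (generator parameters first, then discriminator parameters) and the helper MaskGradient are assumptions for illustration only, not a description of mlpack's internals.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical layout (assumption): gradient[0, generatorSize) belongs to the
// generator and gradient[generatorSize, end) to the discriminator. Zero the
// entries of the network that is NOT being trained.
void MaskGradient(std::vector<double>& gradient,
                  const std::size_t generatorSize,
                  const bool trainingGenerator)
{
  assert(generatorSize <= gradient.size());
  const std::size_t begin = trainingGenerator ? generatorSize : 0;
  const std::size_t end = trainingGenerator ? gradient.size() : generatorSize;
  for (std::size_t k = begin; k < end; ++k)
    gradient[k] = 0.0;
}
```

For a five-entry gradient with two generator entries, training the generator keeps only the first two entries. Training the discriminator keeps only the last three.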

◆ Gradient() [2/3]

template<typename Model , typename InitializationRuleType , typename Noise , typename PolicyType = StandardGAN>
template<typename Policy = PolicyType>
std::enable_if<std::is_same<Policy, WGAN>::value, void>::type mlpack::ann::GAN< Model, InitializationRuleType, Noise, PolicyType >::Gradient ( const arma::mat &  parameters,
const size_t  i,
arma::mat &  gradient,
const size_t  batchSize 
)

Gradient function for WGAN.

This function computes the gradient for whichever network is currently being trained, i.e., the generator or the discriminator.

Parameters
parameters: Present parameters of the network.
i: Index of the first predictor in the current batch.
gradient: Variable to store the present gradient.
batchSize: Number of inputs in the current batch.

◆ Gradient() [3/3]

template<typename Model , typename InitializationRuleType , typename Noise , typename PolicyType = StandardGAN>
template<typename Policy = PolicyType>
std::enable_if<std::is_same<Policy, WGANGP>::value, void>::type mlpack::ann::GAN< Model, InitializationRuleType, Noise, PolicyType >::Gradient ( const arma::mat &  parameters,
const size_t  i,
arma::mat &  gradient,
const size_t  batchSize 
)

Gradient function for WGAN-GP.

This function computes the gradient for whichever network is currently being trained, i.e., the generator or the discriminator.

Parameters
parameters: Present parameters of the network.
i: Index of the first predictor in the current batch.
gradient: Variable to store the present gradient.
batchSize: Number of inputs in the current batch.

◆ Predict()

template<typename Model , typename InitializationRuleType , typename Noise , typename PolicyType >
void mlpack::ann::GAN< Model, InitializationRuleType, Noise, PolicyType >::Predict ( arma::mat  input,
arma::mat &  output 
)

This function predicts the output of the network on the given input.

Parameters
input: The input of the generator network.
output: Result of the discriminator network.

◆ ResetData()

template<typename Model , typename InitializationRuleType , typename Noise , typename PolicyType >
void mlpack::ann::GAN< Model, InitializationRuleType, Noise, PolicyType >::ResetData ( arma::mat  trainData)

Initialize the generator, discriminator and weights of the model for training.

This function does not start the training process.

Parameters
trainData: The data points of the real distribution.

These predictors are shared by the discriminator network. During training, batchSize additional predictors are produced by the generator network; see EvaluateWithGradient() for details.
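The note above implies a predictor matrix with room for the real points plus batchSize generated points. The sketch below builds such a layout with plain vectors. It assumes column-major storage with one data point per column (as Armadillo uses), real points first, and responses of 1 for real points and 0 for generated ones. PredictorLayout and MakePredictors are hypothetical names, and mlpack's actual bookkeeping may differ.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical column-major layout: `rows` features per point, `cols` points.
struct PredictorLayout
{
  std::size_t rows = 0, cols = 0;
  std::vector<double> data;       // rows * cols entries, column-major
  std::vector<double> responses;  // 1 for real points, 0 for generated ones
};

PredictorLayout MakePredictors(const std::vector<double>& trainData,
                               const std::size_t rows,
                               const std::size_t batchSize)
{
  assert(rows > 0 && trainData.size() % rows == 0);
  const std::size_t n = trainData.size() / rows;  // number of real points
  PredictorLayout p;
  p.rows = rows;
  p.cols = n + batchSize;
  // Real points fill the first n columns; the last batchSize columns are
  // zero placeholders for generator output.
  p.data.assign(rows * p.cols, 0.0);
  for (std::size_t k = 0; k < trainData.size(); ++k)
    p.data[k] = trainData[k];
  p.responses.assign(p.cols, 0.0);
  for (std::size_t j = 0; j < n; ++j)
    p.responses[j] = 1.0;
  return p;
}
```

Three two-dimensional real points with batchSize = 2 give a 2 x 5 predictor matrix whose responses are {1, 1, 1, 0, 0}.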

◆ Shuffle()

template<typename Model , typename InitializationRuleType , typename Noise , typename PolicyType >
void mlpack::ann::GAN< Model, InitializationRuleType, Noise, PolicyType >::Shuffle ( )

Shuffle the order of function visitation.

This may be called by the optimizer.
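In ensmallen's separable-function interface, each data point is one "function", and shuffling means visiting them in a new random order. The sketch below produces such a random visiting order for NumFunctions() points. It shows only the idea. The name ShuffledOrder is hypothetical, and this is not how mlpack reorders its predictor matrix internally.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <numeric>
#include <random>
#include <vector>

// Returns a random permutation of 0..numFunctions-1: a new order in which
// an optimizer could visit the separable functions (data points).
std::vector<std::size_t> ShuffledOrder(const std::size_t numFunctions,
                                       std::mt19937& rng)
{
  std::vector<std::size_t> order(numFunctions);
  std::iota(order.begin(), order.end(), std::size_t{0});
  std::shuffle(order.begin(), order.end(), rng);
  return order;
}
```

Whatever the seed, the result always contains each index exactly once. Only the order changes.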

◆ Train()

template<typename Model , typename InitializationRuleType , typename Noise , typename PolicyType >
template<typename OptimizerType , typename... CallbackTypes>
double mlpack::ann::GAN< Model, InitializationRuleType, Noise, PolicyType >::Train ( arma::mat  trainData,
OptimizerType &  Optimizer,
CallbackTypes &&...  callbacks 
)

Train the GAN on the given data using the given optimizer.

Template Parameters
OptimizerType: Type of optimizer to use to train the model.
CallbackTypes: Types of callback functions.
Parameters
trainData: The data points of the real distribution.
Optimizer: Instantiated optimizer used to train the model.
callbacks: Callback functions for the ensmallen optimizer OptimizerType. See https://www.ensmallen.org/docs.html#callback-documentation.
Returns
The final objective of the trained model (NaN or Inf on error).
