mlpack
|
The implementation of the standard GAN module. More...
#include <gan.hpp>
Public Member Functions | |
GAN (Model generator, Model discriminator, InitializationRuleType &initializeRule, Noise &noiseFunction, const size_t noiseDim, const size_t batchSize, const size_t generatorUpdateStep, const size_t preTrainSize, const double multiplier, const double clippingParameter=0.01, const double lambda=10.0) | |
Constructor for GAN class. More... | |
GAN (const GAN &) | |
Copy constructor. | |
GAN (GAN &&) | |
Move constructor. | |
void | ResetData (arma::mat trainData) |
Initialize the generator, discriminator and weights of the model for training. More... | |
void | Reset () |
template<typename OptimizerType , typename... CallbackTypes> | |
double | Train (arma::mat trainData, OptimizerType &Optimizer, CallbackTypes &&... callbacks) |
Train function. More... | |
template<typename Policy = PolicyType> | |
std::enable_if< std::is_same< Policy, StandardGAN >::value||std::is_same< Policy, DCGAN >::value, double >::type | Evaluate (const arma::mat &parameters, const size_t i, const size_t batchSize) |
Evaluate function for the Standard GAN and DCGAN. More... | |
template<typename Policy = PolicyType> | |
std::enable_if< std::is_same< Policy, WGAN >::value, double >::type | Evaluate (const arma::mat &parameters, const size_t i, const size_t batchSize) |
Evaluate function for the WGAN. More... | |
template<typename Policy = PolicyType> | |
std::enable_if< std::is_same< Policy, WGANGP >::value, double >::type | Evaluate (const arma::mat &parameters, const size_t i, const size_t batchSize) |
Evaluate function for the WGAN-GP. More... | |
template<typename GradType , typename Policy = PolicyType> | |
std::enable_if< std::is_same< Policy, StandardGAN >::value||std::is_same< Policy, DCGAN >::value, double >::type | EvaluateWithGradient (const arma::mat &parameters, const size_t i, GradType &gradient, const size_t batchSize) |
EvaluateWithGradient function for the Standard GAN and DCGAN. More... | |
template<typename GradType , typename Policy = PolicyType> | |
std::enable_if< std::is_same< Policy, WGAN >::value, double >::type | EvaluateWithGradient (const arma::mat &parameters, const size_t i, GradType &gradient, const size_t batchSize) |
EvaluateWithGradient function for the WGAN. More... | |
template<typename GradType , typename Policy = PolicyType> | |
std::enable_if< std::is_same< Policy, WGANGP >::value, double >::type | EvaluateWithGradient (const arma::mat &parameters, const size_t i, GradType &gradient, const size_t batchSize) |
EvaluateWithGradient function for the WGAN-GP. More... | |
template<typename Policy = PolicyType> | |
std::enable_if< std::is_same< Policy, StandardGAN >::value||std::is_same< Policy, DCGAN >::value, void >::type | Gradient (const arma::mat &parameters, const size_t i, arma::mat &gradient, const size_t batchSize) |
Gradient function for Standard GAN and DCGAN. More... | |
template<typename Policy = PolicyType> | |
std::enable_if< std::is_same< Policy, WGAN >::value, void >::type | Gradient (const arma::mat &parameters, const size_t i, arma::mat &gradient, const size_t batchSize) |
Gradient function for WGAN. More... | |
template<typename Policy = PolicyType> | |
std::enable_if< std::is_same< Policy, WGANGP >::value, void >::type | Gradient (const arma::mat &parameters, const size_t i, arma::mat &gradient, const size_t batchSize) |
Gradient function for WGAN-GP. More... | |
void | Shuffle () |
Shuffle the order of function visitation. More... | |
void | Forward (const arma::mat &input) |
This function does a forward pass through the GAN network. More... | |
void | Predict (arma::mat input, arma::mat &output) |
This function predicts the output of the network on the given input. More... | |
const arma::mat & | Parameters () const |
Return the parameters of the network. | |
arma::mat & | Parameters () |
Modify the parameters of the network. | |
const Model & | Generator () const |
Return the generator of the GAN. | |
Model & | Generator () |
Modify the generator of the GAN. | |
const Model & | Discriminator () const |
Return the discriminator of the GAN. | |
Model & | Discriminator () |
Modify the discriminator of the GAN. | |
size_t | NumFunctions () const |
Return the number of separable functions (the number of predictor points). | |
const arma::mat & | Responses () const |
Get the matrix of responses to the input data points. | |
arma::mat & | Responses () |
Modify the matrix of responses to the input data points. | |
const arma::mat & | Predictors () const |
Get the matrix of data points (predictors). | |
arma::mat & | Predictors () |
Modify the matrix of data points (predictors). | |
template<typename Archive > | |
void | serialize (Archive &ar, const uint32_t) |
Serialize the model. | |
The implementation of the standard GAN module.
Generative Adversarial Networks (GANs) are a class of artificial intelligence algorithms used in unsupervised machine learning, implemented by a system of two neural networks contesting with each other in a zero-sum game framework. This technique can generate photographs that look at least superficially authentic to human observers, having many realistic characteristics. GANs have been used in Text-to-Image Synthesis, Medical Drug Discovery, High Resolution Imagery Generation, Neural Machine Translation and so on.
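For more information, see the following paper: Ian J. Goodfellow, Jean Pouget-Abadie, Mehdi Mirza, Bing Xu, David Warde-Farley, Sherjil Ozair, Aaron Courville and Yoshua Bengio, "Generative Adversarial Nets", 2014, arXiv:1406.2661 (http://arxiv.org/abs/1406.2661).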
mlpack::ann::GAN< Model, InitializationRuleType, Noise, PolicyType >::GAN | ( | Model | generator, |
Model | discriminator, | ||
InitializationRuleType & | initializeRule, | ||
Noise & | noiseFunction, | ||
const size_t | noiseDim, | ||
const size_t | batchSize, | ||
const size_t | generatorUpdateStep, | ||
const size_t | preTrainSize, | ||
const double | multiplier, | ||
const double | clippingParameter = 0.01 , |
||
const double | lambda = 10.0 |
||
) |
Constructor for GAN class.
generator | Generator network. |
discriminator | Discriminator network. |
initializeRule | Initialization rule to use for initializing parameters. |
noiseFunction | Function to be used for generating noise. |
noiseDim | Dimension of noise vector to be created. |
batchSize | Batch size to be used for training. |
generatorUpdateStep | Number of steps to train Discriminator before updating Generator. |
preTrainSize | Number of pre-training steps of Discriminator. |
multiplier | Ratio of the Discriminator's learning rate to the Generator's learning rate. |
clippingParameter | Weight range for enforcing Lipschitz constraint. |
lambda | Parameter for setting the gradient penalty. |
std::enable_if< std::is_same< Policy, StandardGAN >::value||std::is_same< Policy, DCGAN >::value, double >::type mlpack::ann::GAN< Model, InitializationRuleType, Noise, PolicyType >::Evaluate | ( | const arma::mat & | parameters, |
const size_t | i, | ||
const size_t | batchSize | ||
) |
Evaluate function for the Standard GAN and DCGAN.
std::enable_if<std::is_same<Policy, WGAN>::value, double>::type mlpack::ann::GAN< Model, InitializationRuleType, Noise, PolicyType >::Evaluate | ( | const arma::mat & | parameters, |
const size_t | i, | ||
const size_t | batchSize | ||
) |
Evaluate function for the WGAN.
std::enable_if<std::is_same<Policy, WGANGP>::value, double>::type mlpack::ann::GAN< Model, InitializationRuleType, Noise, PolicyType >::Evaluate | ( | const arma::mat & | parameters, |
const size_t | i, | ||
const size_t | batchSize | ||
) |
Evaluate function for the WGAN-GP.
This function gives the performance of the WGAN-GP on the current input.
parameters | The parameters of the network. |
i | Index of the current input. |
batchSize | Number of inputs to process in the current batch. |
std::enable_if< std::is_same< Policy, StandardGAN >::value||std::is_same< Policy, DCGAN >::value, double >::type mlpack::ann::GAN< Model, InitializationRuleType, Noise, PolicyType >::EvaluateWithGradient | ( | const arma::mat & | parameters, |
const size_t | i, | ||
GradType & | gradient, | ||
const size_t | batchSize | ||
) |
EvaluateWithGradient function for the Standard GAN and DCGAN.
This function gives the performance of the Standard GAN or DCGAN on the current input, while updating Gradients.
parameters | The parameters of the network. |
i | Index of the current input. |
gradient | Variable to store the present gradient. |
batchSize | Number of inputs to process in the current batch. |
std::enable_if<std::is_same<Policy, WGAN>::value, double>::type mlpack::ann::GAN< Model, InitializationRuleType, Noise, PolicyType >::EvaluateWithGradient | ( | const arma::mat & | parameters, |
const size_t | i, | ||
GradType & | gradient, | ||
const size_t | batchSize | ||
) |
EvaluateWithGradient function for the WGAN.
This function gives the performance of the WGAN on the current input, while updating Gradients.
parameters | The parameters of the network. |
i | Index of the current input. |
gradient | Variable to store the present gradient. |
batchSize | Number of inputs to process in the current batch. |
std::enable_if<std::is_same<Policy, WGANGP>::value, double>::type mlpack::ann::GAN< Model, InitializationRuleType, Noise, PolicyType >::EvaluateWithGradient | ( | const arma::mat & | parameters, |
const size_t | i, | ||
GradType & | gradient, | ||
const size_t | batchSize | ||
) |
EvaluateWithGradient function for the WGAN-GP.
This function gives the performance of the WGAN-GP on the current input, while updating Gradients.
parameters | The parameters of the network. |
i | Index of the current input. |
gradient | Variable to store the present gradient. |
batchSize | Number of inputs to process in the current batch. |
void mlpack::ann::GAN< Model, InitializationRuleType, Noise, PolicyType >::Forward | ( | const arma::mat & | input | ) |
This function does a forward pass through the GAN network.
input | Sampled noise. |
std::enable_if< std::is_same< Policy, StandardGAN >::value||std::is_same< Policy, DCGAN >::value, void >::type mlpack::ann::GAN< Model, InitializationRuleType, Noise, PolicyType >::Gradient | ( | const arma::mat & | parameters, |
const size_t | i, | ||
arma::mat & | gradient, | ||
const size_t | batchSize | ||
) |
Gradient function for Standard GAN and DCGAN.
This function passes the gradient based on which network is being trained, i.e., Generator or Discriminator.
parameters | Current parameters of the network. |
i | Index of the predictors. |
gradient | Variable to store the present gradient. |
batchSize | Number of inputs to process in the current batch. |
std::enable_if<std::is_same<Policy, WGAN>::value, void>::type mlpack::ann::GAN< Model, InitializationRuleType, Noise, PolicyType >::Gradient | ( | const arma::mat & | parameters, |
const size_t | i, | ||
arma::mat & | gradient, | ||
const size_t | batchSize | ||
) |
Gradient function for WGAN.
This function passes the gradient based on which network is being trained, i.e., Generator or Discriminator.
parameters | Current parameters of the network. |
i | Index of the predictors. |
gradient | Variable to store the present gradient. |
batchSize | Number of inputs to process in the current batch. |
std::enable_if<std::is_same<Policy, WGANGP>::value, void>::type mlpack::ann::GAN< Model, InitializationRuleType, Noise, PolicyType >::Gradient | ( | const arma::mat & | parameters, |
const size_t | i, | ||
arma::mat & | gradient, | ||
const size_t | batchSize | ||
) |
Gradient function for WGAN-GP.
This function passes the gradient based on which network is being trained, i.e., Generator or Discriminator.
parameters | Current parameters of the network. |
i | Index of the predictors. |
gradient | Variable to store the present gradient. |
batchSize | Number of inputs to process in the current batch. |
void mlpack::ann::GAN< Model, InitializationRuleType, Noise, PolicyType >::Predict | ( | arma::mat | input, |
arma::mat & | output | ||
) |
This function predicts the output of the network on the given input.
input | The input of the Generator network. |
output | Result of the Discriminator network. |
void mlpack::ann::GAN< Model, InitializationRuleType, Noise, PolicyType >::ResetData | ( | arma::mat | trainData | ) |
Initialize the generator, discriminator and weights of the model for training.
This function does not start the training process itself.
trainData | The data points of real distribution. |
These predictors are shared by the discriminator network. The additional batchSize predictors are produced by the generator network during training. For more details, see the EvaluateWithGradient() function.
void mlpack::ann::GAN< Model, InitializationRuleType, Noise, PolicyType >::Shuffle | ( | ) |
Shuffle the order of function visitation.
This may be called by the optimizer.
double mlpack::ann::GAN< Model, InitializationRuleType, Noise, PolicyType >::Train | ( | arma::mat | trainData, |
OptimizerType & | Optimizer, | ||
CallbackTypes &&... | callbacks | ||
) |
Train function.
OptimizerType | Type of optimizer to use to train the model. |
CallbackTypes | Types of Callback functions. |
trainData | The data points of real distribution. |
Optimizer | Instantiated optimizer used to train the model. |
callbacks | Callback functions for the ensmallen optimizer OptimizerType. See https://www.ensmallen.org/docs.html#callback-documentation. |