10 #ifndef MLPACK_METHODS_ANN_GAN_WGAN_IMPL_HPP 11 #define MLPACK_METHODS_ANN_GAN_WGAN_IMPL_HPP 26 typename InitializationRuleType,
// NOTE(review): this listing is a fragment -- the embedded source line
// numbers skip (e.g. 11 -> 26, 37 -> 48, 57 -> 59), so template headers,
// braces and some statements are missing; it is not compilable as shown.
//
// Evaluate() overload that is enabled only when Policy is WGAN (through
// std::enable_if) and returns a double: the output-layer loss on one batch
// of real data plus the output-layer loss on one batch of generator output.
30 template<
typename Policy>
31 typename std::enable_if<std::is_same<Policy, WGAN>::value,
double>::type
// The body of this check is not visible in this fragment.
37 if (parameter.is_empty())
// Non-copying view (copy_aux_mem = false, strict = false) over batchSize
// columns of predictors, starting at column i.
48 currentInput = arma::mat(predictors.memptr() + (i * predictors.n_rows),
49 predictors.n_rows, batchSize,
false,
false);
// Matching 1 x batchSize non-copying view over responses, starting at
// element i.
50 currentTarget = arma::mat(responses.memptr() + i, 1, batchSize,
false,
// Loss on the real batch: forward pass through the discriminator, then the
// output layer's Forward() on the last layer's output against currentTarget.
53 discriminator.Forward(currentInput);
54 double res = discriminator.outputLayer.Forward(
56 outputParameterVisitor,
57 discriminator.network.back()), currentTarget);
// Fill the noise matrix element-wise with samples from noiseFunction and run
// the generator on it.
59 noise.imbue( [&]() {
return noiseFunction();} );
60 generator.Forward(noise);
// Copy the generator output into the batchSize columns of predictors that
// start at index numFunctions (presumably space reserved after the real
// samples -- confirm), then run the discriminator on those columns.
62 predictors.cols(numFunctions, numFunctions + batchSize - 1) =
63 boost::apply_visitor(outputParameterVisitor, generator.network.back());
64 discriminator.Forward(predictors.cols(numFunctions,
65 numFunctions + batchSize - 1));
// Every generated column gets the target value -1.
66 responses.cols(numFunctions, numFunctions + batchSize - 1) =
67 -arma::ones(1, batchSize);
// Non-copying view over those -1 targets; the loss on the generated batch is
// added to the real-batch loss.
69 currentTarget = arma::mat(responses.memptr() + numFunctions,
70 1, batchSize,
false,
false);
71 res += discriminator.outputLayer.Forward(
73 outputParameterVisitor,
74 discriminator.network.back()), currentTarget);
81 typename InitializationRuleType,
// EvaluateWithGradient() overload that is enabled only when Policy is WGAN.
// It returns the discriminator loss over one real batch and one generated
// batch. The discriminator gradient goes into the tail of `gradient`; on
// scheduled steps the generator gradient goes into the head of `gradient`.
// NOTE(review): fragment only -- several source lines (e.g. 87-92, 94-97,
// 103-106, 149-151, 170-173, and the preTrainSize branch body) are missing.
85 template<
typename GradType,
typename Policy>
86 typename std::enable_if<std::is_same<Policy, WGAN>::value,
double>::type
93 if (parameter.is_empty())
// Lazily allocate the combined gradient as a zero column vector with one
// entry per element of parameter.
98 if (gradient.is_empty())
100 if (parameter.is_empty())
102 gradient = arma::zeros<arma::mat>(parameter.n_elem, 1);
// Leave deterministic mode before computing gradients.
107 if (this->deterministic)
109 this->deterministic =
false;
110 ResetDeterministic();
// Scratch buffer for the gradient from the generated batch: allocated on
// first use, otherwise zeroed (the `else` between these lines is not
// visible -- presumed).
// NOTE(review): it is sized from gradientDiscriminator *before*
// gradientDiscriminator is re-pointed below; if gradientDiscriminator is
// still empty here, this buffer gets 0 elements -- confirm it is initialised
// elsewhere.
113 if (noiseGradientDiscriminator.is_empty())
115 noiseGradientDiscriminator = arma::zeros<arma::mat>(
116 gradientDiscriminator.n_elem, 1);
120 noiseGradientDiscriminator.zeros();
// Split `gradient` into two non-copying views: the first
// generator.Parameters().n_elem entries hold the generator gradient, and the
// following discriminator.Parameters().n_elem entries hold the discriminator
// gradient.
123 gradientGenerator = arma::mat(gradient.memptr(),
124 generator.Parameters().n_elem, 1,
false,
false);
126 gradientDiscriminator = arma::mat(gradient.memptr() +
127 gradientGenerator.n_elem,
128 discriminator.Parameters().n_elem, 1,
false,
false);
// Loss and discriminator gradient on the real batch that starts at column i.
131 double res = discriminator.EvaluateWithGradient(discriminator.parameter,
132 i, gradientDiscriminator, batchSize);
// Generate a batch from fresh noise, place it at columns
// [numFunctions, numFunctions + batchSize - 1] and label it -1.
134 noise.imbue( [&]() {
return noiseFunction();} );
135 generator.Forward(noise);
136 predictors.cols(numFunctions, numFunctions + batchSize - 1) =
137 boost::apply_visitor(outputParameterVisitor, generator.network.back());
138 responses.cols(numFunctions, numFunctions + batchSize - 1) =
139 -arma::ones(1, batchSize);
// Add the loss and discriminator gradient from the generated batch.
142 res += discriminator.EvaluateWithGradient(discriminator.parameter,
143 numFunctions, noiseGradientDiscriminator, batchSize);
144 gradientDiscriminator += noiseGradientDiscriminator;
// Clamp each discriminator gradient entry to
// [-clippingParameter, clippingParameter].
// NOTE(review): the WGAN paper clips the critic's *weights*; this clamps the
// *gradient* instead -- confirm this is intended.
145 gradientDiscriminator = arma::clamp(gradientDiscriminator,
146 -clippingParameter, clippingParameter);
// Generator update: runs only every generatorUpdateStep batches, and only
// when preTrainSize == 0.
148 if (currentBatch % generatorUpdateStep == 0 && preTrainSize == 0)
// Relabel the generated columns +1, then backpropagate the output-layer error
// through the discriminator.
152 responses.cols(numFunctions, numFunctions + batchSize - 1) =
153 arma::ones(1, batchSize);
155 discriminator.outputLayer.Backward(
156 boost::apply_visitor(outputParameterVisitor,
157 discriminator.network.back()), discriminator.responses.cols(
158 numFunctions, numFunctions + batchSize - 1), discriminator.error);
159 discriminator.Backward();
// Use the delta of discriminator layer index 1 as the generator's error
// signal.
161 generator.error = boost::apply_visitor(deltaVisitor,
162 discriminator.network[1]);
// Backpropagate through the generator on the same noise input, point its
// gradient storage at gradientGenerator (the head of `gradient`), and
// compute it.
164 generator.Predictors() = noise;
165 generator.Backward();
166 generator.ResetGradients(gradientGenerator);
167 generator.Gradient(generator.Predictors().cols(0, batchSize - 1));
// Scale the generator gradient by `multiplier` (set outside this fragment).
169 gradientGenerator *= multiplier;
// Pre-training branch; its body is not visible in this fragment.
174 if (preTrainSize > 0)
184 typename InitializationRuleType,
// Gradient() overload that is enabled only when Policy is WGAN. It forwards to
// EvaluateWithGradient() and discards the returned loss.
// NOTE(review): fragment only -- the parameter lines for `i` and `gradient`
// (source lines 192-193) and the closing lines are not visible here.
188 template<
typename Policy>
189 typename std::enable_if<std::is_same<Policy, WGAN>::value,
void>::type
191 Gradient(
const arma::mat& parameters,
194 const size_t batchSize)
196 this->EvaluateWithGradient(parameters, i, gradient, batchSize);
std::enable_if< std::is_same< Policy, StandardGAN >::value||std::is_same< Policy, DCGAN >::value, void >::type Gradient(const arma::mat &parameters, const size_t i, arma::mat &gradient, const size_t batchSize)
Gradient function for Standard GAN and DCGAN.
Definition: gan_impl.hpp:398
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
std::enable_if< std::is_same< Policy, StandardGAN >::value||std::is_same< Policy, DCGAN >::value, double >::type Evaluate(const arma::mat &parameters, const size_t i, const size_t batchSize)
Evaluate function for the Standard GAN and DCGAN.
Definition: gan_impl.hpp:239
Include all of the base components required to write mlpack methods, and the main mlpack Doxygen documentation.
std::enable_if< std::is_same< Policy, StandardGAN >::value||std::is_same< Policy, DCGAN >::value, double >::type EvaluateWithGradient(const arma::mat &parameters, const size_t i, GradType &gradient, const size_t batchSize)
EvaluateWithGradient function for the Standard GAN and DCGAN.
Definition: gan_impl.hpp:295