// NOTE(review): this listing was extracted from a rendered source view; the
// leading numbers on each line are original line numbers, and several lines
// (braces, some declarations and right-hand sides) are missing.
10 #ifndef MLPACK_METHODS_ANN_GAN_WGANGP_IMPL_HPP 11 #define MLPACK_METHODS_ANN_GAN_WGANGP_IMPL_HPP 26 typename InitializationRuleType,
// Evaluate() overload selected only when Policy is WGANGP (see enable_if).
// Visible behavior: accumulates into `res` the output-layer loss on a real
// batch, the loss on a second batch labelled -1, and a gradient-penalty
// term. The parameter list and the return statement are not visible here.
30 template<
typename Policy>
31 typename std::enable_if<std::is_same<Policy, WGANGP>::value,
38 if ((parameter.is_empty()))
// Non-copying view (copy_aux_mem = false, strict = false) of `batchSize`
// columns of `predictors`, starting at column `i`.
49 currentInput = arma::mat(predictors.memptr() + (i * predictors.n_rows),
50 predictors.n_rows, batchSize,
false,
false);
// View of `batchSize` consecutive elements of `responses` starting at
// element `i`. This equals columns i..i+batchSize-1 only if `responses` has
// a single row (assumed -- confirm).
51 currentTarget = arma::mat(responses.memptr() + i, 1, batchSize,
false,
// Discriminator loss on the real batch.
54 discriminator.Forward(std::move(currentInput));
55 double res = discriminator.outputLayer.Forward(
56 std::move(boost::apply_visitor(
57 outputParameterVisitor,
58 discriminator.network.back())), std::move(currentTarget));
// Fill `noise` with fresh samples from noiseFunction() and run it through
// the generator.
60 noise.imbue( [&]() {
return noiseFunction();} );
61 generator.Forward(std::move(noise));
// Copy of the generator's final-layer output (the generated batch).
63 arma::mat generatedData = boost::apply_visitor(outputParameterVisitor,
64 generator.network.back());
// Columns [numFunctions, numFunctions + batchSize) of `predictors` and
// `responses` are used as scratch space for the extra batch. The right-hand
// side of this assignment (original line 66) is missing from the listing.
65 predictors.cols(numFunctions, numFunctions + batchSize - 1) =
67 discriminator.Forward(std::move(predictors.cols(numFunctions,
68 numFunctions + batchSize - 1)));
// The scratch batch is labelled -1.
69 responses.cols(numFunctions, numFunctions + batchSize - 1) =
70 -arma::ones(1, batchSize);
// Non-copying view of the -1 labels, then add the loss on that batch.
72 currentTarget = arma::mat(responses.memptr() + numFunctions,
73 1, batchSize,
false,
false);
74 res += discriminator.outputLayer.Forward(
75 std::move(boost::apply_visitor(
76 outputParameterVisitor,
77 discriminator.network.back())), std::move(currentTarget));
// Gradient penalty: interpolate between the real and generated batches.
// `epsilon` is declared on a line not shown here. It appears to be a scalar,
// i.e. one interpolation weight for the whole batch.
// NOTE(review): the WGAN-GP formulation samples one weight per example.
// NOTE(review): `currentInput` was passed through std::move at original
// line 54 and is read again here. This relies on Forward() not actually
// moving from its argument -- confirm.
81 predictors.cols(numFunctions, numFunctions + batchSize - 1) =
82 (epsilon * currentInput) + ((1.0 - epsilon) * generatedData);
83 responses.cols(numFunctions, numFunctions + batchSize - 1) =
84 -arma::ones(1, batchSize);
// NOTE(review): the first argument is discriminator.parameter, so this
// appears to be the gradient with respect to the discriminator's
// parameters. The WGAN-GP penalty uses the gradient with respect to the
// interpolated inputs -- confirm intent.
85 discriminator.Gradient(discriminator.parameter, numFunctions,
86 normGradientDiscriminator, batchSize);
// Adds lambda * (||g||_2 - 1)^2. arma::norm(X, 2) is the Euclidean norm when
// X is a vector (the matrix 2-norm otherwise).
87 res += lambda * std::pow(arma::norm(normGradientDiscriminator, 2) - 1, 2);
94 typename InitializationRuleType,
// EvaluateWithGradient() overload selected only when Policy is WGANGP.
// Visible behavior: computes the same loss terms as the WGANGP Evaluate().
// It writes gradients into `gradient`, whose storage is split into a
// generator part followed by a discriminator part. Every
// `generatorUpdateStep` batches it also back-propagates into the generator.
98 template<
typename GradType,
typename Policy>
99 typename std::enable_if<std::is_same<Policy, WGANGP>::value,
107 if (parameter.is_empty())
// On first use, allocate `gradient` as a zero column with one entry per
// element of `parameter`.
112 if (gradient.is_empty())
114 if (parameter.is_empty())
116 gradient = arma::zeros<arma::mat>(parameter.n_elem, 1);
// If the deterministic flag is set, clear it and call ResetDeterministic()
// (presumably switching layers back to training behavior -- confirm).
121 if (this->deterministic)
123 this->deterministic =
false;
124 ResetDeterministic();
// Allocate (first call) or zero the buffer that receives the discriminator
// gradient on the second (scratch) batch.
// NOTE(review): this reads gradientDiscriminator.n_elem before
// gradientDiscriminator is re-pointed into `gradient` below. On a first call
// it may still be empty -- confirm it is initialised elsewhere.
127 if (noiseGradientDiscriminator.is_empty())
129 noiseGradientDiscriminator = arma::zeros<arma::mat>(
130 gradientDiscriminator.n_elem, 1);
134 noiseGradientDiscriminator.zeros();
// Non-owning views into `gradient`:
//   - the first generator.Parameters().n_elem entries belong to the
//     generator;
//   - the following discriminator.Parameters().n_elem entries belong to the
//     discriminator.
// Writes through these views land in `gradient`.
137 gradientGenerator = arma::mat(gradient.memptr(),
138 generator.Parameters().n_elem, 1,
false,
false);
140 gradientDiscriminator = arma::mat(gradient.memptr() +
141 gradientGenerator.n_elem,
142 discriminator.Parameters().n_elem, 1,
false,
false);
// Non-copying view of `batchSize` real columns starting at column `i`.
144 currentInput = arma::mat(predictors.memptr() + (i * predictors.n_rows),
145 predictors.n_rows, batchSize,
false,
false);
// Loss on the real batch; its discriminator gradient is written into the
// discriminator part of `gradient`.
148 double res = discriminator.EvaluateWithGradient(discriminator.parameter,
149 i, gradientDiscriminator, batchSize);
// Sample noise and generate a batch.
151 noise.imbue( [&]() {
return noiseFunction();} );
152 generator.Forward(std::move(noise));
153 arma::mat generatedData = boost::apply_visitor(outputParameterVisitor,
154 generator.network.back());
// Gradient-penalty loss term. It interpolates between the real and
// generated batches; `epsilon` is declared on a line not shown here.
// NOTE(review): in the lines visible here, normGradientDiscriminator only
// contributes to `res`. It is never added into gradientDiscriminator, so the
// penalty does not affect the returned gradient -- confirm against the full
// source.
158 predictors.cols(numFunctions, numFunctions + batchSize - 1) =
159 (epsilon * currentInput) + ((1.0 - epsilon) * generatedData);
160 responses.cols(numFunctions, numFunctions + batchSize - 1) =
161 -arma::ones(1, batchSize);
162 discriminator.Gradient(discriminator.parameter, numFunctions,
163 normGradientDiscriminator, batchSize);
164 res += lambda * std::pow(arma::norm(normGradientDiscriminator, 2) - 1, 2);
// Overwrite the scratch columns (the right-hand side, original line 167, is
// missing from the listing). Evaluate them against the -1 labels set above
// and accumulate that gradient into the discriminator part of `gradient`.
166 predictors.cols(numFunctions, numFunctions + batchSize - 1) =
168 res += discriminator.EvaluateWithGradient(discriminator.parameter,
169 numFunctions, noiseGradientDiscriminator, batchSize);
170 gradientDiscriminator += noiseGradientDiscriminator;
// Generator update every `generatorUpdateStep` batches, only when
// preTrainSize is 0.
// NOTE(review): a generatorUpdateStep of 0 makes this a modulo by zero;
// confirm it is validated elsewhere.
172 if (currentBatch % generatorUpdateStep == 0 && preTrainSize == 0)
// Relabel the scratch batch as +1 for the generator objective.
176 responses.cols(numFunctions, numFunctions + batchSize - 1) =
177 arma::ones(1, batchSize);
// Back-propagate the output-layer error through the discriminator.
// NOTE(review): this reads discriminator.responses while the rest of the
// function writes to `responses`; presumably they alias -- confirm.
179 discriminator.outputLayer.Backward(
180 boost::apply_visitor(outputParameterVisitor,
181 discriminator.network.back()), discriminator.responses.cols(
182 numFunctions, numFunctions + batchSize - 1), discriminator.error);
183 discriminator.Backward();
// Error signal for the generator, read from discriminator layer index 1.
// The index is hard-coded; presumably it is the first layer after the
// input -- confirm.
185 generator.error = boost::apply_visitor(deltaVisitor,
186 discriminator.network[1]);
// Back-propagate through the generator. ResetGradients(gradientGenerator)
// presumably directs the generator's gradient into the generator part of
// `gradient` -- confirm.
// NOTE(review): `noise` was passed through std::move at original line 152
// and is copied here. This relies on Forward() not moving from it.
188 generator.Predictors() = noise;
189 generator.Backward();
190 generator.ResetGradients(gradientGenerator);
191 generator.Gradient(generator.Predictors().cols(0, batchSize - 1));
// Scale the generator gradient; `multiplier` is defined outside this
// listing.
193 gradientGenerator *= multiplier;
198 if (preTrainSize > 0)
208 typename InitializationRuleType,
// Gradient() overload selected only when Policy is WGANGP. It delegates to
// EvaluateWithGradient() and discards the returned loss value. Every side
// effect of EvaluateWithGradient() (including the conditional generator
// update) therefore also happens here.
212 template<
typename Policy>
213 typename std::enable_if<std::is_same<Policy, WGANGP>::value,
216 Gradient(
const arma::mat& parameters,
219 const size_t batchSize)
221 this->EvaluateWithGradient(parameters, i, gradient, batchSize);
std::enable_if< std::is_same< Policy, StandardGAN >::value||std::is_same< Policy, DCGAN >::value, void >::type Gradient(const arma::mat &parameters, const size_t i, arma::mat &gradient, const size_t batchSize)
Gradient function for Standard GAN and DCGAN.
Definition: gan_impl.hpp:398
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
std::enable_if< std::is_same< Policy, StandardGAN >::value||std::is_same< Policy, DCGAN >::value, double >::type Evaluate(const arma::mat &parameters, const size_t i, const size_t batchSize)
Evaluate function for the Standard GAN and DCGAN.
Definition: gan_impl.hpp:239
Include all of the base components required to write mlpack methods, and the main mlpack Doxygen documentation.
double Random()
Generates a uniform random number between 0 and 1.
Definition: random.hpp:83
std::enable_if< std::is_same< Policy, StandardGAN >::value||std::is_same< Policy, DCGAN >::value, double >::type EvaluateWithGradient(const arma::mat &parameters, const size_t i, GradType &gradient, const size_t batchSize)
EvaluateWithGradient function for the Standard GAN and DCGAN.
Definition: gan_impl.hpp:295