13 #ifndef MLPACK_METHODS_ANN_LAYER_VR_CLASS_REWARD_IMPL_HPP 14 #define MLPACK_METHODS_ANN_LAYER_VR_CLASS_REWARD_IMPL_HPP 19 #include "../visitor/reward_set_visitor.hpp" 24 template<
typename InputDataType,
typename OutputDataType>
27 const bool sizeAverage) :
29 sizeAverage(sizeAverage),
35 template<
typename InputDataType,
typename OutputDataType>
36 template<
typename InputType,
typename TargetType>
38 const InputType& input,
const TargetType& target)
41 for (
size_t i = 0; i < input.n_cols - 1; ++i)
43 size_t currentTarget = target(i) - 1;
45 "Target class out of range.");
47 output -= input(currentTarget, i);
51 arma::uword index = 0;
53 for (
size_t i = 0; i < input.n_cols - 1; ++i)
55 input.unsafe_col(i).max(index);
56 reward = ((index + 1) == target(i)) * scale;
61 return output - reward / (input.n_cols - 1);
64 return output - reward;
67 template<
typename InputDataType,
typename OutputDataType>
68 template<
typename InputType,
typename TargetType,
typename OutputType>
70 const InputType& input,
71 const TargetType& target,
74 output = arma::zeros<OutputType>(input.n_rows, input.n_cols);
75 for (
size_t i = 0; i < (input.n_cols - 1); ++i)
77 size_t currentTarget = target(i) - 1;
79 "Target class out of range.");
81 output(currentTarget, i) = -1;
84 double vrReward = reward - input(0, 1);
87 vrReward /= input.n_cols - 1;
90 const double norm = sizeAverage ? 2.0 / (input.n_cols - 1) : 2.0;
92 output(0, 1) = norm * (input(0, 1) - reward);
96 template<
typename InputDataType,
typename OutputDataType>
97 template<
typename Archive>
99 Archive& ,
const uint32_t )
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
RewardSetVisitor sets the reward parameter given the reward value.
Definition: reward_set_visitor.hpp:26
VRClassReward(const double scale=1, const bool sizeAverage=true)
Create the VRClassReward object.
Definition: vr_class_reward_impl.hpp:25
void Backward(const InputType &input, const TargetType &target, OutputType &output)
Ordinary feed backward pass of a neural network.
Definition: vr_class_reward_impl.hpp:69
double Forward(const InputType &input, const TargetType &target)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
Definition: vr_class_reward_impl.hpp:37
void serialize(Archive &, const uint32_t)
Serialize the layer.
Definition: vr_class_reward_impl.hpp:98
static void Assert(bool condition, const std::string &message="Assert Failed.")
Checks if the specified condition is true.
Definition: log.cpp:38