mlpack
vr_class_reward_impl.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_LAYER_VR_CLASS_REWARD_IMPL_HPP
14 #define MLPACK_METHODS_ANN_LAYER_VR_CLASS_REWARD_IMPL_HPP
15 
16 // In case it hasn't yet been included.
17 #include "vr_class_reward.hpp"
18 
19 #include "../visitor/reward_set_visitor.hpp"
20 
21 namespace mlpack {
22 namespace ann {
23 
24 template<typename InputDataType, typename OutputDataType>
26  const double scale,
27  const bool sizeAverage) :
28  scale(scale),
29  sizeAverage(sizeAverage),
30  reward(0)
31 {
32  // Nothing to do here.
33 }
34 
35 template<typename InputDataType, typename OutputDataType>
36 template<typename InputType, typename TargetType>
38  const InputType& input, const TargetType& target)
39 {
40  double output = 0;
41  for (size_t i = 0; i < input.n_cols - 1; ++i)
42  {
43  size_t currentTarget = target(i) - 1;
44  Log::Assert(currentTarget < input.n_rows,
45  "Target class out of range.");
46 
47  output -= input(currentTarget, i);
48  }
49 
50  reward = 0;
51  arma::uword index = 0;
52 
53  for (size_t i = 0; i < input.n_cols - 1; ++i)
54  {
55  input.unsafe_col(i).max(index);
56  reward = ((index + 1) == target(i)) * scale;
57  }
58 
59  if (sizeAverage)
60  {
61  return output - reward / (input.n_cols - 1);
62  }
63 
64  return output - reward;
65 }
66 
67 template<typename InputDataType, typename OutputDataType>
68 template<typename InputType, typename TargetType, typename OutputType>
70  const InputType& input,
71  const TargetType& target,
72  OutputType& output)
73 {
74  output = arma::zeros<OutputType>(input.n_rows, input.n_cols);
75  for (size_t i = 0; i < (input.n_cols - 1); ++i)
76  {
77  size_t currentTarget = target(i) - 1;
78  Log::Assert(currentTarget < input.n_rows,
79  "Target class out of range.");
80 
81  output(currentTarget, i) = -1;
82  }
83 
84  double vrReward = reward - input(0, 1);
85  if (sizeAverage)
86  {
87  vrReward /= input.n_cols - 1;
88  }
89 
90  const double norm = sizeAverage ? 2.0 / (input.n_cols - 1) : 2.0;
91 
92  output(0, 1) = norm * (input(0, 1) - reward);
93  boost::apply_visitor(RewardSetVisitor(vrReward), network.back());
94 }
95 
96 template<typename InputDataType, typename OutputDataType>
97 template<typename Archive>
99  Archive& /* ar */, const uint32_t /* version */)
100 {
101  // Nothing to do here.
102 }
103 
104 } // namespace ann
105 } // namespace mlpack
106 
107 #endif
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
RewardSetVisitor sets the reward parameter given the reward value.
Definition: reward_set_visitor.hpp:26
VRClassReward(const double scale=1, const bool sizeAverage=true)
Create the VRClassReward object.
Definition: vr_class_reward_impl.hpp:25
void Backward(const InputType &input, const TargetType &target, OutputType &output)
Ordinary feed backward pass of a neural network.
Definition: vr_class_reward_impl.hpp:69
double Forward(const InputType &input, const TargetType &target)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
Definition: vr_class_reward_impl.hpp:37
void serialize(Archive &, const uint32_t)
Serialize the layer.
Definition: vr_class_reward_impl.hpp:98
static void Assert(bool condition, const std::string &message="Assert Failed.")
Checks if the specified condition is true.
Definition: log.cpp:38