15 #ifndef MLPACK_METHODS_RL_ENVIRONMENT_CART_POLE_HPP 16 #define MLPACK_METHODS_RL_ENVIRONMENT_CART_POLE_HPP 96 Action::actions action;
99 static const size_t size = 2;
118 const double gravity = 9.8,
119 const double massCart = 1.0,
120 const double massPole = 0.1,
121 const double length = 0.5,
122 const double forceMag = 10.0,
123 const double tau = 0.02,
124 const double thetaThresholdRadians = 12 * 2 * 3.1416 / 360,
125 const double xThreshold = 2.4,
126 const double doneReward = 1.0) :
131 totalMass(massCart + massPole),
133 poleMassLength(massPole * length),
136 thetaThresholdRadians(thetaThresholdRadians),
137 xThreshold(xThreshold),
138 doneReward(doneReward),
159 double force = action.action ? forceMag : -forceMag;
160 double cosTheta = std::cos(state.
Angle());
161 double sinTheta = std::sin(state.
Angle());
164 double thetaAcc = (gravity * sinTheta - cosTheta * temp) /
165 (length * (4.0 / 3.0 - massPole * cosTheta * cosTheta / totalMass));
166 double xAcc = temp - poleMassLength * thetaAcc * cosTheta / totalMass;
178 if (done && maxSteps != 0 && stepsPerformed >= maxSteps)
199 return Sample(state, action, nextState);
210 return State((arma::randu<arma::colvec>(4) - 0.5) / 10.0);
221 if (maxSteps != 0 && stepsPerformed >= maxSteps)
223 Log::Info <<
"Episode terminated due to the maximum number of steps" 227 else if (std::abs(state.
Position()) > xThreshold ||
228 std::abs(state.
Angle()) > thetaThresholdRadians)
230 Log::Info <<
"Episode terminated due to agent failing.";
264 double poleMassLength;
273 double thetaThresholdRadians;
282 size_t stepsPerformed;
double Sample(const State &state, const Action &action)
Dynamics of Cart Pole.
Definition: cart_pole.hpp:196
State InitialSample()
Initial state representation is randomly generated within [-0.05, 0.05].
Definition: cart_pole.hpp:207
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
double Velocity() const
Get the velocity.
Definition: cart_pole.hpp:59
double & Velocity()
Modify the velocity.
Definition: cart_pole.hpp:61
Implementation of action of Cart Pole.
Definition: cart_pole.hpp:87
State(const arma::colvec &data)
Construct a state instance from given data.
Definition: cart_pole.hpp:47
The core includes that mlpack expects: standard C++ includes and Armadillo.
size_t & MaxSteps()
Set the maximum number of steps allowed.
Definition: cart_pole.hpp:242
double & Angle()
Modify the angle.
Definition: cart_pole.hpp:66
double & AngularVelocity()
Modify the angular velocity.
Definition: cart_pole.hpp:71
const arma::colvec & Encode() const
Encode the state to a column vector.
Definition: cart_pole.hpp:74
arma::colvec & Data()
Modify the internal representation of the state.
Definition: cart_pole.hpp:51
State()
Construct a state instance.
Definition: cart_pole.hpp:39
bool IsTerminal(const State &state) const
This function checks if the cart has reached the terminal state.
Definition: cart_pole.hpp:219
Implementation of the state of Cart Pole.
Definition: cart_pole.hpp:33
size_t MaxSteps() const
Get the maximum number of steps allowed.
Definition: cart_pole.hpp:240
CartPole(const size_t maxSteps=200, const double gravity=9.8, const double massCart=1.0, const double massPole=0.1, const double length=0.5, const double forceMag=10.0, const double tau=0.02, const double thetaThresholdRadians=12 *2 *3.1416/360, const double xThreshold=2.4, const double doneReward=1.0)
Construct a Cart Pole instance using the given constants.
Definition: cart_pole.hpp:117
static constexpr size_t dimension
Dimension of the encoded state.
Definition: cart_pole.hpp:77
double Angle() const
Get the angle.
Definition: cart_pole.hpp:64
size_t StepsPerformed() const
Get the number of steps performed.
Definition: cart_pole.hpp:237
double & Position()
Modify the position.
Definition: cart_pole.hpp:56
static MLPACK_EXPORT util::PrefixedOutStream Info
Prints informational messages if --verbose is specified, prefixed with [INFO ].
Definition: log.hpp:84
double Sample(const State &state, const Action &action, State &nextState)
Dynamics of Cart Pole instance.
Definition: cart_pole.hpp:151
double Position() const
Get the position.
Definition: cart_pole.hpp:54
double AngularVelocity() const
Get the angular velocity.
Definition: cart_pole.hpp:69
Implementation of Cart Pole task.
Definition: cart_pole.hpp:26