mlpack
Functions
q_learning_test.cpp File Reference
#include <mlpack/core.hpp>
#include <mlpack/methods/ann/ffn.hpp>
#include <mlpack/methods/ann/init_rules/gaussian_init.hpp>
#include <mlpack/methods/ann/layer/layer.hpp>
#include <mlpack/methods/ann/loss_functions/mean_squared_error.hpp>
#include <mlpack/methods/ann/loss_functions/empty_loss.hpp>
#include <mlpack/methods/reinforcement_learning/q_learning.hpp>
#include <mlpack/methods/reinforcement_learning/sac.hpp>
#include <mlpack/methods/reinforcement_learning/q_networks/simple_dqn.hpp>
#include <mlpack/methods/reinforcement_learning/q_networks/dueling_dqn.hpp>
#include <mlpack/methods/reinforcement_learning/q_networks/categorical_dqn.hpp>
#include <mlpack/methods/reinforcement_learning/environment/env_type.hpp>
#include <mlpack/methods/reinforcement_learning/environment/pendulum.hpp>
#include <mlpack/methods/reinforcement_learning/environment/mountain_car.hpp>
#include <mlpack/methods/reinforcement_learning/environment/acrobot.hpp>
#include <mlpack/methods/reinforcement_learning/environment/cart_pole.hpp>
#include <mlpack/methods/reinforcement_learning/environment/double_pole_cart.hpp>
#include <mlpack/methods/reinforcement_learning/policy/greedy_policy.hpp>
#include <mlpack/methods/reinforcement_learning/training_config.hpp>
#include <ensmallen.hpp>
#include <numeric>
#include "catch.hpp"
Include dependency graph for q_learning_test.cpp:

Functions

template<typename AgentType >
bool testAgent (AgentType &agent, const double rewardThreshold, const size_t noOfEpisodes, const size_t consecutiveEpisodesTest=50)
 
 TEST_CASE ("CartPoleWithDQN", "[QLearningTest]")
 Test DQN in Cart Pole task.
 
 TEST_CASE ("CartPoleWithDQNPrioritizedReplay", "[QLearningTest]")
 Test DQN in Cart Pole task with Prioritized Replay.
 
 TEST_CASE ("CartPoleWithDoubleDQN", "[QLearningTest]")
 Test Double DQN in Cart Pole task.
 
 TEST_CASE ("AcrobotWithDQN", "[QLearningTest]")
 Test DQN in Acrobot task.
 
 TEST_CASE ("MountainCarWithDQN", "[QLearningTest]")
 Test DQN in Mountain Car task.
 
 TEST_CASE ("DoublePoleCartWithDQN", "[QLearningTest]")
 Test DQN in DoublePoleCart task.
 
 TEST_CASE ("CartPoleWithDuelingDQN", "[QLearningTest]")
 Test Dueling DQN in Cart Pole task.
 
 TEST_CASE ("CartPoleWithDuelingDQNPrioritizedReplay", "[QLearningTest]")
 Test Dueling DQN in Cart Pole task with Prioritized Replay.
 
 TEST_CASE ("CartPoleWithNoisyDQN", "[QLearningTest]")
 Test Noisy DQN in Cart Pole task.
 
 TEST_CASE ("CartPoleWithDuelingDoubleNoisyDQN", "[QLearningTest]")
 Test Dueling-Double-Noisy DQN in Cart Pole task.
 
 TEST_CASE ("CartPoleWithNStepDQN", "[QLearningTest]")
 Test N-step DQN in Cart Pole task (see Function Documentation below for details).
 
 TEST_CASE ("CartPoleWithNStepPrioritizedDQN", "[QLearningTest]")
 Test N-step Prioritized DQN in Cart Pole task (see Function Documentation below for details).
 
 TEST_CASE ("CartPoleWithCategoricalDQN", "[QLearningTest]")
 Test Categorical DQN in Cart Pole task.
 
 TEST_CASE ("PendulumWithSAC", "[QLearningTest]")
 Test SAC on Pendulum task.
 
 TEST_CASE ("SACForMultipleActions", "[QLearningTest]")
 A test to ensure SAC works with multiple actions in the action space.
 

Detailed Description

Author
Shangtong Zhang
Rohan Raj

Tests for the Q-Learning implementation.

mlpack is free software; you may redistribute it and/or modify it under the terms of the 3-clause BSD license. You should have received a copy of the 3-clause BSD license along with mlpack. If not, see http://www.opensource.org/licenses/BSD-3-Clause for more information.

Function Documentation

◆ TEST_CASE() [1/2]

TEST_CASE ( "CartPoleWithNStepDQN", "[QLearningTest]" )

Test N-step DQN in Cart Pole task.

For N-step learning, we need to specify n as the last parameter in the replay method. Here we use n = 3.

◆ TEST_CASE() [2/2]

TEST_CASE ( "CartPoleWithNStepPrioritizedDQN", "[QLearningTest]" )

Test N-step Prioritized DQN in Cart Pole task.

For N-step learning, we need to specify n as the last parameter in the replay method. Here we use n = 3.