RLPark 1.0.0
Reinforcement Learning Framework in Java
|
package rlpark.plugin.rltoys.experiments.testing.predictions;

import java.util.Random;

import org.junit.Assert;
import org.junit.Test;

import rlpark.plugin.rltoys.algorithms.predictions.td.OnPolicyTD;
import rlpark.plugin.rltoys.experiments.testing.predictions.FiniteStateGraphOnPolicy.OnPolicyTDFactory;
import rlpark.plugin.rltoys.experiments.testing.results.TestingResult;
import rlpark.plugin.rltoys.problems.stategraph.FiniteStateGraph;
import rlpark.plugin.rltoys.problems.stategraph.LineProblem;
import rlpark.plugin.rltoys.problems.stategraph.RandomWalk;

/**
 * Base class for JUnit tests that run an {@link OnPolicyTD} implementation on two
 * finite state graph problems: a {@link LineProblem} and a {@link RandomWalk}.
 *
 * <p>Subclasses provide the algorithm under test through
 * {@link #newOnPolicyTD(double, double, double, int)}, the lambda values to try through
 * {@link #lambdaValues()}, and the precision forwarded to the checker through
 * {@link #precision()}. Each test delegates to {@code FiniteStateGraphOnPolicy.testTD} and
 * fails with the message carried by the returned {@link TestingResult} when that result
 * is not marked as passed.
 */
public abstract class OnPolicyTests {
  private final LineProblem lineProblem = new LineProblem();
  // Fixed seed so that the random walk is reproducible from one run to the next.
  private final RandomWalk randomWalkProblem = new RandomWalk(new Random(0));

  /**
   * Returns the factories exercised by every test. By default this is a single factory
   * that forwards to {@link #newOnPolicyTD(double, double, double, int)}; subclasses may
   * override it to supply a different set.
   *
   * @return the factories used to build the algorithms under test
   */
  protected OnPolicyTDFactory[] onPolicyFactories() {
    final OnPolicyTDFactory delegatingFactory = new OnPolicyTDFactory() {
      @Override
      public OnPolicyTD create(double lambda, double gamma, double vectorNorm, int vectorSize) {
        return newOnPolicyTD(lambda, gamma, vectorNorm, vectorSize);
      }
    };
    return new OnPolicyTDFactory[] { delegatingFactory };
  }

  /** Runs every factory on the line problem with lambda set to zero. */
  @Test
  public void testOnLineProblem() {
    runWithZeroLambda(lineProblem);
  }

  /** Runs every factory on the random walk problem with lambda set to zero. */
  @Test
  public void testOnRandomWalkProblem() {
    runWithZeroLambda(randomWalkProblem);
  }

  /** Runs every factory on the line problem, once per value of {@link #lambdaValues()}. */
  @Test
  public void testOnLineProblemWithLambda() {
    runWithEachLambda(lineProblem);
  }

  /** Runs every factory on the random walk problem, once per value of {@link #lambdaValues()}. */
  @Test
  public void testOnRandomWalkProblemWithLambda() {
    runWithEachLambda(randomWalkProblem);
  }

  /** Checks each factory from {@link #onPolicyFactories()} on {@code problem} with lambda = 0. */
  private void runWithZeroLambda(FiniteStateGraph problem) {
    for (OnPolicyTDFactory tdFactory : onPolicyFactories()) {
      checkTD(0.0, problem, tdFactory);
    }
  }

  /**
   * Checks each factory from {@link #onPolicyFactories()} on {@code problem} for every lambda.
   * {@link #lambdaValues()} is queried anew for each factory.
   */
  private void runWithEachLambda(FiniteStateGraph problem) {
    for (OnPolicyTDFactory tdFactory : onPolicyFactories()) {
      for (double lambdaValue : lambdaValues()) {
        checkTD(lambdaValue, problem, tdFactory);
      }
    }
  }

  /**
   * Delegates to {@code FiniteStateGraphOnPolicy.testTD} and asserts that the returned
   * result is marked as passed, reporting the result's message otherwise.
   */
  private void checkTD(double lambda, FiniteStateGraph problem, OnPolicyTDFactory factory) {
    final TestingResult<OnPolicyTD> outcome =
        FiniteStateGraphOnPolicy.testTD(lambda, problem, factory, nbEpisodeMax(), precision());
    Assert.assertTrue(outcome.message, outcome.passed);
  }

  /**
   * Value forwarded to {@code FiniteStateGraphOnPolicy.testTD} as its fourth argument.
   * NOTE(review): presumably an upper bound on the number of episodes - confirm against
   * {@code FiniteStateGraphOnPolicy}.
   *
   * @return 100000 unless overridden
   */
  protected int nbEpisodeMax() {
    return 100000;
  }

  /**
   * Builds the algorithm under test. Called by the default factory returned from
   * {@link #onPolicyFactories()} with the arguments that factory receives.
   */
  protected abstract OnPolicyTD newOnPolicyTD(double lambda, double gamma, double vectorNorm, int vectorSize);

  /** Returns the lambda values iterated over by the {@code ...WithLambda} tests. */
  protected abstract double[] lambdaValues();

  /** Returns the value forwarded to {@code FiniteStateGraphOnPolicy.testTD} as its last argument. */
  protected abstract double precision();
}