RLPark 1.0.0
Reinforcement Learning Framework in Java

OnPolicyTests.java

Go to the documentation of this file.
00001 package rlpark.plugin.rltoys.experiments.testing.predictions;
00002 
00003 import java.util.Random;
00004 
00005 import org.junit.Assert;
00006 import org.junit.Test;
00007 
00008 import rlpark.plugin.rltoys.algorithms.predictions.td.OnPolicyTD;
00009 import rlpark.plugin.rltoys.experiments.testing.predictions.FiniteStateGraphOnPolicy.OnPolicyTDFactory;
00010 import rlpark.plugin.rltoys.experiments.testing.results.TestingResult;
00011 import rlpark.plugin.rltoys.problems.stategraph.FiniteStateGraph;
00012 import rlpark.plugin.rltoys.problems.stategraph.LineProblem;
00013 import rlpark.plugin.rltoys.problems.stategraph.RandomWalk;
00014 
00015 public abstract class OnPolicyTests {
00016   private final LineProblem lineProblem = new LineProblem();
00017   private final RandomWalk randomWalkProblem = new RandomWalk(new Random(0));
00018 
00019   protected OnPolicyTDFactory[] onPolicyFactories() {
00020     return new OnPolicyTDFactory[] { new OnPolicyTDFactory() {
00021       @Override
00022       public OnPolicyTD create(double lambda, double gamma, double vectorNorm, int vectorSize) {
00023         return newOnPolicyTD(lambda, gamma, vectorNorm, vectorSize);
00024       }
00025     } };
00026   }
00027 
00028   @Test
00029   public void testOnLineProblem() {
00030     for (OnPolicyTDFactory factory : onPolicyFactories())
00031       testTD(0, lineProblem, factory);
00032   }
00033 
00034   @Test
00035   public void testOnRandomWalkProblem() {
00036     for (OnPolicyTDFactory factory : onPolicyFactories())
00037       testTD(0, randomWalkProblem, factory);
00038   }
00039 
00040   @Test
00041   public void testOnLineProblemWithLambda() {
00042     for (OnPolicyTDFactory factory : onPolicyFactories())
00043       for (double lambda : lambdaValues())
00044         testTD(lambda, lineProblem, factory);
00045   }
00046 
00047   @Test
00048   public void testOnRandomWalkProblemWithLambda() {
00049     for (OnPolicyTDFactory factory : onPolicyFactories())
00050       for (double lambda : lambdaValues())
00051         testTD(lambda, randomWalkProblem, factory);
00052   }
00053 
00054   private void testTD(double lambda, FiniteStateGraph problem, OnPolicyTDFactory factory) {
00055     TestingResult<OnPolicyTD> result = FiniteStateGraphOnPolicy.testTD(lambda, problem, factory, nbEpisodeMax(),
00056                                                                        precision());
00057     Assert.assertTrue(result.message, result.passed);
00058   }
00059 
00060   protected int nbEpisodeMax() {
00061     return 100000;
00062   }
00063 
00064   abstract protected OnPolicyTD newOnPolicyTD(double lambda, double gamma, double vectorNorm, int vectorSize);
00065 
00066   abstract protected double[] lambdaValues();
00067 
00068   abstract protected double precision();
00069 }
 All Classes Namespaces Files Functions Variables Enumerations
Zephyr
RLPark