RLPark 1.0.0
Reinforcement Learning Framework in Java
00001 package rlpark.plugin.rltoys.experiments.testing.predictions; 00002 00003 import org.junit.Assert; 00004 import org.junit.Test; 00005 00006 import rlpark.plugin.rltoys.algorithms.predictions.td.OffPolicyTD; 00007 import rlpark.plugin.rltoys.experiments.testing.predictions.FiniteStateGraphOnPolicy.OnPolicyTDFactory; 00008 import rlpark.plugin.rltoys.experiments.testing.predictions.RandomWalkOffPolicy.OffPolicyTDFactory; 00009 import rlpark.plugin.rltoys.experiments.testing.results.TestingResult; 00010 00011 public abstract class OffPolicyTests extends OnPolicyTests { 00012 private static final double Gamma = 0.9; 00013 00014 @Test 00015 public void testOffPolicy() { 00016 for (OffPolicyTDFactory factory : offPolicyTDFactory()) { 00017 testOffPolicy(0.0, 0.2, 0.5, factory); 00018 testOffPolicy(0.0, 0.5, 0.2, factory); 00019 } 00020 } 00021 00022 @Test 00023 public void testOffPolicyWithLambda() { 00024 for (OffPolicyTDFactory factory : offPolicyTDFactory()) { 00025 for (double lambda : lambdaValues()) { 00026 testOffPolicy(lambda, 0.2, 0.5, factory); 00027 testOffPolicy(lambda, 0.5, 0.2, factory); 00028 } 00029 } 00030 } 00031 00032 protected OffPolicyTDFactory[] offPolicyTDFactory() { 00033 OnPolicyTDFactory[] onPolicyTDFactories = onPolicyFactories(); 00034 OffPolicyTDFactory[] offPolicyTDFactories = new OffPolicyTDFactory[onPolicyTDFactories.length]; 00035 for (int i = 0; i < offPolicyTDFactories.length; i++) { 00036 final OnPolicyTDFactory onPolicyFactory = onPolicyTDFactories[i]; 00037 offPolicyTDFactories[i] = new OffPolicyTDFactory() { 00038 @Override 00039 public OffPolicyTD newTD(double lambda, double gamma, double vectorNorm, int vectorSize) { 00040 return (OffPolicyTD) onPolicyFactory.create(lambda, gamma, vectorNorm, vectorSize); 00041 } 00042 }; 00043 } 00044 return offPolicyTDFactories; 00045 } 00046 00047 private void testOffPolicy(double lambda, double targetLeftProbability, double behaviourLeftProbability, 00048 OffPolicyTDFactory factory) { 
00049 TestingResult<OffPolicyTD> result = RandomWalkOffPolicy.testOffPolicyGTD(nbEpisodeMax(), precision(), lambda, 00050 Gamma, targetLeftProbability, 00051 behaviourLeftProbability, factory); 00052 Assert.assertTrue(result.message, result.passed); 00053 } 00054 }