RLPark 1.0.0
Reinforcement Learning Framework in Java
00001 package rlpark.plugin.rltoys.experiments.testing.predictions; 00002 00003 import rlpark.plugin.rltoys.algorithms.predictions.td.OnPolicyTD; 00004 import rlpark.plugin.rltoys.experiments.testing.results.TestingResult; 00005 import rlpark.plugin.rltoys.math.vector.RealVector; 00006 import rlpark.plugin.rltoys.math.vector.implementations.PVector; 00007 import rlpark.plugin.rltoys.math.vector.implementations.Vectors; 00008 import rlpark.plugin.rltoys.problems.stategraph.FSGAgentState; 00009 import rlpark.plugin.rltoys.problems.stategraph.FiniteStateGraph; 00010 import rlpark.plugin.rltoys.problems.stategraph.FiniteStateGraph.StepData; 00011 00012 public class FiniteStateGraphOnPolicy { 00013 00014 static public interface OnPolicyTDFactory { 00015 OnPolicyTD create(double lambda, double gamma, double vectorNorm, int vectorSize); 00016 } 00017 00018 static public double distanceToSolution(double[] solution, PVector theta) { 00019 double max = 0; 00020 for (int i = 0; i < Math.max(solution.length, theta.size); i++) 00021 max = Math.max(max, Math.abs(solution[i] - theta.data[i])); 00022 return max; 00023 } 00024 00025 public static TestingResult<OnPolicyTD> testTD(double lambda, FiniteStateGraph problem, 00026 FiniteStateGraphOnPolicy.OnPolicyTDFactory tdFactory, int nbEpisodeMax, double precision) { 00027 FSGAgentState agentState = new FSGAgentState(problem); 00028 OnPolicyTD td = tdFactory.create(lambda, problem.gamma(), agentState.vectorNorm(), agentState.vectorSize()); 00029 int nbEpisode = 0; 00030 double[] solution = problem.expectedDiscountedSolution(); 00031 RealVector x_t = null; 00032 if (FiniteStateGraphOnPolicy.distanceToSolution(solution, td.weights()) <= precision) 00033 return new TestingResult<OnPolicyTD>(false, "Precision is incorrect!", td); 00034 while (distanceToSolution(solution, td.weights()) > precision) { 00035 StepData stepData = agentState.step(); 00036 RealVector x_tp1 = agentState.currentFeatureState(); 00037 td.update(x_t, x_tp1, 
stepData.r_tp1); 00038 if (stepData.s_tp1 == null) { 00039 nbEpisode += 1; 00040 if (nbEpisode >= nbEpisodeMax) { 00041 String message = String.format("Not learning fast enough. Lambda=%f Gamma=%f. Distance to solution=%f", 00042 lambda, problem.gamma(), 00043 FiniteStateGraphOnPolicy.distanceToSolution(solution, td.weights())); 00044 return new TestingResult<OnPolicyTD>(false, message, td); 00045 } 00046 } 00047 x_t = stepData.s_tp1 != null ? x_tp1.copy() : null; 00048 } 00049 if (!Vectors.checkValues(td.weights())) 00050 return new TestingResult<OnPolicyTD>(false, "Weights are incorrect!", td); 00051 return new TestingResult<OnPolicyTD>(true, null, td); 00052 } 00053 }