RLPark 1.0.0
Reinforcement Learning Framework in Java

FiniteStateGraphOnPolicy.java

package rlpark.plugin.rltoys.experiments.testing.predictions;

import rlpark.plugin.rltoys.algorithms.predictions.td.OnPolicyTD;
import rlpark.plugin.rltoys.experiments.testing.results.TestingResult;
import rlpark.plugin.rltoys.math.vector.RealVector;
import rlpark.plugin.rltoys.math.vector.implementations.PVector;
import rlpark.plugin.rltoys.math.vector.implementations.Vectors;
import rlpark.plugin.rltoys.problems.stategraph.FSGAgentState;
import rlpark.plugin.rltoys.problems.stategraph.FiniteStateGraph;
import rlpark.plugin.rltoys.problems.stategraph.FiniteStateGraph.StepData;

public class FiniteStateGraphOnPolicy {

  // Factory used by the test to build the on-policy TD learner under test.
  static public interface OnPolicyTDFactory {
    OnPolicyTD create(double lambda, double gamma, double vectorNorm, int vectorSize);
  }

  // Max-norm distance between the analytical solution and the learned weight vector.
  // Both vectors are expected to have the same length.
  static public double distanceToSolution(double[] solution, PVector theta) {
    double max = 0;
    for (int i = 0; i < Math.max(solution.length, theta.size); i++)
      max = Math.max(max, Math.abs(solution[i] - theta.data[i]));
    return max;
  }

  // Runs episodes on the given finite state graph until the learner's weights are within
  // 'precision' of the expected discounted solution, or fails once 'nbEpisodeMax' episodes elapse.
  public static TestingResult<OnPolicyTD> testTD(double lambda, FiniteStateGraph problem,
      FiniteStateGraphOnPolicy.OnPolicyTDFactory tdFactory, int nbEpisodeMax, double precision) {
    FSGAgentState agentState = new FSGAgentState(problem);
    OnPolicyTD td = tdFactory.create(lambda, problem.gamma(), agentState.vectorNorm(), agentState.vectorSize());
    int nbEpisode = 0;
    double[] solution = problem.expectedDiscountedSolution();
    RealVector x_t = null;
    // A freshly initialized learner must not already satisfy the precision,
    // otherwise the test would pass vacuously.
    if (FiniteStateGraphOnPolicy.distanceToSolution(solution, td.weights()) <= precision)
      return new TestingResult<OnPolicyTD>(false, "Precision is incorrect!", td);
    while (distanceToSolution(solution, td.weights()) > precision) {
      StepData stepData = agentState.step();
      RealVector x_tp1 = agentState.currentFeatureState();
      td.update(x_t, x_tp1, stepData.r_tp1);
      if (stepData.s_tp1 == null) { // End of an episode
        nbEpisode += 1;
        if (nbEpisode >= nbEpisodeMax) {
          String message = String.format("Not learning fast enough. Lambda=%f Gamma=%f. Distance to solution=%f",
                                         lambda, problem.gamma(),
                                         FiniteStateGraphOnPolicy.distanceToSolution(solution, td.weights()));
          return new TestingResult<OnPolicyTD>(false, message, td);
        }
      }
      // At an episode boundary, pass a null previous state so the next update starts a new episode.
      x_t = stepData.s_tp1 != null ? x_tp1.copy() : null;
    }
    // The learned weights must contain only finite values.
    if (!Vectors.checkValues(td.weights()))
      return new TestingResult<OnPolicyTD>(false, "Weights are incorrect!", td);
    return new TestingResult<OnPolicyTD>(true, null, td);
  }
}
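
The testTD method above is the core of this test harness: it builds a learner through the factory, steps it through the state graph episode by episode, and succeeds once the weight vector is within precision of the expected discounted solution. The sketch below shows how it might be invoked. RandomWalk and TDLambda are assumed here to be the FiniteStateGraph and OnPolicyTD implementations shipped with RLToys; their constructor signatures are assumptions, not facts taken from this file.

// Usage sketch: testing TD(lambda) on a random-walk graph (assumptions flagged in comments).
import java.util.Random;

import rlpark.plugin.rltoys.algorithms.predictions.td.OnPolicyTD;
import rlpark.plugin.rltoys.algorithms.predictions.td.TDLambda;
import rlpark.plugin.rltoys.experiments.testing.predictions.FiniteStateGraphOnPolicy;
import rlpark.plugin.rltoys.experiments.testing.predictions.FiniteStateGraphOnPolicy.OnPolicyTDFactory;
import rlpark.plugin.rltoys.experiments.testing.results.TestingResult;
import rlpark.plugin.rltoys.problems.stategraph.RandomWalk;

public class TDLambdaOnRandomWalkSketch {
  public static void main(String[] args) {
    double lambda = 0.5;
    double precision = 0.05;
    int nbEpisodeMax = 100000;
    // Assumption: RandomWalk is a FiniteStateGraph provided by RLToys and takes a Random source.
    RandomWalk problem = new RandomWalk(new Random(0));
    // Assumption: TDLambda implements OnPolicyTD; the argument order (lambda, gamma, alpha, nbFeatures)
    // is not taken from the file above and should be checked against the library.
    OnPolicyTDFactory factory = new OnPolicyTDFactory() {
      @Override
      public OnPolicyTD create(double lambda, double gamma, double vectorNorm, int vectorSize) {
        double alpha = 0.1 / vectorNorm; // step size scaled by the feature-vector norm from FSGAgentState
        return new TDLambda(lambda, gamma, alpha, vectorSize);
      }
    };
    TestingResult<OnPolicyTD> result = FiniteStateGraphOnPolicy.testTD(lambda, problem, factory,
                                                                       nbEpisodeMax, precision);
    // From the constructor calls in testTD, the result carries a pass/fail flag, an optional message,
    // and the learner itself; a JUnit test would typically assert on the flag.
    System.out.println(result);
  }
}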