RLPark 1.0.0
Reinforcement Learning Framework in Java

RandomWalkOffPolicy.java

package rlpark.plugin.rltoys.experiments.testing.predictions;

import java.util.Random;

import rlpark.plugin.rltoys.algorithms.predictions.td.OffPolicyTD;
import rlpark.plugin.rltoys.envio.policy.ConstantPolicy;
import rlpark.plugin.rltoys.experiments.testing.results.TestingResult;
import rlpark.plugin.rltoys.math.vector.implementations.PVector;
import rlpark.plugin.rltoys.math.vector.implementations.Vectors;
import rlpark.plugin.rltoys.problems.stategraph.FSGAgentState;
import rlpark.plugin.rltoys.problems.stategraph.FiniteStateGraph.StepData;
import rlpark.plugin.rltoys.problems.stategraph.RandomWalk;

public class RandomWalkOffPolicy {
  // Factory for the off-policy TD learner under test.
  public interface OffPolicyTDFactory {
    OffPolicyTD newTD(double lambda, double gamma, double vectorNorm, int vectorSize);
  }

  public static TestingResult<OffPolicyTD> testOffPolicyGTD(int nbEpisodeMax, double precision, double lambda,
      double gamma, double targetLeftProbability, double behaviourLeftProbability, OffPolicyTDFactory tdFactory) {
    Random random = new Random(0);
    // Behaviour policy generates the data; target policy is the one being evaluated.
    ConstantPolicy behaviourPolicy = RandomWalk.newPolicy(random, behaviourLeftProbability);
    ConstantPolicy targetPolicy = RandomWalk.newPolicy(random, targetLeftProbability);
    RandomWalk problem = new RandomWalk(behaviourPolicy);
    FSGAgentState agentState = new FSGAgentState(problem);
    OffPolicyTD gtd = tdFactory.newTD(lambda, gamma, agentState.vectorNorm(), agentState.vectorSize());
    int nbEpisode = 0;
    // Closed-form value-function solution for the target policy, used as the learning target.
    double[] solution = agentState.computeSolution(targetPolicy, gamma, lambda);
    PVector phi_t = null;
    // Sanity check: the initial weights must not already be within the requested precision.
    if (FiniteStateGraphOnPolicy.distanceToSolution(solution, gtd.weights()) <= precision)
      return new TestingResult<OffPolicyTD>(false, "Precision is incorrect!", gtd);
    while (FiniteStateGraphOnPolicy.distanceToSolution(solution, gtd.weights()) > precision) {
      StepData stepData = agentState.step();
      PVector phi_tp1 = agentState.currentFeatureState();
      // Target and behaviour probabilities of the action taken; a_t is null on the first
      // step of an episode, in which case the importance-sampling ratio pi_t / b_t is zero.
      double pi_t = stepData.a_t != null ? targetPolicy.pi(stepData.a_t) : 0;
      double b_t = stepData.a_t != null ? behaviourPolicy.pi(stepData.a_t) : 1;
      gtd.update(pi_t, b_t, phi_t, phi_tp1, stepData.r_tp1);
      // s_tp1 == null marks the end of an episode: check the episode budget and the weights.
      if (stepData.s_tp1 == null) {
        nbEpisode += 1;
        if (nbEpisode > nbEpisodeMax)
          return new TestingResult<OffPolicyTD>(false, "Not learning fast enough. Distance to solution: "
              + FiniteStateGraphOnPolicy.distanceToSolution(solution, gtd.weights()), gtd);
        if (!Vectors.checkValues(gtd.weights()))
          return new TestingResult<OffPolicyTD>(false, "Weights are wrong", gtd);
      }
      // Reset phi_t to null at episode boundaries so the next update starts a new trajectory.
      phi_t = stepData.s_tp1 != null ? phi_tp1.copy() : null;
    }
    return new TestingResult<OffPolicyTD>(true, null, gtd);
  }
}
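
A possible way to call this test from client code is sketched below. This sketch is not part of RandomWalkOffPolicy.java: the GTDLambda constructor arguments (step-sizes normalized by the feature-vector norm) and the concrete parameter values are assumptions made for illustration and should be checked against the actual GTDLambda class in rlpark.plugin.rltoys.algorithms.predictions.td.

// Hypothetical usage sketch, not part of this file.
// Assumes GTDLambda implements OffPolicyTD and has a constructor of the form
// (lambda, gamma, alpha_v, alpha_w, nbFeatures); adapt to the real signature.
RandomWalkOffPolicy.OffPolicyTDFactory factory = new RandomWalkOffPolicy.OffPolicyTDFactory() {
  @Override
  public OffPolicyTD newTD(double lambda, double gamma, double vectorNorm, int vectorSize) {
    // Dividing the step-sizes by the feature-vector norm is a common normalization choice.
    return new GTDLambda(lambda, gamma, 0.1 / vectorNorm, 0.001 / vectorNorm, vectorSize);
  }
};
// Illustrative parameters: episode budget, precision, lambda, gamma,
// target left-probability, behaviour left-probability.
TestingResult<OffPolicyTD> result =
    RandomWalkOffPolicy.testOffPolicyGTD(100000, 0.05, 0.0, 0.9, 0.2, 0.5, factory);

The returned TestingResult carries the learner together with a pass/fail diagnostic, so a unit test can assert on it directly.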