RLPark 1.0.0
Reinforcement Learning Framework in Java
RandomWalkOffPolicy.java

package rlpark.plugin.rltoys.experiments.testing.predictions;

import java.util.Random;

import rlpark.plugin.rltoys.algorithms.predictions.td.OffPolicyTD;
import rlpark.plugin.rltoys.envio.policy.ConstantPolicy;
import rlpark.plugin.rltoys.experiments.testing.results.TestingResult;
import rlpark.plugin.rltoys.math.vector.implementations.PVector;
import rlpark.plugin.rltoys.math.vector.implementations.Vectors;
import rlpark.plugin.rltoys.problems.stategraph.FSGAgentState;
import rlpark.plugin.rltoys.problems.stategraph.FiniteStateGraph.StepData;
import rlpark.plugin.rltoys.problems.stategraph.RandomWalk;

// Convergence test for off-policy TD prediction algorithms on the RandomWalk
// finite-state-graph problem: the agent follows the behaviour policy while
// learning the value function of the target policy.
public class RandomWalkOffPolicy {
  // Factory creating the off-policy TD learner under test.
  public interface OffPolicyTDFactory {
    OffPolicyTD newTD(double lambda, double gamma, double vectorNorm, int vectorSize);
  }

  public static TestingResult<OffPolicyTD> testOffPolicyGTD(int nbEpisodeMax, double precision, double lambda,
      double gamma, double targetLeftProbability, double behaviourLeftProbability, OffPolicyTDFactory tdFactory) {
    Random random = new Random(0); // fixed seed for reproducibility
    ConstantPolicy behaviourPolicy = RandomWalk.newPolicy(random, behaviourLeftProbability);
    ConstantPolicy targetPolicy = RandomWalk.newPolicy(random, targetLeftProbability);
    RandomWalk problem = new RandomWalk(behaviourPolicy);
    FSGAgentState agentState = new FSGAgentState(problem);
    OffPolicyTD gtd = tdFactory.newTD(lambda, gamma, agentState.vectorNorm(), agentState.vectorSize());
    int nbEpisode = 0;
    // Reference value function of the target policy, computed in closed form.
    double[] solution = agentState.computeSolution(targetPolicy, gamma, lambda);
    PVector phi_t = null;
    // Sanity check: the untrained weights must not already be within precision
    // of the solution, otherwise the test would pass vacuously.
    if (FiniteStateGraphOnPolicy.distanceToSolution(solution, gtd.weights()) <= precision)
      return new TestingResult<OffPolicyTD>(false, "Precision is incorrect!", gtd);
    while (FiniteStateGraphOnPolicy.distanceToSolution(solution, gtd.weights()) > precision) {
      StepData stepData = agentState.step();
      PVector phi_tp1 = agentState.currentFeatureState();
      // Probabilities of the taken action under the target and behaviour
      // policies; a_t can be null when no action has been taken yet.
      double pi_t = stepData.a_t != null ? targetPolicy.pi(stepData.a_t) : 0;
      double b_t = stepData.a_t != null ? behaviourPolicy.pi(stepData.a_t) : 1;
      gtd.update(pi_t, b_t, phi_t, phi_tp1, stepData.r_tp1);
      if (stepData.s_tp1 == null) { // end of episode
        nbEpisode += 1;
        if (nbEpisode > nbEpisodeMax)
          return new TestingResult<OffPolicyTD>(false, "Not learning fast enough. Distance to solution: "
              + FiniteStateGraphOnPolicy.distanceToSolution(solution, gtd.weights()), gtd);
        // Fails if the weight vector contains NaN or infinite values.
        if (!Vectors.checkValues(gtd.weights()))
          return new TestingResult<OffPolicyTD>(false, "Weights are wrong", gtd);
      }
      // Reset the feature vector at episode boundaries.
      phi_t = stepData.s_tp1 != null ? phi_tp1.copy() : null;
    }
    return new TestingResult<OffPolicyTD>(true, null, gtd);
  }
}
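For context, testOffPolicyGTD is meant to be driven from a unit test through the OffPolicyTDFactory interface. The sketch below shows one plausible wiring using RLPark's GTDLambda learner; it is an illustration rather than code from this file. The GTDLambda constructor order (lambda, gamma, alpha_v, alpha_w, nbFeatures, trace prototype), the ATraces eligibility-trace class, the TestingResult field names, and all numeric parameters are assumptions to be checked against the library sources.

import rlpark.plugin.rltoys.algorithms.predictions.td.GTDLambda;
import rlpark.plugin.rltoys.algorithms.predictions.td.OffPolicyTD;
import rlpark.plugin.rltoys.algorithms.traces.ATraces;
import rlpark.plugin.rltoys.experiments.testing.predictions.RandomWalkOffPolicy;
import rlpark.plugin.rltoys.experiments.testing.predictions.RandomWalkOffPolicy.OffPolicyTDFactory;
import rlpark.plugin.rltoys.experiments.testing.results.TestingResult;

// Hypothetical driver class, not part of RLPark.
public class RandomWalkOffPolicyExample {
  public static void main(String[] args) {
    // Factory building a GTD(lambda) learner with step sizes scaled by the
    // feature-vector norm. The constructor argument order is an assumption;
    // verify it against GTDLambda before relying on this sketch.
    OffPolicyTDFactory factory = new OffPolicyTDFactory() {
      @Override
      public OffPolicyTD newTD(double lambda, double gamma, double vectorNorm, int vectorSize) {
        return new GTDLambda(lambda, gamma, 0.1 / vectorNorm, 0.0001 / vectorNorm, vectorSize, new ATraces());
      }
    };
    // Target policy goes left with probability 0.2, behaviour policy with 0.5:
    // the two differ, so the learner is evaluated off-policy (values illustrative).
    TestingResult<OffPolicyTD> result =
        RandomWalkOffPolicy.testOffPolicyGTD(100000, 0.05, 0.1, 0.9, 0.2, 0.5, factory);
    // Field names on TestingResult are assumed from its constructor usage above.
    System.out.println(result.passed ? "Converged to the target-policy solution" : result.message);
  }
}

One design point worth noting: the factory receives vectorNorm and vectorSize from FSGAgentState, which lets step sizes be normalized by the feature representation rather than hard-coded, so the same test drives learners across different state encodings.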