RLPark 1.0.0
Reinforcement Learning Framework in Java
|
00001 package rlpark.plugin.rltoys.experiments.testing.control; 00002 00003 import java.util.Random; 00004 00005 import rlpark.plugin.rltoys.agents.offpolicy.OffPolicyAgentEvaluable; 00006 import rlpark.plugin.rltoys.algorithms.functions.policydistributions.helpers.RandomPolicy; 00007 import rlpark.plugin.rltoys.algorithms.functions.stateactions.StateToStateAction; 00008 import rlpark.plugin.rltoys.algorithms.functions.states.Projector; 00009 import rlpark.plugin.rltoys.algorithms.representations.discretizer.TabularActionDiscretizer; 00010 import rlpark.plugin.rltoys.algorithms.representations.discretizer.partitions.AbstractPartitionFactory; 00011 import rlpark.plugin.rltoys.algorithms.representations.discretizer.partitions.BoundedSmallPartitionFactory; 00012 import rlpark.plugin.rltoys.algorithms.representations.tilescoding.StateActionCoders; 00013 import rlpark.plugin.rltoys.algorithms.representations.tilescoding.TileCoders; 00014 import rlpark.plugin.rltoys.algorithms.representations.tilescoding.TileCodersHashing; 00015 import rlpark.plugin.rltoys.algorithms.representations.tilescoding.hashing.Hashing; 00016 import rlpark.plugin.rltoys.algorithms.representations.tilescoding.hashing.MurmurHashing; 00017 import rlpark.plugin.rltoys.envio.policy.Policy; 00018 import rlpark.plugin.rltoys.experiments.helpers.Runner; 00019 import rlpark.plugin.rltoys.math.ranges.Range; 00020 import rlpark.plugin.rltoys.problems.ProblemBounded; 00021 import rlpark.plugin.rltoys.problems.mountaincar.MountainCar; 00022 00023 public class MountainCarOffPolicyLearning { 00024 static public interface MountainCarEvaluationAgentFactory { 00025 OffPolicyAgentEvaluable createOffPolicyAgent(Random random, MountainCar problem, Policy behaviour, double gamma); 00026 } 00027 00028 static public long evaluate(MountainCarEvaluationAgentFactory agentFactory) { 00029 return evaluate(agentFactory, 100); 00030 } 00031 00032 static public long evaluate(MountainCarEvaluationAgentFactory agentFactory, int nbLearningEpisodes) { 00033 MountainCar problem = new MountainCar(null); 00034 Random random = new Random(0); 00035 Policy behaviour = new RandomPolicy(random, problem.actions()); 00036 OffPolicyAgentEvaluable agent = agentFactory.createOffPolicyAgent(random, problem, behaviour, .99); 00037 Runner learningRunner = new Runner(problem, agent, 100, 5000); 00038 learningRunner.run(); 00039 Runner evaluationRunner = new Runner(problem, agent.createEvaluatedAgent(), 1, 5000); 00040 evaluationRunner.run(); 00041 return evaluationRunner.runnerEvent().step.time; 00042 } 00043 00044 private static final int MemorySize = 1000000; 00045 00046 static private Hashing createHashing(Random random) { 00047 return new MurmurHashing(random, MemorySize); 00048 } 00049 00050 static private void setTileCoders(TileCoders projector) { 00051 projector.addFullTilings(10, 10); 00052 projector.includeActiveFeature(); 00053 } 00054 00055 static private AbstractPartitionFactory createPartitionFactory(Random random, Range[] observationRanges) { 00056 AbstractPartitionFactory partitionFactory = new BoundedSmallPartitionFactory(observationRanges); 00057 partitionFactory.setRandom(random, .2); 00058 return partitionFactory; 00059 } 00060 00061 00062 static public Projector createProjector(Random random, MountainCar problem) { 00063 final Range[] observationRanges = ((ProblemBounded) problem).getObservationRanges(); 00064 final AbstractPartitionFactory discretizerFactory = createPartitionFactory(random, observationRanges); 00065 Hashing hashing = createHashing(random); 00066 TileCodersHashing projector = new TileCodersHashing(hashing, discretizerFactory, observationRanges.length); 00067 setTileCoders(projector); 00068 return projector; 00069 } 00070 00071 static public StateToStateAction createToStateAction(Random random, MountainCar problem) { 00072 final Range[] observationRanges = problem.getObservationRanges(); 00073 final AbstractPartitionFactory discretizerFactory = createPartitionFactory(random, observationRanges); 00074 TabularActionDiscretizer actionDiscretizer = new TabularActionDiscretizer(problem.actions()); 00075 Hashing hashing = createHashing(random); 00076 StateActionCoders stateActionCoders = new StateActionCoders(actionDiscretizer, hashing, discretizerFactory, 00077 observationRanges.length); 00078 setTileCoders(stateActionCoders.tileCoders()); 00079 return stateActionCoders; 00080 } 00081 }