RLPark 1.0.0
Reinforcement Learning Framework in Java

OffPACPuddleWorld.java

package rlpark.example.demos.learning;

import java.util.Random;

import rlpark.plugin.rltoys.agents.functions.FunctionProjected2D;
import rlpark.plugin.rltoys.agents.functions.ValueFunction2D;
import rlpark.plugin.rltoys.agents.offpolicy.OffPolicyAgentDirect;
import rlpark.plugin.rltoys.agents.offpolicy.OffPolicyAgentEvaluable;
import rlpark.plugin.rltoys.agents.offpolicy.OffPolicyAgentFA;
import rlpark.plugin.rltoys.algorithms.control.actorcritic.offpolicy.ActorLambdaOffPolicy;
import rlpark.plugin.rltoys.algorithms.control.actorcritic.offpolicy.ActorOffPolicy;
import rlpark.plugin.rltoys.algorithms.control.actorcritic.offpolicy.CriticAdapterFA;
import rlpark.plugin.rltoys.algorithms.control.actorcritic.offpolicy.OffPAC;
import rlpark.plugin.rltoys.algorithms.functions.ContinuousFunction;
import rlpark.plugin.rltoys.algorithms.functions.policydistributions.PolicyDistribution;
import rlpark.plugin.rltoys.algorithms.functions.policydistributions.helpers.RandomPolicy;
import rlpark.plugin.rltoys.algorithms.functions.policydistributions.structures.BoltzmannDistribution;
import rlpark.plugin.rltoys.algorithms.functions.stateactions.StateToStateAction;
import rlpark.plugin.rltoys.algorithms.functions.states.Projector;
import rlpark.plugin.rltoys.algorithms.predictions.td.GTDLambda;
import rlpark.plugin.rltoys.algorithms.predictions.td.OffPolicyTD;
import rlpark.plugin.rltoys.algorithms.representations.discretizer.TabularActionDiscretizer;
import rlpark.plugin.rltoys.algorithms.representations.discretizer.partitions.AbstractPartitionFactory;
import rlpark.plugin.rltoys.algorithms.representations.discretizer.partitions.BoundedSmallPartitionFactory;
import rlpark.plugin.rltoys.algorithms.representations.tilescoding.StateActionCoders;
import rlpark.plugin.rltoys.algorithms.representations.tilescoding.TileCoders;
import rlpark.plugin.rltoys.algorithms.representations.tilescoding.TileCodersHashing;
import rlpark.plugin.rltoys.algorithms.representations.tilescoding.hashing.Hashing;
import rlpark.plugin.rltoys.algorithms.representations.tilescoding.hashing.MurmurHashing;
import rlpark.plugin.rltoys.algorithms.traces.ATraces;
import rlpark.plugin.rltoys.envio.policy.Policy;
import rlpark.plugin.rltoys.experiments.helpers.Runner;
import rlpark.plugin.rltoys.experiments.helpers.Runner.RunnerEvent;
import rlpark.plugin.rltoys.math.ranges.Range;
import rlpark.plugin.rltoys.problems.ProblemBounded;
import rlpark.plugin.rltoys.problems.puddleworld.ConstantFunction;
import rlpark.plugin.rltoys.problems.puddleworld.LocalFeatureSumFunction;
import rlpark.plugin.rltoys.problems.puddleworld.PuddleWorld;
import rlpark.plugin.rltoys.problems.puddleworld.SmoothPuddle;
import rlpark.plugin.rltoys.problems.puddleworld.TargetReachedL1NormTermination;
import zephyr.plugin.core.api.Zephyr;
import zephyr.plugin.core.api.monitoring.abstracts.Monitored;
import zephyr.plugin.core.api.monitoring.annotations.Monitor;
import zephyr.plugin.core.api.signals.Listener;
import zephyr.plugin.core.api.synchronization.Clock;

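/**
 * Demo of the Off-PAC (off-policy actor-critic) algorithm on a continuous
 * two-dimensional puddle world. The agent follows a uniformly random
 * behaviour policy while learning a Boltzmann target policy off-policy; a
 * second environment evaluates the learned target policy online. The learned
 * value function and episode statistics are displayed through Zephyr.
 */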
@SuppressWarnings("restriction")
@Monitor
public class OffPACPuddleWorld implements Runnable {
  private final Random random = new Random(0);
  private final PuddleWorld behaviourEnvironment = createEnvironment(random);
  private final PuddleWorld evaluationEnvironment = createEnvironment(random);
  private final Runner learningRunner;
  private final Runner evaluationRunner;
  // Visualization with Zephyr
  final FunctionProjected2D valueFunction;
  final Clock clock = new Clock("Off-PAC Demo");
  final Clock episodeClock = new Clock("Episodes");

  public OffPACPuddleWorld() {
    // Behaviour policy: actions are chosen uniformly at random
    Policy behaviour = new RandomPolicy(random, behaviourEnvironment.actions());
    OffPolicyAgentEvaluable agent = createOffPACAgent(random, behaviourEnvironment, behaviour, .99);
    // Runners step the agents through their environments (-1: no episode limit, 5000: episode time-out)
    learningRunner = new Runner(behaviourEnvironment, agent, -1, 5000);
    // The evaluation runner executes the learned target policy, without learning
    evaluationRunner = new Runner(evaluationEnvironment, agent.createEvaluatedAgent(), -1, 5000);
    // Expose the critic's value function as a 2D display in Zephyr
    CriticAdapterFA criticAdapter = (CriticAdapterFA) ((OffPolicyAgentFA) agent).learner().predictor();
    valueFunction = new ValueFunction2D(criticAdapter.projector(), behaviourEnvironment, criticAdapter.predictor());
    connectEpisodesEventsForZephyr();
    Zephyr.advertise(clock, this);
  }

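  /**
   * Publishes evaluation episode statistics: each time an evaluation episode
   * ends, its length and cumulated reward are logged to stdout and exposed to
   * Zephyr as the monitored values "length" and "reward".
   */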
  private void connectEpisodesEventsForZephyr() {
    // [episode length, episode reward], shared with the Monitored views below
    final double[] episodeInfos = new double[2];
    evaluationRunner.onEpisodeEnd.connect(new Listener<Runner.RunnerEvent>() {
      @Override
      public void listen(RunnerEvent eventInfo) {
        episodeInfos[0] = eventInfo.step.time;
        episodeInfos[1] = eventInfo.episodeReward;
        episodeClock.tick();
        System.out.println(String.format("Episodes %d: %d, %f", eventInfo.nbEpisodeDone, eventInfo.step.time,
                                         eventInfo.episodeReward));
      }
    });
    Zephyr.advertise(episodeClock, new Monitored() {
      @Override
      public double monitoredValue() {
        return episodeInfos[0];
      }
    }, "length");
    Zephyr.advertise(episodeClock, new Monitored() {
      @Override
      public double monitoredValue() {
        return episodeInfos[1];
      }
    }, "reward");
  }

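  /** Feature hashing: murmur hashing into a vector of one million entries. */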
  static private Hashing createHashing(Random random) {
    return new MurmurHashing(random, 1000000);
  }

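  /** Configures the tile coders: 10 tilings of resolution 10 per dimension, plus an always-active (bias) feature. */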
  static private void setTileCoders(TileCoders projector) {
    projector.addFullTilings(10, 10);
    projector.includeActiveFeature();
  }

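  /**
   * Creates the partition factory used to discretize each (bounded)
   * observation dimension; setRandom presumably randomizes the tilings'
   * offsets (with parameter 0.2 here) so that the tilings do not align.
   */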
  static private AbstractPartitionFactory createPartitionFactory(Random random, Range[] observationRanges) {
    AbstractPartitionFactory partitionFactory = new BoundedSmallPartitionFactory(observationRanges);
    partitionFactory.setRandom(random, .2);
    return partitionFactory;
  }

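  /** Builds the tile-coding projector mapping observations to hashed binary feature vectors; used as the critic's state representation. */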
  static public Projector createProjector(Random random, PuddleWorld problem) {
    final Range[] observationRanges = ((ProblemBounded) problem).getObservationRanges();
    final AbstractPartitionFactory discretizerFactory = createPartitionFactory(random, observationRanges);
    Hashing hashing = createHashing(random);
    TileCodersHashing projector = new TileCodersHashing(hashing, discretizerFactory, observationRanges.length);
    setTileCoders(projector);
    return projector;
  }

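  /** Builds state-action features for the actor: the same tile coding over states, combined with a tabular discretization of the actions. */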
  static public StateToStateAction createToStateAction(Random random, PuddleWorld problem) {
    final Range[] observationRanges = problem.getObservationRanges();
    final AbstractPartitionFactory discretizerFactory = createPartitionFactory(random, observationRanges);
    TabularActionDiscretizer actionDiscretizer = new TabularActionDiscretizer(problem.actions());
    Hashing hashing = createHashing(random);
    StateActionCoders stateActionCoders = new StateActionCoders(actionDiscretizer, hashing, discretizerFactory,
                                                                observationRanges.length);
    setTileCoders(stateActionCoders.tileCoders());
    return stateActionCoders;
  }

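  /**
   * Creates the critic: GTD(lambda) with lambda = 0.4, a primary step size
   * normalized by the feature-vector norm and, presumably, a secondary step
   * size of 0. CriticAdapterFA pairs the predictor with its projector.
   */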
  private OffPolicyTD createCritic(Projector criticProjector, double gamma) {
    double alpha_v = .1 / criticProjector.vectorNorm();
    GTDLambda gtd = new GTDLambda(.4, gamma, alpha_v, 0, criticProjector.vectorSize(), new ATraces());
    return new CriticAdapterFA(criticProjector, gtd);
  }

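  /**
   * Assembles the Off-PAC agent: a GTD(lambda) critic, a Boltzmann target
   * policy over state-action features, and an off-policy actor with
   * eligibility traces, all driven by the given behaviour policy.
   */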
  private OffPolicyAgentEvaluable createOffPACAgent(Random random, PuddleWorld problem, Policy behaviour, double gamma) {
    Projector criticProjector = createProjector(random, problem);
    OffPolicyTD critic = createCritic(criticProjector, gamma);
    StateToStateAction toStateAction = createToStateAction(random, problem);
    PolicyDistribution target = new BoltzmannDistribution(random, problem.actions(), toStateAction);
    double alpha_u = .001 / criticProjector.vectorNorm();
    ActorOffPolicy actor = new ActorLambdaOffPolicy(.4, gamma, target, alpha_u, toStateAction.vectorSize(),
                                                    new ATraces());
    return new OffPolicyAgentDirect(behaviour, new OffPAC(behaviour, critic, actor));
  }

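  /**
   * Builds the puddle world: a two-dimensional world on [0,1]^2 with
   * continuous actions in [-0.05,0.05] per dimension. The reward is -1 per
   * step plus a -2 penalty weighted by each of three Gaussian puddles; an
   * episode terminates when the agent is within 0.1 (L1 norm) of the
   * target (1,1).
   */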
  static private PuddleWorld createEnvironment(Random random) {
    PuddleWorld problem = new PuddleWorld(random, 2, new Range(0, 1), new Range(-.05, .05), .1);
    final int[] patternIndexes = new int[] { 0, 1 };
    final double smallStddev = 0.03;
    // A constant feature plus three Gaussian-shaped puddles
    ContinuousFunction[] features = new ContinuousFunction[] { new ConstantFunction(1),
        new SmoothPuddle(patternIndexes, new double[] { .3, .6 }, new double[] { .1, smallStddev }),
        new SmoothPuddle(patternIndexes, new double[] { .4, .5 }, new double[] { smallStddev, .1 }),
        new SmoothPuddle(patternIndexes, new double[] { .8, .9 }, new double[] { smallStddev, .1 }) };
    final double puddleMalus = -2;
    // Reward: -1 per step, plus a penalty of -2 weighted by each puddle's value
    double[] weights = new double[] { -1, puddleMalus, puddleMalus, puddleMalus };
    problem.setRewardFunction(new LocalFeatureSumFunction(weights, features, 0));
    // Episodes end within .1 (L1 norm) of the target (1,1); the start state is (.2,.4)
    problem.setTermination(new TargetReachedL1NormTermination(new double[] { 1, 1 }, .1));
    problem.setStart(new double[] { .2, .4 });
    return problem;
  }

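  /** Interleaves learning and evaluation: one step in each environment per clock tick, until the Zephyr clock is terminated. */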
  @Override
  public void run() {
    while (clock.tick()) {
      learningRunner.step();
      evaluationRunner.step();
    }
  }

  public static void main(String[] args) {
    new OffPACPuddleWorld().run();
  }
}