RLPark 1.0.0
Reinforcement Learning Framework in Java

ActorCriticPendulum.java

Go to the documentation of this file.
00001 package rlpark.example.demos.learning;
00002 
00003 import java.util.Random;
00004 
00005 import rlpark.plugin.rltoys.agents.functions.ValueFunction2D;
00006 import rlpark.plugin.rltoys.agents.functions.FunctionProjected2D;
00007 import rlpark.plugin.rltoys.agents.rl.LearnerAgentFA;
00008 import rlpark.plugin.rltoys.algorithms.control.actorcritic.onpolicy.Actor;
00009 import rlpark.plugin.rltoys.algorithms.control.actorcritic.onpolicy.AverageRewardActorCritic;
00010 import rlpark.plugin.rltoys.algorithms.functions.policydistributions.PolicyDistribution;
00011 import rlpark.plugin.rltoys.algorithms.functions.policydistributions.helpers.ScaledPolicyDistribution;
00012 import rlpark.plugin.rltoys.algorithms.functions.policydistributions.structures.NormalDistributionScaled;
00013 import rlpark.plugin.rltoys.algorithms.predictions.td.OnPolicyTD;
00014 import rlpark.plugin.rltoys.algorithms.predictions.td.TDLambda;
00015 import rlpark.plugin.rltoys.algorithms.representations.discretizer.partitions.AbstractPartitionFactory;
00016 import rlpark.plugin.rltoys.algorithms.representations.tilescoding.TileCodersNoHashing;
00017 import rlpark.plugin.rltoys.experiments.helpers.Runner;
00018 import rlpark.plugin.rltoys.experiments.helpers.Runner.RunnerEvent;
00019 import rlpark.plugin.rltoys.math.ranges.Range;
00020 import rlpark.plugin.rltoys.problems.pendulum.SwingPendulum;
00021 import zephyr.plugin.core.api.Zephyr;
00022 import zephyr.plugin.core.api.monitoring.annotations.Monitor;
00023 import zephyr.plugin.core.api.signals.Listener;
00024 import zephyr.plugin.core.api.synchronization.Clock;
00025 
00026 
00027 @Monitor
00028 public class ActorCriticPendulum implements Runnable {
00029   final FunctionProjected2D valueFunction;
00030   double reward;
00031   private final SwingPendulum problem;
00032   private final Clock clock = new Clock("ActorCriticPendulum");
00033   private final LearnerAgentFA agent;
00034   private final Runner runner;
00035 
00036   public ActorCriticPendulum() {
00037     Random random = new Random(0);
00038     problem = new SwingPendulum(null, false);
00039     TileCodersNoHashing tileCoders = new TileCodersNoHashing(problem.getObservationRanges());
00040     ((AbstractPartitionFactory) tileCoders.discretizerFactory()).setRandom(random, .2);
00041     tileCoders.addFullTilings(10, 10);
00042     double gamma = 1.0;
00043     double lambda = .5;
00044     double vectorNorm = tileCoders.vectorNorm();
00045     int vectorSize = tileCoders.vectorSize();
00046     OnPolicyTD critic = new TDLambda(lambda, gamma, .1 / vectorNorm, vectorSize);
00047     PolicyDistribution policyDistribution = new NormalDistributionScaled(random, 0.0, 1.0);
00048     policyDistribution = new ScaledPolicyDistribution(policyDistribution, new Range(-2, 2), problem.actionRanges()[0]);
00049     Actor actor = new Actor(policyDistribution, 0.001 / vectorNorm, vectorSize);
00050     AverageRewardActorCritic actorCritic = new AverageRewardActorCritic(.0001, critic, actor);
00051     agent = new LearnerAgentFA(actorCritic, tileCoders);
00052     valueFunction = new ValueFunction2D(tileCoders, problem, critic);
00053     runner = new Runner(problem, agent, -1, 1000);
00054     runner.onEpisodeEnd.connect(new Listener<Runner.RunnerEvent>() {
00055       @Override
00056       public void listen(RunnerEvent eventInfo) {
00057         System.out.println(String.format("Episode %d: %f", eventInfo.nbEpisodeDone, eventInfo.episodeReward));
00058       }
00059     });
00060     Zephyr.advertise(clock, this);
00061   }
00062 
00063   @Override
00064   public void run() {
00065     while (clock.tick())
00066       runner.step();
00067   }
00068 
00069   public static void main(String[] args) {
00070     new ActorCriticPendulum().run();
00071   }
00072 }
 All Classes Namespaces Files Functions Variables Enumerations
Zephyr
RLPark