RLPark 1.0.0
Reinforcement Learning Framework in Java
|
package rlpark.example.demos.learning;

import java.util.Random;

import rlpark.plugin.rltoys.agents.functions.ValueFunction2D;
import rlpark.plugin.rltoys.agents.functions.FunctionProjected2D;
import rlpark.plugin.rltoys.agents.rl.LearnerAgentFA;
import rlpark.plugin.rltoys.algorithms.control.actorcritic.onpolicy.Actor;
import rlpark.plugin.rltoys.algorithms.control.actorcritic.onpolicy.AverageRewardActorCritic;
import rlpark.plugin.rltoys.algorithms.functions.policydistributions.PolicyDistribution;
import rlpark.plugin.rltoys.algorithms.functions.policydistributions.helpers.ScaledPolicyDistribution;
import rlpark.plugin.rltoys.algorithms.functions.policydistributions.structures.NormalDistributionScaled;
import rlpark.plugin.rltoys.algorithms.predictions.td.OnPolicyTD;
import rlpark.plugin.rltoys.algorithms.predictions.td.TDLambda;
import rlpark.plugin.rltoys.algorithms.representations.discretizer.partitions.AbstractPartitionFactory;
import rlpark.plugin.rltoys.algorithms.representations.tilescoding.TileCodersNoHashing;
import rlpark.plugin.rltoys.experiments.helpers.Runner;
import rlpark.plugin.rltoys.experiments.helpers.Runner.RunnerEvent;
import rlpark.plugin.rltoys.math.ranges.Range;
import rlpark.plugin.rltoys.problems.pendulum.SwingPendulum;
import zephyr.plugin.core.api.Zephyr;
import zephyr.plugin.core.api.monitoring.annotations.Monitor;
import zephyr.plugin.core.api.signals.Listener;
import zephyr.plugin.core.api.synchronization.Clock;

/**
 * Demo: average-reward actor-critic control on the swing-up pendulum problem.
 * <p>
 * Builds a tile-coding representation over the pendulum's observation ranges,
 * a TD(lambda) critic, a scaled normal-distribution actor, and runs them
 * through a {@link Runner} while advertising the clock and monitored fields
 * to Zephyr for visualization.
 */
@Monitor
public class ActorCriticPendulum implements Runnable {
  // 2D projection of the critic's value function, displayed by Zephyr.
  final FunctionProjected2D valueFunction;
  // Monitored by Zephyr through the class-level @Monitor annotation.
  double reward;
  private final SwingPendulum problem;
  private final Clock clock = new Clock("ActorCriticPendulum");
  private final LearnerAgentFA agent;
  private final Runner runner;

  /** Wires together problem, representation, learner and experiment runner. */
  public ActorCriticPendulum() {
    // Fixed seed keeps the demo run reproducible.
    Random rng = new Random(0);
    problem = new SwingPendulum(null, false);
    TileCodersNoHashing representation = new TileCodersNoHashing(problem.getObservationRanges());
    // Randomly offset the tiling partitions by up to 20% of a cell width.
    ((AbstractPartitionFactory) representation.discretizerFactory()).setRandom(rng, .2);
    representation.addFullTilings(10, 10);
    double vectorNorm = representation.vectorNorm();
    int vectorSize = representation.vectorSize();
    double gamma = 1.0;
    double lambda = .5;
    // Step sizes are normalized by the number of active features.
    OnPolicyTD critic = new TDLambda(lambda, gamma, .1 / vectorNorm, vectorSize);
    // Gaussian policy, rescaled from [-2, 2] onto the problem's action range.
    PolicyDistribution policy = new ScaledPolicyDistribution(new NormalDistributionScaled(rng, 0.0, 1.0),
                                                             new Range(-2, 2), problem.actionRanges()[0]);
    Actor actor = new Actor(policy, 0.001 / vectorNorm, vectorSize);
    AverageRewardActorCritic learner = new AverageRewardActorCritic(.0001, critic, actor);
    agent = new LearnerAgentFA(learner, representation);
    valueFunction = new ValueFunction2D(representation, problem, critic);
    // -1: unlimited episodes; 1000: max time steps per episode.
    runner = new Runner(problem, agent, -1, 1000);
    runner.onEpisodeEnd.connect(new Listener<Runner.RunnerEvent>() {
      @Override
      public void listen(RunnerEvent eventInfo) {
        System.out.println(String.format("Episode %d: %f", eventInfo.nbEpisodeDone, eventInfo.episodeReward));
      }
    });
    Zephyr.advertise(clock, this);
  }

  /** Steps the experiment for as long as the Zephyr clock keeps ticking. */
  @Override
  public void run() {
    while (clock.tick()) {
      runner.step();
    }
  }

  /** Entry point: builds the demo and runs it on the calling thread. */
  public static void main(String[] args) {
    new ActorCriticPendulum().run();
  }
}