RLPark 1.0.0
Reinforcement Learning Framework in Java
package rlpark.example.demos.learning;

import java.util.Random;

import rlpark.plugin.rltoys.agents.functions.FunctionProjected2D;
import rlpark.plugin.rltoys.agents.functions.ValueFunction2D;
import rlpark.plugin.rltoys.algorithms.control.acting.EpsilonGreedy;
import rlpark.plugin.rltoys.algorithms.control.sarsa.Sarsa;
import rlpark.plugin.rltoys.algorithms.control.sarsa.SarsaControl;
import rlpark.plugin.rltoys.algorithms.functions.stateactions.TabularAction;
import rlpark.plugin.rltoys.algorithms.representations.tilescoding.TileCodersNoHashing;
import rlpark.plugin.rltoys.algorithms.traces.RTraces;
import rlpark.plugin.rltoys.envio.actions.Action;
import rlpark.plugin.rltoys.envio.policy.Policy;
import rlpark.plugin.rltoys.envio.rl.TRStep;
import rlpark.plugin.rltoys.math.vector.BinaryVector;
import rlpark.plugin.rltoys.math.vector.RealVector;
import rlpark.plugin.rltoys.math.vector.implementations.Vectors;
import rlpark.plugin.rltoys.problems.mountaincar.MountainCar;
import zephyr.plugin.core.api.Zephyr;
import zephyr.plugin.core.api.monitoring.annotations.Monitor;
import zephyr.plugin.core.api.synchronization.Clock;

/**
 * Demo: Sarsa control on the Mountain Car problem.
 * <p>
 * Builds a tile-coding representation of the two-dimensional observation
 * space, wires a Sarsa learner with replacing traces behind an epsilon-greedy
 * policy, and runs episodes until the Zephyr clock is stopped. Episode lengths
 * are printed to stdout as they complete.
 * <p>
 * NOTE(review): fields are discovered reflectively by name through the
 * {@code @Monitor} annotation, so they keep their original names.
 */
@Monitor
public class SarsaMountainCar implements Runnable {
  // 2D view of the learned value function, advertised to Zephyr for display.
  final FunctionProjected2D valueFunctionDisplay;
  private final MountainCar problem;
  private final SarsaControl control;
  private final TileCodersNoHashing projector;
  private final Clock clock = new Clock("SarsaMountainCar");

  /**
   * Wires the problem, the tile-coding projector, the Sarsa learner and the
   * epsilon-greedy behaviour policy together.
   */
  public SarsaMountainCar() {
    problem = new MountainCar(null);
    // 10 tilings of 10x10 tiles over the observation ranges; the extra active
    // feature acts as a bias unit.
    projector = new TileCodersNoHashing(problem.getObservationRanges());
    projector.addFullTilings(10, 10);
    projector.includeActiveFeature();
    TabularAction stateActionCoder =
        new TabularAction(problem.actions(), projector.vectorNorm(), projector.vectorSize());
    stateActionCoder.includeActiveFeature();
    // Step size is normalized by the number of active features per state.
    double stepSize = .15 / projector.vectorNorm();
    double discount = 0.99;
    double traceDecay = .3;
    Sarsa learner =
        new Sarsa(stepSize, discount, traceDecay, stateActionCoder.vectorSize(), new RTraces());
    double explorationRate = 0.01;
    // Fixed seed keeps the demo's exploration sequence reproducible.
    Policy behaviour =
        new EpsilonGreedy(new Random(0), problem.actions(), stateActionCoder, learner, explorationRate);
    control = new SarsaControl(behaviour, stateActionCoder, learner);
    valueFunctionDisplay = new ValueFunction2D(projector, problem, learner);
    Zephyr.advertise(clock, this);
  }

  /**
   * Runs episodes until the clock stops ticking, printing the length of each
   * finished episode and restarting the problem when an episode ends.
   */
  @Override
  public void run() {
    TRStep step = problem.initialize();
    int episodeCount = 0;
    RealVector previousFeatures = null;
    while (clock.tick()) {
      BinaryVector currentFeatures = projector.project(step.o_tp1);
      Action action = control.step(previousFeatures, step.a_t, currentFeatures, step.r_tp1);
      // Reuse the previous buffer where possible to avoid reallocation.
      previousFeatures = Vectors.bufferedCopy(currentFeatures, previousFeatures);
      if (step.isEpisodeEnding()) {
        System.out.println(String.format("Episode %d: %d steps", episodeCount, step.time));
        step = problem.initialize();
        previousFeatures = null;
        episodeCount++;
      } else {
        step = problem.step(action);
      }
    }
  }

  public static void main(String[] args) {
    new SarsaMountainCar().run();
  }
}