RLPark 1.0.0
Reinforcement Learning Framework in Java

Runner.java

Go to the documentation of this file.
00001 package rlpark.plugin.rltoys.experiments.helpers;
00002 
00003 import java.io.Serializable;
00004 
00005 import rlpark.plugin.rltoys.envio.actions.Action;
00006 import rlpark.plugin.rltoys.envio.rl.RLAgent;
00007 import rlpark.plugin.rltoys.envio.rl.TRStep;
00008 import rlpark.plugin.rltoys.problems.RLProblem;
00009 import zephyr.plugin.core.api.monitoring.abstracts.DataMonitor;
00010 import zephyr.plugin.core.api.monitoring.abstracts.MonitorContainer;
00011 import zephyr.plugin.core.api.monitoring.abstracts.Monitored;
00012 import zephyr.plugin.core.api.monitoring.annotations.Monitor;
00013 import zephyr.plugin.core.api.signals.Signal;
00014 
00015 public class Runner implements Serializable, MonitorContainer {
00016   private static final long serialVersionUID = 465593140388569561L;
00017 
00018   @SuppressWarnings("serial")
00019   static public class RunnerEvent implements Serializable {
00020     public int nbTotalTimeSteps = 0;
00021     public int nbEpisodeDone = 0;
00022     public TRStep step = null;
00023     public double episodeReward = Double.NaN;
00024 
00025     @Override
00026     public String toString() {
00027       return String.format("Ep(%d): %s on %d", nbEpisodeDone, step, nbTotalTimeSteps);
00028     }
00029   }
00030 
00031   public final Signal<RunnerEvent> onEpisodeEnd = new Signal<RunnerEvent>();
00032   public final Signal<RunnerEvent> onTimeStep = new Signal<RunnerEvent>();
00033   protected final RunnerEvent runnerEvent = new RunnerEvent();
00034   @Monitor
00035   private final RLAgent agent;
00036   @Monitor
00037   private final RLProblem problem;
00038   private Action agentAction = null;
00039   private final int maxEpisodeTimeSteps;
00040   private final int nbEpisode;
00041 
00042   public Runner(RLProblem problem, RLAgent agent) {
00043     this(problem, agent, -1, -1);
00044   }
00045 
00046   public Runner(RLProblem environment, RLAgent agent, int nbEpisode, int maxEpisodeTimeSteps) {
00047     this.problem = environment;
00048     this.agent = agent;
00049     this.nbEpisode = nbEpisode;
00050     this.maxEpisodeTimeSteps = maxEpisodeTimeSteps;
00051   }
00052 
00053   public void run() {
00054     for (int i = 0; i < nbEpisode; i++)
00055       runEpisode();
00056   }
00057 
00058   public void runEpisode() {
00059     assert runnerEvent.step == null || runnerEvent.step.isEpisodeEnding();
00060     int currentEpisode = runnerEvent.nbEpisodeDone;
00061     do {
00062       step();
00063     } while (currentEpisode == runnerEvent.nbEpisodeDone);
00064     assert runnerEvent.step.isEpisodeEnding();
00065   }
00066 
00067   public void step() {
00068     assert nbEpisode < 0 || runnerEvent.nbEpisodeDone < nbEpisode;
00069     if (runnerEvent.step == null || runnerEvent.step.isEpisodeEnding()) {
00070       runnerEvent.step = problem.initialize();
00071       runnerEvent.episodeReward = 0;
00072       agentAction = null;
00073       assert runnerEvent.step.isEpisodeStarting();
00074     } else {
00075       runnerEvent.step = problem.step(agentAction);
00076       if (runnerEvent.step.time == maxEpisodeTimeSteps)
00077         runnerEvent.step = problem.forceEndEpisode();
00078     }
00079     agentAction = agent.getAtp1(runnerEvent.step);
00080     runnerEvent.episodeReward += runnerEvent.step.r_tp1;
00081     runnerEvent.nbTotalTimeSteps++;
00082     onTimeStep.fire(runnerEvent);
00083     if (runnerEvent.step.isEpisodeEnding()) {
00084       runnerEvent.nbEpisodeDone += 1;
00085       onEpisodeEnd.fire(runnerEvent);
00086     }
00087   }
00088 
00089   public RunnerEvent runnerEvent() {
00090     return runnerEvent;
00091   }
00092 
00093   public RLAgent agent() {
00094     return agent;
00095   }
00096 
00097   @Override
00098   public void addToMonitor(DataMonitor monitor) {
00099     monitor.add("Reward", new Monitored() {
00100       @Override
00101       public double monitoredValue() {
00102         if (runnerEvent.step == null)
00103           return 0;
00104         return runnerEvent.step.r_tp1;
00105       }
00106     });
00107   }
00108 }
 All Classes Namespaces Files Functions Variables Enumerations
Zephyr
RLPark