RLPark 1.0.0
Reinforcement Learning Framework in Java
package rlpark.plugin.rltoys.experiments.helpers;

import java.io.Serializable;

import rlpark.plugin.rltoys.envio.actions.Action;
import rlpark.plugin.rltoys.envio.rl.RLAgent;
import rlpark.plugin.rltoys.envio.rl.TRStep;
import rlpark.plugin.rltoys.problems.RLProblem;
import zephyr.plugin.core.api.monitoring.abstracts.DataMonitor;
import zephyr.plugin.core.api.monitoring.abstracts.MonitorContainer;
import zephyr.plugin.core.api.monitoring.abstracts.Monitored;
import zephyr.plugin.core.api.monitoring.annotations.Monitor;
import zephyr.plugin.core.api.signals.Signal;

// Drives the interaction loop between an RLProblem and an RLAgent, counting
// time steps and episodes and broadcasting events after each step.
public class Runner implements Serializable, MonitorContainer {
  private static final long serialVersionUID = 465593140388569561L;

  // Snapshot of the run state passed to listeners: step and episode counts,
  // the current transition and the accumulated episode reward.
  @SuppressWarnings("serial")
  static public class RunnerEvent implements Serializable {
    public int nbTotalTimeSteps = 0;
    public int nbEpisodeDone = 0;
    public TRStep step = null;
    public double episodeReward = Double.NaN;

    @Override
    public String toString() {
      return String.format("Ep(%d): %s on %d", nbEpisodeDone, step, nbTotalTimeSteps);
    }
  }

  // Fired after every time step and at the end of every episode, respectively.
  public final Signal<RunnerEvent> onEpisodeEnd = new Signal<RunnerEvent>();
  public final Signal<RunnerEvent> onTimeStep = new Signal<RunnerEvent>();
  protected final RunnerEvent runnerEvent = new RunnerEvent();
  @Monitor
  private final RLAgent agent;
  @Monitor
  private final RLProblem problem;
  private Action agentAction = null;
  private final int maxEpisodeTimeSteps;
  private final int nbEpisode;

  // Unbounded runner: -1 means no limit on the number of episodes or on the
  // episode length; drive it with step() or runEpisode() directly.
  public Runner(RLProblem problem, RLAgent agent) {
    this(problem, agent, -1, -1);
  }

  public Runner(RLProblem environment, RLAgent agent, int nbEpisode, int maxEpisodeTimeSteps) {
    this.problem = environment;
    this.agent = agent;
    this.nbEpisode = nbEpisode;
    this.maxEpisodeTimeSteps = maxEpisodeTimeSteps;
  }

  // Runs nbEpisode complete episodes.
  public void run() {
    for (int i = 0; i < nbEpisode; i++)
      runEpisode();
  }

  // Runs a single episode until the problem signals its end or the episode is
  // forcibly terminated at maxEpisodeTimeSteps.
  public void runEpisode() {
    assert runnerEvent.step == null || runnerEvent.step.isEpisodeEnding();
    int currentEpisode = runnerEvent.nbEpisodeDone;
    do {
      step();
    } while (currentEpisode == runnerEvent.nbEpisodeDone);
    assert runnerEvent.step.isEpisodeEnding();
  }

  // Performs one interaction step: (re)initializes the problem at an episode
  // boundary, otherwise applies the last agent action, then queries the agent
  // for its next action, updates the statistics and fires the signals.
  public void step() {
    assert nbEpisode < 0 || runnerEvent.nbEpisodeDone < nbEpisode;
    if (runnerEvent.step == null || runnerEvent.step.isEpisodeEnding()) {
      runnerEvent.step = problem.initialize();
      runnerEvent.episodeReward = 0;
      agentAction = null;
      assert runnerEvent.step.isEpisodeStarting();
    } else {
      runnerEvent.step = problem.step(agentAction);
      if (runnerEvent.step.time == maxEpisodeTimeSteps)
        runnerEvent.step = problem.forceEndEpisode();
    }
    agentAction = agent.getAtp1(runnerEvent.step);
    runnerEvent.episodeReward += runnerEvent.step.r_tp1;
    runnerEvent.nbTotalTimeSteps++;
    onTimeStep.fire(runnerEvent);
    if (runnerEvent.step.isEpisodeEnding()) {
      runnerEvent.nbEpisodeDone += 1;
      onEpisodeEnd.fire(runnerEvent);
    }
  }

  public RunnerEvent runnerEvent() {
    return runnerEvent;
  }

  public RLAgent agent() {
    return agent;
  }

  // Exposes the last received reward to Zephyr's data monitor.
  @Override
  public void addToMonitor(DataMonitor monitor) {
    monitor.add("Reward", new Monitored() {
      @Override
      public double monitoredValue() {
        if (runnerEvent.step == null)
          return 0;
        return runnerEvent.step.r_tp1;
      }
    });
  }
}
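As a quick illustration of how this class is typically driven, here is a minimal usage sketch. MyProblem and MyAgent are hypothetical placeholders for concrete RLProblem and RLAgent implementations from the toolkit; only the Runner calls shown in the listing above are used.

package example;

import rlpark.plugin.rltoys.envio.rl.RLAgent;
import rlpark.plugin.rltoys.experiments.helpers.Runner;
import rlpark.plugin.rltoys.experiments.helpers.Runner.RunnerEvent;
import rlpark.plugin.rltoys.problems.RLProblem;

public class RunnerExample {
  public static void main(String[] args) {
    // Hypothetical placeholders: substitute any RLProblem/RLAgent pair.
    RLProblem problem = new MyProblem();
    RLAgent agent = new MyAgent();

    // Run 100 episodes, each capped at 1000 time steps.
    Runner runner = new Runner(problem, agent, 100, 1000);
    runner.run();

    // Statistics accumulated by the runner over the whole run.
    RunnerEvent event = runner.runnerEvent();
    System.out.println("Episodes done:       " + event.nbEpisodeDone);
    System.out.println("Total time steps:    " + event.nbTotalTimeSteps);
    System.out.println("Last episode reward: " + event.episodeReward);
  }
}

Per-step or per-episode statistics can also be collected by attaching listeners to onTimeStep and onEpisodeEnd; the listener interface comes from zephyr.plugin.core.api.signals, alongside the Signal class used above.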