RLPark 1.0.0
Reinforcement Learning Framework in Java

LearningCurveJob.java

Go to the documentation of this file.
00001 package rlpark.plugin.rltoys.experiments.parametersweep.onpolicy.internal;
00002 
00003 import java.io.IOException;
00004 import java.io.Serializable;
00005 
00006 import rlpark.plugin.rltoys.experiments.helpers.ExperimentCounter;
00007 import rlpark.plugin.rltoys.experiments.helpers.Runner;
00008 import rlpark.plugin.rltoys.experiments.helpers.Runner.RunnerEvent;
00009 import rlpark.plugin.rltoys.experiments.parametersweep.parameters.Parameters;
00010 import rlpark.plugin.rltoys.experiments.parametersweep.reinforcementlearning.RLParameters;
00011 import rlpark.plugin.rltoys.experiments.parametersweep.reinforcementlearning.ReinforcementLearningContext;
00012 import zephyr.plugin.core.api.internal.monitoring.fileloggers.LoggerRow;
00013 import zephyr.plugin.core.api.signals.Listener;
00014 
00015 @SuppressWarnings("restriction")
00016 public class LearningCurveJob implements Runnable, Serializable {
00017   private static final long serialVersionUID = -5212166519929349880L;
00018   private final Parameters parameters;
00019   private final ReinforcementLearningContext context;
00020   private final ExperimentCounter counter;
00021 
00022   public LearningCurveJob(ReinforcementLearningContext context, Parameters parameters, ExperimentCounter counter) {
00023     this.context = context;
00024     this.parameters = parameters;
00025     this.counter = counter.clone();
00026   }
00027 
00028   protected Listener<RunnerEvent> createRewardListener(final LoggerRow loggerRow) {
00029     return new Listener<Runner.RunnerEvent>() {
00030       @Override
00031       public void listen(RunnerEvent eventInfo) {
00032         loggerRow.writeRow(eventInfo.step.time, eventInfo.step.r_tp1);
00033       }
00034     };
00035   }
00036 
00037   protected Listener<RunnerEvent> createEpisodeListener(final LoggerRow loggerRow) {
00038     return new Listener<Runner.RunnerEvent>() {
00039       @Override
00040       public void listen(RunnerEvent eventInfo) {
00041         loggerRow.writeRow(eventInfo.nbEpisodeDone, eventInfo.step.time);
00042       }
00043     };
00044   }
00045 
00046   protected void setupEpisodeListener(Runner runner, LoggerRow loggerRow) {
00047     loggerRow.writeLegend("Episode", "Steps");
00048     runner.onEpisodeEnd.connect(createEpisodeListener(loggerRow));
00049   }
00050 
00051   protected void setupRewardListener(Runner runner, LoggerRow loggerRow) {
00052     loggerRow.writeLegend("Time", "Reward");
00053     runner.onTimeStep.connect(createRewardListener(loggerRow));
00054   }
00055 
00056   @Override
00057   public void run() {
00058     Runner runner = context.createRunner(counter.currentIndex(), parameters);
00059     String fileName = counter.folderFilename(context.folderPath(), context.fileName());
00060     System.out.println(fileName);
00061     LoggerRow loggerRow = null;
00062     try {
00063       loggerRow = new LoggerRow(fileName, false);
00064     } catch (IOException e) {
00065       e.printStackTrace();
00066       return;
00067     }
00068     if (RLParameters.nbEpisode(parameters) == 1)
00069       setupRewardListener(runner, loggerRow);
00070     else
00071       setupEpisodeListener(runner, loggerRow);
00072     runner.run();
00073     loggerRow.close();
00074   }
00075 }
 All Classes Namespaces Files Functions Variables Enumerations
Zephyr
RLPark