RLPark 1.0.0
Reinforcement Learning Framework in Java
|
00001 package rlpark.plugin.rltoys.experiments.parametersweep.onpolicy.internal; 00002 00003 import java.io.IOException; 00004 import java.io.Serializable; 00005 00006 import rlpark.plugin.rltoys.experiments.helpers.ExperimentCounter; 00007 import rlpark.plugin.rltoys.experiments.helpers.Runner; 00008 import rlpark.plugin.rltoys.experiments.helpers.Runner.RunnerEvent; 00009 import rlpark.plugin.rltoys.experiments.parametersweep.parameters.Parameters; 00010 import rlpark.plugin.rltoys.experiments.parametersweep.reinforcementlearning.RLParameters; 00011 import rlpark.plugin.rltoys.experiments.parametersweep.reinforcementlearning.ReinforcementLearningContext; 00012 import zephyr.plugin.core.api.internal.monitoring.fileloggers.LoggerRow; 00013 import zephyr.plugin.core.api.signals.Listener; 00014 00015 @SuppressWarnings("restriction") 00016 public class LearningCurveJob implements Runnable, Serializable { 00017 private static final long serialVersionUID = -5212166519929349880L; 00018 private final Parameters parameters; 00019 private final ReinforcementLearningContext context; 00020 private final ExperimentCounter counter; 00021 00022 public LearningCurveJob(ReinforcementLearningContext context, Parameters parameters, ExperimentCounter counter) { 00023 this.context = context; 00024 this.parameters = parameters; 00025 this.counter = counter.clone(); 00026 } 00027 00028 protected Listener<RunnerEvent> createRewardListener(final LoggerRow loggerRow) { 00029 return new Listener<Runner.RunnerEvent>() { 00030 @Override 00031 public void listen(RunnerEvent eventInfo) { 00032 loggerRow.writeRow(eventInfo.step.time, eventInfo.step.r_tp1); 00033 } 00034 }; 00035 } 00036 00037 protected Listener<RunnerEvent> createEpisodeListener(final LoggerRow loggerRow) { 00038 return new Listener<Runner.RunnerEvent>() { 00039 @Override 00040 public void listen(RunnerEvent eventInfo) { 00041 loggerRow.writeRow(eventInfo.nbEpisodeDone, eventInfo.step.time); 00042 } 00043 }; 00044 } 00045 00046 protected void setupEpisodeListener(Runner runner, LoggerRow loggerRow) { 00047 loggerRow.writeLegend("Episode", "Steps"); 00048 runner.onEpisodeEnd.connect(createEpisodeListener(loggerRow)); 00049 } 00050 00051 protected void setupRewardListener(Runner runner, LoggerRow loggerRow) { 00052 loggerRow.writeLegend("Time", "Reward"); 00053 runner.onTimeStep.connect(createRewardListener(loggerRow)); 00054 } 00055 00056 @Override 00057 public void run() { 00058 Runner runner = context.createRunner(counter.currentIndex(), parameters); 00059 String fileName = counter.folderFilename(context.folderPath(), context.fileName()); 00060 System.out.println(fileName); 00061 LoggerRow loggerRow = null; 00062 try { 00063 loggerRow = new LoggerRow(fileName, false); 00064 } catch (IOException e) { 00065 e.printStackTrace(); 00066 return; 00067 } 00068 if (RLParameters.nbEpisode(parameters) == 1) 00069 setupRewardListener(runner, loggerRow); 00070 else 00071 setupEpisodeListener(runner, loggerRow); 00072 runner.run(); 00073 loggerRow.close(); 00074 } 00075 }