RLPark 1.0.0
Reinforcement Learning Framework in Java
|
00001 package rlpark.plugin.rltoys.experiments.parametersweep.onpolicy; 00002 00003 import rlpark.plugin.rltoys.agents.representations.RepresentationFactory; 00004 import rlpark.plugin.rltoys.experiments.helpers.ExperimentCounter; 00005 import rlpark.plugin.rltoys.experiments.parametersweep.onpolicy.internal.OnPolicyEvaluationContext; 00006 import rlpark.plugin.rltoys.experiments.parametersweep.onpolicy.internal.OnPolicyRewardMonitor; 00007 import rlpark.plugin.rltoys.experiments.parametersweep.onpolicy.internal.RewardMonitorAverage; 00008 import rlpark.plugin.rltoys.experiments.parametersweep.onpolicy.internal.RewardMonitorEpisode; 00009 import rlpark.plugin.rltoys.experiments.parametersweep.onpolicy.internal.SweepJob; 00010 import rlpark.plugin.rltoys.experiments.parametersweep.parameters.Parameters; 00011 import rlpark.plugin.rltoys.experiments.parametersweep.reinforcementlearning.AgentFactory; 00012 import rlpark.plugin.rltoys.experiments.parametersweep.reinforcementlearning.ProblemFactory; 00013 import rlpark.plugin.rltoys.experiments.parametersweep.reinforcementlearning.RLParameters; 00014 00015 public class ContextEvaluation extends AbstractContextOnPolicy implements OnPolicyEvaluationContext { 00016 private static final long serialVersionUID = -5926779335932073094L; 00017 private final int nbRewardCheckpoint; 00018 00019 public ContextEvaluation(ProblemFactory environmentFactory, RepresentationFactory representationFactory, 00020 AgentFactory agentFactory, int nbRewardCheckpoint) { 00021 super(environmentFactory, representationFactory, agentFactory); 00022 this.nbRewardCheckpoint = nbRewardCheckpoint; 00023 } 00024 00025 @Override 00026 public Runnable createJob(Parameters parameters, ExperimentCounter counter) { 00027 return new SweepJob(this, parameters, counter); 00028 } 00029 00030 private OnPolicyRewardMonitor createRewardMonitor(String prefix, int nbBins, Parameters parameters) { 00031 int nbEpisode = RLParameters.nbEpisode(parameters); 00032 int maxEpisodeTimeSteps = RLParameters.maxEpisodeTimeSteps(parameters); 00033 if (nbEpisode == 1 || parameters.hasFlag(RLParameters.OnPolicyTimeStepsEvaluationFlag)) 00034 return new RewardMonitorAverage(prefix, nbBins, maxEpisodeTimeSteps); 00035 return new RewardMonitorEpisode(prefix, nbBins, nbEpisode); 00036 } 00037 00038 00039 @Override 00040 public OnPolicyRewardMonitor createRewardMonitor(Parameters parameters) { 00041 return createRewardMonitor("", nbRewardCheckpoint, parameters); 00042 } 00043 }