RLPark 1.0.0
Reinforcement Learning Framework in Java
|
00001 package rlpark.plugin.rltoys.experiments.parametersweep.onpolicy; 00002 00003 import rlpark.plugin.rltoys.agents.representations.RepresentationFactory; 00004 import rlpark.plugin.rltoys.envio.rl.RLAgent; 00005 import rlpark.plugin.rltoys.experiments.helpers.ExperimentCounter; 00006 import rlpark.plugin.rltoys.experiments.helpers.Runner; 00007 import rlpark.plugin.rltoys.experiments.parametersweep.parameters.Parameters; 00008 import rlpark.plugin.rltoys.experiments.parametersweep.parameters.RunInfo; 00009 import rlpark.plugin.rltoys.experiments.parametersweep.reinforcementlearning.AgentFactory; 00010 import rlpark.plugin.rltoys.experiments.parametersweep.reinforcementlearning.ProblemFactory; 00011 import rlpark.plugin.rltoys.experiments.parametersweep.reinforcementlearning.RLParameters; 00012 import rlpark.plugin.rltoys.experiments.parametersweep.reinforcementlearning.ReinforcementLearningContext; 00013 import rlpark.plugin.rltoys.problems.RLProblem; 00014 00015 public abstract class AbstractContextOnPolicy implements ReinforcementLearningContext { 00016 private static final long serialVersionUID = -6212106048889219995L; 00017 private final AgentFactory agentFactory; 00018 private final ProblemFactory environmentFactory; 00019 private final RepresentationFactory representationFactory; 00020 00021 public AbstractContextOnPolicy(ProblemFactory environmentFactory, RepresentationFactory representationFactory, 00022 AgentFactory agentFactory) { 00023 this.environmentFactory = environmentFactory; 00024 this.representationFactory = representationFactory; 00025 this.agentFactory = agentFactory; 00026 } 00027 00028 @Override 00029 public Runner createRunner(int counter, Parameters parameters) { 00030 RLProblem problem = environmentFactory.createEnvironment(ExperimentCounter.newRandom(counter)); 00031 RLAgent agent = agentFactory.createAgent(counter, problem, parameters, representationFactory); 00032 int nbEpisode = RLParameters.nbEpisode(parameters); 00033 int maxEpisodeTimeSteps = RLParameters.maxEpisodeTimeSteps(parameters); 00034 return new Runner(problem, agent, nbEpisode, maxEpisodeTimeSteps); 00035 } 00036 00037 @Override 00038 public String fileName() { 00039 return ExperimentCounter.DefaultFileName; 00040 } 00041 00042 @Override 00043 public String folderPath() { 00044 return environmentFactory.label() + "/" + agentFactory.label(); 00045 } 00046 00047 public AgentFactory agentFactory() { 00048 return agentFactory; 00049 } 00050 00051 public ProblemFactory problemFactory() { 00052 return environmentFactory; 00053 } 00054 00055 public Parameters contextParameters() { 00056 RunInfo infos = new RunInfo(); 00057 infos.enableFlag(agentFactory.label()); 00058 infos.enableFlag(environmentFactory.label()); 00059 Parameters parameters = new Parameters(infos); 00060 environmentFactory.setExperimentParameters(parameters); 00061 return parameters; 00062 } 00063 }