RLPark 1.0.0
Reinforcement Learning Framework in Java
|
package rlpark.plugin.rltoys.experiments.parametersweep.prediction;

import java.util.List;

import rlpark.plugin.rltoys.experiments.helpers.ExperimentCounter;
import rlpark.plugin.rltoys.experiments.parametersweep.parameters.Parameters;
import rlpark.plugin.rltoys.experiments.parametersweep.parameters.ParametersProvider;
import rlpark.plugin.rltoys.experiments.parametersweep.parameters.RunInfo;
import rlpark.plugin.rltoys.utils.Utils;

/**
 * Abstract {@link PredictionContext} that pairs one {@link PredictionProblemFactory} with one
 * {@link PredictionLearnerFactory}.
 *
 * <p>The two factories supply the labels used for the output folder and for the run flags, and
 * each one may additionally expand the list of parameters to evaluate when it implements
 * {@link ParametersProvider}.
 */
public abstract class PredictionSweepContext implements PredictionContext {
  private static final long serialVersionUID = 6250984799273140622L;

  private final PredictionProblemFactory problemFactory;
  private final PredictionLearnerFactory learnerFactory;

  /**
   * Stores the two factories this context is built around.
   *
   * @param problemFactory factory whose label forms the first folder segment and the first flag
   * @param learnerFactory factory whose label forms the second folder segment and the second flag
   */
  public PredictionSweepContext(PredictionProblemFactory problemFactory, PredictionLearnerFactory learnerFactory) {
    this.problemFactory = problemFactory;
    this.learnerFactory = learnerFactory;
  }

  /**
   * @return the problem label and the learner label joined by a {@code "/"}
   */
  @Override
  public String folderPath() {
    final String problemLabel = problemFactory.label();
    final String learnerLabel = learnerFactory.label();
    return problemLabel + "/" + learnerLabel;
  }

  /**
   * @return {@link ExperimentCounter#DefaultFileName}
   */
  @Override
  public String fileName() {
    return ExperimentCounter.DefaultFileName;
  }

  /**
   * Wraps this context, the given parameters and the given counter into a new
   * {@code PredictionSweepJob}.
   *
   * @param parameters parameters handed to the job
   * @param counter experiment counter handed to the job
   * @return the newly created job
   */
  @Override
  public Runnable createJob(Parameters parameters, ExperimentCounter counter) {
    final Runnable job = new PredictionSweepJob(this, parameters, counter);
    return job;
  }

  /**
   * Builds the list of parameters for this context.
   *
   * <p>A single seed {@link Parameters} is created from a {@link RunInfo} on which both factory
   * labels are enabled as flags and {@link Parameters#PerformanceNbCheckPoint} is set to
   * {@link Parameters#DefaultNbPerformanceCheckpoints}. The problem factory, then the learner
   * factory, is given the chance to transform the list if it is a {@link ParametersProvider}.
   *
   * @return the resulting list of parameters
   */
  public List<Parameters> provideParameters() {
    final RunInfo runInfo = new RunInfo();
    runInfo.enableFlag(problemFactory.label());
    runInfo.enableFlag(learnerFactory.label());
    runInfo.put(Parameters.PerformanceNbCheckPoint, Parameters.DefaultNbPerformanceCheckpoints);

    final List<Parameters> seed = Utils.asList(new Parameters(runInfo));
    // Order matters: the problem factory sees the seed list, the learner factory sees its output.
    final List<Parameters> afterProblem = expandIfProvider(problemFactory, seed);
    return expandIfProvider(learnerFactory, afterProblem);
  }

  /**
   * Passes {@code current} through {@code factory} when it implements {@link ParametersProvider};
   * otherwise returns {@code current} untouched.
   */
  private static List<Parameters> expandIfProvider(Object factory, List<Parameters> current) {
    if (!(factory instanceof ParametersProvider)) {
      return current;
    }
    final ParametersProvider provider = (ParametersProvider) factory;
    return provider.provideParameters(current);
  }

  /**
   * @return the problem factory given at construction
   */
  @Override
  public PredictionProblemFactory problemFactory() {
    return problemFactory;
  }

  /**
   * @return the learner factory given at construction
   */
  @Override
  public PredictionLearnerFactory learnerFactory() {
    return learnerFactory;
  }
}