RLPark 1.0.0
Reinforcement Learning Framework in Java
|
00001 package rlpark.plugin.rltoys.experiments.parametersweep.prediction; 00002 00003 import rlpark.plugin.rltoys.algorithms.predictions.supervised.LearningAlgorithm; 00004 import rlpark.plugin.rltoys.experiments.helpers.ExperimentCounter; 00005 import rlpark.plugin.rltoys.experiments.parametersweep.interfaces.JobWithParameters; 00006 import rlpark.plugin.rltoys.experiments.parametersweep.parameters.Parameters; 00007 import rlpark.plugin.rltoys.experiments.scheduling.interfaces.TimedJob; 00008 import rlpark.plugin.rltoys.problems.PredictionProblem; 00009 import rlpark.plugin.rltoys.utils.Utils; 00010 import zephyr.plugin.core.api.synchronization.Chrono; 00011 00012 public class PredictionSweepJob implements JobWithParameters, TimedJob { 00013 private static final long serialVersionUID = -1601304080766261525L; 00014 private final PredictionContext context; 00015 private final Parameters parameters; 00016 private final int counter; 00017 00018 public PredictionSweepJob(PredictionContext context, Parameters parameters, ExperimentCounter counter) { 00019 this.context = context; 00020 this.parameters = parameters; 00021 this.counter = counter.currentIndex(); 00022 } 00023 00024 @Override 00025 public void run() { 00026 Chrono chrono = new Chrono(); 00027 PredictionProblem problem = context.problemFactory().createProblem(counter, parameters); 00028 LearningAlgorithm learner = (LearningAlgorithm) context.learnerFactory() 00029 .createLearner(counter, problem, parameters); 00030 PredictorEvaluator evaluator = context.createPredictorEvaluator(parameters); 00031 int nbLearningSteps = PredictionParameters.nbLearningSteps(parameters); 00032 int nbEvaluationSteps = PredictionParameters.nbEvaluationSteps(parameters); 00033 try { 00034 boolean resultEnabled = run(null, problem, learner, nbLearningSteps) 00035 && run(evaluator, problem, learner, nbEvaluationSteps); 00036 if (!resultEnabled) 00037 evaluator.worstResultUntilEnd(); 00038 } catch (Throwable e) { 00039 e.printStackTrace(System.err); 00040 evaluator.worstResultUntilEnd(); 00041 } 00042 evaluator.putResult(parameters); 00043 parameters.setComputationTimeMillis(chrono.getCurrentMillis()); 00044 } 00045 00046 private boolean run(PredictorEvaluator evaluator, PredictionProblem problem, LearningAlgorithm learner, long nbSteps) { 00047 for (int t = 0; t < nbSteps; t++) { 00048 boolean update = problem.update(); 00049 if (!update) 00050 return true; 00051 if (evaluator != null) { 00052 double prediction = learner.predict(problem.input()); 00053 evaluator.registerPrediction(t, problem.target(), prediction); 00054 if (!Utils.checkValue(prediction)) 00055 return false; 00056 } 00057 double error = learner.learn(problem.input(), problem.target()); 00058 if (!Utils.checkValue(error)) 00059 return false; 00060 } 00061 return true; 00062 } 00063 00064 @Override 00065 public long getComputationTimeMillis() { 00066 return parameters.getComputationTimeMillis(); 00067 } 00068 00069 @Override 00070 public Parameters parameters() { 00071 return parameters; 00072 } 00073 }