RLPark 1.0.0
Reinforcement Learning Framework in Java

PredictionSweepJob.java

Go to the documentation of this file.
00001 package rlpark.plugin.rltoys.experiments.parametersweep.prediction;
00002 
00003 import rlpark.plugin.rltoys.algorithms.predictions.supervised.LearningAlgorithm;
00004 import rlpark.plugin.rltoys.experiments.helpers.ExperimentCounter;
00005 import rlpark.plugin.rltoys.experiments.parametersweep.interfaces.JobWithParameters;
00006 import rlpark.plugin.rltoys.experiments.parametersweep.parameters.Parameters;
00007 import rlpark.plugin.rltoys.experiments.scheduling.interfaces.TimedJob;
00008 import rlpark.plugin.rltoys.problems.PredictionProblem;
00009 import rlpark.plugin.rltoys.utils.Utils;
00010 import zephyr.plugin.core.api.synchronization.Chrono;
00011 
00012 public class PredictionSweepJob implements JobWithParameters, TimedJob {
00013   private static final long serialVersionUID = -1601304080766261525L;
00014   private final PredictionContext context;
00015   private final Parameters parameters;
00016   private final int counter;
00017 
00018   public PredictionSweepJob(PredictionContext context, Parameters parameters, ExperimentCounter counter) {
00019     this.context = context;
00020     this.parameters = parameters;
00021     this.counter = counter.currentIndex();
00022   }
00023 
00024   @Override
00025   public void run() {
00026     Chrono chrono = new Chrono();
00027     PredictionProblem problem = context.problemFactory().createProblem(counter, parameters);
00028     LearningAlgorithm learner = (LearningAlgorithm) context.learnerFactory()
00029         .createLearner(counter, problem, parameters);
00030     PredictorEvaluator evaluator = context.createPredictorEvaluator(parameters);
00031     int nbLearningSteps = PredictionParameters.nbLearningSteps(parameters);
00032     int nbEvaluationSteps = PredictionParameters.nbEvaluationSteps(parameters);
00033     try {
00034       boolean resultEnabled = run(null, problem, learner, nbLearningSteps)
00035           && run(evaluator, problem, learner, nbEvaluationSteps);
00036       if (!resultEnabled)
00037         evaluator.worstResultUntilEnd();
00038     } catch (Throwable e) {
00039       e.printStackTrace(System.err);
00040       evaluator.worstResultUntilEnd();
00041     }
00042     evaluator.putResult(parameters);
00043     parameters.setComputationTimeMillis(chrono.getCurrentMillis());
00044   }
00045 
00046   private boolean run(PredictorEvaluator evaluator, PredictionProblem problem, LearningAlgorithm learner, long nbSteps) {
00047     for (int t = 0; t < nbSteps; t++) {
00048       boolean update = problem.update();
00049       if (!update)
00050         return true;
00051       if (evaluator != null) {
00052         double prediction = learner.predict(problem.input());
00053         evaluator.registerPrediction(t, problem.target(), prediction);
00054         if (!Utils.checkValue(prediction))
00055           return false;
00056       }
00057       double error = learner.learn(problem.input(), problem.target());
00058       if (!Utils.checkValue(error))
00059         return false;
00060     }
00061     return true;
00062   }
00063 
00064   @Override
00065   public long getComputationTimeMillis() {
00066     return parameters.getComputationTimeMillis();
00067   }
00068 
00069   @Override
00070   public Parameters parameters() {
00071     return parameters;
00072   }
00073 }
 All Classes Namespaces Files Functions Variables Enumerations
Zephyr
RLPark