RLPark 1.0.0
Reinforcement Learning Framework in Java

PredictionDemonVerifier.java

Go to the documentation of this file.
00001 package rlpark.plugin.rltoys.horde.demons;
00002 
00003 import java.io.Serializable;
00004 
00005 import rlpark.plugin.rltoys.algorithms.predictions.td.OnPolicyTD;
00006 import rlpark.plugin.rltoys.algorithms.predictions.td.TD;
00007 import rlpark.plugin.rltoys.algorithms.predictions.td.TDErrorMonitor;
00008 import rlpark.plugin.rltoys.algorithms.predictions.td.TDLambdaAutostep;
00009 import rlpark.plugin.rltoys.horde.functions.RewardFunction;
00010 import rlpark.plugin.rltoys.utils.NotImplemented;
00011 import zephyr.plugin.core.api.monitoring.annotations.Monitor;
00012 
00013 public class PredictionDemonVerifier implements Serializable {
00014   private static final long serialVersionUID = 6127406364376542150L;
00015   private final PredictionDemon predictionDemon;
00016   private final RewardFunction rewardFunction;
00017   @Monitor
00018   private final TDErrorMonitor errorMonitor;
00019 
00020   public PredictionDemonVerifier(PredictionDemon predictionDemon) {
00021     this(predictionDemon, 0.01);
00022   }
00023 
00024   public PredictionDemonVerifier(PredictionDemon predictionDemon, double precision) {
00025     this.predictionDemon = predictionDemon;
00026     rewardFunction = predictionDemon.rewardFunction();
00027     double gamma = extractGamma(predictionDemon.predicter());
00028     errorMonitor = new TDErrorMonitor(gamma, precision);
00029   }
00030 
00031   public double extractGamma(OnPolicyTD learner) {
00032     if (learner instanceof TD)
00033       return ((TD) learner).gamma();
00034     if (learner instanceof TDLambdaAutostep)
00035       return ((TDLambdaAutostep) learner).gamma();
00036     throw new NotImplemented();
00037   }
00038 
00039   public TDErrorMonitor errorMonitor() {
00040     return errorMonitor;
00041   }
00042 
00043   public double update(boolean endOfEpisode) {
00044     return errorMonitor.update(predictionDemon.prediction(), rewardFunction.reward(), endOfEpisode);
00045   }
00046 }
 All Classes Namespaces Files Functions Variables Enumerations
Zephyr
RLPark