RLPark 1.0.0
Reinforcement Learning Framework in Java

TDErrorMonitor.java

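TDErrorMonitor compares each value prediction with the discounted return subsequently observed, reporting the prediction error after a delay. Predictions and partially accumulated returns are kept in ring buffers whose length is chosen so that the reward terms discarded by truncation are weighted below the requested precision.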
package rlpark.plugin.rltoys.algorithms.predictions.td;

import java.io.Serializable;

import zephyr.plugin.core.api.internal.monitoring.wrappers.Abs;
import zephyr.plugin.core.api.internal.monitoring.wrappers.Squared;
import zephyr.plugin.core.api.monitoring.annotations.Monitor;
/**
 * Monitors the accuracy of a value function by comparing each prediction with
 * the discounted return observed over the following steps. The return is
 * truncated once gamma^i falls below the requested precision, so the error for
 * time t becomes available bufferSize steps later.
 */
@SuppressWarnings("restriction")
public class TDErrorMonitor implements Serializable {
  private static final long serialVersionUID = 6441800170099052600L;
  private final int bufferSize;
  private final double[] gammas; // precomputed discounts: gammas[i] = gamma^i
  private final double[] predictionHistory; // ring buffer of past predictions
  private final double[] observedHistory; // ring buffer of partially accumulated returns
  private int current; // index of the oldest slot in the ring buffers
  private boolean cacheFilled; // true once the buffers contain bufferSize entries
  @Monitor(wrappers = { Squared.ID, Abs.ID })
  private double error;
  @Monitor
  private double prediction, observed;
  private boolean errorComputed;
  private final double precision;
  private final double gamma;

  public TDErrorMonitor(double gamma, double precision) {
    this.gamma = gamma;
    this.precision = precision;
    bufferSize = computeBufferSize(gamma, precision);
    predictionHistory = new double[bufferSize];
    observedHistory = new double[bufferSize];
    gammas = new double[bufferSize];
    // Cache the discount weights gamma^0 .. gamma^(bufferSize - 1).
    for (int i = 0; i < gammas.length; i++)
      gammas[i] = Math.pow(gamma, i);
    current = 0;
    cacheFilled = false;
  }

  /** Smallest n such that gamma^n <= precision, i.e. n = ceil(log(precision) / log(gamma)). */
  public static int computeBufferSize(double gamma, double precision) {
    return gamma > 0 ? (int) Math.ceil(Math.log(precision) / Math.log(gamma)) : 1;
  }

  private void reset() {
    current = 0;
    cacheFilled = false;
    errorComputed = false;
    error = 0;
    prediction = 0;
    observed = 0;
  }

  /**
   * Records prediction_t and the following reward, and returns the error of the
   * prediction made bufferSize steps earlier (0 until enough steps have elapsed).
   */
  public double update(double prediction_t, double reward_tp1, boolean endOfEpisode) {
    if (endOfEpisode) {
      reset();
      return 0.0;
    }
    if (cacheFilled) {
      // The oldest slot now holds a fully accumulated truncated return:
      // compare it with the prediction stored bufferSize steps ago.
      errorComputed = true;
      prediction = predictionHistory[current];
      observed = observedHistory[current];
      error = observed - prediction;
    }
    // Recycle the oldest slot and propagate the new reward backward in time:
    // reward_tp1 contributes gamma^i to the return started i steps earlier.
    observedHistory[current] = 0;
    for (int i = 0; i < bufferSize; i++)
      observedHistory[(current - i + bufferSize) % bufferSize] += reward_tp1 * gammas[i];
    predictionHistory[current] = prediction_t;
    updateCurrent();
    return error;
  }

  protected void updateCurrent() {
    current++;
    if (current >= bufferSize) {
      cacheFilled = true;
      current = 0;
    }
  }

  public double error() {
    return error;
  }

  public boolean errorComputed() {
    return errorComputed;
  }

  public double precision() {
    return precision;
  }

  public double returnValue() {
    return observed;
  }

  public double gamma() {
    return gamma;
  }

  public int bufferSize() {
    return bufferSize;
  }
}
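
Below is a minimal usage sketch, not part of the original file: the driver class, the constant prediction of 10.0, and the constant reward of 1.0 are illustrative assumptions. With gamma = 0.9, a constant reward of 1 gives a discounted return close to 1 / (1 - 0.9) = 10, so the reported error should approach 0 once the buffer fills.

import rlpark.plugin.rltoys.algorithms.predictions.td.TDErrorMonitor;

// Hypothetical driver class, for illustration only.
public class TDErrorMonitorExample {
  public static void main(String[] args) {
    TDErrorMonitor monitor = new TDErrorMonitor(0.9, 1e-3);
    for (int t = 0; t < 200; t++) {
      double prediction = 10.0; // assumed prediction of the return from time t
      double reward = 1.0; // assumed reward observed at time t + 1
      double error = monitor.update(prediction, reward, false);
      if (monitor.errorComputed()) // false until bufferSize steps have elapsed
        System.out.println("t=" + t + " return=" + monitor.returnValue() + " error=" + error);
    }
  }
}

Note that the error returned at step t refers to the prediction made bufferSize() steps earlier; with gamma = 0.9 and precision = 1e-3 that delay is ceil(log(1e-3) / log(0.9)) = 66 steps.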