RLPark 1.0.0
Reinforcement Learning Framework in Java
package rlpark.plugin.rltoys.algorithms.predictions.td;

import java.io.Serializable;

import zephyr.plugin.core.api.internal.monitoring.wrappers.Abs;
import zephyr.plugin.core.api.internal.monitoring.wrappers.Squared;
import zephyr.plugin.core.api.monitoring.annotations.Monitor;

/**
 * Monitors the accuracy of return predictions: the prediction made bufferSize
 * steps ago is compared with the discounted sum of rewards observed since,
 * truncated once the discount factor falls below the requested precision.
 */
@SuppressWarnings("restriction")
public class TDErrorMonitor implements Serializable {
  private static final long serialVersionUID = 6441800170099052600L;
  private final int bufferSize;
  private final double[] gammas;
  private final double[] predictionHistory;
  private final double[] observedHistory;
  private int current;
  private boolean cacheFilled;
  @Monitor(wrappers = { Squared.ID, Abs.ID })
  private double error;
  @Monitor
  private double prediction, observed;
  private boolean errorComputed;
  private final double precision;
  private final double gamma;

  public TDErrorMonitor(double gamma, double precision) {
    this.gamma = gamma;
    this.precision = precision;
    bufferSize = computeBufferSize(gamma, precision);
    predictionHistory = new double[bufferSize];
    observedHistory = new double[bufferSize];
    gammas = new double[bufferSize];
    // Precompute gamma^i for each offset in the circular buffers.
    for (int i = 0; i < gammas.length; i++)
      gammas[i] = Math.pow(gamma, i);
    current = 0;
    cacheFilled = false;
  }

  // Smallest horizon n such that gamma^n <= precision (1 if gamma is 0).
  static public int computeBufferSize(double gamma, double precision) {
    return gamma > 0 ? (int) Math.ceil(Math.log(precision) / Math.log(gamma)) : 1;
  }

  private void reset() {
    current = 0;
    cacheFilled = false;
    errorComputed = false;
    error = 0;
    prediction = 0;
    observed = 0;
  }

  public double update(double prediction_t, double reward_tp1, boolean endOfEpisode) {
    if (endOfEpisode) {
      reset();
      return 0.0;
    }
    if (cacheFilled) {
      // The slot about to be overwritten holds the prediction made bufferSize
      // steps ago and the truncated return observed since then.
      errorComputed = true;
      prediction = predictionHistory[current];
      observed = observedHistory[current];
      error = observed - prediction;
    }
    // Start accumulating a new return in the current slot, and add the new
    // reward, discounted by gamma^i, to the return of each pending prediction.
    observedHistory[current] = 0;
    for (int i = 0; i < bufferSize; i++)
      observedHistory[(current - i + bufferSize) % bufferSize] += reward_tp1 * gammas[i];
    predictionHistory[current] = prediction_t;
    updateCurrent();
    return error;
  }

  protected void updateCurrent() {
    current++;
    if (current >= bufferSize) {
      cacheFilled = true;
      current = 0;
    }
  }

  public double error() {
    return error;
  }

  public boolean errorComputed() {
    return errorComputed;
  }

  public double precision() {
    return precision;
  }

  public double returnValue() {
    return observed;
  }

  public double gamma() {
    return gamma;
  }

  public int bufferSize() {
    return bufferSize;
  }
}
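
The sketch below is a minimal usage example, not part of the RLPark distribution; the constant reward stream, discount factor and prediction value are assumptions chosen for illustration. A constant prediction equal to the true return r / (1 - gamma) is fed to the monitor along with a constant reward; once the internal buffer is filled, errorComputed() becomes true and error() reports the gap between the prediction and the truncated observed return, which stays within the tail mass allowed by the precision parameter.

// Usage sketch (illustration only, not RLPark example code).
// A constant reward of 1 with gamma = 0.9 has true return 1 / (1 - 0.9) = 10;
// the monitor compares that prediction with the truncated observed return.
import rlpark.plugin.rltoys.algorithms.predictions.td.TDErrorMonitor;

public class TDErrorMonitorExample {
  public static void main(String[] args) {
    double gamma = 0.9;
    double reward = 1.0;
    double prediction = reward / (1 - gamma); // exact return for this reward stream
    TDErrorMonitor monitor = new TDErrorMonitor(gamma, 0.01);
    for (int t = 0; t < 100; t++) {
      monitor.update(prediction, reward, false);
      if (monitor.errorComputed())
        // |error| is bounded by the discounted tail cut off by the buffer,
        // roughly precision * prediction in this setting.
        System.out.println("t=" + t + "  error=" + monitor.error());
    }
  }
}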