RLPark 1.0.0
Reinforcement Learning Framework in Java

TD.java

Go to the documentation of this file.
00001 package rlpark.plugin.rltoys.algorithms.predictions.td;
00002 
00003 import rlpark.plugin.rltoys.math.vector.RealVector;
00004 import rlpark.plugin.rltoys.math.vector.implementations.PVector;
00005 import zephyr.plugin.core.api.internal.monitoring.wrappers.Abs;
00006 import zephyr.plugin.core.api.internal.monitoring.wrappers.Squared;
00007 import zephyr.plugin.core.api.monitoring.annotations.Monitor;
00008 
00009 @Monitor
00010 @SuppressWarnings("restriction")
00011 public class TD implements OnPolicyTD {
00012   private static final long serialVersionUID = -3640476464100200081L;
00013   final public double alpha_v;
00014   protected double gamma;
00015   @Monitor(level = 4)
00016   final public PVector v;
00017   @Monitor(wrappers = { Squared.ID, Abs.ID })
00018   protected double delta_t;
00019   protected double v_t;
00020 
00021   public TD(double alpha_v, int nbFeatures) {
00022     this(Double.NaN, alpha_v, nbFeatures);
00023   }
00024 
00025   public TD(double gamma, double alpha_v, int nbFeatures) {
00026     this.alpha_v = alpha_v;
00027     this.gamma = gamma;
00028     v = new PVector(nbFeatures);
00029   }
00030 
00031   protected double initEpisode() {
00032     v_t = 0;
00033     delta_t = 0;
00034     return delta_t;
00035   }
00036 
00037   @Override
00038   public double update(RealVector x_t, RealVector x_tp1, double r_tp1) {
00039     return update(x_t, x_tp1, r_tp1, gamma);
00040   }
00041 
00042   public double update(RealVector x_t, RealVector x_tp1, double r_tp1, double gamma_tp1) {
00043     if (x_t == null)
00044       return initEpisode();
00045     v_t = v.dotProduct(x_t);
00046     delta_t = r_tp1 + gamma_tp1 * v.dotProduct(x_tp1) - v_t;
00047     v.addToSelf(alpha_v * delta_t, x_t);
00048     return delta_t;
00049   }
00050 
00051   @Override
00052   public double predict(RealVector phi) {
00053     return v.dotProduct(phi);
00054   }
00055 
00056   public double gamma() {
00057     return gamma;
00058   }
00059 
00060   @Override
00061   public PVector weights() {
00062     return v;
00063   }
00064 
00065   @Override
00066   public void resetWeight(int index) {
00067     v.data[index] = 0;
00068   }
00069 
00070   @Override
00071   public double error() {
00072     return delta_t;
00073   }
00074 
00075   @Override
00076   public double prediction() {
00077     return v_t;
00078   }
00079 }
 All Classes Namespaces Files Functions Variables Enumerations
Zephyr
RLPark