RLPark 1.0.0
Reinforcement Learning Framework in Java

HTD.java

package rlpark.plugin.rltoys.algorithms.predictions.td;

import rlpark.plugin.rltoys.math.vector.RealVector;
import rlpark.plugin.rltoys.math.vector.implementations.PVector;
import rlpark.plugin.rltoys.math.vector.pool.VectorPool;
import rlpark.plugin.rltoys.math.vector.pool.VectorPools;
import zephyr.plugin.core.api.monitoring.annotations.Monitor;

// Hybrid TD (HTD): a linear TD prediction algorithm that maintains a
// secondary weight vector w used to correct the update when samples come
// from a behavior policy b rather than the target policy pi.
@Monitor
public class HTD implements OnPolicyTD, GVF {
  private static final long serialVersionUID = 8687476023177671278L;
  protected double gamma; // discount factor
  public double alpha_v; // step size for the primary weights v
  public double alpha_w; // step size for the secondary weights w
  @Monitor(level = 4)
  public PVector v; // primary weights: the value-function estimate
  @Monitor(level = 4)
  protected final PVector w; // secondary (correction) weights
  public double v_t; // prediction v.x_t at the last update
  protected double delta_t; // TD error from the last update
  private double correction; // w.x_tp1, the correction term
  private double ratio; // (pi_t - b_t) / b_t
  private double rho_t; // importance-sampling ratio pi_t / b_t

  public HTD(double gamma, double alpha_v, double alpha_w, int nbFeatures) {
    this.alpha_v = alpha_v;
    this.gamma = gamma;
    this.alpha_w = alpha_w;
    v = new PVector(nbFeatures);
    w = new PVector(nbFeatures);
  }

  // On-policy update: target and behavior probabilities are both 1.
  @Override
  public double update(RealVector x_t, RealVector x_tp1, double r_tp1) {
    return update(1, 1, x_t, x_tp1, r_tp1);
  }

  // Off-policy update with a constant discount and no outcome signal.
  @Override
  public double update(double pi_t, double b_t, RealVector x_t, RealVector x_tp1, double r_tp1) {
    return update(pi_t, b_t, x_t, x_tp1, r_tp1, gamma, 0);
  }

  // Full GVF update: pi_t and b_t are the target and behavior probabilities
  // of the action taken at time t, gamma_tp1 is the continuation probability
  // and z_tp1 the outcome signal at time t+1.
  @Override
  public double update(double pi_t, double b_t, RealVector x_t, RealVector x_tp1, double r_tp1, double gamma_tp1,
      double z_tp1) {
    if (x_t == null) // start of a new episode: nothing to update yet
      return initEpisode();
    VectorPool pool = VectorPools.pool(x_t);
    v_t = v.dotProduct(x_t);
    // TD error: delta_t = r + (1 - gamma) * z + gamma * v.x_tp1 - v.x_t
    delta_t = r_tp1 + (1 - gamma_tp1) * z_tp1 + gamma_tp1 * v.dotProduct(x_tp1) - v_t;
    correction = w.dotProduct(x_tp1);
    ratio = (pi_t - b_t) / b_t;
    rho_t = pi_t / b_t;
    // v += alpha_v * (rho_t * delta_t * x_t + gamma_tp1 * ratio * (w.x_tp1) * x_tp1)
    v.addToSelf(alpha_v,
                pool.newVector(x_t).mapMultiplyToSelf(rho_t * delta_t)
                    .addToSelf(pool.newVector(x_tp1).mapMultiplyToSelf(gamma_tp1 * ratio * correction)));
    // w += alpha_w * (rho_t * (delta_t - w.x_tp1) * x_t - gamma_tp1 * (w.x_tp1) * x_tp1)
    w.addToSelf(alpha_w,
                pool.newVector(x_t).mapMultiplyToSelf(rho_t * (delta_t - correction))
                    .addToSelf(pool.newVector(x_tp1).mapMultiplyToSelf(-gamma_tp1 * correction)));
    pool.releaseAll();
    return delta_t;
  }

  // Called when x_t is null, i.e. at the start of an episode.
  protected double initEpisode() {
    v_t = 0;
    delta_t = 0;
    return delta_t;
  }

  @Override
  public void resetWeight(int index) {
    v.data[index] = 0;
  }

  @Override
  public double predict(RealVector phi) {
    return v.dotProduct(phi);
  }

  public double gamma() {
    return gamma;
  }

  @Override
  public PVector weights() {
    return v;
  }

  @Override
  public PVector secondaryWeights() {
    return w;
  }

  // TD error computed by the most recent update.
  @Override
  public double error() {
    return delta_t;
  }

  // Prediction v.x_t computed at the most recent update.
  @Override
  public double prediction() {
    return v_t;
  }
}
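
For context, here is a minimal usage sketch (not part of the original file) that drives HTD through a toy episode. The features, reward, and policy probabilities below are hypothetical placeholders; only the HTD and PVector APIs shown above are assumed.

import rlpark.plugin.rltoys.algorithms.predictions.td.HTD;
import rlpark.plugin.rltoys.math.vector.implementations.PVector;

public class HTDExample {
  public static void main(String[] args) {
    int nbFeatures = 4;
    // Hypothetical discount and step sizes; tune for the problem at hand.
    HTD htd = new HTD(0.9, 0.1, 0.05, nbFeatures);

    PVector x_t = null; // a null x_t tells update() a new episode starts
    for (int t = 0; t <= 20; t++) {
      // Toy one-hot features and reward, standing in for a real environment.
      PVector x_tp1 = new PVector(nbFeatures);
      x_tp1.data[t % nbFeatures] = 1.0;
      double r_tp1 = (t % nbFeatures == 0) ? 1.0 : 0.0;
      // pi_t and b_t are the target/behavior probabilities of the action
      // taken; equal values make this an on-policy update.
      double delta = htd.update(1.0, 1.0, x_t, x_tp1, r_tp1);
      System.out.println("t=" + t + " delta=" + delta);
      x_t = x_tp1;
    }
    System.out.println("prediction: " + htd.predict(x_t));
  }
}

With pi_t different from b_t the same loop performs off-policy evaluation, and the seven-argument overload additionally exposes a per-step continuation probability gamma_tp1 and outcome z_tp1 for general value functions.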