RLPark 1.0.0
Reinforcement Learning Framework in Java
package rlpark.plugin.rltoys.algorithms.predictions.td;

import rlpark.plugin.rltoys.math.vector.RealVector;
import rlpark.plugin.rltoys.math.vector.implementations.PVector;
import rlpark.plugin.rltoys.math.vector.pool.VectorPool;
import rlpark.plugin.rltoys.math.vector.pool.VectorPools;
import zephyr.plugin.core.api.monitoring.annotations.Monitor;

@Monitor
public class HTD implements OnPolicyTD, GVF {
  private static final long serialVersionUID = 8687476023177671278L;
  protected double gamma;
  public double alpha_v;
  public double alpha_w;
  @Monitor(level = 4)
  public PVector v;
  @Monitor(level = 4)
  protected final PVector w;
  public double v_t;
  protected double delta_t;
  private double correction;
  private double ratio;
  private double rho_t;

  public HTD(double gamma, double alpha_v, double alpha_w, int nbFeatures) {
    this.alpha_v = alpha_v;
    this.gamma = gamma;
    this.alpha_w = alpha_w;
    v = new PVector(nbFeatures);
    w = new PVector(nbFeatures);
  }

  @Override
  public double update(RealVector x_t, RealVector x_tp1, double r_tp1) {
    // On-policy update: pi_t = b_t = 1
    return update(1, 1, x_t, x_tp1, r_tp1);
  }

  @Override
  public double update(double pi_t, double b_t, RealVector x_t, RealVector x_tp1, double r_tp1) {
    return update(pi_t, b_t, x_t, x_tp1, r_tp1, gamma, 0);
  }

  @Override
  public double update(double pi_t, double b_t, RealVector x_t, RealVector x_tp1, double r_tp1, double gamma_tp1,
      double z_tp1) {
    if (x_t == null)
      return initEpisode();
    VectorPool pool = VectorPools.pool(x_t);
    // TD error with transition-dependent discount gamma_tp1 and termination outcome z_tp1
    v_t = v.dotProduct(x_t);
    delta_t = r_tp1 + (1 - gamma_tp1) * z_tp1 + gamma_tp1 * v.dotProduct(x_tp1) - v_t;
    // Gradient correction from the secondary weight vector w
    correction = w.dotProduct(x_tp1);
    // Importance sampling: rho_t = pi/b; ratio = rho_t - 1 vanishes on-policy,
    // so the update then reduces to conventional TD(0)
    ratio = (pi_t - b_t) / b_t;
    rho_t = pi_t / b_t;
    v.addToSelf(alpha_v,
        pool.newVector(x_t).mapMultiplyToSelf(rho_t * delta_t)
            .addToSelf(pool.newVector(x_tp1).mapMultiplyToSelf(gamma_tp1 * ratio * correction)));
    w.addToSelf(alpha_w,
        pool.newVector(x_t).mapMultiplyToSelf(rho_t * (delta_t - correction))
            .addToSelf(pool.newVector(x_tp1).mapMultiplyToSelf(-gamma_tp1 * correction)));
    pool.releaseAll();
    return delta_t;
  }

  protected double initEpisode() {
    v_t = 0;
    delta_t = 0;
    return delta_t;
  }

  @Override
  public void resetWeight(int index) {
    v.data[index] = 0;
  }

  @Override
  public double predict(RealVector phi) {
    return v.dotProduct(phi);
  }

  public double gamma() {
    return gamma;
  }

  @Override
  public PVector weights() {
    return v;
  }

  @Override
  public PVector secondaryWeights() {
    return w;
  }

  @Override
  public double error() {
    return delta_t;
  }

  @Override
  public double prediction() {
    return v_t;
  }
}
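
Below is a minimal usage sketch for this class. The step-size and discount values, the binary feature construction, and the HTDExample class name are illustrative assumptions and are not part of the RLPark sources; only HTD, PVector, update and predict come from the listing above.

import rlpark.plugin.rltoys.algorithms.predictions.td.HTD;
import rlpark.plugin.rltoys.math.vector.implementations.PVector;

public class HTDExample {
  public static void main(String[] args) {
    int nbFeatures = 10;
    // gamma, alpha_v, alpha_w, nbFeatures (values chosen for illustration only)
    HTD htd = new HTD(0.99, 0.1, 0.01, nbFeatures);
    // Binary feature vectors for time t and t+1
    PVector x_t = new PVector(nbFeatures);
    x_t.data[0] = 1.0;
    PVector x_tp1 = new PVector(nbFeatures);
    x_tp1.data[1] = 1.0;
    double r_tp1 = 1.0; // observed reward
    // On-policy update (pi_t = b_t = 1); returns the TD error delta_t
    double delta = htd.update(x_t, x_tp1, r_tp1);
    // Current value estimate for the new feature vector
    double prediction = htd.predict(x_tp1);
    System.out.println("delta=" + delta + " prediction=" + prediction);
  }
}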