RLPark 1.0.0
Reinforcement Learning Framework in Java
|
00001 package rlpark.plugin.rltoys.algorithms.predictions.td; 00002 00003 import rlpark.plugin.rltoys.math.vector.RealVector; 00004 import rlpark.plugin.rltoys.math.vector.implementations.PVector; 00005 import zephyr.plugin.core.api.internal.monitoring.wrappers.Abs; 00006 import zephyr.plugin.core.api.internal.monitoring.wrappers.Squared; 00007 import zephyr.plugin.core.api.monitoring.annotations.Monitor; 00008 00009 @Monitor 00010 @SuppressWarnings("restriction") 00011 public class TD implements OnPolicyTD { 00012 private static final long serialVersionUID = -3640476464100200081L; 00013 final public double alpha_v; 00014 protected double gamma; 00015 @Monitor(level = 4) 00016 final public PVector v; 00017 @Monitor(wrappers = { Squared.ID, Abs.ID }) 00018 protected double delta_t; 00019 protected double v_t; 00020 00021 public TD(double alpha_v, int nbFeatures) { 00022 this(Double.NaN, alpha_v, nbFeatures); 00023 } 00024 00025 public TD(double gamma, double alpha_v, int nbFeatures) { 00026 this.alpha_v = alpha_v; 00027 this.gamma = gamma; 00028 v = new PVector(nbFeatures); 00029 } 00030 00031 protected double initEpisode() { 00032 v_t = 0; 00033 delta_t = 0; 00034 return delta_t; 00035 } 00036 00037 @Override 00038 public double update(RealVector x_t, RealVector x_tp1, double r_tp1) { 00039 return update(x_t, x_tp1, r_tp1, gamma); 00040 } 00041 00042 public double update(RealVector x_t, RealVector x_tp1, double r_tp1, double gamma_tp1) { 00043 if (x_t == null) 00044 return initEpisode(); 00045 v_t = v.dotProduct(x_t); 00046 delta_t = r_tp1 + gamma_tp1 * v.dotProduct(x_tp1) - v_t; 00047 v.addToSelf(alpha_v * delta_t, x_t); 00048 return delta_t; 00049 } 00050 00051 @Override 00052 public double predict(RealVector phi) { 00053 return v.dotProduct(phi); 00054 } 00055 00056 public double gamma() { 00057 return gamma; 00058 } 00059 00060 @Override 00061 public PVector weights() { 00062 return v; 00063 } 00064 00065 @Override 00066 public void resetWeight(int index) { 00067 v.data[index] = 0; 00068 } 00069 00070 @Override 00071 public double error() { 00072 return delta_t; 00073 } 00074 00075 @Override 00076 public double prediction() { 00077 return v_t; 00078 } 00079 }