/*
 * RLPark 1.0.0
 * Reinforcement Learning Framework in Java
 */
00001 package rlpark.plugin.rltoys.algorithms.predictions.td; 00002 00003 import rlpark.plugin.rltoys.algorithms.traces.ATraces; 00004 import rlpark.plugin.rltoys.algorithms.traces.EligibilityTraceAlgorithm; 00005 import rlpark.plugin.rltoys.algorithms.traces.Traces; 00006 import rlpark.plugin.rltoys.math.vector.MutableVector; 00007 import rlpark.plugin.rltoys.math.vector.RealVector; 00008 import rlpark.plugin.rltoys.math.vector.implementations.PVector; 00009 import rlpark.plugin.rltoys.math.vector.pool.VectorPool; 00010 import rlpark.plugin.rltoys.math.vector.pool.VectorPools; 00011 import zephyr.plugin.core.api.internal.monitoring.wrappers.Abs; 00012 import zephyr.plugin.core.api.internal.monitoring.wrappers.Squared; 00013 import zephyr.plugin.core.api.monitoring.annotations.Monitor; 00014 00015 @SuppressWarnings("restriction") 00016 @Monitor 00017 public class GTDLambda implements OnPolicyTD, GVF, EligibilityTraceAlgorithm { 00018 private static final long serialVersionUID = 8687476023177671278L; 00019 protected double gamma; 00020 final public double alpha_v; 00021 public final double alpha_w; 00022 protected double lambda; 00023 private double gamma_t; 00024 @Monitor(level = 4) 00025 final public PVector v; 00026 @Monitor(level = 4) 00027 protected final PVector w; 00028 private final Traces e; 00029 protected double v_t; 00030 @Monitor(wrappers = { Squared.ID, Abs.ID }) 00031 protected double delta_t; 00032 private double correction; 00033 private double rho_t; 00034 00035 public GTDLambda(double lambda, double gamma, double alpha_v, double alpha_w, int nbFeatures) { 00036 this(lambda, gamma, alpha_v, alpha_w, nbFeatures, new ATraces()); 00037 } 00038 00039 public GTDLambda(double lambda, double gamma, double alpha_v, double alpha_w, int nbFeatures, Traces prototype) { 00040 this.alpha_v = alpha_v; 00041 this.gamma = gamma; 00042 this.lambda = lambda; 00043 this.alpha_w = alpha_w; 00044 v = new PVector(nbFeatures); 00045 w = new PVector(nbFeatures); 00046 e = 
prototype.newTraces(nbFeatures); 00047 } 00048 00049 @Override 00050 public double update(double pi_t, double b_t, RealVector x_t, RealVector x_tp1, double r_tp1, double gamma_tp1, 00051 double z_tp1) { 00052 if (x_t == null) 00053 return initEpisode(gamma_tp1); 00054 VectorPool pool = VectorPools.pool(e.vect()); 00055 v_t = v.dotProduct(x_t); 00056 delta_t = r_tp1 + (1 - gamma_tp1) * z_tp1 + gamma_tp1 * v.dotProduct(x_tp1) - v_t; 00057 // Update traces 00058 e.update(gamma_t * lambda, x_t); 00059 rho_t = pi_t / b_t; 00060 e.vect().mapMultiplyToSelf(rho_t); 00061 // Compute correction 00062 MutableVector correctionVector = pool.newVector(); 00063 if (x_tp1 != null) { 00064 correction = e.vect().dotProduct(w); 00065 correctionVector.addToSelf(correction * gamma_tp1 * (1 - lambda), x_tp1); 00066 } 00067 // Update parameters 00068 MutableVector deltaE = pool.newVector(e.vect()).mapMultiplyToSelf(delta_t); 00069 v.addToSelf(alpha_v, pool.newVector(deltaE).subtractToSelf(correctionVector)); 00070 w.addToSelf(alpha_w, deltaE.addToSelf(-w.dotProduct(x_t), x_t)); 00071 deltaE = null; 00072 gamma_t = gamma_tp1; 00073 pool.releaseAll(); 00074 return delta_t; 00075 } 00076 00077 protected double initEpisode(double gamma_tp1) { 00078 gamma_t = gamma_tp1; 00079 e.clear(); 00080 v_t = 0; 00081 return 0; 00082 } 00083 00084 @Override 00085 public void resetWeight(int index) { 00086 v.data[index] = 0; 00087 e.vect().setEntry(index, 0); 00088 } 00089 00090 @Override 00091 public double update(RealVector x_t, RealVector x_tp1, double r_tp1) { 00092 return update(1, 1, x_t, x_tp1, r_tp1, gamma, 0); 00093 } 00094 00095 @Override 00096 public double update(double pi_t, double b_t, RealVector x_t, RealVector x_tp1, double r_tp1) { 00097 return update(pi_t, b_t, x_t, x_tp1, r_tp1, gamma, 0); 00098 } 00099 00100 public double update(double pi_t, double b_t, RealVector x_t, RealVector x_tp1, double r_tp1, double gamma_tp1) { 00101 return update(pi_t, b_t, x_t, x_tp1, r_tp1, gamma_tp1, 0); 
00102 } 00103 00104 @Override 00105 public double predict(RealVector phi) { 00106 return v.dotProduct(phi); 00107 } 00108 00109 public double gamma() { 00110 return gamma; 00111 } 00112 00113 @Override 00114 public PVector weights() { 00115 return v; 00116 } 00117 00118 @Override 00119 public PVector secondaryWeights() { 00120 return w; 00121 } 00122 00123 @Override 00124 public Traces traces() { 00125 return e; 00126 } 00127 00128 @Override 00129 public double error() { 00130 return delta_t; 00131 } 00132 00133 @Override 00134 public double prediction() { 00135 return v_t; 00136 } 00137 00138 public double correction() { 00139 return correction; 00140 } 00141 }