RLPark 1.0.0
Reinforcement Learning Framework in Java

GTDLambda.java

Go to the documentation of this file.
00001 package rlpark.plugin.rltoys.algorithms.predictions.td;
00002 
00003 import rlpark.plugin.rltoys.algorithms.traces.ATraces;
00004 import rlpark.plugin.rltoys.algorithms.traces.EligibilityTraceAlgorithm;
00005 import rlpark.plugin.rltoys.algorithms.traces.Traces;
00006 import rlpark.plugin.rltoys.math.vector.MutableVector;
00007 import rlpark.plugin.rltoys.math.vector.RealVector;
00008 import rlpark.plugin.rltoys.math.vector.implementations.PVector;
00009 import rlpark.plugin.rltoys.math.vector.pool.VectorPool;
00010 import rlpark.plugin.rltoys.math.vector.pool.VectorPools;
00011 import zephyr.plugin.core.api.internal.monitoring.wrappers.Abs;
00012 import zephyr.plugin.core.api.internal.monitoring.wrappers.Squared;
00013 import zephyr.plugin.core.api.monitoring.annotations.Monitor;
00014 
00015 @SuppressWarnings("restriction")
00016 @Monitor
00017 public class GTDLambda implements OnPolicyTD, GVF, EligibilityTraceAlgorithm {
00018   private static final long serialVersionUID = 8687476023177671278L;
00019   protected double gamma;
00020   final public double alpha_v;
00021   public final double alpha_w;
00022   protected double lambda;
00023   private double gamma_t;
00024   @Monitor(level = 4)
00025   final public PVector v;
00026   @Monitor(level = 4)
00027   protected final PVector w;
00028   private final Traces e;
00029   protected double v_t;
00030   @Monitor(wrappers = { Squared.ID, Abs.ID })
00031   protected double delta_t;
00032   private double correction;
00033   private double rho_t;
00034 
00035   public GTDLambda(double lambda, double gamma, double alpha_v, double alpha_w, int nbFeatures) {
00036     this(lambda, gamma, alpha_v, alpha_w, nbFeatures, new ATraces());
00037   }
00038 
00039   public GTDLambda(double lambda, double gamma, double alpha_v, double alpha_w, int nbFeatures, Traces prototype) {
00040     this.alpha_v = alpha_v;
00041     this.gamma = gamma;
00042     this.lambda = lambda;
00043     this.alpha_w = alpha_w;
00044     v = new PVector(nbFeatures);
00045     w = new PVector(nbFeatures);
00046     e = prototype.newTraces(nbFeatures);
00047   }
00048 
00049   @Override
00050   public double update(double pi_t, double b_t, RealVector x_t, RealVector x_tp1, double r_tp1, double gamma_tp1,
00051       double z_tp1) {
00052     if (x_t == null)
00053       return initEpisode(gamma_tp1);
00054     VectorPool pool = VectorPools.pool(e.vect());
00055     v_t = v.dotProduct(x_t);
00056     delta_t = r_tp1 + (1 - gamma_tp1) * z_tp1 + gamma_tp1 * v.dotProduct(x_tp1) - v_t;
00057     // Update traces
00058     e.update(gamma_t * lambda, x_t);
00059     rho_t = pi_t / b_t;
00060     e.vect().mapMultiplyToSelf(rho_t);
00061     // Compute correction
00062     MutableVector correctionVector = pool.newVector();
00063     if (x_tp1 != null) {
00064       correction = e.vect().dotProduct(w);
00065       correctionVector.addToSelf(correction * gamma_tp1 * (1 - lambda), x_tp1);
00066     }
00067     // Update parameters
00068     MutableVector deltaE = pool.newVector(e.vect()).mapMultiplyToSelf(delta_t);
00069     v.addToSelf(alpha_v, pool.newVector(deltaE).subtractToSelf(correctionVector));
00070     w.addToSelf(alpha_w, deltaE.addToSelf(-w.dotProduct(x_t), x_t));
00071     deltaE = null;
00072     gamma_t = gamma_tp1;
00073     pool.releaseAll();
00074     return delta_t;
00075   }
00076 
00077   protected double initEpisode(double gamma_tp1) {
00078     gamma_t = gamma_tp1;
00079     e.clear();
00080     v_t = 0;
00081     return 0;
00082   }
00083 
00084   @Override
00085   public void resetWeight(int index) {
00086     v.data[index] = 0;
00087     e.vect().setEntry(index, 0);
00088   }
00089 
00090   @Override
00091   public double update(RealVector x_t, RealVector x_tp1, double r_tp1) {
00092     return update(1, 1, x_t, x_tp1, r_tp1, gamma, 0);
00093   }
00094 
00095   @Override
00096   public double update(double pi_t, double b_t, RealVector x_t, RealVector x_tp1, double r_tp1) {
00097     return update(pi_t, b_t, x_t, x_tp1, r_tp1, gamma, 0);
00098   }
00099 
00100   public double update(double pi_t, double b_t, RealVector x_t, RealVector x_tp1, double r_tp1, double gamma_tp1) {
00101     return update(pi_t, b_t, x_t, x_tp1, r_tp1, gamma_tp1, 0);
00102   }
00103 
00104   @Override
00105   public double predict(RealVector phi) {
00106     return v.dotProduct(phi);
00107   }
00108 
00109   public double gamma() {
00110     return gamma;
00111   }
00112 
00113   @Override
00114   public PVector weights() {
00115     return v;
00116   }
00117 
00118   @Override
00119   public PVector secondaryWeights() {
00120     return w;
00121   }
00122 
00123   @Override
00124   public Traces traces() {
00125     return e;
00126   }
00127 
00128   @Override
00129   public double error() {
00130     return delta_t;
00131   }
00132 
00133   @Override
00134   public double prediction() {
00135     return v_t;
00136   }
00137 
00138   public double correction() {
00139     return correction;
00140   }
00141 }
 All Classes Namespaces Files Functions Variables Enumerations
Zephyr
RLPark