RLPark 1.0.0
Reinforcement Learning Framework in Java

GQ.java
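GQ implements the GQ(lambda) algorithm of Maei and Sutton (2010): a gradient-based, off-policy temporal-difference learner with linear function approximation and eligibility traces. Each call to update() processes one transition and returns the TD error

    delta_t = r_{t+1} + beta_{t+1} z_{t+1} + (1 - beta_{t+1}) v'x_{t+1} - v'x_t

where v is the learned weight vector, beta_{t+1} is the termination probability (one minus the discount factor), z_{t+1} is the outcome value, and rho_t, passed to update(), is the importance-sampling ratio between the target and behavior policies.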

package rlpark.plugin.rltoys.algorithms.control.gq;

import rlpark.plugin.rltoys.algorithms.LinearLearner;
import rlpark.plugin.rltoys.algorithms.functions.Predictor;
import rlpark.plugin.rltoys.algorithms.traces.ATraces;
import rlpark.plugin.rltoys.algorithms.traces.EligibilityTraceAlgorithm;
import rlpark.plugin.rltoys.algorithms.traces.Traces;
import rlpark.plugin.rltoys.math.vector.MutableVector;
import rlpark.plugin.rltoys.math.vector.RealVector;
import rlpark.plugin.rltoys.math.vector.implementations.PVector;
import rlpark.plugin.rltoys.math.vector.pool.VectorPool;
import rlpark.plugin.rltoys.math.vector.pool.VectorPools;
import zephyr.plugin.core.api.monitoring.annotations.Monitor;

@Monitor
public class GQ implements Predictor, LinearLearner, EligibilityTraceAlgorithm {
  private static final long serialVersionUID = -4971665888576276439L;
  // Main weight vector (theta in the GQ(lambda) literature).
  @Monitor(level = 4)
  public final PVector v;
  // Step size for the main weights v.
  protected double alpha_v;
  // Step size for the secondary weights w.
  protected double alpha_w;
  // Termination probability for time t+1 (one minus the discount factor).
  protected double beta_tp1;
  // Trace-decay parameter at time t.
  protected double lambda_t;
  // Secondary weight vector used for the gradient correction.
  @Monitor(level = 4)
  protected final PVector w;
  // Eligibility traces.
  protected final Traces e;
  // Most recent TD error.
  protected double delta_t;

  // Convenience constructor: defaults to accumulating traces (ATraces).
  public GQ(double alpha_v, double alpha_w, double beta, double lambda, int nbFeatures) {
    this(alpha_v, alpha_w, beta, lambda, nbFeatures, new ATraces());
  }

  public GQ(double alpha_v, double alpha_w, double beta, double lambda, int nbFeatures, Traces prototype) {
    this.alpha_v = alpha_v;
    this.alpha_w = alpha_w;
    beta_tp1 = beta;
    lambda_t = lambda;
    e = prototype.newTraces(nbFeatures);
    v = new PVector(nbFeatures);
    w = new PVector(nbFeatures);
  }

  // Called at the start of an episode: clears the traces and returns a zero TD error.
  protected double initEpisode() {
    e.clear();
    return 0.0;
  }

  public double update(RealVector x_t, double rho_t, double r_tp1, RealVector x_bar_tp1, double z_tp1) {
    // A null x_t signals the start of a new episode.
    if (x_t == null)
      return initEpisode();
    VectorPool pool = VectorPools.pool(x_t);
    // TD error: delta_t = r_{t+1} + beta_{t+1} z_{t+1} + (1 - beta_{t+1}) v'x_{t+1} - v'x_t.
    // On a terminal transition x_bar_tp1 is null and the bootstrapping term is dropped,
    // consistent with the null check on the correction term below.
    double bootstrap = x_bar_tp1 != null ? (1 - beta_tp1) * v.dotProduct(x_bar_tp1) : 0;
    delta_t = r_tp1 + beta_tp1 * z_tp1 + bootstrap - v.dotProduct(x_t);
    // Decay the traces by (1 - beta_{t+1}) * lambda_t * rho_t, then add the current features.
    e.update((1 - beta_tp1) * lambda_t * rho_t, x_t);
    MutableVector delta_e = pool.newVector(e.vect()).mapMultiplyToSelf(delta_t);
    // Gradient-correction term: (1 - beta_{t+1}) (1 - lambda_t) (e'w) x_{t+1}.
    MutableVector tdCorrection = pool.newVector();
    if (x_bar_tp1 != null)
      tdCorrection.set(x_bar_tp1).mapMultiplyToSelf((1 - beta_tp1) * (1 - lambda_t) * e.vect().dotProduct(w));
    // Main weights: v += alpha_v * (delta_t * e - correction).
    v.addToSelf(alpha_v, pool.newVector(delta_e).subtractToSelf(tdCorrection));
    // Secondary weights: w += alpha_w * (delta_t * e - (w'x_t) x_t).
    w.addToSelf(alpha_w, delta_e.subtractToSelf(pool.newVector(x_t).mapMultiplyToSelf(w.dotProduct(x_t))));
    pool.releaseAll();
    return delta_t;
  }

  // Prediction: the linear value estimate v'x.
  @Override
  public double predict(RealVector x) {
    return v.dotProduct(x);
  }

  @Override
  public PVector weights() {
    return v;
  }

  // Zeros the weight, the trace, and the secondary weight of a single feature.
  @Override
  public void resetWeight(int index) {
    v.data[index] = 0;
    e.vect().setEntry(index, 0);
    w.data[index] = 0;
  }

  // Most recent TD error, as returned by the last call to update().
  @Override
  public double error() {
    return delta_t;
  }

  @Override
  public Traces traces() {
    return e;
  }
}
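
A minimal usage sketch (not part of this file; the feature dimension, step sizes, rewards, and feature vectors below are illustrative assumptions, not values from the RLPark distribution):

  // Hypothetical setup: 10 features, alpha_v = 0.1, alpha_w = 0.001,
  // beta = 0.01 (i.e. a discount factor of 0.99), lambda = 0.9, default ATraces.
  GQ gq = new GQ(0.1, 0.001, 0.01, 0.9, 10);
  PVector x_t = new PVector(10);
  PVector x_tp1 = new PVector(10);
  x_t.data[0] = 1.0;   // illustrative binary features
  x_tp1.data[1] = 1.0;
  // One transition with importance-sampling ratio rho_t = 1,
  // reward r_{t+1} = 1.0, and outcome z_{t+1} = 0.
  double tdError = gq.update(x_t, 1.0, 1.0, x_tp1, 0.0);
  double value = gq.predict(x_tp1);
  // Passing a null x_t starts a new episode (clears the traces).
  gq.update(null, 0.0, 0.0, null, 0.0);

The value returned by update() is the TD error delta_t, which is also exposed afterwards through error().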