RLPark 1.0.0
Reinforcement Learning Framework in Java
package rlpark.plugin.rltoys.algorithms.control.gq;

import rlpark.plugin.rltoys.algorithms.LinearLearner;
import rlpark.plugin.rltoys.algorithms.functions.Predictor;
import rlpark.plugin.rltoys.algorithms.traces.ATraces;
import rlpark.plugin.rltoys.algorithms.traces.EligibilityTraceAlgorithm;
import rlpark.plugin.rltoys.algorithms.traces.Traces;
import rlpark.plugin.rltoys.math.vector.MutableVector;
import rlpark.plugin.rltoys.math.vector.RealVector;
import rlpark.plugin.rltoys.math.vector.implementations.PVector;
import rlpark.plugin.rltoys.math.vector.pool.VectorPool;
import rlpark.plugin.rltoys.math.vector.pool.VectorPools;
import zephyr.plugin.core.api.monitoring.annotations.Monitor;

// GQ(lambda): off-policy gradient temporal-difference learning with linear
// function approximation and eligibility traces.
@Monitor
public class GQ implements Predictor, LinearLearner, EligibilityTraceAlgorithm {
  private static final long serialVersionUID = -4971665888576276439L;
  // Main weight vector: the learned prediction parameters.
  @Monitor(level = 4)
  public final PVector v;
  // Step sizes for the main (v) and secondary (w) weight vectors.
  protected double alpha_v;
  protected double alpha_w;
  // Termination probability; (1 - beta) plays the role of the discount factor.
  protected double beta_tp1;
  // Eligibility-trace decay rate.
  protected double lambda_t;
  // Secondary weight vector used by the gradient-correction term.
  @Monitor(level = 4)
  protected final PVector w;
  // Eligibility traces.
  protected final Traces e;
  // Last temporal-difference error.
  protected double delta_t;

  public GQ(double alpha_v, double alpha_w, double beta, double lambda, int nbFeatures) {
    this(alpha_v, alpha_w, beta, lambda, nbFeatures, new ATraces());
  }

  public GQ(double alpha_v, double alpha_w, double beta, double lambda, int nbFeatures, Traces prototype) {
    this.alpha_v = alpha_v;
    this.alpha_w = alpha_w;
    beta_tp1 = beta;
    lambda_t = lambda;
    e = prototype.newTraces(nbFeatures);
    v = new PVector(nbFeatures);
    w = new PVector(nbFeatures);
  }

  // Clears the eligibility traces at the start of a new episode.
  protected double initEpisode() {
    e.clear();
    return 0.0;
  }

  // One learning step. x_t: current feature vector (null starts a new episode);
  // rho_t: importance sampling ratio; r_tp1: reward; x_bar_tp1: expected next
  // feature vector under the target policy; z_tp1: outcome at termination.
  // Returns the temporal-difference error delta_t.
  public double update(RealVector x_t, double rho_t, double r_tp1, RealVector x_bar_tp1, double z_tp1) {
    if (x_t == null)
      return initEpisode();
    VectorPool pool = VectorPools.pool(x_t);
    // Temporal-difference error.
    delta_t = r_tp1 + beta_tp1 * z_tp1 + (1 - beta_tp1) * v.dotProduct(x_bar_tp1) - v.dotProduct(x_t);
    // Decay the traces and accumulate the current features.
    e.update((1 - beta_tp1) * lambda_t * rho_t, x_t);
    MutableVector delta_e = pool.newVector(e.vect()).mapMultiplyToSelf(delta_t);
    MutableVector tdCorrection = pool.newVector();
    if (x_bar_tp1 != null)
      tdCorrection.set(x_bar_tp1).mapMultiplyToSelf((1 - beta_tp1) * (1 - lambda_t) * e.vect().dotProduct(w));
    // Main weights: TD update with the gradient-correction term subtracted.
    v.addToSelf(alpha_v, pool.newVector(delta_e).subtractToSelf(tdCorrection));
    // Secondary weights: track the expected TD error given the current features.
    w.addToSelf(alpha_w, delta_e.subtractToSelf(pool.newVector(x_t).mapMultiplyToSelf(w.dotProduct(x_t))));
    delta_e = null;
    pool.releaseAll();
    return delta_t;
  }

  @Override
  public double predict(RealVector x) {
    return v.dotProduct(x);
  }

  @Override
  public PVector weights() {
    return v;
  }

  @Override
  public void resetWeight(int index) {
    v.data[index] = 0;
    e.vect().setEntry(index, 0);
    w.data[index] = 0;
  }

  @Override
  public double error() {
    return delta_t;
  }

  @Override
  public Traces traces() {
    return e;
  }
}
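A minimal usage sketch, not part of the RLPark sources: it assumes only the constructor and update signature shown in the listing above, and that PVector exposes setEntry as used in resetWeight. The GQExample class name, package, and all numeric parameters (step sizes, beta, lambda, reward) are illustrative choices for this sketch, not library recommendations.

package rlpark.example; // hypothetical package for this sketch

import rlpark.plugin.rltoys.algorithms.control.gq.GQ;
import rlpark.plugin.rltoys.math.vector.implementations.PVector;

public class GQExample {
  public static void main(String[] args) {
    int nbFeatures = 10;
    // Illustrative parameters: alpha_v, alpha_w, beta (= 1 - gamma), lambda.
    GQ gq = new GQ(0.1, 0.01, 0.1, 0.4, nbFeatures);

    // Two sparse binary feature vectors standing in for successive observations.
    PVector x_t = new PVector(nbFeatures);
    x_t.setEntry(0, 1.0);
    PVector x_bar_tp1 = new PVector(nbFeatures);
    x_bar_tp1.setEntry(1, 1.0);

    double rho_t = 1.0; // importance sampling ratio; 1.0 when behaviour matches the target policy
    double r_tp1 = 0.5; // reward observed on the transition
    double z_tp1 = 0.0; // outcome at termination

    double delta = gq.update(x_t, rho_t, r_tp1, x_bar_tp1, z_tp1);
    System.out.println("delta_t = " + delta + ", prediction = " + gq.predict(x_bar_tp1));

    // A null feature vector signals the end of an episode and clears the traces.
    gq.update(null, 0.0, 0.0, null, 0.0);
  }
}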