RLPark 1.0.0
Reinforcement Learning Framework in Java

TDLambda.java

Go to the documentation of this file.
00001 package rlpark.plugin.rltoys.algorithms.predictions.td;
00002 
00003 
00004 import rlpark.plugin.rltoys.algorithms.traces.ATraces;
00005 import rlpark.plugin.rltoys.algorithms.traces.EligibilityTraceAlgorithm;
00006 import rlpark.plugin.rltoys.algorithms.traces.Traces;
00007 import rlpark.plugin.rltoys.math.vector.RealVector;
00008 import zephyr.plugin.core.api.monitoring.annotations.Monitor;
00009 
00010 public class TDLambda extends TD implements EligibilityTraceAlgorithm {
00011   private static final long serialVersionUID = 8613865620293286722L;
00012   private final double lambda;
00013   @Monitor
00014   public final Traces e;
00015 
00016   public TDLambda(double lambda, double gamma, double alpha, int nbFeatures) {
00017     this(lambda, gamma, alpha, nbFeatures, new ATraces());
00018   }
00019 
00020   public TDLambda(double lambda, double gamma, double alpha, int nbFeatures, Traces prototype) {
00021     super(gamma, alpha, nbFeatures);
00022     this.lambda = lambda;
00023     e = prototype.newTraces(nbFeatures);
00024   }
00025 
00026   @Override
00027   protected double initEpisode() {
00028     e.clear();
00029     return super.initEpisode();
00030   }
00031 
00032   @Override
00033   public double update(RealVector x_t, RealVector x_tp1, double r_tp1, double gamma_tp1) {
00034     if (x_t == null)
00035       return initEpisode();
00036     v_t = v.dotProduct(x_t);
00037     delta_t = r_tp1 + gamma_tp1 * v.dotProduct(x_tp1) - v_t;
00038     e.update(lambda * gamma_tp1, x_t);
00039     v.addToSelf(alpha_v * delta_t, e.vect());
00040     return delta_t;
00041   }
00042 
00043   @Override
00044   public void resetWeight(int index) {
00045     super.resetWeight(index);
00046     e.vect().setEntry(index, 0);
00047   }
00048 
00049   @Override
00050   public Traces traces() {
00051     return e;
00052   }
00053 }
 All Classes Namespaces Files Functions Variables Enumerations
Zephyr
RLPark