// RLPark 1.0.0
// Reinforcement Learning Framework in Java
00001 package rlpark.plugin.rltoys.algorithms.predictions.td; 00002 00003 00004 import rlpark.plugin.rltoys.algorithms.traces.ATraces; 00005 import rlpark.plugin.rltoys.algorithms.traces.EligibilityTraceAlgorithm; 00006 import rlpark.plugin.rltoys.algorithms.traces.Traces; 00007 import rlpark.plugin.rltoys.math.vector.RealVector; 00008 import zephyr.plugin.core.api.monitoring.annotations.Monitor; 00009 00010 public class TDLambda extends TD implements EligibilityTraceAlgorithm { 00011 private static final long serialVersionUID = 8613865620293286722L; 00012 private final double lambda; 00013 @Monitor 00014 public final Traces e; 00015 00016 public TDLambda(double lambda, double gamma, double alpha, int nbFeatures) { 00017 this(lambda, gamma, alpha, nbFeatures, new ATraces()); 00018 } 00019 00020 public TDLambda(double lambda, double gamma, double alpha, int nbFeatures, Traces prototype) { 00021 super(gamma, alpha, nbFeatures); 00022 this.lambda = lambda; 00023 e = prototype.newTraces(nbFeatures); 00024 } 00025 00026 @Override 00027 protected double initEpisode() { 00028 e.clear(); 00029 return super.initEpisode(); 00030 } 00031 00032 @Override 00033 public double update(RealVector x_t, RealVector x_tp1, double r_tp1, double gamma_tp1) { 00034 if (x_t == null) 00035 return initEpisode(); 00036 v_t = v.dotProduct(x_t); 00037 delta_t = r_tp1 + gamma_tp1 * v.dotProduct(x_tp1) - v_t; 00038 e.update(lambda * gamma_tp1, x_t); 00039 v.addToSelf(alpha_v * delta_t, e.vect()); 00040 return delta_t; 00041 } 00042 00043 @Override 00044 public void resetWeight(int index) { 00045 super.resetWeight(index); 00046 e.vect().setEntry(index, 0); 00047 } 00048 00049 @Override 00050 public Traces traces() { 00051 return e; 00052 } 00053 }