RLPark 1.0.0
Reinforcement Learning Framework in Java
|
00001 package rlpark.plugin.rltoys.algorithms.predictions.td; 00002 00003 import rlpark.plugin.rltoys.math.vector.RealVector; 00004 import rlpark.plugin.rltoys.math.vector.implementations.PVector; 00005 import zephyr.plugin.core.api.monitoring.annotations.Monitor; 00006 00007 @Monitor 00008 public class TDC extends TD { 00009 private static final long serialVersionUID = 7305877522126081130L; 00010 @Monitor(level = 4) 00011 protected final PVector w; 00012 public final double alpha_w; 00013 00014 public TDC(double gamma, double alpha_v, double alpha_w, int nbFeatures) { 00015 super(gamma, alpha_v, nbFeatures); 00016 w = new PVector(nbFeatures); 00017 this.alpha_w = alpha_w; 00018 } 00019 00020 @Override 00021 public double update(RealVector x_t, RealVector x_tp1, double r_tp1, double gamma_tp1) { 00022 if (x_t == null) 00023 return initEpisode(); 00024 v_t = v.dotProduct(x_t); 00025 delta_t = r_tp1 + gamma_tp1 * v.dotProduct(x_tp1) - v_t; 00026 RealVector tdCorrection = (x_tp1 != null ? x_tp1.mapMultiply(x_t.dotProduct(w)).mapMultiply(alpha_v * gamma_tp1) 00027 : new PVector(x_t.getDimension())); 00028 v.addToSelf(x_t.mapMultiply(alpha_v * delta_t).subtract(tdCorrection)); 00029 w.addToSelf(x_t.mapMultiply(alpha_w * (delta_t - x_t.dotProduct(w)))); 00030 return delta_t; 00031 } 00032 00033 @Override 00034 public void resetWeight(int index) { 00035 super.resetWeight(index); 00036 w.data[index] = 0; 00037 } 00038 00039 public PVector secondaryWeights() { 00040 return w; 00041 } 00042 }