RLPark 1.0.0
Reinforcement Learning Framework in Java
|
00001 package rlpark.plugin.rltoys.algorithms.control.sarsa; 00002 00003 import rlpark.plugin.rltoys.algorithms.functions.ParameterizedFunction; 00004 import rlpark.plugin.rltoys.algorithms.functions.Predictor; 00005 import rlpark.plugin.rltoys.algorithms.traces.ATraces; 00006 import rlpark.plugin.rltoys.algorithms.traces.Traces; 00007 import rlpark.plugin.rltoys.math.vector.RealVector; 00008 import rlpark.plugin.rltoys.math.vector.implementations.PVector; 00009 import zephyr.plugin.core.api.monitoring.annotations.Monitor; 00010 00011 @Monitor 00012 public class Sarsa implements Predictor, ParameterizedFunction { 00013 private static final long serialVersionUID = 9030254074554565900L; 00014 @Monitor(level = 4) 00015 protected final Traces e; 00016 @Monitor(level = 4) 00017 protected final PVector q; 00018 protected final double lambda; 00019 protected final double gamma; 00020 protected final double alpha; 00021 protected double delta; 00022 protected double v_t; 00023 protected double v_tp1; 00024 00025 public Sarsa(double alpha, double gamma, double lambda, int nbFeatures) { 00026 this(alpha, gamma, lambda, nbFeatures, new ATraces()); 00027 } 00028 00029 public Sarsa(double alpha, double gamma, double lambda, int nbFeatures, Traces prototype) { 00030 this(alpha, gamma, lambda, new PVector(nbFeatures), prototype); 00031 } 00032 00033 public Sarsa(double alpha, double gamma, double lambda, PVector q, Traces prototype) { 00034 this.alpha = alpha; 00035 this.gamma = gamma; 00036 this.lambda = lambda; 00037 this.q = q; 00038 e = prototype.newTraces(q.getDimension()); 00039 } 00040 00041 public double update(RealVector phi_t, RealVector phi_tp1, double r_tp1) { 00042 if (phi_t == null) 00043 return initEpisode(); 00044 v_tp1 = phi_tp1 != null ? q.dotProduct(phi_tp1) : 0; 00045 v_t = q.dotProduct(phi_t); 00046 delta = r_tp1 + gamma * v_tp1 - v_t; 00047 e.update(gamma * lambda, phi_t); 00048 q.addToSelf(alpha * delta, e.vect()); 00049 return delta; 00050 } 00051 00052 protected double initEpisode() { 00053 e.clear(); 00054 return 0.0; 00055 } 00056 00057 @Override 00058 public double predict(RealVector phi_sa) { 00059 assert q.getDimension() == phi_sa.getDimension(); 00060 return q.dotProduct(phi_sa); 00061 } 00062 00063 @Override 00064 public PVector weights() { 00065 return q; 00066 } 00067 }