RLPark 1.0.0
Reinforcement Learning Framework in Java

Sarsa.java

Go to the documentation of this file.
00001 package rlpark.plugin.rltoys.algorithms.control.sarsa;
00002 
00003 import rlpark.plugin.rltoys.algorithms.functions.ParameterizedFunction;
00004 import rlpark.plugin.rltoys.algorithms.functions.Predictor;
00005 import rlpark.plugin.rltoys.algorithms.traces.ATraces;
00006 import rlpark.plugin.rltoys.algorithms.traces.Traces;
00007 import rlpark.plugin.rltoys.math.vector.RealVector;
00008 import rlpark.plugin.rltoys.math.vector.implementations.PVector;
00009 import zephyr.plugin.core.api.monitoring.annotations.Monitor;
00010 
00011 @Monitor
00012 public class Sarsa implements Predictor, ParameterizedFunction {
00013   private static final long serialVersionUID = 9030254074554565900L;
00014   @Monitor(level = 4)
00015   protected final Traces e;
00016   @Monitor(level = 4)
00017   protected final PVector q;
00018   protected final double lambda;
00019   protected final double gamma;
00020   protected final double alpha;
00021   protected double delta;
00022   protected double v_t;
00023   protected double v_tp1;
00024 
00025   public Sarsa(double alpha, double gamma, double lambda, int nbFeatures) {
00026     this(alpha, gamma, lambda, nbFeatures, new ATraces());
00027   }
00028 
00029   public Sarsa(double alpha, double gamma, double lambda, int nbFeatures, Traces prototype) {
00030     this(alpha, gamma, lambda, new PVector(nbFeatures), prototype);
00031   }
00032 
00033   public Sarsa(double alpha, double gamma, double lambda, PVector q, Traces prototype) {
00034     this.alpha = alpha;
00035     this.gamma = gamma;
00036     this.lambda = lambda;
00037     this.q = q;
00038     e = prototype.newTraces(q.getDimension());
00039   }
00040 
00041   public double update(RealVector phi_t, RealVector phi_tp1, double r_tp1) {
00042     if (phi_t == null)
00043       return initEpisode();
00044     v_tp1 = phi_tp1 != null ? q.dotProduct(phi_tp1) : 0;
00045     v_t = q.dotProduct(phi_t);
00046     delta = r_tp1 + gamma * v_tp1 - v_t;
00047     e.update(gamma * lambda, phi_t);
00048     q.addToSelf(alpha * delta, e.vect());
00049     return delta;
00050   }
00051 
00052   protected double initEpisode() {
00053     e.clear();
00054     return 0.0;
00055   }
00056 
00057   @Override
00058   public double predict(RealVector phi_sa) {
00059     assert q.getDimension() == phi_sa.getDimension();
00060     return q.dotProduct(phi_sa);
00061   }
00062 
00063   @Override
00064   public PVector weights() {
00065     return q;
00066   }
00067 }
 All Classes Namespaces Files Functions Variables Enumerations
Zephyr
RLPark