RLPark 1.0.0
Reinforcement Learning Framework in Java

QLearning.java

package rlpark.plugin.rltoys.algorithms.control.qlearning;

import rlpark.plugin.rltoys.algorithms.LinearLearner;
import rlpark.plugin.rltoys.algorithms.control.acting.Greedy;
import rlpark.plugin.rltoys.algorithms.functions.Predictor;
import rlpark.plugin.rltoys.algorithms.functions.stateactions.StateToStateAction;
import rlpark.plugin.rltoys.algorithms.traces.EligibilityTraceAlgorithm;
import rlpark.plugin.rltoys.algorithms.traces.Traces;
import rlpark.plugin.rltoys.envio.actions.Action;
import rlpark.plugin.rltoys.envio.policy.Policy;
import rlpark.plugin.rltoys.math.vector.RealVector;
import rlpark.plugin.rltoys.math.vector.implementations.PVector;
import zephyr.plugin.core.api.monitoring.annotations.Monitor;

@Monitor
public class QLearning implements Predictor, LinearLearner, EligibilityTraceAlgorithm {
  private static final long serialVersionUID = -404558746167490755L;
  // Weight vector of the linear action-value function Q(x, a) = theta . phi(x, a)
  @Monitor(level = 4)
  protected final PVector theta;
  // Eligibility traces
  private final Traces e;
  // Trace decay rate (lambda), discount factor (gamma) and learning rate (alpha)
  private final double lambda;
  private final double gamma;
  private final double alpha;
  // Maps a state feature vector and an action to a state-action feature vector
  private final StateToStateAction toStateAction;
  // Last temporal-difference error
  private double delta;
  private final Greedy greedy;
  /** Builds a Q(lambda) learner from a learning rate, a discount factor, a trace
      decay rate, a state-action projector and a trace prototype. */
  public QLearning(Action[] actions, double alpha, double gamma, double lambda, StateToStateAction toStateAction,
      Traces prototype) {
    this.alpha = alpha;
    this.gamma = gamma;
    this.lambda = lambda;
    this.toStateAction = toStateAction;
    greedy = new Greedy(this, actions, toStateAction);
    theta = new PVector(toStateAction.vectorSize());
    e = prototype.newTraces(toStateAction.vectorSize());
  }

  /**
   * Watkins's Q(lambda) update for the transition (x_t, a_t, r_tp1, x_tp1).
   * Passing x_t == null starts a new episode and clears the traces.
   * Returns the temporal-difference error delta.
   */
  public double update(RealVector x_t, Action a_t, RealVector x_tp1, Action a_tp1, double r_tp1) {
    if (x_t == null)
      return initEpisode();
    greedy.update(x_t);
    Action at_star = greedy.bestAction();
    greedy.update(x_tp1);
    RealVector phi_sa_t = toStateAction.stateAction(x_t, a_t);
    // TD error with respect to the greedy action in x_tp1
    delta = r_tp1 + gamma * greedy.bestActionValue() - theta.dotProduct(phi_sa_t);
    if (a_t == at_star)
      // a_t was the greedy action: decay the traces by gamma * lambda and update them with phi(x_t, a_t)
      e.update(gamma * lambda, phi_sa_t);
    else {
      // Exploratory action taken: cut the traces (Watkins's Q(lambda))
      e.clear();
      e.update(0, phi_sa_t);
    }
    theta.addToSelf(alpha * delta, e.vect());
    return delta;
  }

  // Clears the traces and resets the TD error at the beginning of an episode
  private double initEpisode() {
    if (e != null)
      e.clear();
    delta = 0.0;
    return delta;
  }

  @Override
  public double predict(RealVector phi_sa) {
    return theta.dotProduct(phi_sa);
  }

  public PVector theta() {
    return theta;
  }

  @Override
  public void resetWeight(int index) {
    theta.setEntry(index, 0);
  }

  @Override
  public PVector weights() {
    return theta;
  }

  @Override
  public double error() {
    return delta;
  }

  public Policy greedy() {
    return greedy;
  }

  @Override
  public Traces traces() {
    return e;
  }
}
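
For reference, the update() method above is a form of Watkins's Q(λ) with linear function approximation. In the notation of the code, θ corresponds to theta, φ(x, a) to toStateAction.stateAction(x, a), and e to the eligibility traces; assuming accumulating traces (the exact behaviour of e.update() depends on the Traces prototype passed to the constructor), the update reads:

\[ \delta_t = r_{t+1} + \gamma \max_a \theta^\top \phi(x_{t+1}, a) - \theta^\top \phi(x_t, a_t) \]
\[ e \leftarrow \begin{cases} \gamma \lambda\, e + \phi(x_t, a_t) & \text{if } a_t \text{ was the greedy action in } x_t, \\ \phi(x_t, a_t) & \text{otherwise (the traces are cut),} \end{cases} \]
\[ \theta \leftarrow \theta + \alpha\, \delta_t\, e. \]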