RLPark 1.0.0
Reinforcement Learning Framework in Java

GreedyGQ.java

package rlpark.plugin.rltoys.algorithms.control.gq;

import rlpark.plugin.rltoys.algorithms.control.OffPolicyLearner;
import rlpark.plugin.rltoys.algorithms.functions.stateactions.StateToStateAction;
import rlpark.plugin.rltoys.envio.actions.Action;
import rlpark.plugin.rltoys.envio.policy.Policies;
import rlpark.plugin.rltoys.envio.policy.Policy;
import rlpark.plugin.rltoys.math.vector.MutableVector;
import rlpark.plugin.rltoys.math.vector.RealVector;
import rlpark.plugin.rltoys.math.vector.implementations.PVector;
import rlpark.plugin.rltoys.math.vector.pool.VectorPool;
import rlpark.plugin.rltoys.math.vector.pool.VectorPools;
import rlpark.plugin.rltoys.utils.Prototype;
import rlpark.plugin.rltoys.utils.Utils;
import zephyr.plugin.core.api.monitoring.annotations.Monitor;

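/** Greedy-GQ control: combines the GQ(lambda) gradient-TD learner with a
 *  target policy (typically greedy with respect to the learned action values)
 *  while actions are selected by a separate behaviour policy. */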
public class GreedyGQ implements OffPolicyLearner {
  private static final long serialVersionUID = 7017521530598253457L;
  @Monitor
  protected final GQ gq;
  @Monitor
  protected final Policy target;
  protected final Policy behaviour;
  protected final StateToStateAction toStateAction;
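  // Importance sampling ratio of the latest transition (exposed for monitoring)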
  @Monitor
  public double rho_t;
  private final Action[] actions;
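  // TD error returned by the latest GQ update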
  private double delta_t;
  private final RealVector prototype;

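  /** Builds the learner from a GQ predictor, the action set, a state-action
   *  feature projector, and the target/behaviour policy pair. */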
  @SuppressWarnings("unchecked")
  public GreedyGQ(GQ gq, Action[] actions, StateToStateAction toStateAction, Policy target, Policy behaviour) {
    this.gq = gq;
    this.target = target;
    this.behaviour = behaviour;
    this.toStateAction = toStateAction;
    this.actions = actions;
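    // The trace's vector type serves as a prototype for pooled temporaries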
    prototype = ((Prototype<RealVector>) gq.e).prototype();
  }

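  /** Performs one off-policy update for the transition (x_t, a_t) to
   *  (r_tp1, x_tp1), where gamma_tp1 and z_tp1 are the discount and outcome
   *  signals of GQ(lambda). Returns the TD error delta_t. */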
  public double update(RealVector x_t, Action a_t, double r_tp1, double gamma_tp1, double z_tp1, RealVector x_tp1,
      Action a_tp1) {
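    // Importance sampling ratio rho_t = pi(a_t|x_t) / b(a_t|x_t); it stays 0
    // when there is no previous action (e.g. at the start of an episode)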
    rho_t = 0.0;
    if (a_t != null) {
      target.update(x_t);
      behaviour.update(x_t);
      rho_t = target.pi(a_t) / behaviour.pi(a_t);
    }
    assert Utils.checkValue(rho_t);
    VectorPool pool = VectorPools.pool(prototype, gq.v.size);
    MutableVector sa_bar_tp1 = pool.newVector();
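    // Expected next state-action features under the target policy:
    // sa_bar_tp1 = sum_a pi(a|x_tp1) * phi(x_tp1, a)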
    if (x_t != null && x_tp1 != null) {
      target.update(x_tp1);
      for (Action a : actions) {
        double pi = target.pi(a);
        if (pi == 0)
          continue;
        sa_bar_tp1.addToSelf(pi, toStateAction.stateAction(x_tp1, a));
      }
    }
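    // Delegate the gradient-TD update to GQ; delta_t is the resulting TD error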
    RealVector phi_stat = x_t != null ? toStateAction.stateAction(x_t, a_t) : null;
    delta_t = gq.update(phi_stat, rho_t, r_tp1, sa_bar_tp1, z_tp1);
    pool.releaseAll();
    return delta_t;
  }

  public PVector theta() {
    return gq.v;
  }

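  /** GQ stores beta_tp1 = 1 - gamma_tp1, hence the conversion. */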
  public double gamma() {
    return 1 - gq.beta_tp1;
  }

  public GQ gq() {
    return gq;
  }

  @Override
  public Policy targetPolicy() {
    return target;
  }

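  /** OffPolicyLearner step: forwards to update() with the outcome signal
   *  z_tp1 fixed to 0 and the discount taken from the GQ learner. */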
  @Override
  public void learn(RealVector x_t, Action a_t, RealVector x_tp1, Action a_tp1, double reward) {
    update(x_t, a_t, reward, gamma(), 0, x_tp1, a_tp1);
  }

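  /** Proposes an action for x_t by querying the target policy. */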
  @Override
  public Action proposeAction(RealVector x_t) {
    return Policies.decide(target, x_t);
  }

  @Override
  public GQ predictor() {
    return gq;
  }
}
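
For context, here is a minimal sketch of a control loop driving GreedyGQ. Only GreedyGQ's own methods and Policies.decide(...) are taken from the file above; the Environment interface and the run(...) driver are hypothetical scaffolding written for this example, not part of RLPark.

// Hypothetical driver, not part of RLPark: the Environment interface below
// is an assumption made for this sketch.
import rlpark.plugin.rltoys.algorithms.control.gq.GreedyGQ;
import rlpark.plugin.rltoys.envio.actions.Action;
import rlpark.plugin.rltoys.envio.policy.Policies;
import rlpark.plugin.rltoys.envio.policy.Policy;
import rlpark.plugin.rltoys.math.vector.RealVector;

public class GreedyGQExample {
  // Assumed minimal environment interface for the sketch
  public interface Environment {
    RealVector observe();  // current observation as a feature vector
    double step(Action a); // applies a and returns the reward
  }

  public static void run(GreedyGQ control, Policy behaviour, Environment env, int nbSteps) {
    RealVector x_t = null; // previous observation (null before the first step)
    Action a_t = null;     // previous action
    double r_tp1 = 0;      // reward received when entering the current state
    for (int t = 0; t < nbSteps; t++) {
      RealVector x_tp1 = env.observe();
      // Actions come from the behaviour policy; GreedyGQ corrects the
      // mismatch with the target policy through the ratio rho_t
      Action a_tp1 = Policies.decide(behaviour, x_tp1);
      // Learn from the transition (x_t, a_t) -> (r_tp1, x_tp1); null values
      // on the first step are handled by GreedyGQ.update()
      control.learn(x_t, a_t, x_tp1, a_tp1, r_tp1);
      r_tp1 = env.step(a_tp1);
      x_t = x_tp1;
      a_t = a_tp1;
    }
  }
}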