RLPark 1.0.0
Reinforcement Learning Framework in Java

ExpectedSarsaControl.java

Go to the documentation of this file.
00001 package rlpark.plugin.rltoys.algorithms.control.sarsa;
00002 
00003 import rlpark.plugin.rltoys.algorithms.functions.stateactions.StateToStateAction;
00004 import rlpark.plugin.rltoys.envio.actions.Action;
00005 import rlpark.plugin.rltoys.envio.policy.Policies;
00006 import rlpark.plugin.rltoys.envio.policy.Policy;
00007 import rlpark.plugin.rltoys.math.vector.MutableVector;
00008 import rlpark.plugin.rltoys.math.vector.RealVector;
00009 import rlpark.plugin.rltoys.math.vector.pool.VectorPool;
00010 import rlpark.plugin.rltoys.math.vector.pool.VectorPools;
00011 
00012 public class ExpectedSarsaControl extends SarsaControl {
00013   private static final long serialVersionUID = 738626133717186128L;
00014   private final Action[] actions;
00015 
00016   public ExpectedSarsaControl(Action[] actions, Policy acting, StateToStateAction toStateAction, Sarsa sarsa) {
00017     super(acting, toStateAction, sarsa);
00018     this.actions = actions;
00019   }
00020 
00021   @Override
00022   public Action step(RealVector x_t, Action a_t, RealVector x_tp1, double r_tp1) {
00023     acting.update(x_tp1);
00024     Action a_tp1 = acting.sampleAction();
00025     RealVector xa_tp1 = null;
00026     VectorPool pool = VectorPools.pool(x_tp1, sarsa.q.size);
00027     MutableVector phi_bar_tp1 = pool.newVector();
00028     if (x_tp1 != null) {
00029       for (Action a : actions) {
00030         double pi = acting.pi(a);
00031         if (pi == 0.0) {
00032           assert a != a_tp1;
00033           continue;
00034         }
00035         RealVector phi_stp1a = toStateAction.stateAction(x_tp1, a);
00036         if (a == a_tp1)
00037           xa_tp1 = phi_stp1a.copy();
00038         phi_bar_tp1.addToSelf(pi, phi_stp1a);
00039       }
00040     }
00041     sarsa.update(x_t != null ? xa_t : null, xa_tp1, r_tp1);
00042     xa_t = xa_tp1;
00043     pool.releaseAll();
00044     return a_tp1;
00045   }
00046 
00047   @Override
00048   public Policy acting() {
00049     return acting;
00050   }
00051 
00052   @Override
00053   public Action proposeAction(RealVector x) {
00054     return Policies.decide(acting, x);
00055   }
00056 }
 All Classes Namespaces Files Functions Variables Enumerations
Zephyr
RLPark