RLPark 1.0.0
Reinforcement Learning Framework in Java
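ExpectedSarsaControl implements Expected Sarsa control on top of SarsaControl. On each step it samples the next action a_tp1 from the acting policy, but instead of bootstrapping on the features of that sampled action, it accumulates the expected next feature vector phi_bar_tp1 = sum_a pi(a | x_tp1) * phi(x_tp1, a) over all actions with non-zero probability under the policy, and hands that expectation to the underlying Sarsa learner. Bootstrapping on the expectation removes the sampling noise of the next action from the update target.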
package rlpark.plugin.rltoys.algorithms.control.sarsa;

import rlpark.plugin.rltoys.algorithms.functions.stateactions.StateToStateAction;
import rlpark.plugin.rltoys.envio.actions.Action;
import rlpark.plugin.rltoys.envio.policy.Policies;
import rlpark.plugin.rltoys.envio.policy.Policy;
import rlpark.plugin.rltoys.math.vector.MutableVector;
import rlpark.plugin.rltoys.math.vector.RealVector;
import rlpark.plugin.rltoys.math.vector.pool.VectorPool;
import rlpark.plugin.rltoys.math.vector.pool.VectorPools;

public class ExpectedSarsaControl extends SarsaControl {
  private static final long serialVersionUID = 738626133717186128L;
  private final Action[] actions;

  public ExpectedSarsaControl(Action[] actions, Policy acting, StateToStateAction toStateAction, Sarsa sarsa) {
    super(acting, toStateAction, sarsa);
    this.actions = actions;
  }

  @Override
  public Action step(RealVector x_t, Action a_t, RealVector x_tp1, double r_tp1) {
    acting.update(x_tp1);
    Action a_tp1 = acting.sampleAction();
    RealVector xa_tp1 = null;
    VectorPool pool = VectorPools.pool(x_tp1, sarsa.q.size);
    MutableVector phi_bar_tp1 = pool.newVector();
    if (x_tp1 != null) {
      // Accumulate the expected next state-action features:
      // phi_bar_tp1 = sum_a pi(a | x_tp1) * phi(x_tp1, a).
      for (Action a : actions) {
        double pi = acting.pi(a);
        if (pi == 0.0) {
          assert a != a_tp1;
          continue;
        }
        RealVector phi_stp1a = toStateAction.stateAction(x_tp1, a);
        // Keep the features of the sampled action for the next step.
        if (a == a_tp1)
          xa_tp1 = phi_stp1a.copy();
        phi_bar_tp1.addToSelf(pi, phi_stp1a);
      }
    }
    // Expected Sarsa bootstraps on the expected features phi_bar_tp1 rather
    // than on the sampled features xa_tp1 (which plain Sarsa would use).
    sarsa.update(x_t != null ? xa_t : null, phi_bar_tp1, r_tp1);
    xa_t = xa_tp1;
    pool.releaseAll();
    return a_tp1;
  }

  @Override
  public Policy acting() {
    return acting;
  }

  @Override
  public Action proposeAction(RealVector x) {
    return Policies.decide(acting, x);
  }
}
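To make the update concrete, here is a minimal, self-contained sketch of a single Expected Sarsa step with linear function approximation and no eligibility traces (the lambda = 0 case; the RLPark Sarsa learner above additionally manages traces). It is independent of the RLPark types in this listing: every class, method, and variable name in it is illustrative, not RLPark API.

import java.util.Arrays;

public class ExpectedSarsaStepSketch {
  static double dot(double[] a, double[] b) {
    double s = 0;
    for (int i = 0; i < a.length; i++)
      s += a[i] * b[i];
    return s;
  }

  // One update of the weight vector theta, given:
  //   phi_t   : features of the previous state-action pair
  //   phiNext : features of each candidate next action, phiNext[a]
  //   piNext  : probability pi(a | s_{t+1}) of each candidate action
  //   reward  : r_{t+1}
  static void expectedSarsaStep(double[] theta, double[] phi_t, double[][] phiNext,
      double[] piNext, double reward, double alpha, double gamma) {
    // Expected next features: phiBar = sum_a pi(a) * phi(s_{t+1}, a).
    double[] phiBar = new double[theta.length];
    for (int a = 0; a < phiNext.length; a++)
      for (int i = 0; i < theta.length; i++)
        phiBar[i] += piNext[a] * phiNext[a][i];
    // The TD error bootstraps on the expected next value, not a sampled one.
    double delta = reward + gamma * dot(theta, phiBar) - dot(theta, phi_t);
    // Gradient step on the previous state-action features.
    for (int i = 0; i < theta.length; i++)
      theta[i] += alpha * delta * phi_t[i];
  }

  public static void main(String[] args) {
    double[] theta = new double[3];
    double[] phi_t = { 1, 0, 0 };
    double[][] phiNext = { { 0, 1, 0 }, { 0, 0, 1 } };
    double[] piNext = { 0.9, 0.1 }; // e.g. an epsilon-greedy distribution
    expectedSarsaStep(theta, phi_t, phiNext, piNext, 1.0, 0.1, 0.99);
    System.out.println(Arrays.toString(theta)); // [0.1, 0.0, 0.0]
  }
}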