RLPark 1.0.0
Reinforcement Learning Framework in Java
|
00001 package rlpark.plugin.rltoys.algorithms.control.sarsa; 00002 00003 import rlpark.plugin.rltoys.algorithms.control.ControlLearner; 00004 import rlpark.plugin.rltoys.algorithms.functions.stateactions.StateToStateAction; 00005 import rlpark.plugin.rltoys.envio.actions.Action; 00006 import rlpark.plugin.rltoys.envio.policy.Policies; 00007 import rlpark.plugin.rltoys.envio.policy.Policy; 00008 import rlpark.plugin.rltoys.math.vector.RealVector; 00009 import rlpark.plugin.rltoys.math.vector.implementations.Vectors; 00010 import zephyr.plugin.core.api.monitoring.annotations.Monitor; 00011 00012 public class SarsaControl implements ControlLearner { 00013 private static final long serialVersionUID = 2848271828496458933L; 00014 @Monitor 00015 protected final Sarsa sarsa; 00016 @Monitor 00017 protected final Policy acting; 00018 protected final StateToStateAction toStateAction; 00019 protected RealVector xa_t = null; 00020 00021 public SarsaControl(Policy acting, StateToStateAction toStateAction, Sarsa sarsa) { 00022 this.sarsa = sarsa; 00023 this.toStateAction = toStateAction; 00024 this.acting = acting; 00025 } 00026 00027 @Override 00028 public Action step(RealVector x_t, Action a_t, RealVector x_tp1, double r_tp1) { 00029 Action a_tp1 = Policies.decide(acting, x_tp1); 00030 RealVector xa_tp1 = toStateAction.stateAction(x_tp1, a_tp1); 00031 sarsa.update(x_t != null ? xa_t : null, xa_tp1, r_tp1); 00032 xa_t = Vectors.bufferedCopy(xa_tp1, xa_t); 00033 return a_tp1; 00034 } 00035 00036 public Policy acting() { 00037 return acting; 00038 } 00039 00040 @Override 00041 public Action proposeAction(RealVector x) { 00042 return Policies.decide(acting, x); 00043 } 00044 }