RLPark 1.0.0
Reinforcement Learning Framework in Java

FiniteStateGraph.java

Go to the documentation of this file.
00001 package rlpark.plugin.rltoys.problems.stategraph;
00002 
00003 import java.io.Serializable;
00004 
00005 import rlpark.plugin.rltoys.envio.actions.Action;
00006 import rlpark.plugin.rltoys.envio.policy.Policies;
00007 import rlpark.plugin.rltoys.envio.policy.Policy;
00008 import rlpark.plugin.rltoys.math.vector.RealVector;
00009 import rlpark.plugin.rltoys.math.vector.implementations.BVector;
00010 
00011 
00012 public abstract class FiniteStateGraph implements Serializable {
00013   private static final long serialVersionUID = 50902147743062052L;
00014 
00015   static public class StepData {
00016     public final int stepTime;
00017     public final GraphState s_t;
00018     public final Action a_t;
00019     public final double r_tp1;
00020     public final GraphState s_tp1;
00021     public final Action a_tp1;
00022 
00023     public StepData(int stepTime, GraphState s_t, Action a_t, GraphState s_tp1, double r_tp1, Action a_tp1) {
00024       assert s_t != null || a_t == null;
00025       this.stepTime = stepTime;
00026       this.s_t = s_t;
00027       this.a_t = a_t;
00028       this.s_tp1 = s_tp1;
00029       this.r_tp1 = r_tp1;
00030       this.a_tp1 = a_tp1;
00031     }
00032 
00033     public RealVector v_t() {
00034       return s_t != null ? s_t.v() : null;
00035     }
00036 
00037     public RealVector v_tp1() {
00038       return s_tp1 != null ? s_tp1.v() : null;
00039     }
00040 
00041     @Override
00042     public boolean equals(Object obj) {
00043       if (super.equals(obj))
00044         return true;
00045       StepData other = (StepData) obj;
00046       return stepTime == other.stepTime && s_t == other.s_t && a_t == other.a_t && s_tp1 == other.s_tp1
00047           && r_tp1 == other.r_tp1;
00048     }
00049 
00050     @Override
00051     public int hashCode() {
00052       return toString().hashCode();
00053     }
00054 
00055     @Override
00056     public String toString() {
00057       return String.format("%d: %s,%s -> %s", stepTime, s_t, a_t, s_tp1);
00058     }
00059   }
00060 
00061   private int stepTime = -1;
00062   private GraphState s_0;
00063   private Action a_t;
00064   private GraphState s_t;
00065   private final GraphState[] states;
00066   private final Policy acting;
00067 
00068   public FiniteStateGraph(Policy policy, GraphState[] states) {
00069     this.states = states;
00070     acting = policy;
00071     for (int i = 0; i < states.length; i++)
00072       states[i].setVectorRepresentation(BVector.toBVector(states.length, new int[] { i }));
00073   }
00074 
00075   protected void setInitialState(GraphState s_0) {
00076     assert this.s_0 == null;
00077     assert s_0 != null;
00078     this.s_0 = s_0;
00079   }
00080 
00081   public StepData step() {
00082     stepTime += 1;
00083     GraphState s_tm1 = s_t;
00084     Action a_tm1 = null;
00085     if (s_t == null)
00086       s_t = s_0;
00087     else {
00088       a_tm1 = a_t;
00089       s_t = s_tm1.nextState(a_tm1);
00090     }
00091     a_t = Policies.decide(acting, s_t.v());
00092     double r_t = s_t.reward;
00093     if (!s_t.hasNextState()) {
00094       a_t = null;
00095       s_t = null;
00096     }
00097     return new StepData(stepTime, s_tm1, a_tm1, s_t, r_t, a_t);
00098   }
00099 
00100   abstract public double gamma();
00101 
00102   abstract public double[] expectedDiscountedSolution();
00103 
00104   public GraphState[] states() {
00105     return states;
00106   }
00107 
00108   public int nbStates() {
00109     return states.length;
00110   }
00111 
00112   public GraphState currentState() {
00113     return s_t;
00114   }
00115 
00116   abstract public Action[] actions();
00117 
00118   public GraphState initialState() {
00119     return s_0;
00120   }
00121 
00122   public Policy policy() {
00123     return acting;
00124   }
00125 
00126   public int indexOf(GraphState s) {
00127     for (int i = 0; i < states.length; i++)
00128       if (states[i] == s)
00129         return i;
00130     return -1;
00131   }
00132 
00133   public GraphState state(RealVector s) {
00134     if (s == null)
00135       return null;
00136     return states[((BVector) s).getActiveIndexes()[0]];
00137   }
00138 }
 All Classes Namespaces Files Functions Variables Enumerations
Zephyr
RLPark