RLPark 1.0.0
Reinforcement Learning Framework in Java
package rlpark.plugin.rltoys.problems.stategraph;

import java.io.Serializable;

import rlpark.plugin.rltoys.envio.actions.Action;
import rlpark.plugin.rltoys.envio.policy.Policies;
import rlpark.plugin.rltoys.envio.policy.Policy;
import rlpark.plugin.rltoys.math.vector.RealVector;
import rlpark.plugin.rltoys.math.vector.implementations.BVector;


public abstract class FiniteStateGraph implements Serializable {
  private static final long serialVersionUID = 50902147743062052L;

  // One transition (s_t, a_t) -> (r_tp1, s_tp1), plus the next action a_tp1.
  static public class StepData {
    public final int stepTime;
    public final GraphState s_t;
    public final Action a_t;
    public final double r_tp1;
    public final GraphState s_tp1;
    public final Action a_tp1;

    public StepData(int stepTime, GraphState s_t, Action a_t, GraphState s_tp1, double r_tp1, Action a_tp1) {
      assert s_t != null || a_t == null;
      this.stepTime = stepTime;
      this.s_t = s_t;
      this.a_t = a_t;
      this.s_tp1 = s_tp1;
      this.r_tp1 = r_tp1;
      this.a_tp1 = a_tp1;
    }

    public RealVector v_t() {
      return s_t != null ? s_t.v() : null;
    }

    public RealVector v_tp1() {
      return s_tp1 != null ? s_tp1.v() : null;
    }

    @Override
    public boolean equals(Object obj) {
      if (this == obj)
        return true;
      if (!(obj instanceof StepData)) // guard against null and foreign types before casting
        return false;
      StepData other = (StepData) obj;
      return stepTime == other.stepTime && s_t == other.s_t && a_t == other.a_t && s_tp1 == other.s_tp1
          && r_tp1 == other.r_tp1;
    }

    @Override
    public int hashCode() {
      return toString().hashCode();
    }

    @Override
    public String toString() {
      return String.format("%d: %s,%s -> %s", stepTime, s_t, a_t, s_tp1);
    }
  }

  private int stepTime = -1;
  private GraphState s_0;
  private Action a_t;
  private GraphState s_t;
  private final GraphState[] states;
  private final Policy acting;

  public FiniteStateGraph(Policy policy, GraphState[] states) {
    this.states = states;
    acting = policy;
    // Each state is represented by a one-hot binary vector over the state set.
    for (int i = 0; i < states.length; i++)
      states[i].setVectorRepresentation(BVector.toBVector(states.length, new int[] { i }));
  }

  protected void setInitialState(GraphState s_0) {
    assert this.s_0 == null;
    assert s_0 != null;
    this.s_0 = s_0;
  }

  public StepData step() {
    stepTime += 1;
    GraphState s_tm1 = s_t;
    Action a_tm1 = null;
    if (s_t == null) // start (or restart) an episode from the initial state
      s_t = s_0;
    else {
      a_tm1 = a_t;
      s_t = s_tm1.nextState(a_tm1);
    }
    a_t = Policies.decide(acting, s_t.v());
    double r_t = s_t.reward;
    if (!s_t.hasNextState()) { // absorbing state: the next call to step() restarts the episode
      a_t = null;
      s_t = null;
    }
    return new StepData(stepTime, s_tm1, a_tm1, s_t, r_t, a_t);
  }

  abstract public double gamma();

  abstract public double[] expectedDiscountedSolution();

  public GraphState[] states() {
    return states;
  }

  public int nbStates() {
    return states.length;
  }

  public GraphState currentState() {
    return s_t;
  }

  abstract public Action[] actions();

  public GraphState initialState() {
    return s_0;
  }

  public Policy policy() {
    return acting;
  }

  public int indexOf(GraphState s) {
    for (int i = 0; i < states.length; i++)
      if (states[i] == s)
        return i;
    return -1;
  }

  public GraphState state(RealVector s) {
    if (s == null)
      return null;
    // State vectors are one-hot, so the single active index identifies the state.
    return states[((BVector) s).getActiveIndexes()[0]];
  }
}
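
For readers new to the class, a minimal concrete subclass might look like the sketch below. Everything it uses from FiniteStateGraph appears in the listing above; the GraphState(String, double) constructor, the GraphState.connect(Action, GraphState) wiring call, and the assumption that Action is a marker interface with no methods are guesses about the rest of the rlpark API, not taken from this listing, and may need adjusting against the real sources.

package rlpark.plugin.rltoys.problems.stategraph;

import rlpark.plugin.rltoys.envio.actions.Action;
import rlpark.plugin.rltoys.envio.policy.Policy;

// Hypothetical two-state chain: Start --right--> Goal, reward 1.0 on reaching Goal.
public class TwoStateChain extends FiniteStateGraph {
  private static final long serialVersionUID = 1L;

  // Assumed constructor GraphState(String name, double reward).
  private static final GraphState Start = new GraphState("Start", 0.0);
  private static final GraphState Goal = new GraphState("Goal", 1.0);
  // Assumes Action is a marker interface with no methods.
  private static final Action Right = new Action() {
    private static final long serialVersionUID = 1L;
  };

  static {
    // Assumed wiring method GraphState.connect(Action, GraphState); Goal gets no
    // outgoing edge, so hasNextState() is false and the episode terminates there.
    Start.connect(Right, Goal);
  }

  public TwoStateChain(Policy policy) {
    super(policy, new GraphState[] { Start, Goal });
    setInitialState(Start);
  }

  @Override
  public double gamma() {
    return 0.9;
  }

  @Override
  public double[] expectedDiscountedSolution() {
    // Under the single deterministic action: V(Start) = r(Goal) = 1.0,
    // and V(Goal) = 0.0 since the episode terminates there.
    return new double[] { 1.0, 0.0 };
  }

  @Override
  public Action[] actions() {
    return new Action[] { Right };
  }
}

A driver then only needs step(); the policy instance is left abstract here because its construction depends on the agent being evaluated:

FiniteStateGraph graph = new TwoStateChain(policy);
FiniteStateGraph.StepData step;
do {
  step = graph.step();
  System.out.println(step); // toString() prints "time: s_t,a_t -> s_tp1"
} while (step.s_tp1 != null); // a null s_tp1 marks the end of the episode

Because step() sets s_tp1 to null on an absorbing state and restarts from the initial state on the next call, the same loop runs episodic problems without any explicit reset logic.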