RLPark 1.0.0
Reinforcement Learning Framework in Java
|
package rlpark.plugin.rltoys.problems.stategraph02;

import java.io.Serializable;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Random;

import rlpark.plugin.rltoys.envio.actions.Action;

/**
 * Tabular transition model for a finite MDP: for each {@link Action} it holds a
 * square matrix {@code p[s][s']} of transition probabilities between the states
 * supplied at construction time.
 * <p>
 * A state is considered terminal when every action's outgoing probability row
 * for that state sums to zero.
 */
public class StateGraph implements Serializable {
  private static final long serialVersionUID = -2849828765062029412L;
  /** Tolerance used when validating that a probability row sums to 1 (floating-point safe). */
  private static final double DISTRIBUTION_EPSILON = 1e-9;

  /** All states of the graph; the array index is the canonical state index. */
  private final State[] states;
  /** Reverse lookup from state to its index in {@link #states}. Insertion-ordered. */
  private final Map<State, Integer> stateToIndex = new LinkedHashMap<State, Integer>();
  /** Per-action transition matrix; {@code transitions.get(a)[s][s']} is P(s'|s,a). */
  private final Map<Action, double[][]> transitions = new LinkedHashMap<Action, double[][]>();

  /**
   * Builds a graph with all transition probabilities initialized to zero; use
   * {@link #addTransition(State, Action, State, double)} to populate them.
   *
   * @param s0 start state. NOTE(review): this parameter is currently unused by
   *          the class; kept for interface compatibility with existing callers.
   * @param states all states of the graph; their array order defines their indices
   * @param actions all actions; each gets its own zero-filled transition matrix
   */
  public StateGraph(State s0, State[] states, Action[] actions) {
    this.states = states;
    for (int i = 0; i < states.length; i++)
      stateToIndex.put(states[i], i);
    for (int i = 0; i < actions.length; i++)
      transitions.put(actions[i], newMatrix(states.length));
  }

  /** Allocates a zero-filled square matrix of the given dimension. */
  private double[][] newMatrix(int length) {
    // Java zero-initializes numeric arrays; a rectangular allocation is enough.
    return new double[length][length];
  }

  /** @return the number of states in the graph */
  public int nbStates() {
    return states.length;
  }

  /**
   * @param s a state registered at construction time
   * @return the canonical index of {@code s}
   * @throws NullPointerException if {@code s} was not part of the constructor's state array
   */
  public int indexOf(State s) {
    return stateToIndex.get(s);
  }

  /**
   * Samples the successor state of {@code s} under action {@code a} according to
   * the stored transition distribution.
   *
   * @param random source of randomness
   * @param s current state
   * @param a action taken in {@code s}
   * @return the sampled next state
   */
  public State sampleNextState(Random random, State s, Action a) {
    double[] p_sa = transitions.get(a)[stateToIndex.get(s)];
    double randomValue = random.nextDouble();
    // Walk the cumulative distribution until it exceeds the drawn value;
    // the second clause guards against rounding leaving the sum below 1.
    int i = -1;
    double sum = 0;
    do {
      i++;
      sum += p_sa[i];
    } while (sum < randomValue && i < p_sa.length - 1);
    // A terminal state (all-zero row) has no successor; sampling from it is a caller error.
    assert sum > 0;
    return states[i];
  }

  /**
   * @param s the state to test
   * @return true iff no action has any outgoing probability mass from {@code s}
   */
  public boolean isTerminal(State s) {
    int s_i = stateToIndex.get(s);
    for (double[][] ps : transitions.values()) {
      if (sum(ps[s_i]) == 0)
        return true;
    }
    return false;
  }

  /** Sums the entries of a probability row. */
  private double sum(double[] ds) {
    double sum = 0;
    for (double p : ds)
      sum += p;
    return sum;
  }

  /**
   * Sets P(s_tp1 | s_t, a_t) to {@code prob}, overwriting any previous value.
   *
   * @param s_t source state
   * @param a_t action
   * @param s_tp1 destination state
   * @param prob transition probability
   */
  public void addTransition(State s_t, Action a_t, State s_tp1, double prob) {
    transitions.get(a_t)[stateToIndex.get(s_t)][stateToIndex.get(s_tp1)] = prob;
  }

  /**
   * Validates every transition row: each must sum to either 0 (terminal row) or 1.
   * <p>
   * Bug fix: the previous implementation used an exact {@code sum != 1.0}
   * comparison, which spuriously rejected valid rows whose probabilities do not
   * sum to exactly 1.0 in binary floating point (e.g. 0.1 + 0.2 + 0.7). The sum
   * is now compared within {@link #DISTRIBUTION_EPSILON}.
   *
   * @return true iff all rows of all action matrices are valid distributions
   */
  public boolean checkDistribution() {
    for (double[][] psa : transitions.values()) {
      for (double[] ps : psa) {
        double sum = sum(ps);
        if (sum != 0 && Math.abs(sum - 1.0) > DISTRIBUTION_EPSILON)
          return false;
      }
    }
    return true;
  }
}