RLPark 1.0.0
Reinforcement Learning Framework in Java

StateGraph.java

Go to the documentation of this file.
00001 package rlpark.plugin.rltoys.problems.stategraph02;
00002 
00003 import java.io.Serializable;
00004 import java.util.LinkedHashMap;
00005 import java.util.Map;
00006 import java.util.Random;
00007 
00008 import rlpark.plugin.rltoys.envio.actions.Action;
00009 
00010 public class StateGraph implements Serializable {
00011   private static final long serialVersionUID = -2849828765062029412L;
00012   private final State[] states;
00013   private final Map<State, Integer> stateToIndex = new LinkedHashMap<State, Integer>();
00014   private final Map<Action, double[][]> transitions = new LinkedHashMap<Action, double[][]>();
00015 
00016   public StateGraph(State s0, State[] states, Action[] actions) {
00017     this.states = states;
00018     for (int i = 0; i < states.length; i++)
00019       stateToIndex.put(states[i], i);
00020     for (int i = 0; i < actions.length; i++)
00021       transitions.put(actions[i], newMatrix(states.length));
00022   }
00023 
00024   private double[][] newMatrix(int length) {
00025     double[][] matrix = new double[length][];
00026     for (int i = 0; i < matrix.length; i++)
00027       matrix[i] = new double[matrix.length];
00028     return matrix;
00029   }
00030 
00031   public int nbStates() {
00032     return states.length;
00033   }
00034 
00035   public int indexOf(State s) {
00036     return stateToIndex.get(s);
00037   }
00038 
00039   public State sampleNextState(Random random, State s, Action a) {
00040     double[] p_sa = transitions.get(a)[stateToIndex.get(s)];
00041     double randomValue = random.nextDouble();
00042     int i = -1;
00043     double sum = 0;
00044     do {
00045       i++;
00046       sum += p_sa[i];
00047     } while (sum < randomValue && i < p_sa.length - 1);
00048     assert sum > 0;
00049     return states[i];
00050   }
00051 
00052   public boolean isTerminal(State s) {
00053     int s_i = stateToIndex.get(s);
00054     for (double[][] ps : transitions.values()) {
00055       if (sum(ps[s_i]) == 0)
00056         return true;
00057     }
00058     return false;
00059   }
00060 
00061   private double sum(double[] ds) {
00062     double sum = 0;
00063     for (double p : ds)
00064       sum += p;
00065     return sum;
00066   }
00067 
00068   public void addTransition(State s_t, Action a_t, State s_tp1, double prob) {
00069     transitions.get(a_t)[stateToIndex.get(s_t)][stateToIndex.get(s_tp1)] = prob;
00070   }
00071 
00072   public boolean checkDistribution() {
00073     for (double[][] psa : transitions.values()) {
00074       for (double[] ps : psa) {
00075         double sum = sum(ps);
00076         if (sum != 0 && sum != 1.0)
00077           return false;
00078       }
00079     }
00080     return true;
00081   }
00082 }
 All Classes Namespaces Files Functions Variables Enumerations
Zephyr
RLPark