RLPark 1.0.0
Reinforcement Learning Framework in Java

FSGAgentState.java

package rlpark.plugin.rltoys.problems.stategraph;

import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;

import org.apache.commons.math.linear.Array2DRowRealMatrix;
import org.apache.commons.math.linear.ArrayRealVector;
import org.apache.commons.math.linear.LUDecompositionImpl;
import org.apache.commons.math.linear.RealMatrix;

import rlpark.plugin.rltoys.algorithms.functions.stateactions.StateToStateAction;
import rlpark.plugin.rltoys.envio.actions.Action;
import rlpark.plugin.rltoys.envio.policy.Policy;
import rlpark.plugin.rltoys.math.vector.RealVector;
import rlpark.plugin.rltoys.math.vector.implementations.PVector;
import rlpark.plugin.rltoys.problems.stategraph.FiniteStateGraph.StepData;

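/**
 * Agent-side view of a {@link FiniteStateGraph}: maps every non-absorbing
 * state of the graph to a tabular (one-hot) feature vector and provides the
 * matrices needed to compute the closed-form TD(lambda) solution for a given
 * policy.
 */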
public class FSGAgentState implements StateToStateAction {
  private static final long serialVersionUID = -6312948577339609928L;
  public final int size;
  private final Map<GraphState, Integer> stateIndexes;
  private final FiniteStateGraph graph;
  private final PVector featureState;

  public FSGAgentState(FiniteStateGraph graph) {
    this.graph = graph;
    stateIndexes = indexStates(graph.states());
    size = nbNonAbsorbingState();
    featureState = new PVector(size);
  }

  private Map<GraphState, Integer> indexStates(GraphState[] states) {
    Map<GraphState, Integer> stateIndexes = new LinkedHashMap<GraphState, Integer>();
    int ci = 0;
    for (GraphState state : states) {
      GraphState s = state;
      if (!s.hasNextState())
        continue;
      stateIndexes.put(s, ci);
      ci++;
    }
    return stateIndexes;
  }

  public StepData step() {
    StepData stepData = graph.step();
    if (stepData.s_t != null && stepData.s_t.hasNextState())
      featureState.data[stateIndexes.get(stepData.s_t)] = 0;
    if (stepData.s_tp1 != null && stepData.s_tp1.hasNextState())
      featureState.data[stateIndexes.get(stepData.s_tp1)] = 1;
    return stepData;
  }

  public PVector currentFeatureState() {
    if (graph.currentState() == null)
      return new PVector(size);
    return featureState;
  }

  private RealMatrix createIdentityMatrix(int size) {
    RealMatrix phi = new Array2DRowRealMatrix(size, size);
    for (int i = 0; i < size; i++)
      phi.setEntry(i, i, 1.0);
    return phi;
  }

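  // Phi: one row per graph state, one column per non-absorbing state; each
  // non-absorbing state has a single 1 in its own column, while absorbing
  // states map to the zero row.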
  public RealMatrix createPhi() {
    RealMatrix result = new Array2DRowRealMatrix(nbStates(), nbNonAbsorbingState());
    for (int i = 0; i < nbStates(); i++)
      result.setRow(i, getFeatureVector(states()[i]).data);
    return result;
  }

  private PVector getFeatureVector(GraphState graphState) {
    PVector result = new PVector(nbNonAbsorbingState());
    int ci = 0;
    for (int i = 0; i < nbStates(); i++) {
      GraphState s = states()[i];
      if (!s.hasNextState())
        continue;
      if (s == graphState)
        result.data[ci] = 1;
      ci++;
    }
    return result;
  }

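  // Closed-form TD(lambda) fixed point for the given policy:
  //   P        : transition matrix under the policy
  //   d_pi     : diagonal matrix of the state visit distribution
  //   P_lambda : (1 - lambda) (I - gamma lambda P)^-1 P
  //   A        : Phi^T d_pi (gamma P_lambda - I) Phi
  //   b        : Phi^T d_pi (I - gamma lambda P)^-1 r_bar
  // and the returned weight vector is theta = -A^-1 b.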
  public double[] computeSolution(Policy policy, double gamma, double lambda) {
    RealMatrix phi = createPhi();
    RealMatrix p = createTransitionProbabilityMatrix(policy);
    ArrayRealVector d = createStateDistribution(p);
    RealMatrix d_pi = createStateDistributionMatrix(d);
    RealMatrix p_lambda = computePLambda(p, gamma, lambda);
    ArrayRealVector r_bar = computeAverageReward(p);

    RealMatrix A = computeA(phi, d_pi, gamma, p_lambda);
    ArrayRealVector b = computeB(phi, d_pi, p, r_bar, gamma, lambda);
    RealMatrix minusAInverse = new LUDecompositionImpl(A).getSolver().getInverse().scalarMultiply(-1);
    return minusAInverse.operate(b).getData();
  }

  private ArrayRealVector computeB(RealMatrix phi, RealMatrix dPi, RealMatrix p, ArrayRealVector rBar, double gamma,
      double lambda) {
    RealMatrix inv = computeIdMinusGammaLambdaP(p, gamma, lambda);
    return (ArrayRealVector) phi.transpose().operate(dPi.operate(inv.operate(rBar)));
  }

  private RealMatrix computeA(RealMatrix phi, RealMatrix dPi, double gamma, RealMatrix pLambda) {
    RealMatrix id = createIdentityMatrix(phi.getRowDimension());
    return phi.transpose().multiply(dPi.multiply(pLambda.scalarMultiply(gamma).subtract(id).multiply(phi)));
  }

  private ArrayRealVector computeAverageReward(RealMatrix p) {
    ArrayRealVector result = new ArrayRealVector(p.getColumnDimension());
    for (int i = 0; i < nbStates(); i++) {
      if (!states()[i].hasNextState())
        continue;
      double sum = 0;
      for (int j = 0; j < nbStates(); j++)
        sum += p.getEntry(i, j) * states()[j].reward;
      result.setEntry(i, sum);
    }
    return result;
  }

  private RealMatrix computePLambda(RealMatrix p, double gamma, double lambda) {
    RealMatrix inv = computeIdMinusGammaLambdaP(p, gamma, lambda);
    return inv.multiply(p).scalarMultiply(1 - lambda);
  }

  private RealMatrix computeIdMinusGammaLambdaP(RealMatrix p, double gamma, double lambda) {
    RealMatrix id = createIdentityMatrix(p.getColumnDimension());
    return new LUDecompositionImpl(id.subtract(p.scalarMultiply(lambda * gamma))).getSolver().getInverse();
  }

  private RealMatrix createStateDistributionMatrix(ArrayRealVector d) {
    RealMatrix d_pi = new Array2DRowRealMatrix(nbStates(), nbStates());
    int ci = 0;
    for (int i = 0; i < nbStates(); i++) {
      GraphState s = states()[i];
      if (!s.hasNextState())
        continue;
      d_pi.setEntry(i, i, d.getEntry(ci));
      ci++;
    }
    return d_pi;
  }

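  // Expected visit counts per non-absorbing state, mu (I - P)^-1, normalized
  // to sum to 1, where mu is the initial state distribution and P is the
  // transition matrix restricted to non-absorbing states.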
  private ArrayRealVector createStateDistribution(RealMatrix p) {
    RealMatrix p_copy = p.copy();
    p_copy = removeColumnAndRow(p_copy, absorbingStatesSet());
    assert p_copy.getColumnDimension() == p_copy.getRowDimension();
    RealMatrix id = createIdentityMatrix(p_copy.getColumnDimension());
    RealMatrix inv = new LUDecompositionImpl(id.subtract(p_copy)).getSolver().getInverse();
    RealMatrix mu = createInitialStateDistribution();
    RealMatrix visits = mu.multiply(inv);
    double sum = 0;
    for (int i = 0; i < visits.getColumnDimension(); i++)
      sum += visits.getEntry(0, i);
    return (ArrayRealVector) visits.scalarMultiply(1 / sum).getRowVector(0);
  }

  private Set<Integer> absorbingStatesSet() {
    Set<Integer> endStates = new LinkedHashSet<Integer>();
    for (int i = 0; i < nbStates(); i++)
      if (!states()[i].hasNextState())
        endStates.add(i);
    return endStates;
  }

  private int nbNonAbsorbingState() {
    return stateIndexes.size();
  }

  private RealMatrix removeColumnAndRow(RealMatrix m, Set<Integer> absorbingState) {
    RealMatrix result = new Array2DRowRealMatrix(nbNonAbsorbingState(), nbNonAbsorbingState());
    int ci = 0;
    for (int i = 0; i < m.getRowDimension(); i++) {
      if (absorbingState.contains(i))
        continue;
      int cj = 0;
      for (int j = 0; j < m.getColumnDimension(); j++) {
        if (absorbingState.contains(j))
          continue;
        result.setEntry(ci, cj, m.getEntry(i, j));
        cj++;
      }
      ci++;
    }
    return result;
  }

  private RealMatrix createInitialStateDistribution() {
    double[] numbers = new double[nbNonAbsorbingState()];
    int ci = 0;
    for (int i = 0; i < nbStates(); i++) {
      GraphState s = states()[i];
      if (!s.hasNextState())
        continue;
      if (s != graph.initialState())
        numbers[ci] = 0.0;
      else
        numbers[ci] = 1.0;
      ci++;
    }
    RealMatrix result = new Array2DRowRealMatrix(1, numbers.length);
    for (int i = 0; i < numbers.length; i++)
      result.setEntry(0, i, numbers[i]);
    return result;
  }

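  // Transition matrix induced by the policy: for each action a, the entry
  // (s, s.nextState(a)) is set to pi(a | s); absorbing states self-transition
  // with probability 1.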
  private RealMatrix createTransitionProbabilityMatrix(Policy policy) {
    RealMatrix p = new Array2DRowRealMatrix(nbStates(), nbStates());
    for (int si = 0; si < nbStates(); si++) {
      GraphState s_t = states()[si];
      policy.update(s_t.v());
      for (Action a : graph.actions()) {
        double pa = policy.pi(a);
        GraphState s_tp1 = s_t.nextState(a);
        if (s_tp1 != null)
          p.setEntry(si, graph.indexOf(s_tp1), pa);
      }
    }
    for (Integer absorbingState : absorbingStatesSet())
      p.setEntry(absorbingState, absorbingState, 1.0);
    return p;
  }

  private int nbStates() {
    return graph.nbStates();
  }

  private GraphState[] states() {
    return graph.states();
  }

  public Map<GraphState, Integer> stateIndexes() {
    return stateIndexes;
  }

  public FiniteStateGraph graph() {
    return graph;
  }

  public PVector featureState(GraphState s) {
    PVector result = new PVector(size);
    if (s != null && s.hasNextState())
      result.data[stateIndexes.get(s)] = 1;
    return result;
  }

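  // Tabular state-action features: one block of nbNonAbsorbingState() entries
  // per action, with a single 1 at the index of the current state inside the
  // block of the selected action.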
  @Override
  public PVector stateAction(RealVector s, Action a) {
    PVector sa = new PVector(nbNonAbsorbingState() * graph.actions().length);
    if (s == null)
      return sa;
    GraphState sg = graph.state(s);
    for (int ai = 0; ai < graph.actions().length; ai++)
      if (graph.actions()[ai] == a) {
        sa.setEntry(ai * nbNonAbsorbingState() + stateIndexes.get(sg), 1);
        return sa;
      }
    return null;
  }

  @Override
  public int vectorSize() {
    return size;
  }

  @Override
  public double vectorNorm() {
    return 1;
  }
}
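
Usage sketch (not part of the file above): given an already constructed FiniteStateGraph and a matching Policy, FSGAgentState can drive the graph and compare sampled one-hot feature vectors against the closed-form TD(lambda) solution. The example class FSGAgentStateExample and its run method are hypothetical names introduced here for illustration; the graph and policy construction is left to the caller, and everything else uses only the methods shown in the listing.

import rlpark.plugin.rltoys.envio.policy.Policy;
import rlpark.plugin.rltoys.math.vector.implementations.PVector;
import rlpark.plugin.rltoys.problems.stategraph.FSGAgentState;
import rlpark.plugin.rltoys.problems.stategraph.FiniteStateGraph;
import rlpark.plugin.rltoys.problems.stategraph.FiniteStateGraph.StepData;

public class FSGAgentStateExample {
  // Hypothetical driver: the graph and policy are assumed to be provided by
  // the caller, e.g. one of the FiniteStateGraph problems shipped with RLPark
  // and a policy defined over its actions.
  public static void run(FiniteStateGraph graph, Policy policy, double gamma, double lambda) {
    FSGAgentState agentState = new FSGAgentState(graph);

    // Closed-form TD(lambda) weights; with one-hot features each weight can be
    // read as the value estimate of one non-absorbing state.
    double[] solution = agentState.computeSolution(policy, gamma, lambda);
    System.out.println("Solution size: " + solution.length);

    // Drive the graph for a few steps and read the one-hot feature vector.
    for (int t = 0; t < 10; t++) {
      StepData step = agentState.step();
      PVector x_tp1 = agentState.currentFeatureState();
      System.out.println("t=" + t + " features=" + java.util.Arrays.toString(x_tp1.data));
      if (step.s_tp1 == null) // episode ended in an absorbing state
        break;
    }
  }
}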