RLPark 1.0.0
Reinforcement Learning Framework in Java
package rlpark.plugin.rltoys.problems.stategraph;

import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;

import org.apache.commons.math.linear.Array2DRowRealMatrix;
import org.apache.commons.math.linear.ArrayRealVector;
import org.apache.commons.math.linear.LUDecompositionImpl;
import org.apache.commons.math.linear.RealMatrix;

import rlpark.plugin.rltoys.algorithms.functions.stateactions.StateToStateAction;
import rlpark.plugin.rltoys.envio.actions.Action;
import rlpark.plugin.rltoys.envio.policy.Policy;
import rlpark.plugin.rltoys.math.vector.RealVector;
import rlpark.plugin.rltoys.math.vector.implementations.PVector;
import rlpark.plugin.rltoys.problems.stategraph.FiniteStateGraph.StepData;

/**
 * Tabular (one-hot) state representation on top of a {@link FiniteStateGraph}.
 * <p>
 * Each non-absorbing graph state is assigned one component of a binary feature
 * vector; absorbing states map to the all-zero vector. The class also builds
 * the matrices (feature matrix Phi, transition matrix P, state-distribution
 * matrix D_pi, ...) required to compute the closed-form TD(lambda) fixed point
 * for a fixed policy via {@link #computeSolution(Policy, double, double)}.
 */
public class FSGAgentState implements StateToStateAction {
  private static final long serialVersionUID = -6312948577339609928L;
  /** Number of non-absorbing states, i.e. the dimension of the feature vector. */
  public final int size;
  /** Feature index of each non-absorbing state, in graph declaration order. */
  private final Map<GraphState, Integer> stateIndexes;
  private final FiniteStateGraph graph;
  /** One-hot feature vector of the current state; maintained in place by step(). */
  private final PVector featureState;

  public FSGAgentState(FiniteStateGraph graph) {
    this.graph = graph;
    stateIndexes = indexStates(graph.states());
    size = nbNonAbsorbingState();
    featureState = new PVector(size);
  }

  /** Assigns consecutive feature indexes to non-absorbing states only. */
  private Map<GraphState, Integer> indexStates(GraphState[] states) {
    Map<GraphState, Integer> indexes = new LinkedHashMap<GraphState, Integer>();
    int ci = 0;
    for (GraphState state : states) {
      // Absorbing states (no successor) get no feature component.
      if (!state.hasNextState())
        continue;
      indexes.put(state, ci);
      ci++;
    }
    return indexes;
  }

  /**
   * Advances the underlying graph by one transition and keeps the one-hot
   * feature vector consistent: the bit of the previous state is cleared, the
   * bit of the new state is set (absorbing states have no bit).
   *
   * @return the transition data produced by the graph
   */
  public StepData step() {
    StepData stepData = graph.step();
    if (stepData.s_t != null && stepData.s_t.hasNextState())
      featureState.data[stateIndexes.get(stepData.s_t)] = 0;
    if (stepData.s_tp1 != null && stepData.s_tp1.hasNextState())
      featureState.data[stateIndexes.get(stepData.s_tp1)] = 1;
    return stepData;
  }

  /**
   * @return the feature vector of the current state, or a fresh zero vector
   *         when the graph has no current state (e.g. at episode end)
   */
  public PVector currentFeatureState() {
    if (graph.currentState() == null)
      return new PVector(size);
    return featureState;
  }

  /** Builds a size x size identity matrix. */
  private RealMatrix createIdentityMatrix(int size) {
    RealMatrix phi = new Array2DRowRealMatrix(size, size);
    for (int i = 0; i < size; i++)
      phi.setEntry(i, i, 1.0);
    return phi;
  }

  /**
   * Builds the feature matrix Phi (nbStates x nbNonAbsorbingStates): row i is
   * the feature vector of state i; rows of absorbing states are all zero.
   */
  public RealMatrix createPhi() {
    RealMatrix result = new Array2DRowRealMatrix(nbStates(), nbNonAbsorbingState());
    for (int i = 0; i < nbStates(); i++)
      result.setRow(i, getFeatureVector(states()[i]).data);
    return result;
  }

  /** One-hot feature vector of the given state; zero vector if it is absorbing. */
  private PVector getFeatureVector(GraphState graphState) {
    PVector result = new PVector(nbNonAbsorbingState());
    int ci = 0;
    for (int i = 0; i < nbStates(); i++) {
      GraphState s = states()[i];
      if (!s.hasNextState())
        continue;
      if (s == graphState)
        result.data[ci] = 1;
      ci++;
    }
    return result;
  }

  /**
   * Computes the TD(lambda) fixed-point weight vector for the given policy in
   * closed form: solves A w = -b, i.e. w = -A^{-1} b.
   *
   * @param policy the fixed policy inducing the transition matrix
   * @param gamma discount factor
   * @param lambda eligibility-trace parameter
   * @return the weight vector, one entry per non-absorbing state feature
   */
  public double[] computeSolution(Policy policy, double gamma, double lambda) {
    RealMatrix phi = createPhi();
    RealMatrix p = createTransitionProbabilityMatrix(policy);
    ArrayRealVector d = createStateDistribution(p);
    RealMatrix d_pi = createStateDistributionMatrix(d);
    RealMatrix p_lambda = computePLambda(p, gamma, lambda);
    ArrayRealVector r_bar = computeAverageReward(p);

    RealMatrix A = computeA(phi, d_pi, gamma, p_lambda);
    ArrayRealVector b = computeB(phi, d_pi, p, r_bar, gamma, lambda);
    RealMatrix minusAInverse = new LUDecompositionImpl(A).getSolver().getInverse().scalarMultiply(-1);
    return minusAInverse.operate(b).getData();
  }

  /** b = Phi^T D_pi (I - gamma lambda P)^{-1} r_bar. */
  private ArrayRealVector computeB(RealMatrix phi, RealMatrix dPi, RealMatrix p, ArrayRealVector rBar, double gamma,
      double lambda) {
    RealMatrix inv = computeIdMinusGammaLambdaP(p, gamma, lambda);
    return (ArrayRealVector) phi.transpose().operate(dPi.operate(inv.operate(rBar)));
  }

  /** A = Phi^T D_pi (gamma P^lambda - I) Phi. */
  private RealMatrix computeA(RealMatrix phi, RealMatrix dPi, double gamma, RealMatrix pLambda) {
    RealMatrix id = createIdentityMatrix(phi.getRowDimension());
    return phi.transpose().multiply(dPi.multiply(pLambda.scalarMultiply(gamma).subtract(id).multiply(phi)));
  }

  /** Expected one-step reward from each non-absorbing state under P. */
  private ArrayRealVector computeAverageReward(RealMatrix p) {
    ArrayRealVector result = new ArrayRealVector(p.getColumnDimension());
    for (int i = 0; i < nbStates(); i++) {
      // Absorbing states keep a zero expected reward.
      if (!states()[i].hasNextState())
        continue;
      double sum = 0;
      for (int j = 0; j < nbStates(); j++)
        sum += p.getEntry(i, j) * states()[j].reward;
      result.setEntry(i, sum);
    }
    return result;
  }

  /** P^lambda = (1 - lambda) (I - gamma lambda P)^{-1} P. */
  private RealMatrix computePLambda(RealMatrix p, double gamma, double lambda) {
    RealMatrix inv = computeIdMinusGammaLambdaP(p, gamma, lambda);
    return inv.multiply(p).scalarMultiply(1 - lambda);
  }

  /** Inverse of (I - gamma lambda P), computed through an LU decomposition. */
  private RealMatrix computeIdMinusGammaLambdaP(RealMatrix p, double gamma, double lambda) {
    RealMatrix id = createIdentityMatrix(p.getColumnDimension());
    return new LUDecompositionImpl(id.subtract(p.scalarMultiply(lambda * gamma))).getSolver().getInverse();
  }

  /**
   * Expands the distribution over non-absorbing states into a diagonal
   * nbStates x nbStates matrix D_pi; diagonal entries of absorbing states
   * stay zero.
   */
  private RealMatrix createStateDistributionMatrix(ArrayRealVector d) {
    RealMatrix d_pi = new Array2DRowRealMatrix(nbStates(), nbStates());
    int ci = 0;
    for (int i = 0; i < nbStates(); i++) {
      GraphState s = states()[i];
      if (!s.hasNextState())
        continue;
      d_pi.setEntry(i, i, d.getEntry(ci));
      ci++;
    }
    return d_pi;
  }

  /**
   * Computes the normalized expected state-visit distribution over the
   * non-absorbing states: mu (I - Q)^{-1} where Q is the transition matrix
   * restricted to non-absorbing states and mu the initial distribution.
   */
  private ArrayRealVector createStateDistribution(RealMatrix p) {
    RealMatrix p_copy = p.copy();
    p_copy = removeColumnAndRow(p_copy, absorbingStatesSet());
    assert p_copy.getColumnDimension() == p_copy.getRowDimension();
    RealMatrix id = createIdentityMatrix(p_copy.getColumnDimension());
    RealMatrix inv = new LUDecompositionImpl(id.subtract(p_copy)).getSolver().getInverse();
    RealMatrix mu = createInitialStateDistribution();
    RealMatrix visits = mu.multiply(inv);
    // Normalize the expected visit counts into a probability distribution.
    double sum = 0;
    for (int i = 0; i < visits.getColumnDimension(); i++)
      sum += visits.getEntry(0, i);
    return (ArrayRealVector) visits.scalarMultiply(1 / sum).getRowVector(0);
  }

  /** Graph indexes of the absorbing (successor-less) states. */
  private Set<Integer> absorbingStatesSet() {
    Set<Integer> endStates = new LinkedHashSet<Integer>();
    for (int i = 0; i < nbStates(); i++)
      if (!states()[i].hasNextState())
        endStates.add(i);
    return endStates;
  }

  private int nbNonAbsorbingState() {
    return stateIndexes.size();
  }

  /** Copies m without the rows and columns listed in absorbingState. */
  private RealMatrix removeColumnAndRow(RealMatrix m, Set<Integer> absorbingState) {
    RealMatrix result = new Array2DRowRealMatrix(nbNonAbsorbingState(), nbNonAbsorbingState());
    int ci = 0;
    for (int i = 0; i < m.getRowDimension(); i++) {
      if (absorbingState.contains(i))
        continue;
      int cj = 0;
      for (int j = 0; j < m.getColumnDimension(); j++) {
        if (absorbingState.contains(j))
          continue;
        result.setEntry(ci, cj, m.getEntry(i, j));
        cj++;
      }
      ci++;
    }
    return result;
  }

  /** 1 x nbNonAbsorbingStates row vector: 1.0 on the initial state, 0 elsewhere. */
  private RealMatrix createInitialStateDistribution() {
    double[] numbers = new double[nbNonAbsorbingState()];
    int ci = 0;
    for (int i = 0; i < nbStates(); i++) {
      GraphState s = states()[i];
      if (!s.hasNextState())
        continue;
      numbers[ci] = s == graph.initialState() ? 1.0 : 0.0;
      ci++;
    }
    RealMatrix result = new Array2DRowRealMatrix(1, numbers.length);
    for (int i = 0; i < numbers.length; i++)
      result.setEntry(0, i, numbers[i]);
    return result;
  }

  /**
   * Builds the nbStates x nbStates transition-probability matrix induced by
   * the policy. Absorbing states are made self-absorbing (probability 1.0 on
   * the diagonal) so that the matrix is a proper stochastic matrix.
   */
  private RealMatrix createTransitionProbabilityMatrix(Policy policy) {
    RealMatrix p = new Array2DRowRealMatrix(nbStates(), nbStates());
    for (int si = 0; si < nbStates(); si++) {
      GraphState s_t = states()[si];
      policy.update(s_t.v());
      for (Action a : graph.actions()) {
        double pa = policy.pi(a);
        GraphState s_tp1 = s_t.nextState(a);
        if (s_tp1 != null) {
          int sj = graph.indexOf(s_tp1);
          // Accumulate rather than overwrite: several actions can lead to the
          // same successor, and its probability is the sum over those actions.
          p.setEntry(si, sj, p.getEntry(si, sj) + pa);
        }
      }
    }
    for (Integer absorbingState : absorbingStatesSet())
      p.setEntry(absorbingState, absorbingState, 1.0);
    return p;
  }

  private int nbStates() {
    return graph.nbStates();
  }

  private GraphState[] states() {
    return graph.states();
  }

  public Map<GraphState, Integer> stateIndexes() {
    return stateIndexes;
  }

  public FiniteStateGraph graph() {
    return graph;
  }

  /** One-hot feature vector of s; zero vector when s is null or absorbing. */
  public PVector featureState(GraphState s) {
    PVector result = new PVector(size);
    if (s != null && s.hasNextState())
      result.data[stateIndexes.get(s)] = 1;
    return result;
  }

  /**
   * Builds a one-hot state-action feature vector of size
   * nbNonAbsorbingStates * nbActions; returns an all-zero vector for a null
   * state and null when the action is not one of the graph's actions.
   */
  @Override
  public PVector stateAction(RealVector s, Action a) {
    PVector sa = new PVector(nbNonAbsorbingState() * graph.actions().length);
    if (s == null)
      return sa;
    GraphState sg = graph.state(s);
    // NOTE(review): if sg is an absorbing state, stateIndexes.get(sg) is null
    // and unboxing throws NullPointerException — presumably callers only pass
    // non-absorbing states; confirm against call sites.
    for (int ai = 0; ai < graph.actions().length; ai++)
      if (graph.actions()[ai] == a) {
        sa.setEntry(ai * nbNonAbsorbingState() + stateIndexes.get(sg), 1);
        return sa;
      }
    return null;
  }

  @Override
  public int vectorSize() {
    // NOTE(review): stateAction() builds vectors of size
    // nbNonAbsorbingState() * graph.actions().length, while this returns only
    // the state-feature size — looks inconsistent; confirm what callers of
    // StateToStateAction.vectorSize() expect before changing it.
    return size;
  }

  @Override
  public double vectorNorm() {
    // Feature vectors are one-hot, so the max norm is 1.
    return 1;
  }
}