RLPark 1.0.0
Reinforcement Learning Framework in Java
|
package rlpark.plugin.rltoys.algorithms.functions.stateactions;

import java.util.Map;

import rlpark.plugin.rltoys.envio.actions.Action;
import rlpark.plugin.rltoys.envio.actions.Actions;
import rlpark.plugin.rltoys.math.vector.BinaryVector;
import rlpark.plugin.rltoys.math.vector.MutableVector;
import rlpark.plugin.rltoys.math.vector.RealVector;
import rlpark.plugin.rltoys.math.vector.implementations.BVector;

/**
 * Builds a state-action feature vector phi(s, a) by copying the state vector
 * into a per-action slot of a larger vector: action {@code a} owns indices
 * {@code [atoi(a) * stateVectorSize, atoi(a) * stateVectorSize + stateVectorSize)}.
 * Optionally a single always-active ("bias") feature is appended at the last
 * index (see {@link #includeActiveFeature()}).
 * <p>
 * NOTE(review): {@link #stateAction(RealVector, Action)} reuses a single
 * internal scratch buffer across calls — callers that keep the returned vector
 * must copy it first.
 */
public class TabularAction implements StateToStateAction, Cloneable {
  private static final long serialVersionUID = 1705117400022134128L;
  private final Action[] actions;
  private final int stateVectorSize;
  // Returned for a null state; re-allocated if the active feature changes vectorSize().
  private BVector nullVector;
  private final double vectorNorm;
  private boolean includeActiveFeature = false;
  // Lazily-allocated scratch vector reused by every stateAction() call.
  private RealVector buffer;
  private final Map<Action, Integer> actionToIndex;

  /**
   * @param actions the discrete action set; each action gets its own feature slot
   * @param vectorNorm norm of the state vectors; stored as {@code vectorNorm + 1}
   *          to account for the optional always-active feature
   * @param vectorSize dimension of the state vectors
   */
  public TabularAction(Action[] actions, double vectorNorm, int vectorSize) {
    this.actions = actions;
    this.vectorNorm = vectorNorm + 1;
    this.stateVectorSize = vectorSize;
    this.nullVector = new BVector(vectorSize());
    actionToIndex = Actions.createActionIntMap(actions);
  }

  /** Returns the slot index of {@code a} within the stacked feature vector. */
  protected int atoi(Action a) {
    return actionToIndex.get(a);
  }

  /**
   * Enables the trailing always-active (bias) feature. Grows
   * {@link #vectorSize()} by one, so the null vector is re-allocated.
   */
  public void includeActiveFeature() {
    includeActiveFeature = true;
    this.nullVector = new BVector(vectorSize());
  }

  /** Dimension of phi(s, a): one state-sized slot per action, plus the optional bias. */
  @Override
  public int vectorSize() {
    int result = stateVectorSize * actions.length;
    if (includeActiveFeature)
      result += 1;
    return result;
  }

  /**
   * Computes phi(s, a) into the shared buffer.
   *
   * @param s current state feature vector, or {@code null} for the absorbing state
   * @param a action selecting the destination slot
   * @return the (reused) state-action vector; an all-zero vector when {@code s} is null
   */
  @Override
  public RealVector stateAction(RealVector s, Action a) {
    if (s == null)
      return nullVector;
    if (buffer == null)
      buffer = (s instanceof BinaryVector) ? new BVector(vectorSize()) : s.newInstance(vectorSize());
    int offset = atoi(a) * stateVectorSize;
    if (s instanceof BinaryVector)
      return stateAction((BinaryVector) s, offset);
    MutableVector phi_sa = (MutableVector) buffer;
    phi_sa.clear();
    if (includeActiveFeature)
      phi_sa.setEntry(vectorSize() - 1, 1);
    for (int s_i = 0; s_i < s.getDimension(); s_i++)
      phi_sa.setEntry(s_i + offset, s.getEntry(s_i));
    return phi_sa;
  }

  /** Sparse fast path: merges the binary state into its action slot in one call. */
  private RealVector stateAction(BinaryVector s, int offset) {
    BVector phi_sa = (BVector) buffer;
    phi_sa.clear();
    phi_sa.mergeSubVector(offset, s);
    if (includeActiveFeature)
      phi_sa.setOn(phi_sa.getDimension() - 1);
    return phi_sa;
  }

  /** Returns the action set (shared array, not a copy). */
  public Action[] actions() {
    return actions;
  }

  @Override
  public double vectorNorm() {
    return vectorNorm;
  }

  /**
   * Returns an independent copy with a fresh buffer.
   * <p>
   * Fix: the copy now preserves {@code includeActiveFeature}; previously a
   * clone of a configured instance silently dropped the bias feature, giving
   * it a different {@link #vectorSize()} than its original.
   */
  @Override
  public StateToStateAction clone() throws CloneNotSupportedException {
    // The constructor adds 1 to vectorNorm, so hand the original value back.
    TabularAction copy = new TabularAction(actions, vectorNorm - 1, stateVectorSize);
    if (includeActiveFeature)
      copy.includeActiveFeature();
    return copy;
  }
}