RLPark 1.0.0
Reinforcement Learning Framework in Java
|
00001 package rlpark.plugin.rltoys.algorithms.representations.observations; 00002 00003 import java.io.Serializable; 00004 import java.util.ArrayList; 00005 import java.util.List; 00006 00007 import rlpark.plugin.rltoys.envio.observations.Legend; 00008 import rlpark.plugin.rltoys.math.ranges.Range; 00009 import zephyr.plugin.core.api.monitoring.annotations.LabelProvider; 00010 import zephyr.plugin.core.api.monitoring.annotations.Monitor; 00011 00012 public class ObsHistory implements Serializable { 00013 private static final long serialVersionUID = 7843542344680953444L; 00014 public int nbTimeSteps; 00015 private final int obsVectorSize; 00016 private final double[] oh_t; 00017 @Monitor(level = 2) 00018 private final double[] oh_tp1; 00019 private final Legend legend; 00020 private final Range[] ranges; 00021 00022 public ObsHistory(int nbStepHistory, Legend legend) { 00023 this(nbStepHistory, legend, null); 00024 } 00025 00026 public ObsHistory(int nbStepHistory, Legend legend, Range[] ranges) { 00027 assert nbStepHistory >= 0; 00028 assert ranges == null || legend.nbLabels() == ranges.length; 00029 nbTimeSteps = nbStepHistory + 1; 00030 obsVectorSize = legend.nbLabels(); 00031 oh_t = new double[nbTimeSteps * obsVectorSize]; 00032 oh_tp1 = new double[nbTimeSteps * obsVectorSize]; 00033 this.legend = buildLegend(legend); 00034 this.ranges = ranges; 00035 } 00036 00037 @LabelProvider(ids = { "oh_tp1" }) 00038 protected String labelOf(int index) { 00039 return legend.label(index); 00040 } 00041 00042 private Legend buildLegend(Legend legend) { 00043 List<String> obsLabel = legend.getLabels(); 00044 String[] labels = new String[nbTimeSteps * obsLabel.size()]; 00045 for (int i = 0; i < nbTimeSteps; i++) { 00046 int timeOffset = nbTimeSteps - i - 1; 00047 for (int j = 0; j < obsLabel.size(); j++) { 00048 String label = obsLabel.get(j); 00049 label += toTimeLabel(timeOffset); 00050 labels[i * obsLabel.size() + j] = label; 00051 } 00052 } 00053 return new Legend(labels); 00054 } 00055 00056 static protected String toTimeLabel(int timeOffset) { 00057 return "[t-" + timeOffset + "]"; 00058 } 00059 00060 public Legend legend() { 00061 return legend; 00062 } 00063 00064 public int historyVectorSize() { 00065 return nbTimeSteps * obsVectorSize; 00066 } 00067 00068 public double[] update(double[] o_tp1) { 00069 if (o_tp1 == null) 00070 return null; 00071 System.arraycopy(oh_tp1, 0, oh_t, 0, oh_t.length); 00072 int historyObsLength = (nbTimeSteps - 1) * obsVectorSize; 00073 System.arraycopy(oh_t, obsVectorSize, oh_tp1, 0, historyObsLength); 00074 System.arraycopy(o_tp1, 0, oh_tp1, historyObsLength, obsVectorSize); 00075 return oh_tp1; 00076 } 00077 00078 public int[] selectIndexes(int timeOffset, String... prefixes) { 00079 List<String> selectedLabels = selectLabels(timeOffset, prefixes); 00080 int[] indexes = new int[selectedLabels.size()]; 00081 for (int i = 0; i < indexes.length; i++) 00082 indexes[i] = legend.indexOf(selectedLabels.get(i)); 00083 return indexes; 00084 } 00085 00086 public List<String> selectLabels(int timeOffset, String... prefixes) { 00087 List<String> result = new ArrayList<String>(); 00088 String timeStringLabel = toTimeLabel(timeOffset); 00089 for (String label : legend.getLabels()) { 00090 if (!label.endsWith(timeStringLabel)) 00091 continue; 00092 for (String prefix : prefixes) 00093 if (label.startsWith(prefix)) { 00094 result.add(label); 00095 break; 00096 } 00097 } 00098 return result; 00099 } 00100 00101 public Range[] getRanges() { 00102 if (ranges == null || legend.nbLabels() / nbTimeSteps != ranges.length) 00103 return null; 00104 Range[] result = new Range[historyVectorSize()]; 00105 for (int i = 0; i < result.length; i++) 00106 result[i] = ranges[i % ranges.length]; 00107 return result; 00108 } 00109 }