RLPark 1.0.0
Reinforcement Learning Framework in Java

ObsHistory.java

Go to the documentation of this file.
00001 package rlpark.plugin.rltoys.algorithms.representations.observations;
00002 
00003 import java.io.Serializable;
00004 import java.util.ArrayList;
00005 import java.util.List;
00006 
00007 import rlpark.plugin.rltoys.envio.observations.Legend;
00008 import rlpark.plugin.rltoys.math.ranges.Range;
00009 import zephyr.plugin.core.api.monitoring.annotations.LabelProvider;
00010 import zephyr.plugin.core.api.monitoring.annotations.Monitor;
00011 
00012 public class ObsHistory implements Serializable {
00013   private static final long serialVersionUID = 7843542344680953444L;
00014   public int nbTimeSteps;
00015   private final int obsVectorSize;
00016   private final double[] oh_t;
00017   @Monitor(level = 2)
00018   private final double[] oh_tp1;
00019   private final Legend legend;
00020   private final Range[] ranges;
00021 
00022   public ObsHistory(int nbStepHistory, Legend legend) {
00023     this(nbStepHistory, legend, null);
00024   }
00025 
00026   public ObsHistory(int nbStepHistory, Legend legend, Range[] ranges) {
00027     assert nbStepHistory >= 0;
00028     assert ranges == null || legend.nbLabels() == ranges.length;
00029     nbTimeSteps = nbStepHistory + 1;
00030     obsVectorSize = legend.nbLabels();
00031     oh_t = new double[nbTimeSteps * obsVectorSize];
00032     oh_tp1 = new double[nbTimeSteps * obsVectorSize];
00033     this.legend = buildLegend(legend);
00034     this.ranges = ranges;
00035   }
00036 
00037   @LabelProvider(ids = { "oh_tp1" })
00038   protected String labelOf(int index) {
00039     return legend.label(index);
00040   }
00041 
00042   private Legend buildLegend(Legend legend) {
00043     List<String> obsLabel = legend.getLabels();
00044     String[] labels = new String[nbTimeSteps * obsLabel.size()];
00045     for (int i = 0; i < nbTimeSteps; i++) {
00046       int timeOffset = nbTimeSteps - i - 1;
00047       for (int j = 0; j < obsLabel.size(); j++) {
00048         String label = obsLabel.get(j);
00049         label += toTimeLabel(timeOffset);
00050         labels[i * obsLabel.size() + j] = label;
00051       }
00052     }
00053     return new Legend(labels);
00054   }
00055 
00056   static protected String toTimeLabel(int timeOffset) {
00057     return "[t-" + timeOffset + "]";
00058   }
00059 
00060   public Legend legend() {
00061     return legend;
00062   }
00063 
00064   public int historyVectorSize() {
00065     return nbTimeSteps * obsVectorSize;
00066   }
00067 
00068   public double[] update(double[] o_tp1) {
00069     if (o_tp1 == null)
00070       return null;
00071     System.arraycopy(oh_tp1, 0, oh_t, 0, oh_t.length);
00072     int historyObsLength = (nbTimeSteps - 1) * obsVectorSize;
00073     System.arraycopy(oh_t, obsVectorSize, oh_tp1, 0, historyObsLength);
00074     System.arraycopy(o_tp1, 0, oh_tp1, historyObsLength, obsVectorSize);
00075     return oh_tp1;
00076   }
00077 
00078   public int[] selectIndexes(int timeOffset, String... prefixes) {
00079     List<String> selectedLabels = selectLabels(timeOffset, prefixes);
00080     int[] indexes = new int[selectedLabels.size()];
00081     for (int i = 0; i < indexes.length; i++)
00082       indexes[i] = legend.indexOf(selectedLabels.get(i));
00083     return indexes;
00084   }
00085 
00086   public List<String> selectLabels(int timeOffset, String... prefixes) {
00087     List<String> result = new ArrayList<String>();
00088     String timeStringLabel = toTimeLabel(timeOffset);
00089     for (String label : legend.getLabels()) {
00090       if (!label.endsWith(timeStringLabel))
00091         continue;
00092       for (String prefix : prefixes)
00093         if (label.startsWith(prefix)) {
00094           result.add(label);
00095           break;
00096         }
00097     }
00098     return result;
00099   }
00100 
00101   public Range[] getRanges() {
00102     if (ranges == null || legend.nbLabels() / nbTimeSteps != ranges.length)
00103       return null;
00104     Range[] result = new Range[historyVectorSize()];
00105     for (int i = 0; i < result.length; i++)
00106       result[i] = ranges[i % ranges.length];
00107     return result;
00108   }
00109 }
 All Classes Namespaces Files Functions Variables Enumerations
Zephyr
RLPark