RLPark 1.0.0
Reinforcement Learning Framework in Java

RBFs.java

Go to the documentation of this file.
00001 package rlpark.plugin.rltoys.algorithms.representations.rbf;
00002 
00003 import java.util.ArrayList;
00004 import java.util.Arrays;
00005 import java.util.List;
00006 
00007 import rlpark.plugin.rltoys.algorithms.functions.states.Projector;
00008 import rlpark.plugin.rltoys.math.ranges.Range;
00009 import rlpark.plugin.rltoys.math.vector.implementations.SVector;
00010 import rlpark.plugin.rltoys.utils.Utils;
00011 
00012 public class RBFs implements Projector {
00013   private static final long serialVersionUID = -1905703492835265008L;
00014   private final List<RBF> rbfs = new ArrayList<RBF>();
00015   private SVector vector;
00016   private boolean includeActiveFeature = false;
00017   private final double tolerance;
00018 
00019   public RBFs(double tolerance) {
00020     this.tolerance = tolerance;
00021   }
00022 
00023   private SVector newVectorInstance() {
00024     return new SVector(vectorSize());
00025   }
00026 
00027   public void includeActiveFeature() {
00028     includeActiveFeature = true;
00029     vector = newVectorInstance();
00030   }
00031 
00032   public void addIndependentRBFs(Range[] allRanges, int resolution, double stddev) {
00033     for (int i = 0; i < allRanges.length; i++)
00034       addRBFs(allRanges, new int[] { i }, resolution, stddev);
00035   }
00036 
00037   public void addIndependentRBFs(Range[] allRanges, int[] selectedInputs, int resolution, double stddev) {
00038     for (int i : selectedInputs)
00039       addRBFs(allRanges, new int[] { i }, resolution, stddev);
00040   }
00041 
00042   public void addFullRBFs(Range[] allRanges, int resolution, double stddev) {
00043     addRBFs(allRanges, Utils.range(0, allRanges.length), resolution, stddev);
00044   }
00045 
00046   public void addRBFs(Range[] ranges, int[] inputIndexes, int resolution, double stddev) {
00047     RBF[] addedRbfs = new RBF[] { new RBF(new int[] {}, new double[] {}, stddev) };
00048     for (int i = 0; i < inputIndexes.length; i++) {
00049       int inputIndex = inputIndexes[i];
00050       addedRbfs = combine(addedRbfs, inputIndex, ranges[inputIndex], resolution, stddev);
00051     }
00052     for (RBF rbf : addedRbfs)
00053       rbfs.add(rbf);
00054     vector = newVectorInstance();
00055   }
00056 
00057   private RBF[] combine(RBF[] rbfs, int inputIndex, Range range, int resolution, double stddev) {
00058     RBF[] result = new RBF[rbfs.length * resolution];
00059     double step = range.length() / resolution;
00060     double start = range.min() + step / 2;
00061     for (int i = 0; i < rbfs.length; i++) {
00062       RBF source = rbfs[i % rbfs.length];
00063       int newPatternLength = source.patternIndexes().length + 1;
00064       for (int x = 0; x < resolution; x++) {
00065         int[] patternIndexes = Arrays.copyOf(source.patternIndexes(), newPatternLength);
00066         patternIndexes[newPatternLength - 1] = inputIndex;
00067         double[] patternValues = Arrays.copyOf(source.patternValues(), newPatternLength);
00068         patternValues[newPatternLength - 1] = x * step + start;
00069         result[i * resolution + x] = new RBF(patternIndexes, patternValues, stddev);
00070       }
00071     }
00072     return result;
00073   }
00074 
00075   @Override
00076   public SVector project(double[] inputs) {
00077     vector.clear();
00078     if (inputs == null)
00079       return vector.copy();
00080     for (int i = 0; i < rbfs.size(); i++) {
00081       final RBF rbf = rbfs.get(i);
00082       double distance = rbf.value(inputs);
00083       if (distance > tolerance)
00084         vector.setEntry(i, distance);
00085     }
00086     if (includeActiveFeature)
00087       vector.setEntry(vector.getDimension() - 1, 1.0);
00088     return vector.copy();
00089   }
00090 
00091   @Override
00092   public int vectorSize() {
00093     int vectorSize = rbfs.size();
00094     return includeActiveFeature ? vectorSize + 1 : vectorSize;
00095   }
00096 
00097   @Override
00098   public double vectorNorm() {
00099     return vectorSize();
00100   }
00101 }
 All Classes Namespaces Files Functions Variables Enumerations
Zephyr
RLPark