/*
 * RLPark 1.0.0
 * Reinforcement Learning Framework in Java
 */
00001 package rlpark.plugin.rltoys.algorithms.representations.rbf; 00002 00003 import java.util.ArrayList; 00004 import java.util.Arrays; 00005 import java.util.List; 00006 00007 import rlpark.plugin.rltoys.algorithms.functions.states.Projector; 00008 import rlpark.plugin.rltoys.math.ranges.Range; 00009 import rlpark.plugin.rltoys.math.vector.implementations.SVector; 00010 import rlpark.plugin.rltoys.utils.Utils; 00011 00012 public class RBFs implements Projector { 00013 private static final long serialVersionUID = -1905703492835265008L; 00014 private final List<RBF> rbfs = new ArrayList<RBF>(); 00015 private SVector vector; 00016 private boolean includeActiveFeature = false; 00017 private final double tolerance; 00018 00019 public RBFs(double tolerance) { 00020 this.tolerance = tolerance; 00021 } 00022 00023 private SVector newVectorInstance() { 00024 return new SVector(vectorSize()); 00025 } 00026 00027 public void includeActiveFeature() { 00028 includeActiveFeature = true; 00029 vector = newVectorInstance(); 00030 } 00031 00032 public void addIndependentRBFs(Range[] allRanges, int resolution, double stddev) { 00033 for (int i = 0; i < allRanges.length; i++) 00034 addRBFs(allRanges, new int[] { i }, resolution, stddev); 00035 } 00036 00037 public void addIndependentRBFs(Range[] allRanges, int[] selectedInputs, int resolution, double stddev) { 00038 for (int i : selectedInputs) 00039 addRBFs(allRanges, new int[] { i }, resolution, stddev); 00040 } 00041 00042 public void addFullRBFs(Range[] allRanges, int resolution, double stddev) { 00043 addRBFs(allRanges, Utils.range(0, allRanges.length), resolution, stddev); 00044 } 00045 00046 public void addRBFs(Range[] ranges, int[] inputIndexes, int resolution, double stddev) { 00047 RBF[] addedRbfs = new RBF[] { new RBF(new int[] {}, new double[] {}, stddev) }; 00048 for (int i = 0; i < inputIndexes.length; i++) { 00049 int inputIndex = inputIndexes[i]; 00050 addedRbfs = combine(addedRbfs, inputIndex, ranges[inputIndex], 
resolution, stddev); 00051 } 00052 for (RBF rbf : addedRbfs) 00053 rbfs.add(rbf); 00054 vector = newVectorInstance(); 00055 } 00056 00057 private RBF[] combine(RBF[] rbfs, int inputIndex, Range range, int resolution, double stddev) { 00058 RBF[] result = new RBF[rbfs.length * resolution]; 00059 double step = range.length() / resolution; 00060 double start = range.min() + step / 2; 00061 for (int i = 0; i < rbfs.length; i++) { 00062 RBF source = rbfs[i % rbfs.length]; 00063 int newPatternLength = source.patternIndexes().length + 1; 00064 for (int x = 0; x < resolution; x++) { 00065 int[] patternIndexes = Arrays.copyOf(source.patternIndexes(), newPatternLength); 00066 patternIndexes[newPatternLength - 1] = inputIndex; 00067 double[] patternValues = Arrays.copyOf(source.patternValues(), newPatternLength); 00068 patternValues[newPatternLength - 1] = x * step + start; 00069 result[i * resolution + x] = new RBF(patternIndexes, patternValues, stddev); 00070 } 00071 } 00072 return result; 00073 } 00074 00075 @Override 00076 public SVector project(double[] inputs) { 00077 vector.clear(); 00078 if (inputs == null) 00079 return vector.copy(); 00080 for (int i = 0; i < rbfs.size(); i++) { 00081 final RBF rbf = rbfs.get(i); 00082 double distance = rbf.value(inputs); 00083 if (distance > tolerance) 00084 vector.setEntry(i, distance); 00085 } 00086 if (includeActiveFeature) 00087 vector.setEntry(vector.getDimension() - 1, 1.0); 00088 return vector.copy(); 00089 } 00090 00091 @Override 00092 public int vectorSize() { 00093 int vectorSize = rbfs.size(); 00094 return includeActiveFeature ? vectorSize + 1 : vectorSize; 00095 } 00096 00097 @Override 00098 public double vectorNorm() { 00099 return vectorSize(); 00100 } 00101 }