RLPark 1.0.0
Reinforcement Learning Framework in Java
|
00001 package rlpark.plugin.rltoys.algorithms.discovery.ltu; 00002 00003 import java.util.Comparator; 00004 import java.util.HashSet; 00005 import java.util.Set; 00006 00007 import rlpark.plugin.rltoys.algorithms.LinearLearner; 00008 import rlpark.plugin.rltoys.algorithms.discovery.sorting.WeightSorter; 00009 import rlpark.plugin.rltoys.algorithms.representations.ltu.networks.RandomNetwork; 00010 import rlpark.plugin.rltoys.algorithms.representations.ltu.units.LTU; 00011 import rlpark.plugin.rltoys.math.vector.implementations.PVector; 00012 import rlpark.plugin.rltoys.utils.Utils; 00013 import zephyr.plugin.core.api.monitoring.annotations.IgnoreMonitor; 00014 import zephyr.plugin.core.api.monitoring.annotations.Monitor; 00015 00016 @Monitor 00017 public class RecursiveWeightSorter extends WeightSorter { 00018 private static final long serialVersionUID = -654469131883608071L; 00019 @IgnoreMonitor 00020 final protected RandomNetwork network; 00021 private final int nbMaxParents; 00022 @Monitor(level = 4) 00023 private final PVector recursiveSum; 00024 private final double discount; 00025 00026 public RecursiveWeightSorter(RandomNetwork network, LinearLearner[] learners, int nbMaxParents) { 00027 super(learners); 00028 assert network.inputSize > network.outputSize; 00029 this.network = network; 00030 this.nbMaxParents = nbMaxParents; 00031 this.recursiveSum = new PVector(sums.size); 00032 this.discount = Utils.timeStepsToDiscount(nbMaxParents); 00033 } 00034 00035 @Override 00036 protected Comparator<Integer> createComparator() { 00037 return new PVectorBasedComparator(recursiveSum) { 00038 private static final long serialVersionUID = 4220495235775683757L; 00039 private final int maxSort = network.outputSize; 00040 00041 @Override 00042 public int compare(Integer o1, Integer o2) { 00043 if (o1 >= maxSort && o2 < maxSort) 00044 return 1; 00045 if (o2 >= maxSort && o1 < maxSort) 00046 return -1; 00047 return super.compare(o1, o2); 00048 } 00049 }; 00050 } 00051 00052 @Override 00053 protected void updateUnitEvaluation() { 00054 super.updateUnitEvaluation(); 00055 recursiveSum.set(0); 00056 for (int i = 0; i < network.outputSize; i++) 00057 recursiveSum.data[i] = computeRecursiveWeights(new HashSet<LTU>(), network.ltu(i), 1.0); 00058 } 00059 00060 private double computeRecursiveWeights(Set<LTU> counted, LTU ltu, double currentDiscount) { 00061 boolean added = counted.add(ltu); 00062 if (!added) 00063 return 0; 00064 double result = sums.data[ltu.index()] * currentDiscount; 00065 if (counted.size() >= nbMaxParents) 00066 return result; 00067 double parentDiscount = currentDiscount * discount; 00068 for (LTU parentLTU : network.parents(ltu.index())) 00069 result += computeRecursiveWeights(counted, parentLTU, parentDiscount); 00070 return result; 00071 } 00072 }