RLPark 1.0.0
Reinforcement Learning Framework in Java

RecursiveWeightSorter.java

Go to the documentation of this file.
00001 package rlpark.plugin.rltoys.algorithms.discovery.ltu;
00002 
00003 import java.util.Comparator;
00004 import java.util.HashSet;
00005 import java.util.Set;
00006 
00007 import rlpark.plugin.rltoys.algorithms.LinearLearner;
00008 import rlpark.plugin.rltoys.algorithms.discovery.sorting.WeightSorter;
00009 import rlpark.plugin.rltoys.algorithms.representations.ltu.networks.RandomNetwork;
00010 import rlpark.plugin.rltoys.algorithms.representations.ltu.units.LTU;
00011 import rlpark.plugin.rltoys.math.vector.implementations.PVector;
00012 import rlpark.plugin.rltoys.utils.Utils;
00013 import zephyr.plugin.core.api.monitoring.annotations.IgnoreMonitor;
00014 import zephyr.plugin.core.api.monitoring.annotations.Monitor;
00015 
00016 @Monitor
00017 public class RecursiveWeightSorter extends WeightSorter {
00018   private static final long serialVersionUID = -654469131883608071L;
00019   @IgnoreMonitor
00020   final protected RandomNetwork network;
00021   private final int nbMaxParents;
00022   @Monitor(level = 4)
00023   private final PVector recursiveSum;
00024   private final double discount;
00025 
00026   public RecursiveWeightSorter(RandomNetwork network, LinearLearner[] learners, int nbMaxParents) {
00027     super(learners);
00028     assert network.inputSize > network.outputSize;
00029     this.network = network;
00030     this.nbMaxParents = nbMaxParents;
00031     this.recursiveSum = new PVector(sums.size);
00032     this.discount = Utils.timeStepsToDiscount(nbMaxParents);
00033   }
00034 
00035   @Override
00036   protected Comparator<Integer> createComparator() {
00037     return new PVectorBasedComparator(recursiveSum) {
00038       private static final long serialVersionUID = 4220495235775683757L;
00039       private final int maxSort = network.outputSize;
00040 
00041       @Override
00042       public int compare(Integer o1, Integer o2) {
00043         if (o1 >= maxSort && o2 < maxSort)
00044           return 1;
00045         if (o2 >= maxSort && o1 < maxSort)
00046           return -1;
00047         return super.compare(o1, o2);
00048       }
00049     };
00050   }
00051 
00052   @Override
00053   protected void updateUnitEvaluation() {
00054     super.updateUnitEvaluation();
00055     recursiveSum.set(0);
00056     for (int i = 0; i < network.outputSize; i++)
00057       recursiveSum.data[i] = computeRecursiveWeights(new HashSet<LTU>(), network.ltu(i), 1.0);
00058   }
00059 
00060   private double computeRecursiveWeights(Set<LTU> counted, LTU ltu, double currentDiscount) {
00061     boolean added = counted.add(ltu);
00062     if (!added)
00063       return 0;
00064     double result = sums.data[ltu.index()] * currentDiscount;
00065     if (counted.size() >= nbMaxParents)
00066       return result;
00067     double parentDiscount = currentDiscount * discount;
00068     for (LTU parentLTU : network.parents(ltu.index()))
00069       result += computeRecursiveWeights(counted, parentLTU, parentDiscount);
00070     return result;
00071   }
00072 }
 All Classes Namespaces Files Functions Variables Enumerations
Zephyr
RLPark