RLPark 1.0.0
Reinforcement Learning Framework in Java

WeightSorter.java

Go to the documentation of this file.
00001 package rlpark.plugin.rltoys.algorithms.discovery.sorting;
00002 
00003 import java.io.Serializable;
00004 import java.util.Arrays;
00005 import java.util.Comparator;
00006 
00007 import rlpark.plugin.rltoys.algorithms.LinearLearner;
00008 import rlpark.plugin.rltoys.math.vector.implementations.PVector;
00009 import zephyr.plugin.core.api.monitoring.annotations.IgnoreMonitor;
00010 import zephyr.plugin.core.api.monitoring.annotations.Monitor;
00011 
00012 @Monitor
00013 public class WeightSorter implements Serializable {
00014   private static final long serialVersionUID = 3375889959423486133L;
00015 
00016   public static class PVectorBasedComparator implements Comparator<Integer>, Serializable {
00017     private static final long serialVersionUID = -2221092563348361745L;
00018     final private double[] data;
00019 
00020     public PVectorBasedComparator(PVector reference) {
00021       data = reference.data;
00022     }
00023 
00024     @Override
00025     public int compare(Integer o1, Integer o2) {
00026       return Double.compare(data[o1], data[o2]);
00027     }
00028   };
00029 
00030   @IgnoreMonitor
00031   private final LinearLearner[] learners;
00032   @Monitor(level = 4)
00033   protected final PVector sums;
00034   @Monitor(level = 4)
00035   private final Integer[] order;
00036   private Comparator<Integer> comparator;
00037   private int worst;
00038   private final int startSorting;
00039   private final int endSorting;
00040 
00041   public WeightSorter(LinearLearner[] learners) {
00042     this(learners, 0, -1);
00043   }
00044 
00045   public WeightSorter(LinearLearner[] learners, int startSorting, int endSorting) {
00046     this.learners = learners;
00047     sums = new PVector(learners[0].weights().size);
00048     this.startSorting = startSorting;
00049     this.endSorting = endSorting > 0 ? endSorting : sums.getDimension();
00050     order = new Integer[this.endSorting - this.startSorting];
00051     for (int i = this.startSorting; i < this.endSorting; i++)
00052       order[i - this.startSorting] = i;
00053   }
00054 
00055   protected Comparator<Integer> createComparator() {
00056     return new PVectorBasedComparator(sums);
00057   }
00058 
00059   public void sort() {
00060     if (comparator == null)
00061       comparator = createComparator();
00062     worst = 0;
00063     updateUnitEvaluation();
00064     Arrays.sort(order, comparator);
00065   }
00066 
00067   protected void updateUnitEvaluation() {
00068     sums.set(0);
00069     for (LinearLearner learner : learners) {
00070       double[] weight = learner.weights().data;
00071       double maxWeight = 0.0;
00072       for (int i = startSorting; i < endSorting; i++)
00073         maxWeight = Math.max(Math.abs(weight[i]), maxWeight);
00074       for (int i = startSorting; i < endSorting; i++)
00075         sums.data[i] += Math.abs(weight[i]) / maxWeight;
00076     }
00077   }
00078 
00079   public boolean hasNext() {
00080     return worst < order.length;
00081   }
00082 
00083   public int nextWorst() {
00084     int result = order[worst];
00085     worst++;
00086     return result;
00087   }
00088 
00089   public void resetWeights(int index) {
00090     for (LinearLearner learner : learners)
00091       learner.resetWeight(index);
00092   }
00093 
00094   public int endSortingPosition() {
00095     return endSorting;
00096   }
00097 
00098   public int startSortingPosition() {
00099     return startSorting;
00100   }
00101 }
 All Classes Namespaces Files Functions Variables Enumerations
Zephyr
RLPark