RLPark 1.0.0
Reinforcement Learning Framework in Java
|
00001 package rlpark.plugin.rltoys.algorithms.discovery.sorting; 00002 00003 import java.io.Serializable; 00004 import java.util.Arrays; 00005 import java.util.Comparator; 00006 00007 import rlpark.plugin.rltoys.algorithms.LinearLearner; 00008 import rlpark.plugin.rltoys.math.vector.implementations.PVector; 00009 import zephyr.plugin.core.api.monitoring.annotations.IgnoreMonitor; 00010 import zephyr.plugin.core.api.monitoring.annotations.Monitor; 00011 00012 @Monitor 00013 public class WeightSorter implements Serializable { 00014 private static final long serialVersionUID = 3375889959423486133L; 00015 00016 public static class PVectorBasedComparator implements Comparator<Integer>, Serializable { 00017 private static final long serialVersionUID = -2221092563348361745L; 00018 final private double[] data; 00019 00020 public PVectorBasedComparator(PVector reference) { 00021 data = reference.data; 00022 } 00023 00024 @Override 00025 public int compare(Integer o1, Integer o2) { 00026 return Double.compare(data[o1], data[o2]); 00027 } 00028 }; 00029 00030 @IgnoreMonitor 00031 private final LinearLearner[] learners; 00032 @Monitor(level = 4) 00033 protected final PVector sums; 00034 @Monitor(level = 4) 00035 private final Integer[] order; 00036 private Comparator<Integer> comparator; 00037 private int worst; 00038 private final int startSorting; 00039 private final int endSorting; 00040 00041 public WeightSorter(LinearLearner[] learners) { 00042 this(learners, 0, -1); 00043 } 00044 00045 public WeightSorter(LinearLearner[] learners, int startSorting, int endSorting) { 00046 this.learners = learners; 00047 sums = new PVector(learners[0].weights().size); 00048 this.startSorting = startSorting; 00049 this.endSorting = endSorting > 0 ? endSorting : sums.getDimension(); 00050 order = new Integer[this.endSorting - this.startSorting]; 00051 for (int i = this.startSorting; i < this.endSorting; i++) 00052 order[i - this.startSorting] = i; 00053 } 00054 00055 protected Comparator<Integer> createComparator() { 00056 return new PVectorBasedComparator(sums); 00057 } 00058 00059 public void sort() { 00060 if (comparator == null) 00061 comparator = createComparator(); 00062 worst = 0; 00063 updateUnitEvaluation(); 00064 Arrays.sort(order, comparator); 00065 } 00066 00067 protected void updateUnitEvaluation() { 00068 sums.set(0); 00069 for (LinearLearner learner : learners) { 00070 double[] weight = learner.weights().data; 00071 double maxWeight = 0.0; 00072 for (int i = startSorting; i < endSorting; i++) 00073 maxWeight = Math.max(Math.abs(weight[i]), maxWeight); 00074 for (int i = startSorting; i < endSorting; i++) 00075 sums.data[i] += Math.abs(weight[i]) / maxWeight; 00076 } 00077 } 00078 00079 public boolean hasNext() { 00080 return worst < order.length; 00081 } 00082 00083 public int nextWorst() { 00084 int result = order[worst]; 00085 worst++; 00086 return result; 00087 } 00088 00089 public void resetWeights(int index) { 00090 for (LinearLearner learner : learners) 00091 learner.resetWeight(index); 00092 } 00093 00094 public int endSortingPosition() { 00095 return endSorting; 00096 } 00097 00098 public int startSortingPosition() { 00099 return startSorting; 00100 } 00101 }