RLPark 1.0.0
Reinforcement Learning Framework in Java
|
00001 package rlpark.plugin.rltoys.algorithms.discovery.ltu; 00002 00003 import java.io.Serializable; 00004 import java.util.LinkedList; 00005 import java.util.Random; 00006 00007 import rlpark.plugin.rltoys.algorithms.discovery.sorting.WeightSorter; 00008 import rlpark.plugin.rltoys.algorithms.representations.ltu.networks.RandomNetwork; 00009 import rlpark.plugin.rltoys.algorithms.representations.ltu.networks.RandomNetworks; 00010 import rlpark.plugin.rltoys.algorithms.representations.ltu.units.LTU; 00011 import zephyr.plugin.core.api.monitoring.annotations.Monitor; 00012 import zephyr.plugin.core.api.signals.Signal; 00013 00014 public class RepresentationDiscovery implements Serializable { 00015 private static final long serialVersionUID = 8579686361420622461L; 00016 public final Signal<LTU> onLTUAdded = new Signal<LTU>(); 00017 private final RandomNetwork network; 00018 @Monitor 00019 private final WeightSorter sorter; 00020 private final LinkedList<Integer> protectedUnits = new LinkedList<Integer>(); 00021 private final int nbProtectedUnits; 00022 private final LTU prototype; 00023 private final Random random; 00024 private int worstUnit; 00025 private final int nbInputForUnit; 00026 00027 public RepresentationDiscovery(Random random, RandomNetwork network, WeightSorter sorter, LTU prototype, 00028 int nbProtectedUnit, int nbInputForUnit) { 00029 this.random = random; 00030 this.network = network; 00031 this.sorter = sorter; 00032 this.prototype = prototype; 00033 this.nbInputForUnit = nbInputForUnit; 00034 this.nbProtectedUnits = nbProtectedUnit; 00035 assert nbProtectedUnits > 0; 00036 } 00037 00038 public void changeRepresentation(int nbUnitsToChange) { 00039 sorter.sort(); 00040 for (int unitIndex = 0; unitIndex < nbUnitsToChange; unitIndex++) { 00041 worstUnit = findWorstUnit(); 00042 assert worstUnit >= 0; 00043 LTU ltu = createNewUnit(worstUnit); 00044 network.addLTU(ltu); 00045 sorter.resetWeights(ltu.index()); 00046 addIntoProtectedUnits(worstUnit); 00047 onLTUAdded.fire(ltu); 00048 } 00049 } 00050 00051 private void addIntoProtectedUnits(int worstUnit) { 00052 protectedUnits.push(worstUnit); 00053 if (protectedUnits.size() > nbProtectedUnits) 00054 protectedUnits.pollLast(); 00055 } 00056 00057 private LTU createNewUnit(int ltuIndex) { 00058 return RandomNetworks.newRandomUnit(random, prototype, ltuIndex, network.inputSize, nbInputForUnit); 00059 } 00060 00061 protected int findWorstUnit() { 00062 int worstUnit = -1; 00063 do { 00064 worstUnit = sorter.nextWorst(); 00065 } while (protectedUnits.contains(worstUnit)); 00066 return worstUnit; 00067 } 00068 00069 public void fillNetwork() { 00070 RandomNetworks.connect(random, network, prototype, nbInputForUnit, sorter.startSortingPosition(), 00071 sorter.endSortingPosition()); 00072 } 00073 }