RLPark 1.0.0
Reinforcement Learning Framework in Java

RepresentationDiscovery.java

Go to the documentation of this file.
00001 package rlpark.plugin.rltoys.algorithms.discovery.ltu;
00002 
00003 import java.io.Serializable;
00004 import java.util.LinkedList;
00005 import java.util.Random;
00006 
00007 import rlpark.plugin.rltoys.algorithms.discovery.sorting.WeightSorter;
00008 import rlpark.plugin.rltoys.algorithms.representations.ltu.networks.RandomNetwork;
00009 import rlpark.plugin.rltoys.algorithms.representations.ltu.networks.RandomNetworks;
00010 import rlpark.plugin.rltoys.algorithms.representations.ltu.units.LTU;
00011 import zephyr.plugin.core.api.monitoring.annotations.Monitor;
00012 import zephyr.plugin.core.api.signals.Signal;
00013 
00014 public class RepresentationDiscovery implements Serializable {
00015   private static final long serialVersionUID = 8579686361420622461L;
00016   public final Signal<LTU> onLTUAdded = new Signal<LTU>();
00017   private final RandomNetwork network;
00018   @Monitor
00019   private final WeightSorter sorter;
00020   private final LinkedList<Integer> protectedUnits = new LinkedList<Integer>();
00021   private final int nbProtectedUnits;
00022   private final LTU prototype;
00023   private final Random random;
00024   private int worstUnit;
00025   private final int nbInputForUnit;
00026 
00027   public RepresentationDiscovery(Random random, RandomNetwork network, WeightSorter sorter, LTU prototype,
00028       int nbProtectedUnit, int nbInputForUnit) {
00029     this.random = random;
00030     this.network = network;
00031     this.sorter = sorter;
00032     this.prototype = prototype;
00033     this.nbInputForUnit = nbInputForUnit;
00034     this.nbProtectedUnits = nbProtectedUnit;
00035     assert nbProtectedUnits > 0;
00036   }
00037 
00038   public void changeRepresentation(int nbUnitsToChange) {
00039     sorter.sort();
00040     for (int unitIndex = 0; unitIndex < nbUnitsToChange; unitIndex++) {
00041       worstUnit = findWorstUnit();
00042       assert worstUnit >= 0;
00043       LTU ltu = createNewUnit(worstUnit);
00044       network.addLTU(ltu);
00045       sorter.resetWeights(ltu.index());
00046       addIntoProtectedUnits(worstUnit);
00047       onLTUAdded.fire(ltu);
00048     }
00049   }
00050 
00051   private void addIntoProtectedUnits(int worstUnit) {
00052     protectedUnits.push(worstUnit);
00053     if (protectedUnits.size() > nbProtectedUnits)
00054       protectedUnits.pollLast();
00055   }
00056 
00057   private LTU createNewUnit(int ltuIndex) {
00058     return RandomNetworks.newRandomUnit(random, prototype, ltuIndex, network.inputSize, nbInputForUnit);
00059   }
00060 
00061   protected int findWorstUnit() {
00062     int worstUnit = -1;
00063     do {
00064       worstUnit = sorter.nextWorst();
00065     } while (protectedUnits.contains(worstUnit));
00066     return worstUnit;
00067   }
00068 
00069   public void fillNetwork() {
00070     RandomNetworks.connect(random, network, prototype, nbInputForUnit, sorter.startSortingPosition(),
00071                            sorter.endSortingPosition());
00072   }
00073 }
 All Classes Namespaces Files Functions Variables Enumerations
Zephyr
RLPark