RLPark 1.0.0
Reinforcement Learning Framework in Java

LTUThreshold.java

Go to the documentation of this file.
00001 package rlpark.plugin.rltoys.algorithms.representations.ltu.units;
00002 
00003 import java.util.Random;
00004 
00005 import zephyr.plugin.core.api.monitoring.annotations.Monitor;
00006 
00007 public class LTUThreshold implements LTUAdaptiveDensity {
00008   private static final long serialVersionUID = -4100313691365362138L;
00009   final static public double Beta = .6;
00010   final public int index;
00011   protected final Connections connections;
00012   @Monitor
00013   protected double threshold;
00014   @Monitor
00015   private double sum;
00016   @Monitor
00017   private boolean isActive;
00018 
00019   public LTUThreshold() {
00020     index = -1;
00021     connections = null;
00022   }
00023 
00024   public LTUThreshold(int index, int[] inputs, byte[] weights) {
00025     this.index = index;
00026     int nbNegative = 0;
00027     connections = new Connections(inputs.length);
00028     for (int i = 0; i < inputs.length; i++) {
00029       connections.setEntry(inputs[i], weights[i]);
00030       if (weights[i] < 0)
00031         nbNegative++;
00032     }
00033     threshold = -nbNegative + 0.6 * inputs.length;
00034   }
00035 
00036   @Override
00037   public void updateSum(double[] inputVector) {
00038     sum = connections.dotProduct(inputVector);
00039   }
00040 
00041   @Override
00042   public LTUThreshold newLTU(int index, int[] inputs, byte[] weights) {
00043     return new LTUThreshold(index, inputs, weights);
00044   }
00045 
00046   @Override
00047   public int[] inputs() {
00048     return connections.indexes;
00049   }
00050 
00051   @Override
00052   public int index() {
00053     return index;
00054   }
00055 
00056   @Override
00057   public void decreaseDensity(Random random, double[] inputVector) {
00058     int bit = random.nextInt(connections.nbActive);
00059     double weight = connections.weights[bit];
00060     boolean inputActive = inputVector[connections.indexes[bit]] > 0;
00061     if ((weight == 1 && inputActive) || (weight == -1 && !inputActive)) {
00062       connections.weights[bit] *= -1;
00063       threshold += 1;
00064     }
00065   }
00066 
00067   @Override
00068   public void increaseDensity(Random random, double[] inputVector) {
00069     int bit = random.nextInt(connections.nbActive);
00070     double weight = connections.weights[bit];
00071     boolean inputActive = inputVector[connections.indexes[bit]] > 0;
00072     if ((weight == -1 && inputActive) || (weight == +1 && !inputActive)) {
00073       connections.weights[bit] *= -1;
00074       threshold -= 1;
00075     }
00076   }
00077 
00078   @Override
00079   public boolean updateActivation() {
00080     isActive = sum >= threshold;
00081     sum = 0;
00082     return isActive;
00083   }
00084 
00085   @Override
00086   public boolean isActive() {
00087     return isActive;
00088   }
00089 
00090   public Connections connections() {
00091     return connections;
00092   }
00093 
00094   public void setThreshold(double threshold) {
00095     this.threshold = threshold;
00096   }
00097 }
 All Classes Namespaces Files Functions Variables Enumerations
Zephyr
RLPark