RLPark 1.0.0
Reinforcement Learning Framework in Java
|
package rlpark.plugin.rltoys.algorithms.representations.ltu.units;

import java.util.Random;

import zephyr.plugin.core.api.monitoring.annotations.Monitor;

/**
 * Threshold unit (presumably "LTU" stands for linear threshold unit -- confirm against the
 * package documentation) that is active whenever the dot product of its connections with the
 * input vector is greater than or equal to {@link #threshold}.
 *
 * <p>The expected call order is {@link #updateSum(double[])} followed by
 * {@link #updateActivation()}; the latter consumes and resets the stored sum.
 */
public class LTUThreshold implements LTUAdaptiveDensity {
  private static final long serialVersionUID = -4100313691365362138L;

  /** Fraction of the number of inputs that is added to the initial threshold. */
  public static final double Beta = .6;

  /** Index given at construction time; {@code -1} for an instance built with no arguments. */
  public final int index;

  /** Connections of this unit; {@code null} for an instance built with no arguments. */
  protected final Connections connections;

  /** Value the sum must reach for the unit to be active. */
  @Monitor
  protected double threshold;

  /** Result of the last {@link #updateSum(double[])}; reset to 0 by {@link #updateActivation()}. */
  @Monitor
  private double sum;

  /** Activation computed by the last call to {@link #updateActivation()}. */
  @Monitor
  private boolean isActive;

  /**
   * Creates a unit with index {@code -1} and no connections.
   *
   * <p>NOTE(review): presumably only meant as a prototype on which
   * {@link #newLTU(int, int[], byte[])} is called -- confirm with callers. Because
   * {@code connections} is {@code null}, every method that dereferences it throws
   * {@link NullPointerException} on such an instance.
   */
  public LTUThreshold() {
    index = -1;
    connections = null;
  }

  /**
   * Creates a unit connected to the given inputs.
   *
   * <p>The initial threshold is {@code -nbNegative + Beta * inputs.length}, where
   * {@code nbNegative} is the number of strictly negative entries among the first
   * {@code inputs.length} weights.
   *
   * @param index index reported by {@link #index()}
   * @param inputs input indexes, passed one by one to {@code Connections.setEntry}
   * @param weights weight of each connection; must hold at least {@code inputs.length} entries
   */
  public LTUThreshold(int index, int[] inputs, byte[] weights) {
    this.index = index;
    int nbNegative = 0;
    connections = new Connections(inputs.length);
    for (int i = 0; i < inputs.length; i++) {
      connections.setEntry(inputs[i], weights[i]);
      if (weights[i] < 0) {
        nbNegative++;
      }
    }
    // Use the shared constant rather than repeating the literal 0.6.
    threshold = -nbNegative + Beta * inputs.length;
  }

  /**
   * Stores the dot product of the connections with {@code inputVector}.
   *
   * @param inputVector input values handed to {@code Connections.dotProduct}
   */
  @Override
  public void updateSum(double[] inputVector) {
    sum = connections.dotProduct(inputVector);
  }

  /**
   * Builds a new unit of this class; does not read or modify the receiver.
   *
   * @param index index of the new unit
   * @param inputs input indexes of the new unit
   * @param weights connection weights of the new unit
   * @return a freshly constructed {@code LTUThreshold}
   */
  @Override
  public LTUThreshold newLTU(int index, int[] inputs, byte[] weights) {
    return new LTUThreshold(index, inputs, weights);
  }

  /**
   * @return the {@code indexes} array of the connections (the array itself, not a copy)
   */
  @Override
  public int[] inputs() {
    return connections.indexes;
  }

  /**
   * @return the index given at construction time
   */
  @Override
  public int index() {
    return index;
  }

  /**
   * Picks one connection slot uniformly in {@code [0, connections.nbActive)} and, if its weight is
   * {@code 1} while its input is positive or {@code -1} while its input is not positive, negates
   * that weight and raises the threshold by 1. Otherwise nothing changes.
   *
   * <p>Does nothing, and draws no random number, when {@code connections.nbActive} is not
   * positive, since {@link Random#nextInt(int)} rejects such a bound.
   *
   * @param random source of the random slot choice
   * @param inputVector current input values, read at the chosen connection's input index
   */
  @Override
  public void decreaseDensity(Random random, double[] inputVector) {
    if (connections.nbActive <= 0) {
      return;
    }
    int bit = random.nextInt(connections.nbActive);
    double weight = connections.weights[bit];
    boolean inputActive = inputVector[connections.indexes[bit]] > 0;
    if ((weight == 1 && inputActive) || (weight == -1 && !inputActive)) {
      connections.weights[bit] *= -1;
      threshold += 1;
    }
  }

  /**
   * Picks one connection slot uniformly in {@code [0, connections.nbActive)} and, if its weight is
   * {@code -1} while its input is positive or {@code +1} while its input is not positive, negates
   * that weight and lowers the threshold by 1. Otherwise nothing changes.
   *
   * <p>Does nothing, and draws no random number, when {@code connections.nbActive} is not
   * positive, since {@link Random#nextInt(int)} rejects such a bound.
   *
   * @param random source of the random slot choice
   * @param inputVector current input values, read at the chosen connection's input index
   */
  @Override
  public void increaseDensity(Random random, double[] inputVector) {
    if (connections.nbActive <= 0) {
      return;
    }
    int bit = random.nextInt(connections.nbActive);
    double weight = connections.weights[bit];
    boolean inputActive = inputVector[connections.indexes[bit]] > 0;
    if ((weight == -1 && inputActive) || (weight == +1 && !inputActive)) {
      connections.weights[bit] *= -1;
      threshold -= 1;
    }
  }

  /**
   * Recomputes the activation as {@code sum >= threshold} and then resets the stored sum to 0, so
   * {@link #updateSum(double[])} has to be called again before the next activation update.
   *
   * @return the new activation
   */
  @Override
  public boolean updateActivation() {
    isActive = sum >= threshold;
    sum = 0;
    return isActive;
  }

  /**
   * @return the activation computed by the last call to {@link #updateActivation()}
   */
  @Override
  public boolean isActive() {
    return isActive;
  }

  /**
   * @return the connections object of this unit (the object itself, not a copy)
   */
  public Connections connections() {
    return connections;
  }

  /**
   * Overwrites the activation threshold.
   *
   * @param threshold new threshold value
   */
  public void setThreshold(double threshold) {
    this.threshold = threshold;
  }
}