RLPark 1.0.0
Reinforcement Learning Framework in Java
package rlpark.plugin.rltoys.algorithms.representations.ltu.networks;

import java.util.LinkedHashSet;
import java.util.Random;
import java.util.Set;

import rlpark.plugin.rltoys.algorithms.representations.ltu.units.LTU;
import rlpark.plugin.rltoys.algorithms.representations.ltu.units.LTUAdaptiveDensity;
import rlpark.plugin.rltoys.math.vector.BinaryVector;
import zephyr.plugin.core.api.monitoring.annotations.Monitor;

/**
 * A random network of linear threshold units (LTUs) that regulates the density of
 * its output: after each projection, if the number of active units falls outside
 * the band [minUnitActive, maxUnitActive], the LTUAdaptiveDensity units are asked
 * to increase or decrease their firing density.
 */
@Monitor
public class AutoRegulatedNetwork extends RandomNetwork {
  private static final long serialVersionUID = 1847556584654367004L;
  public final int minUnitActive;
  public final int maxUnitActive;
  public final double minDensity;
  public final double maxDensity;
  private final Random random;
  // Diagnostics for the last projection: how far the active-unit count was below or above the band.
  private int missingUnit;
  private int overUnit;

  /**
   * @param minDensity minimum fraction of output units that should be active
   * @param maxDensity maximum fraction of output units that should be active
   */
  public AutoRegulatedNetwork(Random random, int inputSize, int outputSize, double minDensity, double maxDensity) {
    super(inputSize, outputSize);
    this.random = random;
    this.minDensity = minDensity;
    this.maxDensity = maxDensity;
    this.minUnitActive = (int) (minDensity * outputSize);
    this.maxUnitActive = (int) (maxDensity * outputSize);
  }

  /** Regulates the output density once the observation has been projected. */
  @Override
  protected void postProjection(BinaryVector obs) {
    missingUnit = 0;
    overUnit = 0;
    int nbActive = output.nonZeroElements();
    if (nbActive > maxUnitActive)
      decreaseDensity(obs);
    if (nbActive < minUnitActive)
      increaseDensity(obs);
    super.postProjection(obs);
  }

  /**
   * Too few units are active: each inactive adaptive unit connected to an active
   * input is asked, with some probability, to increase its density.
   */
  private void increaseDensity(BinaryVector obs) {
    missingUnit = minUnitActive - output.nonZeroElements();
    Set<LTUAdaptiveDensity> couldHaveAgree = buildCouldHaveAgreeUnits(obs);
    assert missingUnit > 0;
    double selectionProbability = Math.min(1.0, missingUnit / (double) couldHaveAgree.size());
    for (LTUAdaptiveDensity ltu : couldHaveAgree) {
      if (random.nextFloat() > selectionProbability)
        continue;
      ltu.increaseDensity(random, denseInputVector);
    }
  }

  /** Collects the inactive adaptive units that share at least one active input with the observation. */
  private Set<LTUAdaptiveDensity> buildCouldHaveAgreeUnits(BinaryVector obs) {
    Set<LTUAdaptiveDensity> couldHaveAgree = new LinkedHashSet<LTUAdaptiveDensity>();
    for (int activeInput : obs.getActiveIndexes()) {
      for (LTU ltu : parents(activeInput)) {
        if (ltu == null || ltu.isActive())
          continue;
        if (!(ltu instanceof LTUAdaptiveDensity))
          continue;
        couldHaveAgree.add((LTUAdaptiveDensity) ltu);
      }
    }
    return couldHaveAgree;
  }

  /**
   * Too many units are active: each active adaptive unit is asked, with some
   * probability, to decrease its density.
   */
  private void decreaseDensity(BinaryVector obs) {
    overUnit = output.nonZeroElements() - maxUnitActive;
    assert overUnit > 0;
    double selectionProbability = overUnit / (double) output.nonZeroElements();
    for (int activeLTUIndex : output.getActiveIndexes()) {
      if (random.nextFloat() > selectionProbability)
        continue;
      LTU ltu = ltus[activeLTUIndex];
      if (ltu == null || !(ltu instanceof LTUAdaptiveDensity))
        continue;
      ((LTUAdaptiveDensity) ltu).decreaseDensity(random, denseInputVector);
    }
  }
}
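
For orientation, here is a minimal usage sketch. Only the AutoRegulatedNetwork constructor and its public minUnitActive/maxUnitActive fields are confirmed by the listing above; the BVector observation class, its setOn method, and an inherited RandomNetwork.project(BinaryVector) call (which would trigger postProjection and the density regulation) are assumptions about the surrounding rlpark API, not taken from this file.

import java.util.Random;

import rlpark.plugin.rltoys.algorithms.representations.ltu.networks.AutoRegulatedNetwork;
import rlpark.plugin.rltoys.math.vector.implementations.BVector; // assumed sparse binary vector implementation

public class AutoRegulatedNetworkSketch {
  public static void main(String[] args) {
    Random random = new Random(0);
    int inputSize = 100;   // size of the binary observation
    int outputSize = 1000; // number of LTUs in the network
    // Keep between 2% and 10% of the LTUs active after each projection.
    AutoRegulatedNetwork network =
        new AutoRegulatedNetwork(random, inputSize, outputSize, 0.02, 0.10);
    System.out.println("Active units regulated towards [" + network.minUnitActive
        + ", " + network.maxUnitActive + "]");

    // Assumption: LTUAdaptiveDensity units have already been registered with the
    // inherited RandomNetwork (e.g. through a setLTU-style method not shown here).

    BVector obs = new BVector(inputSize); // assumed constructor for an empty binary vector
    obs.setOn(3);                         // assumed setter activating input bit 3
    // Assumption: project(BinaryVector) is inherited from RandomNetwork; it computes the
    // active LTUs and then calls postProjection(obs), where the regulation above happens.
    network.project(obs);
  }
}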