/*
 * RLPark 1.0.0
 * Reinforcement Learning Framework in Java
 */
00001 package rlpark.plugin.rltoys.algorithms.representations.ltu.networks; 00002 00003 import java.io.Serializable; 00004 00005 import rlpark.plugin.rltoys.algorithms.representations.ltu.internal.LTUArray; 00006 import rlpark.plugin.rltoys.algorithms.representations.ltu.internal.LTUUpdated; 00007 import rlpark.plugin.rltoys.algorithms.representations.ltu.units.LTU; 00008 import rlpark.plugin.rltoys.math.vector.BinaryVector; 00009 import rlpark.plugin.rltoys.math.vector.implementations.BVector; 00010 import zephyr.plugin.core.api.monitoring.annotations.Monitor; 00011 00012 public class RandomNetwork implements Serializable { 00013 private static final long serialVersionUID = 8140259178658376161L; 00014 final public int outputSize; 00015 final public int inputSize; 00016 @Monitor(level = 4) 00017 protected final BVector output; 00018 @Monitor(level = 4) 00019 protected final LTU[] ltus; 00020 protected final LTUArray[] connectedLTUs; 00021 @Monitor 00022 protected int nbConnection = 0; 00023 @Monitor 00024 protected int nbActive = 0; 00025 protected final double[] denseInputVector; 00026 final LTUUpdated updatedLTUs; 00027 private final RandomNetworkScheduler scheduler; 00028 @Monitor 00029 int nbUnitUpdated = 0; 00030 00031 public RandomNetwork(int inputSize, int outputSize) { 00032 this(new RandomNetworkScheduler(), inputSize, outputSize); 00033 } 00034 00035 public RandomNetwork(RandomNetworkScheduler scheduler, int inputSize, int outputSize) { 00036 this.outputSize = outputSize; 00037 this.inputSize = inputSize; 00038 this.scheduler = scheduler; 00039 connectedLTUs = new LTUArray[inputSize]; 00040 for (int i = 0; i < connectedLTUs.length; i++) 00041 connectedLTUs[i] = new LTUArray(); 00042 ltus = new LTU[outputSize]; 00043 output = new BVector(outputSize); 00044 denseInputVector = new double[inputSize]; 00045 updatedLTUs = new LTUUpdated(outputSize); 00046 } 00047 00048 public void addLTU(LTU ltu) { 00049 removeLTU(ltus[ltu.index()]); 00050 ltus[ltu.index()] = 
ltu; 00051 int[] ltuInputs = ltu.inputs(); 00052 for (int input : ltuInputs) 00053 connectedLTUs[input].add(ltu); 00054 addLTUStat(ltu); 00055 } 00056 00057 private void addLTUStat(LTU ltu) { 00058 nbConnection += ltu.inputs().length; 00059 } 00060 00061 private void removeLTUStat(LTU ltu) { 00062 nbConnection -= ltu.inputs().length; 00063 } 00064 00065 public void removeLTU(LTU ltu) { 00066 if (ltu == null) 00067 return; 00068 assert ltus[ltu.index()] != null; 00069 removeLTUStat(ltu); 00070 ltus[ltu.index()] = null; 00071 for (int input : ltu.inputs()) 00072 if (connectedLTUs[input] != null) 00073 connectedLTUs[input].remove(ltu); 00074 } 00075 00076 protected void prepareProjection(BinaryVector obs) { 00077 output.clear(); 00078 updatedLTUs.clean(); 00079 if (obs == null) 00080 return; 00081 for (int activeIndex : obs.getActiveIndexes()) 00082 denseInputVector[activeIndex] = 1; 00083 } 00084 00085 protected void postProjection(BinaryVector obs) { 00086 for (int activeIndex : obs.getActiveIndexes()) 00087 denseInputVector[activeIndex] = 0; 00088 } 00089 00090 public BVector project(BinaryVector obs) { 00091 prepareProjection(obs); 00092 if (obs == null) 00093 return output.copy(); 00094 scheduler.update(this, obs); 00095 nbUnitUpdated = updatedLTUs.nbUnitUpdated(); 00096 postProjection(obs); 00097 nbActive = output.nonZeroElements(); 00098 return output.copy(); 00099 } 00100 00101 public LTU ltu(int i) { 00102 return ltus[i]; 00103 } 00104 00105 public LTU[] parents(int index) { 00106 return connectedLTUs[index] != null ? connectedLTUs[index].array() : new LTU[] {}; 00107 } 00108 00109 public LTU[] ltus() { 00110 return ltus; 00111 } 00112 00113 public void dispose() { 00114 scheduler.dispose(); 00115 } 00116 }