RLPark 1.0.0
Reinforcement Learning Framework in Java

RandomNetworkScheduler.java

Go to the documentation of this file.
00001 package rlpark.plugin.rltoys.algorithms.representations.ltu.networks;
00002 
00003 import java.io.Serializable;
00004 import java.util.concurrent.ExecutionException;
00005 import java.util.concurrent.ExecutorService;
00006 import java.util.concurrent.Future;
00007 
00008 import rlpark.plugin.rltoys.algorithms.representations.ltu.internal.LTUArray;
00009 import rlpark.plugin.rltoys.algorithms.representations.ltu.internal.LTUUpdated;
00010 import rlpark.plugin.rltoys.algorithms.representations.ltu.units.LTU;
00011 import rlpark.plugin.rltoys.math.vector.BinaryVector;
00012 import rlpark.plugin.rltoys.utils.Scheduling;
00013 
00014 public class RandomNetworkScheduler implements Serializable {
00015   private static final long serialVersionUID = -2515509378000478726L;
00016 
00017   protected class LTUSumUpdater implements Runnable {
00018     private final int offset;
00019     private final LTUArray[] connectedLTUs;
00020     private final LTUUpdated updatedLTUs;
00021     private final double[] denseInputVector;
00022     private final boolean[] updated;
00023 
00024     LTUSumUpdater(RandomNetwork randomNetwork, int offset) {
00025       this.offset = offset;
00026       connectedLTUs = randomNetwork.connectedLTUs;
00027       updatedLTUs = randomNetwork.updatedLTUs;
00028       updated = updatedLTUs.updated;
00029       denseInputVector = randomNetwork.denseInputVector;
00030     }
00031 
00032     @Override
00033     public void run() {
00034       int currentPosition = offset;
00035       int[] activeIndexes = obs.getActiveIndexes();
00036       while (currentPosition < activeIndexes.length) {
00037         int activeInput = activeIndexes[currentPosition];
00038         LTU[] connected = connectedLTUs[activeInput].array();
00039         updateConnectedLTU(connected);
00040         currentPosition += nbThread;
00041       }
00042     }
00043 
00044     private void updateConnectedLTU(LTU[] connected) {
00045       for (LTU ltu : connected) {
00046         final int index = ltu.index();
00047         if (updated[index])
00048           continue;
00049         updatedLTUs.updateLTUSum(index, ltu, denseInputVector);
00050       }
00051     }
00052   }
00053 
00054   protected class LTUActivationUpdater implements Runnable {
00055     private final int offset;
00056     private final LTU[] ltus;
00057 
00058     LTUActivationUpdater(RandomNetwork randomNetwork, int offset) {
00059       this.offset = offset;
00060       ltus = randomNetwork.ltus;
00061     }
00062 
00063     @Override
00064     public void run() {
00065       int currentPosition = offset;
00066       while (currentPosition < ltus.length) {
00067         final LTU ltu = ltus[currentPosition];
00068         if (ltu != null && ltu.updateActivation())
00069           setOutputOn(currentPosition);
00070         currentPosition += nbThread;
00071       }
00072     }
00073   }
00074 
00075   transient private ExecutorService executor = null;
00076   transient private LTUSumUpdater[] sumUpdaters;
00077   transient private LTUActivationUpdater[] activationUpdaters;
00078   transient private Future<?>[] futurs;
00079   protected final int nbThread;
00080   BinaryVector obs;
00081   BinaryVector output;
00082 
00083   public RandomNetworkScheduler() {
00084     this(Scheduling.getDefaultNbThreads());
00085   }
00086 
00087   public RandomNetworkScheduler(int nbThread) {
00088     this.nbThread = nbThread;
00089   }
00090 
00091   private void initialize(RandomNetwork randomNetwork) {
00092     sumUpdaters = new LTUSumUpdater[nbThread];
00093     activationUpdaters = new LTUActivationUpdater[nbThread];
00094     for (int i = 0; i < nbThread; i++) {
00095       sumUpdaters[i] = new LTUSumUpdater(randomNetwork, i);
00096       activationUpdaters[i] = new LTUActivationUpdater(randomNetwork, i);
00097     }
00098     futurs = new Future<?>[nbThread];
00099     executor = Scheduling.newFixedThreadPool("randomnetwork", nbThread);
00100   }
00101 
00102   public void update(RandomNetwork randomNetwork, BinaryVector obs) {
00103     if (executor == null)
00104       initialize(randomNetwork);
00105     this.obs = obs;
00106     this.output = randomNetwork.output;
00107     for (int i = 0; i < sumUpdaters.length; i++)
00108       futurs[i] = executor.submit(sumUpdaters[i]);
00109     waitWorkingThread();
00110     for (int i = 0; i < sumUpdaters.length; i++)
00111       futurs[i] = executor.submit(activationUpdaters[i]);
00112     waitWorkingThread();
00113   }
00114 
00115   private void waitWorkingThread() {
00116     try {
00117       for (Future<?> futur : futurs)
00118         futur.get();
00119     } catch (InterruptedException e) {
00120       e.printStackTrace();
00121     } catch (ExecutionException e) {
00122       throw new RuntimeException(e.getCause());
00123     }
00124   }
00125 
00126   synchronized final void setOutputOn(int index) {
00127     output.setOn(index);
00128   }
00129 
00130   public void dispose() {
00131     executor.shutdown();
00132     executor = null;
00133   }
00134 }
 All Classes Namespaces Files Functions Variables Enumerations
Zephyr
RLPark