RLPark 1.0.0
Reinforcement Learning Framework in Java
|
00001 package rlpark.plugin.rltoys.algorithms.representations.ltu.networks; 00002 00003 import java.io.Serializable; 00004 import java.util.concurrent.ExecutionException; 00005 import java.util.concurrent.ExecutorService; 00006 import java.util.concurrent.Future; 00007 00008 import rlpark.plugin.rltoys.algorithms.representations.ltu.internal.LTUArray; 00009 import rlpark.plugin.rltoys.algorithms.representations.ltu.internal.LTUUpdated; 00010 import rlpark.plugin.rltoys.algorithms.representations.ltu.units.LTU; 00011 import rlpark.plugin.rltoys.math.vector.BinaryVector; 00012 import rlpark.plugin.rltoys.utils.Scheduling; 00013 00014 public class RandomNetworkScheduler implements Serializable { 00015 private static final long serialVersionUID = -2515509378000478726L; 00016 00017 protected class LTUSumUpdater implements Runnable { 00018 private final int offset; 00019 private final LTUArray[] connectedLTUs; 00020 private final LTUUpdated updatedLTUs; 00021 private final double[] denseInputVector; 00022 private final boolean[] updated; 00023 00024 LTUSumUpdater(RandomNetwork randomNetwork, int offset) { 00025 this.offset = offset; 00026 connectedLTUs = randomNetwork.connectedLTUs; 00027 updatedLTUs = randomNetwork.updatedLTUs; 00028 updated = updatedLTUs.updated; 00029 denseInputVector = randomNetwork.denseInputVector; 00030 } 00031 00032 @Override 00033 public void run() { 00034 int currentPosition = offset; 00035 int[] activeIndexes = obs.getActiveIndexes(); 00036 while (currentPosition < activeIndexes.length) { 00037 int activeInput = activeIndexes[currentPosition]; 00038 LTU[] connected = connectedLTUs[activeInput].array(); 00039 updateConnectedLTU(connected); 00040 currentPosition += nbThread; 00041 } 00042 } 00043 00044 private void updateConnectedLTU(LTU[] connected) { 00045 for (LTU ltu : connected) { 00046 final int index = ltu.index(); 00047 if (updated[index]) 00048 continue; 00049 updatedLTUs.updateLTUSum(index, ltu, denseInputVector); 00050 } 00051 } 00052 } 00053 00054 protected class LTUActivationUpdater implements Runnable { 00055 private final int offset; 00056 private final LTU[] ltus; 00057 00058 LTUActivationUpdater(RandomNetwork randomNetwork, int offset) { 00059 this.offset = offset; 00060 ltus = randomNetwork.ltus; 00061 } 00062 00063 @Override 00064 public void run() { 00065 int currentPosition = offset; 00066 while (currentPosition < ltus.length) { 00067 final LTU ltu = ltus[currentPosition]; 00068 if (ltu != null && ltu.updateActivation()) 00069 setOutputOn(currentPosition); 00070 currentPosition += nbThread; 00071 } 00072 } 00073 } 00074 00075 transient private ExecutorService executor = null; 00076 transient private LTUSumUpdater[] sumUpdaters; 00077 transient private LTUActivationUpdater[] activationUpdaters; 00078 transient private Future<?>[] futurs; 00079 protected final int nbThread; 00080 BinaryVector obs; 00081 BinaryVector output; 00082 00083 public RandomNetworkScheduler() { 00084 this(Scheduling.getDefaultNbThreads()); 00085 } 00086 00087 public RandomNetworkScheduler(int nbThread) { 00088 this.nbThread = nbThread; 00089 } 00090 00091 private void initialize(RandomNetwork randomNetwork) { 00092 sumUpdaters = new LTUSumUpdater[nbThread]; 00093 activationUpdaters = new LTUActivationUpdater[nbThread]; 00094 for (int i = 0; i < nbThread; i++) { 00095 sumUpdaters[i] = new LTUSumUpdater(randomNetwork, i); 00096 activationUpdaters[i] = new LTUActivationUpdater(randomNetwork, i); 00097 } 00098 futurs = new Future<?>[nbThread]; 00099 executor = Scheduling.newFixedThreadPool("randomnetwork", nbThread); 00100 } 00101 00102 public void update(RandomNetwork randomNetwork, BinaryVector obs) { 00103 if (executor == null) 00104 initialize(randomNetwork); 00105 this.obs = obs; 00106 this.output = randomNetwork.output; 00107 for (int i = 0; i < sumUpdaters.length; i++) 00108 futurs[i] = executor.submit(sumUpdaters[i]); 00109 waitWorkingThread(); 00110 for (int i = 0; i < sumUpdaters.length; i++) 00111 futurs[i] = executor.submit(activationUpdaters[i]); 00112 waitWorkingThread(); 00113 } 00114 00115 private void waitWorkingThread() { 00116 try { 00117 for (Future<?> futur : futurs) 00118 futur.get(); 00119 } catch (InterruptedException e) { 00120 e.printStackTrace(); 00121 } catch (ExecutionException e) { 00122 throw new RuntimeException(e.getCause()); 00123 } 00124 } 00125 00126 synchronized final void setOutputOn(int index) { 00127 output.setOn(index); 00128 } 00129 00130 public void dispose() { 00131 executor.shutdown(); 00132 executor = null; 00133 } 00134 }