RLPark 1.0.0
Reinforcement Learning Framework in Java
|
00001 package rlpark.plugin.rltoys.horde; 00002 00003 import java.io.Serializable; 00004 import java.util.concurrent.ExecutionException; 00005 import java.util.concurrent.ExecutorService; 00006 import java.util.concurrent.Future; 00007 00008 import rlpark.plugin.rltoys.utils.Scheduling; 00009 00010 public class HordeScheduler implements Serializable { 00011 private static final long serialVersionUID = 6003588160245867945L; 00012 00013 protected class Updater implements Runnable, Serializable { 00014 private static final long serialVersionUID = 3170029744578080040L; 00015 private final int offset; 00016 00017 Updater(int offset) { 00018 this.offset = offset; 00019 } 00020 00021 @Override 00022 public void run() { 00023 int currentPosition = offset; 00024 while (currentPosition < context.nbElements()) { 00025 if (throwable != null) 00026 return; 00027 try { 00028 context.updateElement(currentPosition); 00029 } catch (Throwable throwable) { 00030 HordeScheduler.this.throwable = throwable; 00031 return; 00032 } 00033 currentPosition += nbThread; 00034 } 00035 } 00036 } 00037 00038 public interface Context { 00039 int nbElements(); 00040 00041 void updateElement(int index); 00042 } 00043 00044 00045 transient private ExecutorService executor = null; 00046 private final Updater[] updaters; 00047 Context context; 00048 transient private Future<?>[] futurs; 00049 transient Throwable throwable = null; 00050 protected final int nbThread; 00051 00052 public HordeScheduler() { 00053 this(Scheduling.getDefaultNbThreads()); 00054 } 00055 00056 public HordeScheduler(int nbThread) { 00057 this.nbThread = nbThread; 00058 updaters = new Updater[nbThread]; 00059 for (int i = 0; i < updaters.length; i++) 00060 updaters[i] = new Updater(i); 00061 } 00062 00063 private void initialize() { 00064 futurs = new Future<?>[nbThread]; 00065 executor = Scheduling.newFixedThreadPool("demons", nbThread); 00066 } 00067 00068 public void update(Context context) { 00069 this.context = context; 00070 if (executor == null) 00071 initialize(); 00072 throwable = null; 00073 for (int i = 0; i < updaters.length; i++) 00074 futurs[i] = executor.submit(updaters[i]); 00075 try { 00076 for (Future<?> futur : futurs) 00077 futur.get(); 00078 } catch (InterruptedException e) { 00079 e.printStackTrace(); 00080 } catch (ExecutionException e) { 00081 e.printStackTrace(); 00082 } 00083 this.context = null; 00084 if (throwable != null) 00085 throw new RuntimeException(throwable); 00086 } 00087 }