RLPark 1.0.0
Reinforcement Learning Framework in Java
|
00001 package rlpark.plugin.rltoys.experiments.scheduling.network; 00002 00003 import java.net.Socket; 00004 import java.util.ArrayList; 00005 import java.util.HashMap; 00006 import java.util.HashSet; 00007 import java.util.Map; 00008 import java.util.Random; 00009 import java.util.Set; 00010 00011 import rlpark.plugin.rltoys.experiments.scheduling.interfaces.JobDoneEvent; 00012 import rlpark.plugin.rltoys.experiments.scheduling.interfaces.JobQueue; 00013 import rlpark.plugin.rltoys.experiments.scheduling.internal.messages.ClientInfo; 00014 import rlpark.plugin.rltoys.experiments.scheduling.internal.messages.MessageJob; 00015 import rlpark.plugin.rltoys.experiments.scheduling.internal.messages.Messages; 00016 import rlpark.plugin.rltoys.experiments.scheduling.internal.network.NetworkClassLoader; 00017 import rlpark.plugin.rltoys.experiments.scheduling.internal.network.SyncSocket; 00018 import rlpark.plugin.rltoys.experiments.scheduling.queue.LocalQueue; 00019 import zephyr.plugin.core.api.signals.Signal; 00020 import zephyr.plugin.core.api.synchronization.Chrono; 00021 00022 public class NetworkJobQueue implements JobQueue { 00023 private static final double MessagePeriod = 1800; 00024 private final SyncSocket syncSocket; 00025 private final Map<Runnable, Integer> jobToId = new HashMap<Runnable, Integer>(); 00026 private final NetworkClassLoader classLoader; 00027 private final Chrono chrono = new Chrono(); 00028 private final Signal<JobDoneEvent> onJobDone = new Signal<JobDoneEvent>(); 00029 private int nbJobsSinceLastMessage = 0; 00030 private boolean denyNewJobRequest = false; 00031 private final LocalQueue localQueue = new LocalQueue(); 00032 00033 public NetworkJobQueue(String serverHostName, int port, int nbCore, boolean multipleConnectionAttempts) { 00034 Socket socket = connectToServer(serverHostName, port, multipleConnectionAttempts); 00035 syncSocket = new SyncSocket(socket); 00036 syncSocket.sendClientInfo(new ClientInfo(nbCore)); 00037 classLoader = NetworkClassLoader.newClassLoader(syncSocket); 00038 } 00039 00040 private void requestJobsToServer() { 00041 MessageJob messageJobTodo = syncSocket.jobTransaction(classLoader); 00042 if (messageJobTodo == null || messageJobTodo.nbJobs() == 0) 00043 return; 00044 Runnable[] jobs = messageJobTodo.jobs(); 00045 int[] ids = messageJobTodo.jobIds(); 00046 Set<Integer> addedIds = new HashSet<Integer>(); 00047 ArrayList<Runnable> newJobs = new ArrayList<Runnable>(); 00048 for (int i = 0; i < jobs.length; i++) { 00049 if (addedIds.contains(ids[i])) 00050 continue; 00051 jobToId.put(jobs[i], ids[i]); 00052 newJobs.add(jobs[i]); 00053 } 00054 localQueue.add(newJobs.iterator(), null); 00055 } 00056 00057 @Override 00058 synchronized public Runnable request() { 00059 if (denyNewJobRequest) 00060 return null; 00061 Runnable job = localQueue.request(); 00062 if (job != null) 00063 return job; 00064 requestJobsToServer(); 00065 return localQueue.request(); 00066 } 00067 00068 @Override 00069 synchronized public void done(Runnable todo, Runnable done) { 00070 Integer jobId = jobToId.remove(todo); 00071 if (jobId != null) 00072 jobDone(done, jobId); 00073 if (localQueue.areAllDone()) 00074 requestJobsToServer(); 00075 onJobDone.fire(new JobDoneEvent(todo, done)); 00076 } 00077 00078 private void jobDone(Runnable done, int jobId) { 00079 syncSocket.write(new MessageJob(jobId, done)); 00080 nbJobsSinceLastMessage += 1; 00081 if (chrono.getCurrentChrono() > MessagePeriod) { 00082 Messages.println(nbJobsSinceLastMessage / chrono.getCurrentChrono() + " jobs per seconds"); 00083 chrono.start(); 00084 nbJobsSinceLastMessage = 0; 00085 } 00086 } 00087 00088 public boolean canAnswerJobRequest() { 00089 return !syncSocket.isClosed() && !denyNewJobRequest; 00090 } 00091 00092 @Override 00093 public Signal<JobDoneEvent> onJobDone() { 00094 return onJobDone; 00095 } 00096 00097 public void denyNewJobRequest() { 00098 denyNewJobRequest = true; 00099 } 00100 00101 public NetworkClassLoader classLoader() { 00102 return classLoader; 00103 } 00104 00105 static private Socket connectToServer(String serverHostName, int port, boolean multipleAttempts) { 00106 Socket socket = null; 00107 Random random = null; 00108 Exception lastException = null; 00109 Chrono connectionTime = new Chrono(); 00110 while (socket == null) { 00111 try { 00112 if (lastException != null) 00113 System.err.println("Retrying to connect..."); 00114 socket = new Socket(serverHostName, port); 00115 } catch (Exception e) { 00116 lastException = e; 00117 if (!multipleAttempts) 00118 break; 00119 if (random == null) 00120 random = new Random(); 00121 if (connectionTime.getCurrentChrono() > 3600) 00122 break; 00123 sleepForConnection(random, 120); 00124 } 00125 } 00126 if (socket == null && lastException != null) 00127 throw new RuntimeException(lastException); 00128 if (socket != null && lastException != null) 00129 System.err.println("Finally connected"); 00130 return socket; 00131 } 00132 00133 private static void sleepForConnection(Random random, int maxWaitingTime) { 00134 long sleepingTime = (long) (random.nextDouble() * maxWaitingTime + 5); 00135 System.err.println(sleepingTime + "s of sleeping time before another attempt to connect"); 00136 try { 00137 Thread.sleep(sleepingTime * 1000); 00138 } catch (InterruptedException e) { 00139 } 00140 } 00141 00142 @Override 00143 public void dispose() { 00144 syncSocket.close(); 00145 localQueue.dispose(); 00146 classLoader.dispose(); 00147 } 00148 }