RLPark 1.0.0
Reinforcement Learning Framework in Java
|
package rlpark.plugin.rltoys.experiments.parametersweep;

import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import rlpark.plugin.rltoys.experiments.helpers.ExperimentCounter;
import rlpark.plugin.rltoys.experiments.parametersweep.interfaces.Context;
import rlpark.plugin.rltoys.experiments.parametersweep.interfaces.SweepDescriptor;
import rlpark.plugin.rltoys.experiments.parametersweep.onpolicy.AbstractContextOnPolicy;
import rlpark.plugin.rltoys.experiments.parametersweep.parameters.FrozenParameters;
import rlpark.plugin.rltoys.experiments.parametersweep.parameters.Parameters;
import rlpark.plugin.rltoys.experiments.parametersweep.parameters.RunInfo;
import rlpark.plugin.rltoys.experiments.scheduling.schedulers.LocalScheduler;
import rlpark.plugin.rltoys.experiments.scheduling.schedulers.Schedulers;

/**
 * Runs only a hand-picked subset of a parameter sweep.
 * <p>
 * The subset is given as a list of {@link FrozenParameters}; for every context
 * offered by the {@link SweepDescriptor}, only the sweep parameters that match
 * one of the requested entries are turned into jobs and executed on a
 * {@link LocalScheduler}.
 */
public class SweepSelected {
  private final SweepDescriptor sweepDescriptor;
  private final ExperimentCounter counter;
  private final LocalScheduler scheduler = new LocalScheduler();
  private final List<FrozenParameters> todoParameters;

  /**
   * @param todoParameters the parameter sets that should be run (see
   *          {@link #toParametersList(String[])} for one way to build them)
   * @param sweepDescriptor provides the contexts and the full parameter list of the sweep
   * @param counter iterated by {@link #generateLearningCurve()} and handed to every created job
   */
  public SweepSelected(List<FrozenParameters> todoParameters, SweepDescriptor sweepDescriptor,
      ExperimentCounter counter) {
    this.counter = counter;
    this.sweepDescriptor = sweepDescriptor;
    this.todoParameters = todoParameters;
  }

  /**
   * Keeps, among all the parameters the sweep descriptor provides for
   * {@code context}, those whose frozen form is one of the requested entries.
   *
   * @return the matching parameters, in the order the descriptor provides them;
   *         empty when nothing was requested for this context
   */
  private List<Parameters> createJobsDescription(Context context) {
    Set<FrozenParameters> contextTodoParameters = selectConsistentParameters(context);
    List<Parameters> result = new ArrayList<Parameters>();
    if (contextTodoParameters.isEmpty()) {
      // Nothing requested for this context: skip asking the descriptor for its parameters.
      return result;
    }
    for (Parameters parameters : sweepDescriptor.provideParameters(context)) {
      if (contextTodoParameters.contains(parameters.froze())) {
        result.add(parameters);
      }
    }
    return result;
  }

  /**
   * Selects the requested parameters carrying both the problem label and the
   * algorithm label of {@code context} as flags. Each selected entry is copied,
   * passed to the problem factory's {@code setExperimentParameters} and frozen
   * again, so that it can be compared with the parameters produced by the sweep
   * descriptor.
   *
   * @throws ClassCastException if {@code context} is not an {@link AbstractContextOnPolicy}
   */
  private Set<FrozenParameters> selectConsistentParameters(Context context) {
    AbstractContextOnPolicy onPolicyContext = (AbstractContextOnPolicy) context;
    String algorithmLabel = onPolicyContext.agentFactory().label();
    String problemLabel = onPolicyContext.problemFactory().label();
    Set<FrozenParameters> selected = new LinkedHashSet<FrozenParameters>();
    for (FrozenParameters parameters : todoParameters) {
      if (!parameters.hasFlag(problemLabel) || !parameters.hasFlag(algorithmLabel)) {
        continue;
      }
      Parameters parametersCompleted = new Parameters(parameters);
      onPolicyContext.problemFactory().setExperimentParameters(parametersCompleted);
      selected.add(parametersCompleted.froze());
    }
    return selected;
  }

  /**
   * Builds one job per (experiment, context, selected parameters) triple, runs
   * them all on the local scheduler and blocks until they are finished.
   * <p>
   * The scheduler is disposed when this method returns, whether the jobs
   * completed normally or an exception was thrown while scheduling or waiting.
   */
  public void generateLearningCurve() {
    System.out.println("Preparing job descriptions...");
    Map<Context, List<Parameters>> descriptions = new LinkedHashMap<Context, List<Parameters>>();
    for (Context context : sweepDescriptor.provideContexts()) {
      descriptions.put(context, createJobsDescription(context));
    }
    List<Runnable> jobs = new ArrayList<Runnable>();
    while (counter.hasNext()) {
      counter.nextExperiment();
      for (Map.Entry<Context, List<Parameters>> entry : descriptions.entrySet()) {
        for (Parameters parameters : entry.getValue()) {
          jobs.add(entry.getKey().createJob(parameters, counter));
        }
      }
    }
    try {
      Schedulers.addAll(scheduler, jobs, null);
      scheduler.start();
      scheduler.waitAll();
    } finally {
      // Always release the scheduler, even when scheduling or waiting fails.
      scheduler.dispose();
    }
  }

  /**
   * Parses command-line style arguments into frozen parameters.
   * <p>
   * Each argument is split on {@code '_'}. The first two components are enabled
   * as flags on a new {@link RunInfo}; the remaining components are read as
   * consecutive {@code name}, {@code value} pairs and stored as sweep
   * parameters. An unpaired trailing component is ignored. Arguments that yield
   * fewer than two components (for instance {@code "foo"}, {@code ""} or
   * {@code "_"}) are skipped.
   *
   * @param args the arguments to parse
   * @return one frozen parameter set per accepted argument, in argument order
   * @throws NumberFormatException if a value component is not a parsable double
   */
  public static List<FrozenParameters> toParametersList(String[] args) {
    List<FrozenParameters> results = new ArrayList<FrozenParameters>();
    for (String arg : args) {
      String[] components = arg.split("_");
      // split() drops trailing empty strings, so "_" gives an empty array:
      // test for "fewer than two" rather than "exactly one".
      if (components.length < 2) {
        continue;
      }
      RunInfo infos = new RunInfo();
      infos.enableFlag(components[0]);
      infos.enableFlag(components[1]);
      Parameters parameters = new Parameters(infos);
      // Name/value pairs start right after the two flag components.
      for (int nameIndex = 2; nameIndex + 1 < components.length; nameIndex += 2) {
        parameters.putSweepParam(components[nameIndex], Double.parseDouble(components[nameIndex + 1]));
      }
      results.add(parameters.froze());
    }
    return results;
  }
}