RLPark 1.0.0
Reinforcement Learning Framework in Java

SweepSelected.java

Go to the documentation of this file.
00001 package rlpark.plugin.rltoys.experiments.parametersweep;
00002 
00003 import java.util.ArrayList;
00004 import java.util.LinkedHashMap;
00005 import java.util.LinkedHashSet;
00006 import java.util.List;
00007 import java.util.Map;
00008 import java.util.Set;
00009 
00010 import rlpark.plugin.rltoys.experiments.helpers.ExperimentCounter;
00011 import rlpark.plugin.rltoys.experiments.parametersweep.interfaces.Context;
00012 import rlpark.plugin.rltoys.experiments.parametersweep.interfaces.SweepDescriptor;
00013 import rlpark.plugin.rltoys.experiments.parametersweep.onpolicy.AbstractContextOnPolicy;
00014 import rlpark.plugin.rltoys.experiments.parametersweep.parameters.FrozenParameters;
00015 import rlpark.plugin.rltoys.experiments.parametersweep.parameters.Parameters;
00016 import rlpark.plugin.rltoys.experiments.parametersweep.parameters.RunInfo;
00017 import rlpark.plugin.rltoys.experiments.scheduling.schedulers.LocalScheduler;
00018 import rlpark.plugin.rltoys.experiments.scheduling.schedulers.Schedulers;
00019 
00020 public class SweepSelected {
00021   private final SweepDescriptor sweepDescriptor;
00022   private final ExperimentCounter counter;
00023   private final LocalScheduler scheduler = new LocalScheduler();
00024   private final List<FrozenParameters> todoParameters;
00025 
00026   public SweepSelected(List<FrozenParameters> todoParameters, SweepDescriptor sweepDescriptor, ExperimentCounter counter) {
00027     this.counter = counter;
00028     this.sweepDescriptor = sweepDescriptor;
00029     this.todoParameters = todoParameters;
00030   }
00031 
00032   private List<Parameters> createJobsDescription(Context context) {
00033     Set<FrozenParameters> contextTodoParameters = selectConsistentParameters(context);
00034     ArrayList<Parameters> result = new ArrayList<Parameters>();
00035     if (contextTodoParameters.isEmpty())
00036       return result;
00037     List<Parameters> allParameters = sweepDescriptor.provideParameters(context);
00038     for (Parameters parameters : allParameters) {
00039       if (contextTodoParameters.contains(parameters.froze()))
00040         result.add(parameters);
00041     }
00042     return result;
00043   }
00044 
00045   private Set<FrozenParameters> selectConsistentParameters(Context context) {
00046     AbstractContextOnPolicy onPolicyContext = (AbstractContextOnPolicy) context;
00047     String algorithmLabel = onPolicyContext.agentFactory().label();
00048     String problemLabel = onPolicyContext.problemFactory().label();
00049     Set<FrozenParameters> selected = new LinkedHashSet<FrozenParameters>();
00050     for (FrozenParameters parameters : todoParameters)
00051       if (parameters.hasFlag(problemLabel) && parameters.hasFlag(algorithmLabel)) {
00052         Parameters parametersCompleted = new Parameters(parameters);
00053         onPolicyContext.problemFactory().setExperimentParameters(parametersCompleted);
00054         selected.add(parametersCompleted.froze());
00055       }
00056     return selected;
00057   }
00058 
00059   public void generateLearningCurve() {
00060     System.out.println("Preparing job descriptions...");
00061     Map<Context, List<Parameters>> descriptions = new LinkedHashMap<Context, List<Parameters>>();
00062     for (Context context : sweepDescriptor.provideContexts())
00063       descriptions.put(context, createJobsDescription(context));
00064     List<Runnable> jobs = new ArrayList<Runnable>();
00065     while (counter.hasNext()) {
00066       counter.nextExperiment();
00067       for (Map.Entry<Context, List<Parameters>> entry : descriptions.entrySet())
00068         for (Parameters parameters : entry.getValue())
00069           jobs.add(entry.getKey().createJob(parameters, counter));
00070     }
00071     Schedulers.addAll(scheduler, jobs, null);
00072     scheduler.start();
00073     scheduler.waitAll();
00074     scheduler.dispose();
00075   }
00076 
00077   public static List<FrozenParameters> toParametersList(String[] args) {
00078     List<FrozenParameters> results = new ArrayList<FrozenParameters>();
00079     for (String arg : args) {
00080       String[] components = arg.split("_");
00081       if (components.length == 1)
00082         continue;
00083       RunInfo infos = new RunInfo();
00084       infos.enableFlag(components[0]);
00085       infos.enableFlag(components[1]);
00086       Parameters parameters = new Parameters(infos);
00087       for (int i = 1; i < components.length / 2; i++)
00088         parameters.putSweepParam(components[i * 2], Double.parseDouble(components[i * 2 + 1]));
00089       results.add(parameters.froze());
00090     }
00091     return results;
00092   }
00093 }
 All Classes Namespaces Files Functions Variables Enumerations
Zephyr
RLPark