RLPark 1.0.0
Reinforcement Learning Framework in Java
|
00001 package rlpark.plugin.rltoys.experiments.parametersweep.parameters; 00002 00003 import java.io.Serializable; 00004 import java.util.HashMap; 00005 import java.util.LinkedHashMap; 00006 import java.util.Map; 00007 00008 public abstract class AbstractParameters implements Comparable<AbstractParameters>, Serializable { 00009 private static final long serialVersionUID = 8135997315567194984L; 00010 protected final Map<String, Double> parameters = new LinkedHashMap<String, Double>(); 00011 protected final Map<String, Double> results = new LinkedHashMap<String, Double>(); 00012 private final RunInfo infos; 00013 00014 public AbstractParameters(RunInfo infos) { 00015 this(infos, new HashMap<String, Double>(), new HashMap<String, Double>()); 00016 assert infos != null; 00017 } 00018 00019 public AbstractParameters(RunInfo infos, Map<String, Double> parameters, Map<String, Double> results) { 00020 this.infos = infos; 00021 this.parameters.putAll(parameters); 00022 this.results.putAll(results); 00023 } 00024 00025 public void putResult(String parameterName, double parameterValue) { 00026 results.put(parameterName, parameterValue); 00027 } 00028 00029 public boolean hasKey(String key) { 00030 return parameters.containsKey(key) || infos.hasKey(key); 00031 } 00032 00033 public double get(String name, double defaultValue) { 00034 Double value = getDouble(name); 00035 return value != null ? value : defaultValue; 00036 } 00037 00038 public double get(String name) { 00039 return getDouble(name); 00040 } 00041 00042 public Double getDouble(String name) { 00043 Double parameterValue = parameters.get(name); 00044 if (parameterValue != null) 00045 return parameterValue; 00046 parameterValue = infos().get(name); 00047 if (parameterValue != null) 00048 return parameterValue; 00049 return results.get(name); 00050 } 00051 00052 @Override 00053 public String toString() { 00054 StringBuilder result = new StringBuilder(parameters.toString()); 00055 if (!results.isEmpty()) 00056 result.append("=" + results.toString()); 00057 return result.toString(); 00058 } 00059 00060 private Map<String, Double> getAllSweepValues() { 00061 Map<String, Double> all = new LinkedHashMap<String, Double>(parameters); 00062 all.putAll(results); 00063 return all; 00064 } 00065 00066 public String[] labels() { 00067 Map<String, Double> all = getAllSweepValues(); 00068 String[] result = new String[all.size()]; 00069 all.keySet().toArray(result); 00070 return result; 00071 } 00072 00073 public double[] values() { 00074 Map<String, Double> all = getAllSweepValues(); 00075 double[] result = new double[all.size()]; 00076 int index = 0; 00077 for (Double value : all.values()) { 00078 result[index] = value; 00079 index++; 00080 } 00081 return result; 00082 } 00083 00084 @Override 00085 public int compareTo(AbstractParameters other) { 00086 for (Map.Entry<String, Double> entry : other.parameters.entrySet()) { 00087 Double d1 = parameters.get(entry.getKey()); 00088 if (d1 == null) 00089 return -1; 00090 Double d2 = entry.getValue(); 00091 if (d2 == null) 00092 return 1; 00093 int compared = Double.compare(d1, d2); 00094 if (compared != 0) 00095 return compared; 00096 } 00097 return 0; 00098 } 00099 00100 public RunInfo infos() { 00101 return infos; 00102 } 00103 00104 public boolean hasFlag(String flag) { 00105 return infos().hasFlag(flag); 00106 } 00107 }