RLPark 1.0.0
Reinforcement Learning Framework in Java
00001 package rlpark.plugin.rltoys.utils; 00002 00003 import java.io.File; 00004 import java.io.FileInputStream; 00005 import java.io.FileOutputStream; 00006 import java.io.IOException; 00007 import java.io.InputStream; 00008 import java.io.ObjectInputStream; 00009 import java.io.ObjectOutputStream; 00010 import java.io.ObjectStreamClass; 00011 import java.io.OutputStream; 00012 import java.io.Serializable; 00013 import java.lang.reflect.Constructor; 00014 import java.lang.reflect.InvocationTargetException; 00015 import java.lang.reflect.Method; 00016 import java.util.ArrayList; 00017 import java.util.Arrays; 00018 import java.util.Collection; 00019 import java.util.LinkedHashMap; 00020 import java.util.LinkedHashSet; 00021 import java.util.LinkedList; 00022 import java.util.List; 00023 import java.util.Map; 00024 import java.util.Random; 00025 import java.util.Set; 00026 import java.util.zip.GZIPInputStream; 00027 import java.util.zip.GZIPOutputStream; 00028 00029 00030 public class Utils { 00031 final public static double EPSILON = 10e-8; 00032 00033 public static <T> Set<T> asSet(T... ts) { 00034 return new LinkedHashSet<T>(asList(ts)); 00035 } 00036 00037 public static <T> List<T> asList(T... ts) { 00038 return Arrays.asList(ts); 00039 } 00040 00041 public static <T> LinkedList<T> asLinkedList(T... 
ts) { 00042 return new LinkedList<T>(asList(ts)); 00043 } 00044 00045 public static <T, U> Map<T, U> asMap(T key, U value) { 00046 Map<T, U> result = new LinkedHashMap<T, U>(); 00047 result.put(key, value); 00048 return result; 00049 } 00050 00051 public static boolean checkValue(double value) { 00052 return !Double.isInfinite(value) && !Double.isNaN(value); 00053 } 00054 00055 public static boolean checkProbability(double value) { 00056 return value >= 0 && value <= 1; 00057 } 00058 00059 public static String[] asStringArray(Collection<String> collection) { 00060 String[] result = new String[collection.size()]; 00061 collection.toArray(result); 00062 return result; 00063 } 00064 00065 public static int[] asIntArray(Collection<Integer> collection) { 00066 int[] result = new int[collection.size()]; 00067 int index = 0; 00068 for (Integer i : collection) { 00069 result[index] = i; 00070 index++; 00071 } 00072 return result; 00073 } 00074 00075 public static double[] asDoubleArray(List<Double> list) { 00076 double[] result = new double[list.size()]; 00077 for (int i = 0; i < result.length; i++) 00078 result[i] = list.get(i); 00079 return result; 00080 } 00081 00082 public static <T> T choose(Random random, List<T> list) { 00083 if (random == null) 00084 return list.get(0); 00085 return list.get(random.nextInt(list.size())); 00086 } 00087 00088 public static <T> T choose(Random random, T... 
elements) { 00089 return elements[random.nextInt(elements.length)]; 00090 } 00091 00092 public static <T> T choose(Random random, Collection<T> set) { 00093 return choose(random, new ArrayList<T>(set)); 00094 } 00095 00096 public static double trunc(double value, double threshold) { 00097 return Math.max(Math.min(value, threshold), -threshold); 00098 } 00099 00100 public static <T> T first(Iterable<T> iterable) { 00101 return iterable.iterator().next(); 00102 } 00103 00104 public static Object[] asObjectArray(Collection<?> collection) { 00105 Object[] result = new Object[collection.size()]; 00106 int i = 0; 00107 for (Object o : collection) { 00108 result[i] = o; 00109 i++; 00110 } 00111 return result; 00112 } 00113 00114 public static int[] range(int imin, int imax) { 00115 int[] result = new int[imax - imin]; 00116 for (int i = 0; i < result.length; i++) 00117 result[i] = imin + i; 00118 return result; 00119 } 00120 00121 public static File createTempFile(String prefix) { 00122 try { 00123 return File.createTempFile(prefix, ""); 00124 } catch (IOException e) { 00125 throw new RuntimeException(e); 00126 } 00127 } 00128 00129 public static boolean checkInstanciated(Object[] array) { 00130 for (Object o : array) 00131 if (o == null) 00132 return false; 00133 return true; 00134 } 00135 00136 static public void notSupported() { 00137 throw new RuntimeException("Operation not supported"); 00138 } 00139 00140 static public void notImplemented() { 00141 throw new RuntimeException("Operation not implemented"); 00142 } 00143 00144 static final public double square(double a) { 00145 return a * a; 00146 } 00147 00148 static public double discountToTimeSteps(double discount) { 00149 assert discount >= 0 && discount < 1.0; 00150 return 1 / (1 - discount); 00151 } 00152 00153 static public double timeStepsToDiscount(int timeSteps) { 00154 assert timeSteps > 0; 00155 return 1.0 - 1.0 / timeSteps; 00156 } 00157 00158 public static String[] concat(String[] array01, String... 
array02) { 00159 String[] result = new String[array01.length + array02.length]; 00160 System.arraycopy(array01, 0, result, 0, array01.length); 00161 System.arraycopy(array02, 0, result, array01.length, array02.length); 00162 return result; 00163 } 00164 00165 public static double[] concat(double[] array01, double... array02) { 00166 double[] result = new double[array01.length + array02.length]; 00167 System.arraycopy(array01, 0, result, 0, array01.length); 00168 System.arraycopy(array02, 0, result, array01.length, array02.length); 00169 return result; 00170 } 00171 00172 public static <T> T newInstance(Class<T> type, Object... args) { 00173 Class<?>[] classArgs = new Class<?>[args.length]; 00174 for (int i = 0; i < classArgs.length; i++) 00175 classArgs[i] = args[i].getClass(); 00176 return newInstance(type, classArgs, args); 00177 } 00178 00179 public static <T> T newInstance(Class<T> type, Class<?>[] classArgs, Object... args) { 00180 Constructor<T> constructor = null; 00181 try { 00182 constructor = type.getConstructor(classArgs); 00183 } catch (SecurityException e) { 00184 e.printStackTrace(); 00185 return null; 00186 } catch (NoSuchMethodException e) { 00187 e.printStackTrace(); 00188 return null; 00189 } 00190 try { 00191 return constructor.newInstance(args); 00192 } catch (IllegalArgumentException e) { 00193 e.printStackTrace(); 00194 } catch (InstantiationException e) { 00195 e.printStackTrace(); 00196 } catch (IllegalAccessException e) { 00197 e.printStackTrace(); 00198 } catch (InvocationTargetException e) { 00199 e.printStackTrace(); 00200 } 00201 return null; 00202 } 00203 00204 public static void save(Serializable serialized, String filepath) { 00205 save(serialized, new File(filepath)); 00206 } 00207 00208 public static void save(Serializable serialized, File file) { 00209 try { 00210 OutputStream fout = new FileOutputStream(file); 00211 if (file.getName().endsWith(".gz")) 00212 fout = new GZIPOutputStream(fout); 00213 ObjectOutputStream out = new 
ObjectOutputStream(fout); 00214 out.writeObject(serialized); 00215 out.close(); 00216 } catch (IOException e) { 00217 throw new RuntimeException(e); 00218 } 00219 } 00220 00221 public static Object load(String filepath) { 00222 return load(new File(filepath)); 00223 } 00224 00225 // 00226 // public static Object load(File file) { 00227 // try { 00228 // FileInputStream fis = new FileInputStream(file); 00229 // ObjectInputStream in = new ObjectInputStream(fis); 00230 // Object serialized = in.readObject(); 00231 // in.close(); 00232 // return serialized; 00233 // } catch (IOException e) { 00234 // throw new RuntimeException(e); 00235 // } catch (ClassNotFoundException e) { 00236 // throw new RuntimeException(e); 00237 // } 00238 // } 00239 00240 public static Object load(File file) { 00241 return load(file, new ClassLoader[] { Thread.currentThread().getContextClassLoader() }); 00242 } 00243 00244 public static Object load(File file, final ClassLoader... loaders) { 00245 try { 00246 InputStream fis = new FileInputStream(file); 00247 if (file.getName().endsWith(".gz")) 00248 fis = new GZIPInputStream(fis); 00249 ObjectInputStream oIn = new ObjectInputStream(fis) { 00250 @Override 00251 protected Class<?> resolveClass(ObjectStreamClass desc) throws IOException, ClassNotFoundException { 00252 String className = desc.getName(); 00253 try { 00254 return Class.forName(className); 00255 } catch (ClassNotFoundException exc) { 00256 for (ClassLoader cl : loaders) { 00257 try { 00258 return cl.loadClass(className); 00259 } catch (ClassNotFoundException e) { 00260 } 00261 } 00262 throw new ClassNotFoundException(className); 00263 } 00264 } 00265 }; 00266 Object serialized = oIn.readObject(); 00267 oIn.close(); 00268 return serialized; 00269 } catch (IOException e) { 00270 throw new RuntimeException(e); 00271 } catch (ClassNotFoundException e) { 00272 throw new RuntimeException(e); 00273 } 00274 } 00275 00276 static public double[] newFilledArray(int length, double value) { 
00277 double[] result = new double[length]; 00278 Arrays.fill(result, value); 00279 return result; 00280 } 00281 00282 @SuppressWarnings("unchecked") 00283 public static <T> T clone(T t) { 00284 Method method; 00285 try { 00286 method = t.getClass().getMethod("clone"); 00287 } catch (SecurityException e) { 00288 e.printStackTrace(); 00289 return null; 00290 } catch (NoSuchMethodException e) { 00291 e.printStackTrace(); 00292 return null; 00293 } 00294 try { 00295 return (T) method.invoke(t); 00296 } catch (IllegalArgumentException e) { 00297 e.printStackTrace(); 00298 } catch (IllegalAccessException e) { 00299 e.printStackTrace(); 00300 } catch (InvocationTargetException e) { 00301 e.printStackTrace(); 00302 } 00303 return null; 00304 } 00305 }