RLPark 1.0.0
Reinforcement Learning Framework in Java
|
package rlpark.plugin.rltoys.envio.policy;

import java.util.Map;
import java.util.Random;

import rlpark.plugin.rltoys.envio.actions.Action;
import rlpark.plugin.rltoys.envio.actions.Actions;
import rlpark.plugin.rltoys.utils.Utils;
import zephyr.plugin.core.api.labels.Labels;
import zephyr.plugin.core.api.monitoring.abstracts.DataMonitor;
import zephyr.plugin.core.api.monitoring.abstracts.MonitorContainer;
import zephyr.plugin.core.api.monitoring.abstracts.Monitored;

/**
 * Base class for policies that select one of a fixed array of actions by sampling from a
 * probability distribution provided by the subclass through {@link #distribution()}.
 * <p>
 * Entry {@code i} of a distribution is the probability of {@code actions[i]}.
 */
public abstract class StochasticPolicy implements DiscreteActionPolicy, MonitorContainer {
  private static final long serialVersionUID = 6747532059495537542L;
  /** Source of randomness used by {@link #chooseAction(double[])}. */
  protected final Random random;
  /** Available actions; the array is stored as given, not copied. */
  protected final Action[] actions;
  /** Lookup built by {@code Actions.createActionIntMap(actions)}. */
  protected final Map<Action, Integer> actionToIndex;

  /**
   * @param random random generator used when sampling an action
   * @param actions actions this policy chooses from
   */
  public StochasticPolicy(Random random, Action[] actions) {
    this.random = random;
    this.actions = actions;
    actionToIndex = Actions.createActionIntMap(actions);
  }

  /**
   * Looks up the integer associated with an action in {@link #actionToIndex}.
   *
   * @param a action to look up
   * @return the mapped integer (presumably the position of {@code a} in {@link #actions}; confirm
   *         against {@code Actions.createActionIntMap})
   * @throws NullPointerException if the map holds no entry for {@code a}, because the missing
   *           {@code Integer} cannot be unboxed
   */
  protected int atoi(Action a) {
    return actionToIndex.get(a);
  }

  /**
   * Samples an action: draws one uniform value from {@link #random} and returns the first action
   * whose cumulative probability strictly exceeds it.
   * <p>
   * An action whose probability is zero is never returned as long as some scanned entry has
   * positive probability. The comparison is strict because {@link Random#nextDouble()} may return
   * exactly {@code 0.0}; and when the draw is not exceeded by the accumulated sum (which can stay
   * slightly below one through rounding), the fallback prefers the last entry that actually carries
   * probability mass.
   *
   * @param distribution probabilities, one per action, expected to sum to one (checked by an
   *          {@code assert} only)
   * @return the sampled action, or {@code null} if {@code Utils.checkValue} rejects a partial sum
   *         (presumably NaN or infinite values; confirm against {@code Utils})
   */
  protected Action chooseAction(double[] distribution) {
    assert checkDistribution(distribution);
    final double randomValue = random.nextDouble();
    double sum = 0;
    // Index of the most recent scanned entry with strictly positive probability.
    int lastSupported = -1;
    // The final entry is not scanned: it is the default outcome once every other one is ruled out.
    for (int i = 0; i < distribution.length - 1; i++) {
      final double probability = distribution[i];
      sum += probability;
      if (!Utils.checkValue(sum))
        return null;
      // Strict comparison: with ">=" a draw of exactly 0.0 would select a leading
      // zero-probability action.
      if (sum > randomValue)
        return actions[i];
      if (probability > 0)
        lastSupported = i;
    }
    // lastSupported >= 0 implies distribution.length >= 2, so the index below is valid.
    if (lastSupported >= 0 && !(distribution[distribution.length - 1] > 0))
      return actions[lastSupported];
    return actions[actions.length - 1];
  }

  /**
   * Tells whether the entries of an array sum to one, within {@code Utils.EPSILON}.
   * <p>
   * Only the total is checked; individual entries are not inspected.
   *
   * @param distribution values to sum
   * @return {@code true} if the absolute difference between the sum and one is below
   *         {@code Utils.EPSILON}
   */
  public static boolean checkDistribution(double[] distribution) {
    double sum = 0.0;
    for (double value : distribution)
      sum += value;
    return Math.abs(1.0 - sum) < Utils.EPSILON;
  }

  /**
   * @return the action array held by this policy (the internal array itself, not a copy)
   */
  @Override
  public Action[] actions() {
    return actions;
  }

  /**
   * @return the current probability of each action, indexed like {@link #actions}
   */
  public abstract double[] distribution();

  /**
   * @return the result of {@link #distribution()}
   */
  @Override
  public double[] values() {
    return distribution();
  }

  /**
   * Registers one monitored value per action, labelled with {@code Labels.label(action)}, that
   * reports the action's entry in {@link #distribution()}.
   * <p>
   * Each read of a monitored value calls {@link #distribution()} again.
   *
   * @param monitor monitor receiving the registrations
   */
  @Override
  public void addToMonitor(DataMonitor monitor) {
    for (int i = 0; i < actions.length; i++) {
      final int actionIndex = i;
      monitor.add(Labels.label(actions[i]), new Monitored() {
        @Override
        public double monitoredValue() {
          return distribution()[actionIndex];
        }
      });
    }
  }
}