RLPark 1.0.0
Reinforcement Learning Framework in Java

StochasticPolicy.java

Go to the documentation of this file.
00001 package rlpark.plugin.rltoys.envio.policy;
00002 
00003 import java.util.Map;
00004 import java.util.Random;
00005 
00006 import rlpark.plugin.rltoys.envio.actions.Action;
00007 import rlpark.plugin.rltoys.envio.actions.Actions;
00008 import rlpark.plugin.rltoys.utils.Utils;
00009 import zephyr.plugin.core.api.labels.Labels;
00010 import zephyr.plugin.core.api.monitoring.abstracts.DataMonitor;
00011 import zephyr.plugin.core.api.monitoring.abstracts.MonitorContainer;
00012 import zephyr.plugin.core.api.monitoring.abstracts.Monitored;
00013 
00014 public abstract class StochasticPolicy implements DiscreteActionPolicy, MonitorContainer {
00015   private static final long serialVersionUID = 6747532059495537542L;
00016   protected final Random random;
00017   protected final Action[] actions;
00018   protected final Map<Action, Integer> actionToIndex;
00019 
00020   public StochasticPolicy(Random random, Action[] actions) {
00021     this.random = random;
00022     this.actions = actions;
00023     actionToIndex = Actions.createActionIntMap(actions);
00024   }
00025 
00026   protected int atoi(Action a) {
00027     return actionToIndex.get(a);
00028   }
00029 
00030   protected Action chooseAction(double[] distribution) {
00031     assert checkDistribution(distribution);
00032     double randomValue = random.nextDouble();
00033     double sum = 0;
00034     for (int i = 0; i < distribution.length - 1; i++) {
00035       sum += distribution[i];
00036       if (!Utils.checkValue(sum))
00037         return null;
00038       if (sum >= randomValue)
00039         return actions[i];
00040     }
00041     return actions[actions.length - 1];
00042   }
00043 
00044   public static boolean checkDistribution(double[] distribution) {
00045     double sum = 0.0;
00046     for (double value : distribution)
00047       sum += value;
00048     return Math.abs(1.0 - sum) < Utils.EPSILON;
00049   }
00050 
00051   @Override
00052   public Action[] actions() {
00053     return actions;
00054   }
00055 
00056   public abstract double[] distribution();
00057 
00058   @Override
00059   public double[] values() {
00060     return distribution();
00061   }
00062 
00063   @Override
00064   public void addToMonitor(DataMonitor monitor) {
00065     for (int i = 0; i < actions.length; i++) {
00066       final int a_i = i;
00067       monitor.add(Labels.label(actions[i]), new Monitored() {
00068         @Override
00069         public double monitoredValue() {
00070           return distribution()[a_i];
00071         }
00072       });
00073     }
00074   }
00075 }
 All Classes Namespaces Files Functions Variables Enumerations
Zephyr
RLPark