RLPark 1.0.0
Reinforcement Learning Framework in Java

PuddleWorld.java

package rlpark.plugin.rltoys.problems.puddleworld;

import java.util.Arrays;
import java.util.Random;

import rlpark.plugin.rltoys.algorithms.functions.ContinuousFunction;
import rlpark.plugin.rltoys.envio.actions.Action;
import rlpark.plugin.rltoys.envio.actions.ActionArray;
import rlpark.plugin.rltoys.envio.observations.Legend;
import rlpark.plugin.rltoys.envio.rl.TRStep;
import rlpark.plugin.rltoys.math.ranges.Range;
import rlpark.plugin.rltoys.problems.ProblemBounded;
import rlpark.plugin.rltoys.problems.ProblemContinuousAction;
import rlpark.plugin.rltoys.problems.ProblemDiscreteAction;
import rlpark.plugin.rltoys.utils.Utils;
import zephyr.plugin.core.api.monitoring.abstracts.DataMonitor;
import zephyr.plugin.core.api.monitoring.abstracts.MonitorContainer;
import zephyr.plugin.core.api.monitoring.abstracts.Monitored;
import zephyr.plugin.core.api.monitoring.annotations.Monitor;

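/**
 * Continuous-state puddle-world benchmark generalized to n dimensions. The
 * agent observes a bounded position vector and moves it with either discrete
 * or continuous actions; rewards and episode termination are delegated to
 * pluggable ContinuousFunction and TerminationFunction instances.
 */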
public class PuddleWorld implements ProblemBounded, ProblemDiscreteAction, ProblemContinuousAction,
    MonitorContainer {
  private final Action[] actions;
  protected TRStep step = null;
  private final int nbDimensions;
  private double[] start = null;
  @Monitor
  private ContinuousFunction rewardFunction = null;
  private final Legend legend;
  private final Random random;
  private TerminationFunction terminationFunction = null;
  private final Range observationRange;
  private final Range actionRange;
  private final double absoluteNoise;
  @Monitor
  private final double[] lastActions;

  public PuddleWorld(Random random, int nbDimension, Range observationRange, Range actionRange,
      double relativeNoise) {
    this.random = random;
    this.observationRange = observationRange;
    this.actionRange = actionRange;
    this.nbDimensions = nbDimension;
    this.absoluteNoise = (actionRange.length() / 2.0) * relativeNoise;
    legend = createLegend();
    actions = createActions();
    lastActions = new double[nbDimension];
  }

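  // Builds the discrete action set: one -1 and one +1 action per dimension,
  // plus a final no-op action, for 2 * nbDimensions + 1 actions in total.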
  private Action[] createActions() {
    Action[] actions = new Action[2 * nbDimensions + 1];
    for (int i = 0; i < actions.length - 1; i++) {
      int dimension = i / 2;
      int dimensionAction = i % 2;
      double[] actionValues = Utils.newFilledArray(nbDimensions, 0);
      if (dimensionAction == 0)
        actionValues[dimension] = -1;
      else
        actionValues[dimension] = 1;
      actions[i] = new ActionArray(actionValues);
    }
    actions[actions.length - 1] = new ActionArray(Utils.newFilledArray(nbDimensions, 0));
    return actions;
  }

  public void setStart(double[] start) {
    this.start = start;
  }

  public void setRewardFunction(ContinuousFunction rewardFunction) {
    this.rewardFunction = rewardFunction;
  }

  public void setTermination(TerminationFunction terminationFunction) {
    this.terminationFunction = terminationFunction;
  }

  private Legend createLegend() {
    String[] labels = new String[nbDimensions];
    for (int i = 0; i < nbDimensions; i++)
      labels[i] = "x" + i;
    return new Legend(labels);
  }

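  // Starts an episode from the fixed start state if one was set, otherwise
  // from a uniformly random non-terminal position.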
  @Override
  public TRStep initialize() {
    double[] position = start;
    if (position == null) {
      position = new double[nbDimensions];
      do {
        for (int i = 0; i < position.length; i++)
          position[i] = observationRange.choose(random);
      } while (isTerminated(position));
    }
    step = new TRStep(position, reward(position));
    return step;
  }

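  // Applies one transition: if the current observation is already terminal,
  // returns an episode-ending step; otherwise perturbs the action with noise
  // and clamps the next position to the observation range.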
  @Override
  public TRStep step(Action action) {
    if (isTerminated(step.o_tp1)) {
      step = step.createEndingStep();
      return step;
    }
    double[] envAction = computeEnvironmentAction(action);
    double[] x_tp1 = new double[nbDimensions];
    for (int i = 0; i < x_tp1.length; i++)
      x_tp1[i] = observationRange.bound(step.o_tp1[i] + envAction[i]);
    step = new TRStep(step, action, x_tp1, reward(x_tp1));
    return step;
  }

  private double reward(double[] position) {
    if (rewardFunction == null)
      return 0.0;
    return rewardFunction.value(position);
  }

  private boolean isTerminated(double[] position) {
    if (terminationFunction == null)
      return false;
    return terminationFunction.isTerminated(position);
  }

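  // Records the agent's action for monitoring, then bounds each component to
  // the action range and adds uniform noise in [-absoluteNoise/2, +absoluteNoise/2].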
  private double[] computeEnvironmentAction(Action action) {
    double[] agentAction = ((ActionArray) action).actions;
    System.arraycopy(agentAction, 0, lastActions, 0, nbDimensions);
    double[] envAction = new double[agentAction.length];
    for (int i = 0; i < envAction.length; i++) {
      double noise = (random.nextDouble() * absoluteNoise) - (absoluteNoise / 2);
      envAction[i] = actionRange.bound(agentAction[i]) + noise;
    }
    return envAction;
  }

  @Override
  public Legend legend() {
    return legend;
  }

  @Override
  public Range[] actionRanges() {
    Range[] ranges = new Range[nbDimensions];
    Arrays.fill(ranges, actionRange);
    return ranges;
  }

  @Override
  public Range[] getObservationRanges() {
    Range[] ranges = new Range[nbDimensions];
    Arrays.fill(ranges, observationRange);
    return ranges;
  }

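  // Exposes the current reward and each observation dimension to Zephyr's
  // data monitor, guarding against the null step present before initialize().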
  @Override
  public void addToMonitor(DataMonitor monitor) {
    monitor.add("Reward", new Monitored() {
      @Override
      public double monitoredValue() {
        return step != null ? step.r_tp1 : 0.0;
      }
    });
    for (int i = 0; i < legend.nbLabels(); i++) {
      final int index = i;
      monitor.add(legend.label(i), new Monitored() {
        @Override
        public double monitoredValue() {
          return step != null && step.o_tp1 != null ? step.o_tp1[index] : 0.0;
        }
      });
    }
  }

  public int nbDimensions() {
    return nbDimensions;
  }

  public ContinuousFunction rewardFunction() {
    return rewardFunction;
  }

  public double[] start() {
    return start;
  }

  @Override
  public Action[] actions() {
    return actions;
  }

  @Override
  public TRStep lastStep() {
    return step;
  }

  @Override
  public TRStep forceEndEpisode() {
    step = step.createEndingStep();
    return step;
  }
}
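
Below is a minimal usage sketch built only from the API shown above. The two-dimensional setup, the ranges, the goal region, and the constant step penalty are illustrative placeholders rather than values prescribed by RLPark, and the loop assumes TRStep exposes isEpisodeEnding() to flag the ending step produced by createEndingStep().

package rlpark.plugin.rltoys.problems.puddleworld;

import java.util.Random;

import rlpark.plugin.rltoys.algorithms.functions.ContinuousFunction;
import rlpark.plugin.rltoys.envio.actions.Action;
import rlpark.plugin.rltoys.envio.rl.TRStep;
import rlpark.plugin.rltoys.math.ranges.Range;

public class PuddleWorldExample {
  public static void main(String[] args) {
    Random random = new Random(0);
    // Two dimensions, observations in [0, 1], actions in [-0.05, 0.05],
    // 10% relative action noise: illustrative values, not RLPark defaults.
    PuddleWorld problem = new PuddleWorld(random, 2, new Range(0, 1), new Range(-0.05, 0.05), 0.1);
    // Placeholder goal region: episodes end near the upper-right corner.
    problem.setTermination(new TerminationFunction() {
      @Override
      public boolean isTerminated(double[] position) {
        return position[0] > 0.95 && position[1] > 0.95;
      }
    });
    // Placeholder reward: -1 per step; a real puddle world adds a
    // position-dependent penalty inside the puddles.
    problem.setRewardFunction(new ContinuousFunction() {
      @Override
      public double value(double[] position) {
        return -1.0;
      }
    });
    TRStep step = problem.initialize();
    Action[] actions = problem.actions();
    double episodeReturn = 0;
    for (int t = 0; t < 1000 && !step.isEpisodeEnding(); t++) {
      Action action = actions[random.nextInt(actions.length)]; // uniform random policy
      step = problem.step(action);
      episodeReturn += step.r_tp1;
    }
    System.out.println("Return: " + episodeReturn);
  }
}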