RLPark 1.0.0
Reinforcement Learning Framework in Java

RewardObservationFunction.java

Go to the documentation of this file.
00001 package rlpark.plugin.rltoys.horde.functions;
00002 
00003 import rlpark.plugin.rltoys.envio.actions.Action;
00004 import rlpark.plugin.rltoys.envio.observations.Legend;
00005 import rlpark.plugin.rltoys.envio.observations.ObsAsDoubles;
00006 import rlpark.plugin.rltoys.envio.observations.Observation;
00007 import rlpark.plugin.rltoys.math.vector.RealVector;
00008 import zephyr.plugin.core.api.labels.Labeled;
00009 import zephyr.plugin.core.api.monitoring.annotations.Monitor;
00010 
00011 public class RewardObservationFunction implements RewardFunction, Labeled, HordeUpdatable {
00012   private static final long serialVersionUID = -5930168576876015871L;
00013   @Monitor
00014   protected double reward;
00015   private final int observationIndex;
00016   private final String label;
00017 
00018   public RewardObservationFunction(Legend legend, String label) {
00019     this.label = label;
00020     observationIndex = legend.indexOf(label);
00021     if (observationIndex < 0)
00022       throw new RuntimeException(label + " not found in the legend");
00023   }
00024 
00025   public void update(double[] o) {
00026     reward = o != null ? o[observationIndex] : 0.0;
00027   }
00028 
00029   @Override
00030   public double reward() {
00031     return reward;
00032   }
00033 
00034   @Override
00035   public String label() {
00036     return label;
00037   }
00038 
00039   @Override
00040   public void update(Observation o_tp1, RealVector x_t, Action a_t, RealVector x_tp1) {
00041     update(((ObsAsDoubles) o_tp1).doubleValues());
00042   }
00043 }
 All Classes Namespaces Files Functions Variables Enumerations
Zephyr
RLPark