RLPark 1.0.0
Reinforcement Learning Framework in Java
|
00001 package rlpark.plugin.rltoys.horde.functions; 00002 00003 import rlpark.plugin.rltoys.envio.actions.Action; 00004 import rlpark.plugin.rltoys.envio.observations.Legend; 00005 import rlpark.plugin.rltoys.envio.observations.ObsAsDoubles; 00006 import rlpark.plugin.rltoys.envio.observations.Observation; 00007 import rlpark.plugin.rltoys.math.vector.RealVector; 00008 import zephyr.plugin.core.api.labels.Labeled; 00009 import zephyr.plugin.core.api.monitoring.annotations.Monitor; 00010 00011 public class RewardObservationFunction implements RewardFunction, Labeled, HordeUpdatable { 00012 private static final long serialVersionUID = -5930168576876015871L; 00013 @Monitor 00014 protected double reward; 00015 private final int observationIndex; 00016 private final String label; 00017 00018 public RewardObservationFunction(Legend legend, String label) { 00019 this.label = label; 00020 observationIndex = legend.indexOf(label); 00021 if (observationIndex < 0) 00022 throw new RuntimeException(label + " not found in the legend"); 00023 } 00024 00025 public void update(double[] o) { 00026 reward = o != null ? o[observationIndex] : 0.0; 00027 } 00028 00029 @Override 00030 public double reward() { 00031 return reward; 00032 } 00033 00034 @Override 00035 public String label() { 00036 return label; 00037 } 00038 00039 @Override 00040 public void update(Observation o_tp1, RealVector x_t, Action a_t, RealVector x_tp1) { 00041 update(((ObsAsDoubles) o_tp1).doubleValues()); 00042 } 00043 }