RLPark 1.0.0
Reinforcement Learning Framework in Java
|
00001 package rlpark.plugin.rltoys.horde.demons; 00002 00003 import rlpark.plugin.rltoys.algorithms.LinearLearner; 00004 import rlpark.plugin.rltoys.algorithms.predictions.td.OnPolicyTD; 00005 import rlpark.plugin.rltoys.envio.actions.Action; 00006 import rlpark.plugin.rltoys.horde.functions.RewardFunction; 00007 import rlpark.plugin.rltoys.math.vector.RealVector; 00008 import zephyr.plugin.core.api.labels.Labeled; 00009 import zephyr.plugin.core.api.labels.Labels; 00010 import zephyr.plugin.core.api.monitoring.annotations.Monitor; 00011 00012 public class PredictionDemon implements Demon, Labeled { 00013 private static final long serialVersionUID = -6966208035134604865L; 00014 private final RewardFunction rewardFunction; 00015 @Monitor 00016 private final OnPolicyTD td; 00017 00018 public PredictionDemon(RewardFunction rewardFunction, OnPolicyTD td) { 00019 this.rewardFunction = rewardFunction; 00020 this.td = td; 00021 } 00022 00023 @Override 00024 public void update(RealVector x_t, Action a_t, RealVector x_tp1) { 00025 td.update(x_t, x_tp1, rewardFunction.reward()); 00026 } 00027 00028 public double prediction() { 00029 return td.prediction(); 00030 } 00031 00032 public RewardFunction rewardFunction() { 00033 return rewardFunction; 00034 } 00035 00036 public OnPolicyTD predicter() { 00037 return td; 00038 } 00039 00040 @Override 00041 public String label() { 00042 return "demon" + Labels.label(rewardFunction); 00043 } 00044 00045 @Override 00046 public LinearLearner learner() { 00047 return td; 00048 } 00049 }