RLPark 1.0.0
Reinforcement Learning Framework in Java
|
package rlpark.plugin.rltoys.horde.demons;

import java.io.Serializable;

import rlpark.plugin.rltoys.algorithms.predictions.td.OnPolicyTD;
import rlpark.plugin.rltoys.algorithms.predictions.td.TD;
import rlpark.plugin.rltoys.algorithms.predictions.td.TDErrorMonitor;
import rlpark.plugin.rltoys.algorithms.predictions.td.TDLambdaAutostep;
import rlpark.plugin.rltoys.horde.functions.RewardFunction;
import rlpark.plugin.rltoys.utils.NotImplemented;
import zephyr.plugin.core.api.monitoring.annotations.Monitor;

/**
 * Couples a {@link PredictionDemon} with a {@link TDErrorMonitor}: on every
 * {@link #update(boolean)} the demon's current prediction and the reward given
 * by the demon's reward function are forwarded to the monitor.
 */
public class PredictionDemonVerifier implements Serializable {
  private static final long serialVersionUID = 6127406364376542150L;

  /** Demon whose prediction is forwarded to the monitor. */
  private final PredictionDemon predictionDemon;
  /** Reward function obtained from the demon at construction time. */
  private final RewardFunction rewardFunction;
  /** Monitor receiving the (prediction, reward, endOfEpisode) triples. */
  @Monitor
  private final TDErrorMonitor errorMonitor;

  /**
   * Builds a verifier with a precision of {@code 0.01}.
   *
   * @param predictionDemon demon to verify
   */
  public PredictionDemonVerifier(PredictionDemon predictionDemon) {
    this(predictionDemon, 0.01);
  }

  /**
   * Builds a verifier whose monitor is created from the gamma of the demon's
   * learner and the given precision.
   *
   * @param predictionDemon demon to verify
   * @param precision second argument passed to the {@link TDErrorMonitor}
   *          constructor
   * @throws NotImplemented if gamma cannot be extracted from the demon's
   *           learner (see {@link #extractGamma(OnPolicyTD)})
   */
  public PredictionDemonVerifier(PredictionDemon predictionDemon, double precision) {
    this.predictionDemon = predictionDemon;
    this.rewardFunction = predictionDemon.rewardFunction();
    // The reward function is read before the learner is queried, as in the
    // original statement order.
    this.errorMonitor = new TDErrorMonitor(extractGamma(predictionDemon.predicter()), precision);
  }

  /**
   * Reads the gamma value from a learner of a supported concrete type.
   *
   * @param learner learner to inspect; only {@link TD} and
   *          {@link TDLambdaAutostep} instances are supported
   * @return the value returned by the learner's {@code gamma()} method
   * @throws NotImplemented if the learner is of neither supported type
   */
  public double extractGamma(OnPolicyTD learner) {
    if (learner instanceof TD) {
      TD td = (TD) learner;
      return td.gamma();
    }
    if (learner instanceof TDLambdaAutostep) {
      TDLambdaAutostep autostep = (TDLambdaAutostep) learner;
      return autostep.gamma();
    }
    throw new NotImplemented();
  }

  /**
   * @return the monitor fed by this verifier
   */
  public TDErrorMonitor errorMonitor() {
    return errorMonitor;
  }

  /**
   * Forwards the demon's current prediction and the current reward to the
   * monitor.
   *
   * @param endOfEpisode passed through unchanged to the monitor
   * @return whatever {@link TDErrorMonitor#update} returns for this step
   */
  public double update(boolean endOfEpisode) {
    return errorMonitor.update(
        predictionDemon.prediction(),
        rewardFunction.reward(),
        endOfEpisode);
  }
}