RLPark 1.0.0
Reinforcement Learning Framework in Java
|
00001 package rlpark.plugin.rltoys.algorithms.control.actorcritic.onpolicy; 00002 00003 import rlpark.plugin.rltoys.algorithms.predictions.td.OnPolicyTD; 00004 import rlpark.plugin.rltoys.math.vector.RealVector; 00005 import zephyr.plugin.core.api.monitoring.annotations.Monitor; 00006 00007 @Monitor 00008 public class AverageRewardActorCritic extends AbstractActorCritic { 00009 private static final long serialVersionUID = 3772938582043052714L; 00010 protected double averageReward = 0.0; 00011 private final double alpha_r; 00012 00013 public AverageRewardActorCritic(double alpha_r, OnPolicyTD critic, Actor actor) { 00014 super(critic, actor); 00015 this.alpha_r = alpha_r; 00016 } 00017 00018 @Override 00019 protected double updateCritic(RealVector x_t, RealVector x_tp1, double r_tp1) { 00020 double delta = critic.update(x_t, x_tp1, r_tp1 - averageReward); 00021 averageReward += alpha_r * delta; 00022 return delta; 00023 } 00024 00025 public double currentAverage() { 00026 return averageReward; 00027 } 00028 }