RLPark 1.0.0
Reinforcement Learning Framework in Java

AbstractActorCritic.java

Go to the documentation of this file.
00001 package rlpark.plugin.rltoys.algorithms.control.actorcritic.onpolicy;
00002 
00003 
00004 import rlpark.plugin.rltoys.algorithms.control.ControlLearner;
00005 import rlpark.plugin.rltoys.algorithms.functions.policydistributions.PolicyDistribution;
00006 import rlpark.plugin.rltoys.algorithms.predictions.td.OnPolicyTD;
00007 import rlpark.plugin.rltoys.envio.actions.Action;
00008 import rlpark.plugin.rltoys.math.vector.RealVector;
00009 import zephyr.plugin.core.api.monitoring.annotations.Monitor;
00010 
00011 @Monitor
00012 public abstract class AbstractActorCritic implements ControlLearner {
00013   private static final long serialVersionUID = -6085810735822394602L;
00014   public final Actor actor;
00015   public final OnPolicyTD critic;
00016   protected double reward = 0.0;
00017   protected boolean policyRequireUpdate = true;
00018 
00019   public AbstractActorCritic(OnPolicyTD critic, Actor actor) {
00020     this.critic = critic;
00021     this.actor = actor;
00022   }
00023 
00024   abstract protected double updateCritic(RealVector x_t, RealVector x_tp1, double r_tp1);
00025 
00026   protected void updateActor(RealVector x_t, Action a_t, double actorDelta) {
00027     actor.update(x_t, a_t, actorDelta);
00028   }
00029 
00030   @Override
00031   public Action proposeAction(RealVector x) {
00032     policyRequireUpdate = true;
00033     policy().update(x);
00034     return policy().sampleAction();
00035   }
00036 
00037   protected PolicyDistribution policy() {
00038     return actor.policy();
00039   }
00040 
00041   public Actor actor() {
00042     return actor;
00043   }
00044 
00045   public OnPolicyTD critic() {
00046     return critic;
00047   }
00048 
00049   @Override
00050   public Action step(RealVector x_t, Action a_t, RealVector x_tp1, double r_tp1) {
00051     reward = r_tp1;
00052     double actorDelta = updateCritic(x_t, x_tp1, r_tp1);
00053     policyRequireUpdate = x_t == null || policyRequireUpdate;
00054     if (policyRequireUpdate && x_t != null) {
00055       policy().update(x_t);
00056       policyRequireUpdate = false;
00057     }
00058     updateActor(x_t, a_t, actorDelta);
00059     policy().update(x_tp1);
00060     policyRequireUpdate = false;
00061     return policy().sampleAction();
00062   }
00063 }
 All Classes Namespaces Files Functions Variables Enumerations
Zephyr
RLPark