RLPark 1.0.0
Reinforcement Learning Framework in Java
|
00001 package rlpark.plugin.rltoys.algorithms.control.actorcritic.onpolicy; 00002 00003 00004 import rlpark.plugin.rltoys.algorithms.control.ControlLearner; 00005 import rlpark.plugin.rltoys.algorithms.functions.policydistributions.PolicyDistribution; 00006 import rlpark.plugin.rltoys.algorithms.predictions.td.OnPolicyTD; 00007 import rlpark.plugin.rltoys.envio.actions.Action; 00008 import rlpark.plugin.rltoys.math.vector.RealVector; 00009 import zephyr.plugin.core.api.monitoring.annotations.Monitor; 00010 00011 @Monitor 00012 public abstract class AbstractActorCritic implements ControlLearner { 00013 private static final long serialVersionUID = -6085810735822394602L; 00014 public final Actor actor; 00015 public final OnPolicyTD critic; 00016 protected double reward = 0.0; 00017 protected boolean policyRequireUpdate = true; 00018 00019 public AbstractActorCritic(OnPolicyTD critic, Actor actor) { 00020 this.critic = critic; 00021 this.actor = actor; 00022 } 00023 00024 abstract protected double updateCritic(RealVector x_t, RealVector x_tp1, double r_tp1); 00025 00026 protected void updateActor(RealVector x_t, Action a_t, double actorDelta) { 00027 actor.update(x_t, a_t, actorDelta); 00028 } 00029 00030 @Override 00031 public Action proposeAction(RealVector x) { 00032 policyRequireUpdate = true; 00033 policy().update(x); 00034 return policy().sampleAction(); 00035 } 00036 00037 protected PolicyDistribution policy() { 00038 return actor.policy(); 00039 } 00040 00041 public Actor actor() { 00042 return actor; 00043 } 00044 00045 public OnPolicyTD critic() { 00046 return critic; 00047 } 00048 00049 @Override 00050 public Action step(RealVector x_t, Action a_t, RealVector x_tp1, double r_tp1) { 00051 reward = r_tp1; 00052 double actorDelta = updateCritic(x_t, x_tp1, r_tp1); 00053 policyRequireUpdate = x_t == null || policyRequireUpdate; 00054 if (policyRequireUpdate && x_t != null) { 00055 policy().update(x_t); 00056 policyRequireUpdate = false; 00057 } 00058 updateActor(x_t, a_t, actorDelta); 00059 policy().update(x_tp1); 00060 policyRequireUpdate = false; 00061 return policy().sampleAction(); 00062 } 00063 }