RLPark 1.0.0 — Reinforcement Learning Framework in Java
00001 package rlpark.plugin.rltoys.algorithms.control.actorcritic.offpolicy; 00002 00003 import rlpark.plugin.rltoys.algorithms.control.OffPolicyLearner; 00004 import rlpark.plugin.rltoys.algorithms.functions.Predictor; 00005 import rlpark.plugin.rltoys.algorithms.predictions.td.OffPolicyTD; 00006 import rlpark.plugin.rltoys.envio.actions.Action; 00007 import rlpark.plugin.rltoys.envio.policy.Policy; 00008 import rlpark.plugin.rltoys.math.vector.RealVector; 00009 import rlpark.plugin.rltoys.utils.Utils; 00010 import zephyr.plugin.core.api.monitoring.annotations.Monitor; 00011 00012 00013 @Monitor 00014 public class OffPAC implements OffPolicyLearner { 00015 private static final long serialVersionUID = -3586849056133550941L; 00016 public final Policy behaviour; 00017 public final OffPolicyTD critic; 00018 public final ActorOffPolicy actor; 00019 protected double pi_t; 00020 protected double b_t; 00021 00022 public OffPAC(Policy behavior, OffPolicyTD critic, ActorOffPolicy actor) { 00023 this.critic = critic; 00024 this.actor = actor; 00025 this.behaviour = behavior; 00026 } 00027 00028 @Override 00029 public void learn(RealVector x_t, Action a_t, RealVector x_tp1, Action a_tp1, double r_tp1) { 00030 if (x_t != null) { 00031 actor.policy().update(x_t); 00032 pi_t = actor.policy().pi(a_t); 00033 behaviour.update(x_t); 00034 b_t = behaviour.pi(a_t); 00035 } 00036 double delta = critic.update(pi_t, b_t, x_t, x_tp1, r_tp1); 00037 assert Utils.checkValue(delta); 00038 actor.update(pi_t, b_t, x_t, a_t, delta); 00039 } 00040 00041 @Override 00042 public Action proposeAction(RealVector s) { 00043 final Action action = actor.proposeAction(s); 00044 assert action != null; 00045 return action; 00046 } 00047 00048 @Override 00049 public Policy targetPolicy() { 00050 return actor.policy(); 00051 } 00052 00053 @Override 00054 public Predictor predictor() { 00055 return critic; 00056 } 00057 00058 public ActorOffPolicy actor() { 00059 return actor; 00060 } 00061 }