RLPark 1.0.0
Reinforcement Learning Framework in Java

ActorLambdaOffPolicy.java

Go to the documentation of this file.
00001 package rlpark.plugin.rltoys.algorithms.control.actorcritic.offpolicy;
00002 
00003 import rlpark.plugin.rltoys.algorithms.functions.policydistributions.PolicyDistribution;
00004 import rlpark.plugin.rltoys.algorithms.traces.Traces;
00005 import rlpark.plugin.rltoys.envio.actions.Action;
00006 import rlpark.plugin.rltoys.math.vector.RealVector;
00007 import rlpark.plugin.rltoys.math.vector.implementations.PVector;
00008 
00009 public class ActorLambdaOffPolicy extends AbstractActorOffPolicy {
00010   final protected Traces[] e_u;
00011   final public double lambda;
00012   final protected double alpha_u;
00013   private double rho_t;
00014 
00015   public ActorLambdaOffPolicy(double lambda, double gamma, PolicyDistribution policyDistribution, double alpha_u,
00016       int nbFeatures, Traces prototype) {
00017     this(policyDistribution.createParameters(nbFeatures), lambda, gamma, policyDistribution, alpha_u, prototype);
00018   }
00019 
00020   public ActorLambdaOffPolicy(PVector[] policyParameters, double lambda, double gamma,
00021       PolicyDistribution policyDistribution, double alpha_u, Traces prototype) {
00022     super(policyParameters, policyDistribution);
00023     this.alpha_u = alpha_u;
00024     this.lambda = lambda;
00025     e_u = new Traces[u.length];
00026     for (int i = 0; i < e_u.length; i++)
00027       e_u[i] = prototype.newTraces(u[i].size);
00028   }
00029 
00030   protected void updateEligibilityTraces(double rho_t, Action a_t, double delta) {
00031     RealVector[] gradLog = targetPolicy.computeGradLog(a_t);
00032     for (int i = 0; i < u.length; i++) {
00033       e_u[i].update(lambda, gradLog[i]);
00034       e_u[i].vect().mapMultiplyToSelf(rho_t);
00035     }
00036   }
00037 
00038   protected void updatePolicyParameters(double rho_t, Action a_t, double delta) {
00039     for (int i = 0; i < u.length; i++)
00040       u[i].addToSelf(alpha_u * delta, e_u[i].vect());
00041   }
00042 
00043   @Override
00044   protected void updateParameters(double pi_t, double b_t, RealVector x_t, Action a_t, double delta) {
00045     targetPolicy.update(x_t);
00046     rho_t = pi_t / b_t;
00047     updateEligibilityTraces(rho_t, a_t, delta);
00048     updatePolicyParameters(rho_t, a_t, delta);
00049   }
00050 
00051   @Override
00052   protected void initEpisode() {
00053     for (Traces e : e_u)
00054       e.clear();
00055   }
00056 
00057   public Traces[] eligibilities() {
00058     return e_u;
00059   }
00060 }
 All Classes Namespaces Files Functions Variables Enumerations
Zephyr
RLPark