RLPark 1.0.0
Reinforcement Learning Framework in Java
ActorLambdaOffPolicy.java
package rlpark.plugin.rltoys.algorithms.control.actorcritic.offpolicy;

import rlpark.plugin.rltoys.algorithms.functions.policydistributions.PolicyDistribution;
import rlpark.plugin.rltoys.algorithms.traces.Traces;
import rlpark.plugin.rltoys.envio.actions.Action;
import rlpark.plugin.rltoys.math.vector.RealVector;
import rlpark.plugin.rltoys.math.vector.implementations.PVector;

/**
 * Actor of an off-policy actor-critic architecture, with one eligibility trace
 * per policy-parameter vector. Policy parameters are updated from the critic's
 * TD error delta, weighted by the importance-sampling ratio rho_t = pi_t / b_t.
 */
public class ActorLambdaOffPolicy extends AbstractActorOffPolicy {
  protected final Traces[] e_u; // eligibility traces, one per parameter vector in u
  public final double lambda; // trace-decay rate
  protected final double alpha_u; // actor step size
  private double rho_t; // importance-sampling ratio of the last update

  public ActorLambdaOffPolicy(double lambda, double gamma, PolicyDistribution policyDistribution, double alpha_u,
      int nbFeatures, Traces prototype) {
    this(policyDistribution.createParameters(nbFeatures), lambda, gamma, policyDistribution, alpha_u, prototype);
  }

  public ActorLambdaOffPolicy(PVector[] policyParameters, double lambda, double gamma,
      PolicyDistribution policyDistribution, double alpha_u, Traces prototype) {
    super(policyParameters, policyDistribution);
    this.alpha_u = alpha_u;
    this.lambda = lambda;
    // One trace vector per policy-parameter vector, created from the prototype
    e_u = new Traces[u.length];
    for (int i = 0; i < e_u.length; i++)
      e_u[i] = prototype.newTraces(u[i].size);
  }

  // e_u <- rho_t * (lambda * e_u + grad_u log pi(a_t | x_t))
  protected void updateEligibilityTraces(double rho_t, Action a_t, double delta) {
    RealVector[] gradLog = targetPolicy.computeGradLog(a_t);
    for (int i = 0; i < u.length; i++) {
      e_u[i].update(lambda, gradLog[i]);
      e_u[i].vect().mapMultiplyToSelf(rho_t);
    }
  }

  // u <- u + alpha_u * delta * e_u
  protected void updatePolicyParameters(double rho_t, Action a_t, double delta) {
    for (int i = 0; i < u.length; i++)
      u[i].addToSelf(alpha_u * delta, e_u[i].vect());
  }

  @Override
  protected void updateParameters(double pi_t, double b_t, RealVector x_t, Action a_t, double delta) {
    targetPolicy.update(x_t);
    rho_t = pi_t / b_t; // importance-sampling correction for off-policy learning
    updateEligibilityTraces(rho_t, a_t, delta);
    updatePolicyParameters(rho_t, a_t, delta);
  }

  @Override
  protected void initEpisode() {
    // Traces do not carry over episode boundaries
    for (Traces e : e_u)
      e.clear();
  }

  public Traces[] eligibilities() {
    return e_u;
  }
}
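For context, below is a minimal sketch of how this actor might be constructed. In practice the class is composed into an off-policy actor-critic control loop rather than driven directly. The names NormalDistributionScaled, ATraces, and GTD(lambda) are recalled from elsewhere in RLPark and are assumptions here, not part of this file; the constants are purely illustrative.

// Hypothetical wiring sketch; class names outside this file are assumptions.
import java.util.Random;

import rlpark.plugin.rltoys.algorithms.functions.policydistributions.PolicyDistribution;
import rlpark.plugin.rltoys.algorithms.functions.policydistributions.structures.NormalDistributionScaled;
import rlpark.plugin.rltoys.algorithms.traces.ATraces;
import rlpark.plugin.rltoys.algorithms.traces.Traces;

public class ActorLambdaOffPolicyExample {
  public static void main(String[] args) {
    int nbFeatures = 1000;               // illustrative feature-vector size
    double lambda = 0.4;                 // trace-decay rate
    double gamma = 0.99;                 // discount factor, forwarded to the constructor
    double alpha_u = 0.001 / nbFeatures; // illustrative actor step size
    // Gaussian policy over a continuous action (assumed RLPark class)
    PolicyDistribution policy = new NormalDistributionScaled(new Random(0), 0.0, 1.0);
    // Accumulating traces, used as a prototype factory via newTraces()
    Traces prototype = new ATraces();
    ActorLambdaOffPolicy actor = new ActorLambdaOffPolicy(lambda, gamma, policy, alpha_u, nbFeatures, prototype);
    // In a full agent, an off-policy critic (e.g. GTD(lambda)) computes the TD error
    // delta, and the control loop calls the actor's update with pi_t, b_t, x_t, a_t, delta.
    System.out.println("traces allocated: " + actor.eligibilities().length);
  }
}

The prototype pattern lets the caller choose the trace implementation (accumulating, replacing, thresholded) once; the actor then clones it to the right dimension for each policy-parameter vector.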