RLPark 1.0.0
Reinforcement Learning Framework in Java
|
00001 package rlpark.plugin.rltoys.algorithms.control.actorcritic.onpolicy; 00002 00003 import rlpark.plugin.rltoys.algorithms.functions.policydistributions.PolicyDistribution; 00004 import rlpark.plugin.rltoys.algorithms.traces.ATraces; 00005 import rlpark.plugin.rltoys.algorithms.traces.Traces; 00006 import rlpark.plugin.rltoys.envio.actions.Action; 00007 import rlpark.plugin.rltoys.math.vector.RealVector; 00008 import rlpark.plugin.rltoys.utils.Utils; 00009 import zephyr.plugin.core.api.monitoring.annotations.LabelProvider; 00010 import zephyr.plugin.core.api.monitoring.annotations.Monitor; 00011 00012 @Monitor 00013 public class ActorLambda extends Actor { 00014 private static final long serialVersionUID = -1601184295976574511L; 00015 public final Traces[] e_u; 00016 private final double lambda; 00017 private final double gamma; 00018 00019 public ActorLambda(double lambda, double gamma, PolicyDistribution policyDistribution, double alpha_u, int nbFeatures) { 00020 this(lambda, gamma, policyDistribution, alpha_u, nbFeatures, new ATraces()); 00021 } 00022 00023 public ActorLambda(double lambda, double gamma, PolicyDistribution policyDistribution, double alpha_u, 00024 int nbFeatures, Traces prototype) { 00025 this(lambda, gamma, policyDistribution, Utils.newFilledArray(policyDistribution.nbParameterVectors(), alpha_u), 00026 nbFeatures, prototype); 00027 } 00028 00029 public ActorLambda(double lambda, double gamma, PolicyDistribution policyDistribution, double[] alpha_u, 00030 int nbFeatures, Traces prototype) { 00031 super(policyDistribution, alpha_u, nbFeatures); 00032 this.lambda = lambda; 00033 this.gamma = gamma; 00034 e_u = new Traces[policyDistribution.nbParameterVectors()]; 00035 for (int i = 0; i < e_u.length; i++) 00036 e_u[i] = prototype.newTraces(u[i].size); 00037 } 00038 00039 @Override 00040 public void update(RealVector x_t, Action a_t, double delta) { 00041 if (x_t == null) { 00042 initEpisode(); 00043 return; 00044 } 00045 RealVector[] gradLog = policyDistribution.computeGradLog(a_t); 00046 for (int i = 0; i < u.length; i++) 00047 e_u[i].update(gamma * lambda, gradLog[i]); 00048 updatePolicyParameters(gradLog, delta); 00049 } 00050 00051 protected void updatePolicyParameters(RealVector[] gradLog, double delta) { 00052 for (int i = 0; i < u.length; i++) 00053 u[i].addToSelf(alpha_u[i] * delta, e_u[i].vect()); 00054 } 00055 00056 private void initEpisode() { 00057 for (Traces e : e_u) 00058 e.clear(); 00059 } 00060 00061 @LabelProvider(ids = { "e_u" }) 00062 String eligiblityLabelOf(int index) { 00063 return super.labelOf(index); 00064 } 00065 }