RLPark 1.0.0
Reinforcement Learning Framework in Java

ActorLambda.java

package rlpark.plugin.rltoys.algorithms.control.actorcritic.onpolicy;

import rlpark.plugin.rltoys.algorithms.functions.policydistributions.PolicyDistribution;
import rlpark.plugin.rltoys.algorithms.traces.ATraces;
import rlpark.plugin.rltoys.algorithms.traces.Traces;
import rlpark.plugin.rltoys.envio.actions.Action;
import rlpark.plugin.rltoys.math.vector.RealVector;
import rlpark.plugin.rltoys.utils.Utils;
import zephyr.plugin.core.api.monitoring.annotations.LabelProvider;
import zephyr.plugin.core.api.monitoring.annotations.Monitor;

/**
 * On-policy actor with eligibility traces (actor(lambda)): keeps one trace per
 * policy parameter vector, decays it by gamma*lambda on every step, accumulates
 * the gradient of the log-policy, and steps the parameters along the traces
 * weighted by the critic's TD error.
 */
@Monitor
public class ActorLambda extends Actor {
  private static final long serialVersionUID = -1601184295976574511L;
  /** One eligibility trace per policy parameter vector. */
  public final Traces[] e_u;
  private final double lambda;
  private final double gamma;

  /** Uses a default ATraces prototype and the same step size for every parameter vector. */
  public ActorLambda(double lambda, double gamma, PolicyDistribution policyDistribution, double alpha_u,
      int nbFeatures) {
    this(lambda, gamma, policyDistribution, alpha_u, nbFeatures, new ATraces());
  }

  /** Uses the given trace prototype and the same step size for every parameter vector. */
  public ActorLambda(double lambda, double gamma, PolicyDistribution policyDistribution, double alpha_u,
      int nbFeatures, Traces prototype) {
    this(lambda, gamma, policyDistribution, Utils.newFilledArray(policyDistribution.nbParameterVectors(), alpha_u),
         nbFeatures, prototype);
  }

  /** Full constructor: one step size and one trace per parameter vector of the policy distribution. */
  public ActorLambda(double lambda, double gamma, PolicyDistribution policyDistribution, double[] alpha_u,
      int nbFeatures, Traces prototype) {
    super(policyDistribution, alpha_u, nbFeatures);
    this.lambda = lambda;
    this.gamma = gamma;
    e_u = new Traces[policyDistribution.nbParameterVectors()];
    for (int i = 0; i < e_u.length; i++)
      e_u[i] = prototype.newTraces(u[i].size);
  }

  /**
   * One actor step: a null x_t marks the start of a new episode and clears the traces;
   * otherwise each trace is decayed by gamma*lambda, the grad-log of the policy for a_t
   * is accumulated, and the policy parameters are updated with the TD error delta.
   */
  @Override
  public void update(RealVector x_t, Action a_t, double delta) {
    if (x_t == null) {
      initEpisode();
      return;
    }
    RealVector[] gradLog = policyDistribution.computeGradLog(a_t);
    for (int i = 0; i < u.length; i++)
      e_u[i].update(gamma * lambda, gradLog[i]);
    updatePolicyParameters(gradLog, delta);
  }

  /** Steps each parameter vector by alpha_u[i] * delta along its eligibility trace. */
  protected void updatePolicyParameters(RealVector[] gradLog, double delta) {
    for (int i = 0; i < u.length; i++)
      u[i].addToSelf(alpha_u[i] * delta, e_u[i].vect());
  }

  /** Resets all eligibility traces at the start of an episode. */
  private void initEpisode() {
    for (Traces e : e_u)
      e.clear();
  }

  /** Labels for the e_u traces in the Zephyr monitoring views. */
  @LabelProvider(ids = { "e_u" })
  String eligiblityLabelOf(int index) {
    return super.labelOf(index);
  }
}
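
The per-step computation implemented by update() and updatePolicyParameters() is the usual actor(lambda) update, with one eligibility trace per policy parameter vector u_i (the notation below is standard actor-critic notation, not taken from the source):

\[
e_{u,i} \leftarrow \gamma \lambda \, e_{u,i} + \nabla_{u_i} \log \pi_u(a_t \mid x_t),
\qquad
u_i \leftarrow u_i + \alpha_{u,i} \, \delta_t \, e_{u,i}
\]

where delta_t is the TD error supplied by the critic; a null x_t signals the start of a new episode and only clears the traces.

For reference, here is a minimal self-contained sketch of the same update rule over plain double arrays, assuming an accumulating trace (decay then add, as in the loop of update() above); the class and member names are illustrative only and are not part of the RLPark API:

public final class ActorLambdaSketch {
  private final double[] u;    // policy parameters
  private final double[] e;    // eligibility trace, same size as u
  private final double alpha;  // actor step size
  private final double gamma;  // discount factor
  private final double lambda; // trace decay rate

  public ActorLambdaSketch(int n, double alpha, double gamma, double lambda) {
    this.u = new double[n];
    this.e = new double[n];
    this.alpha = alpha;
    this.gamma = gamma;
    this.lambda = lambda;
  }

  // gradLog is the gradient of log pi(a_t | x_t) with respect to u,
  // delta is the TD error computed by the critic.
  public void update(double[] gradLog, double delta) {
    for (int i = 0; i < u.length; i++) {
      e[i] = gamma * lambda * e[i] + gradLog[i]; // decay then accumulate
      u[i] += alpha * delta * e[i];              // step along the trace
    }
  }

  // Clears the trace at the start of a new episode.
  public void initEpisode() {
    java.util.Arrays.fill(e, 0.0);
  }
}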