RLPark 1.0.0
Reinforcement Learning Framework in Java
00001 package rlpark.plugin.rltoys.algorithms.control.actorcritic.onpolicy; 00002 00003 import java.io.Serializable; 00004 00005 import rlpark.plugin.rltoys.algorithms.functions.policydistributions.PolicyDistribution; 00006 import rlpark.plugin.rltoys.envio.actions.Action; 00007 import rlpark.plugin.rltoys.math.vector.RealVector; 00008 import rlpark.plugin.rltoys.math.vector.implementations.PVector; 00009 import rlpark.plugin.rltoys.utils.Utils; 00010 import zephyr.plugin.core.api.monitoring.abstracts.LabeledCollection; 00011 import zephyr.plugin.core.api.monitoring.annotations.LabelProvider; 00012 import zephyr.plugin.core.api.monitoring.annotations.Monitor; 00013 00014 public class Actor implements Serializable { 00015 private static final long serialVersionUID = 3063342634037779182L; 00016 public final double alpha_u[]; 00017 @Monitor(level = 4) 00018 protected final PVector[] u; 00019 @Monitor 00020 protected final PolicyDistribution policyDistribution; 00021 00022 public Actor(PolicyDistribution policyDistribution, double alpha_u, int nbFeatures) { 00023 this(policyDistribution, Utils.newFilledArray(policyDistribution.nbParameterVectors(), alpha_u), nbFeatures); 00024 } 00025 00026 public Actor(PolicyDistribution policyDistribution, double[] alpha_u, int nbFeatures) { 00027 this(policyDistribution.createParameters(nbFeatures), policyDistribution, alpha_u); 00028 } 00029 00030 public Actor(PVector[] policyParameters, PolicyDistribution policyDistribution, double[] alpha_u) { 00031 this.policyDistribution = policyDistribution; 00032 this.alpha_u = alpha_u; 00033 u = policyParameters; 00034 } 00035 00036 public void update(RealVector x_t, Action a_t, double delta) { 00037 if (x_t == null) 00038 return; 00039 RealVector[] gradLog = policyDistribution.computeGradLog(a_t); 00040 for (int i = 0; i < u.length; i++) 00041 u[i].addToSelf(alpha_u[i] * delta, gradLog[i]); 00042 } 00043 00044 public PolicyDistribution policy() { 00045 return policyDistribution; 00046 } 
00047 00048 public int vectorSize() { 00049 int result = 0; 00050 for (PVector v : u) 00051 result += v.size; 00052 return result; 00053 } 00054 00055 public PVector[] actorParameters() { 00056 return u; 00057 } 00058 00059 @LabelProvider(ids = { "u" }) 00060 protected String labelOf(int index) { 00061 if (policyDistribution instanceof LabeledCollection) 00062 return ((LabeledCollection) policyDistribution).label(index); 00063 return null; 00064 } 00065 }