RLPark 1.0.0
Reinforcement Learning Framework in Java

ScaledPolicyDistribution.java

Go to the documentation of this file.
00001 package rlpark.plugin.rltoys.algorithms.functions.policydistributions.helpers;
00002 
00003 import rlpark.plugin.rltoys.algorithms.functions.policydistributions.BoundedPdf;
00004 import rlpark.plugin.rltoys.algorithms.functions.policydistributions.PolicyDistribution;
00005 import rlpark.plugin.rltoys.algorithms.functions.policydistributions.PolicyParameterized;
00006 import rlpark.plugin.rltoys.envio.actions.Action;
00007 import rlpark.plugin.rltoys.envio.actions.ActionArray;
00008 import rlpark.plugin.rltoys.envio.policy.BoundedPolicy;
00009 import rlpark.plugin.rltoys.math.ranges.Range;
00010 import rlpark.plugin.rltoys.math.vector.RealVector;
00011 import rlpark.plugin.rltoys.math.vector.implementations.PVector;
00012 import zephyr.plugin.core.api.monitoring.annotations.Monitor;
00013 
00014 public class ScaledPolicyDistribution implements BoundedPdf, PolicyParameterized {
00015   private static final long serialVersionUID = -7521424991872961399L;
00016   @Monitor
00017   protected final PolicyDistribution policy;
00018   protected final Range policyRange;
00019   protected final Range problemRange;
00020 
00021   public ScaledPolicyDistribution(BoundedPolicy policy, Range problemRange) {
00022     this((PolicyDistribution) policy, policy.range(), problemRange);
00023   }
00024 
00025 
00026   public ScaledPolicyDistribution(PolicyDistribution policy, Range policyRange, Range problemRange) {
00027     this.policy = policy;
00028     this.policyRange = policyRange;
00029     this.problemRange = problemRange;
00030   }
00031 
00032   @Override
00033   public double pi(Action a) {
00034     return policy.pi(problemToPolicy(ActionArray.toDouble(a)));
00035   }
00036 
00037   @Override
00038   public PVector[] createParameters(int nbFeatures) {
00039     return policy.createParameters(nbFeatures);
00040   }
00041 
00042   @Override
00043   public Action sampleAction() {
00044     return policyToProblem(ActionArray.toDouble(policy.sampleAction()));
00045   }
00046 
00047   @Override
00048   public RealVector[] computeGradLog(Action a_t) {
00049     return policy.computeGradLog(problemToPolicy(ActionArray.toDouble(a_t)));
00050   }
00051 
00052   private ActionArray policyToProblem(double policyAction) {
00053     double normalizedAction = normalize(policyRange, policyAction);
00054     return new ActionArray(scale(problemRange, normalizedAction));
00055   }
00056 
00057   protected ActionArray problemToPolicy(double problemAction) {
00058     double normalizedAction = normalize(problemRange, problemAction);
00059     return new ActionArray(scale(policyRange, normalizedAction));
00060   }
00061 
00062   private double normalize(Range range, double a) {
00063     return (a - range.center()) / (range.length() / 2.0);
00064   }
00065 
00066   private double scale(Range range, double a) {
00067     return (a * (range.length() / 2.0)) + range.center();
00068   }
00069 
00070   @Override
00071   public int nbParameterVectors() {
00072     return policy.nbParameterVectors();
00073   }
00074 
00075 
00076   @Override
00077   public double piMax() {
00078     return ((BoundedPdf) policy).piMax();
00079   }
00080 
00081 
00082   @Override
00083   public void update(RealVector x) {
00084     policy.update(x);
00085   }
00086 
00087   @Override
00088   public void setParameters(PVector... parameters) {
00089     ((PolicyParameterized) policy).setParameters(parameters);
00090   }
00091 
00092 
00093   @Override
00094   public PVector[] parameters() {
00095     return ((PolicyParameterized) policy).parameters();
00096   }
00097 }
 All Classes Namespaces Files Functions Variables Enumerations
Zephyr
RLPark