RLPark 1.0.0
Reinforcement Learning Framework in Java

AbstractNormalDistribution.java

Go to the documentation of this file.
00001 package rlpark.plugin.rltoys.algorithms.functions.policydistributions.structures;
00002 
00003 import java.util.Random;
00004 
00005 import rlpark.plugin.rltoys.algorithms.functions.policydistributions.BoundedPdf;
00006 import rlpark.plugin.rltoys.algorithms.functions.policydistributions.PolicyParameterized;
00007 import rlpark.plugin.rltoys.envio.actions.Action;
00008 import rlpark.plugin.rltoys.envio.actions.ActionArray;
00009 import rlpark.plugin.rltoys.envio.actions.Actions;
00010 import rlpark.plugin.rltoys.math.vector.MutableVector;
00011 import rlpark.plugin.rltoys.math.vector.RealVector;
00012 import rlpark.plugin.rltoys.math.vector.implementations.PVector;
00013 import zephyr.plugin.core.api.internal.monitoring.wrappers.Abs;
00014 import zephyr.plugin.core.api.internal.monitoring.wrappers.Squared;
00015 import zephyr.plugin.core.api.monitoring.abstracts.LabeledCollection;
00016 import zephyr.plugin.core.api.monitoring.annotations.Monitor;
00017 
00018 @Monitor
00019 @SuppressWarnings("restriction")
00020 public abstract class AbstractNormalDistribution implements PolicyParameterized, LabeledCollection, BoundedPdf {
00021   private static final long serialVersionUID = -6707070542157254303L;
00022   @Monitor(level = 4)
00023   protected PVector u_mean;
00024   @Monitor(level = 4)
00025   protected PVector u_stddev;
00026   @Monitor(wrappers = { Abs.ID })
00027   protected double mean = 0;
00028   protected double stddev = 0;
00029   protected final Random random;
00030   public double a_t;
00031   @Monitor(wrappers = { Squared.ID, Abs.ID })
00032   protected double meanStep;
00033   @Monitor(wrappers = { Squared.ID, Abs.ID })
00034   protected double stddevStep;
00035 
00036   protected RealVector x = null;
00037   protected MutableVector gradMean = null;
00038   protected MutableVector gradStddev = null;
00039 
00040   public AbstractNormalDistribution(Random random) {
00041     this.random = random;
00042   }
00043 
00044   @Override
00045   public PVector[] createParameters(int nbFeatures) {
00046     setParameters(new PVector(nbFeatures), new PVector(nbFeatures));
00047     return new PVector[] { u_mean, u_stddev };
00048   }
00049 
00050   @Override
00051   public void setParameters(PVector... u) {
00052     assert u.length == 2;
00053     u_mean = u[0];
00054     u_stddev = u[1];
00055   }
00056 
00057   @Override
00058   public PVector[] parameters() {
00059     return new PVector[] { u_mean, u_stddev };
00060   }
00061 
00062   public double stddev() {
00063     return stddev;
00064   }
00065 
00066   public double mean() {
00067     return mean;
00068   }
00069 
00070   @Override
00071   final public void update(RealVector x) {
00072     if (this.x == null)
00073       allocateBuffers(x);
00074     ((MutableVector) this.x).set(x);
00075     updateDistribution();
00076   }
00077 
00078   protected void allocateBuffers(RealVector prototype) {
00079     x = prototype.copyAsMutable();
00080     gradMean = prototype.copyAsMutable();
00081     gradStddev = prototype.copyAsMutable();
00082   }
00083 
00084   abstract protected void updateDistribution();
00085 
00086   @Override
00087   public double pi(Action a) {
00088     assert Actions.isOneDimension(a);
00089     return pi_s(ActionArray.toDouble(a));
00090   }
00091 
00092   public abstract double pi_s(double a);
00093 
00094   @Override
00095   public String label(int index) {
00096     return index == 0 ? "mean" : "stddev";
00097   }
00098 
00099   @Override
00100   public int nbParameterVectors() {
00101     return 2;
00102   }
00103 }
 All Classes Namespaces Files Functions Variables Enumerations
Zephyr
RLPark