RLPark 1.0.0
Reinforcement Learning Framework in Java

NormalDistribution.java

Go to the documentation of this file.
00001 package rlpark.plugin.rltoys.algorithms.functions.policydistributions.structures;
00002 
00003 import static rlpark.plugin.rltoys.utils.Utils.square;
00004 
00005 import java.util.Random;
00006 
00007 import rlpark.plugin.rltoys.algorithms.functions.policydistributions.PolicyDistribution;
00008 import rlpark.plugin.rltoys.envio.actions.Action;
00009 import rlpark.plugin.rltoys.envio.actions.ActionArray;
00010 import rlpark.plugin.rltoys.math.vector.RealVector;
00011 import rlpark.plugin.rltoys.math.vector.implementations.Vectors;
00012 import rlpark.plugin.rltoys.utils.Utils;
00013 import zephyr.plugin.core.api.monitoring.annotations.Monitor;
00014 
00015 @Monitor
00016 public class NormalDistribution extends AbstractNormalDistribution {
00017   private static final long serialVersionUID = -4074721193363280217L;
00018   protected double sigma2;
00019   private final double initialMean;
00020   private final double initialStddev;
00021 
00022   public NormalDistribution(Random random, double mean, double sigma) {
00023     super(random);
00024     initialMean = mean;
00025     initialStddev = sigma;
00026   }
00027 
00028   @Override
00029   public RealVector[] computeGradLog(Action a) {
00030     updateSteps(ActionArray.toDouble(a));
00031     gradMean.set(x).mapMultiplyToSelf(meanStep);
00032     gradStddev.set(x).mapMultiplyToSelf(stddevStep);
00033     assert Vectors.checkValues(gradMean) && Vectors.checkValues(gradStddev);
00034     return new RealVector[] { gradMean, gradStddev };
00035   }
00036 
00037   protected void updateSteps(double a) {
00038     meanStep = (a - mean) / sigma2;
00039     stddevStep = square(a - mean) / sigma2 - 1;
00040   }
00041 
00042   @Override
00043   public Action sampleAction() {
00044     a_t = random.nextGaussian() * stddev + mean;
00045     if (!Utils.checkValue(a_t))
00046       return null;
00047     return new ActionArray(a_t);
00048   }
00049 
00050   @Override
00051   protected void updateDistribution() {
00052     mean = u_mean.dotProduct(x) + initialMean;
00053     stddev = Math.exp(u_stddev.dotProduct(x)) * initialStddev + Utils.EPSILON;
00054     sigma2 = square(stddev);
00055     assert Utils.checkValue(mean) && Utils.checkValue(sigma2);
00056   }
00057 
00058   @Override
00059   public double pi_s(double a) {
00060     double ammu2 = (a - mean) * (a - mean);
00061     return Math.exp(-ammu2 / (2 * sigma2)) / Math.sqrt(2 * Math.PI * sigma2);
00062   }
00063 
00064   static public JointDistribution newJointDistribution(Random random, int nbNormalDistribution, double mean,
00065       double sigma) {
00066     PolicyDistribution[] distributions = new PolicyDistribution[nbNormalDistribution];
00067     for (int i = 0; i < distributions.length; i++)
00068       distributions[i] = new NormalDistribution(random, mean, sigma);
00069     return new JointDistribution(distributions);
00070   }
00071 
00072   @Override
00073   public double piMax() {
00074     return Math.max(pi(new ActionArray(mean)), Utils.EPSILON);
00075   }
00076 }
 All Classes Namespaces Files Functions Variables Enumerations
Zephyr
RLPark