RLPark 1.0.0
Reinforcement Learning Framework in Java
|
00001 package rlpark.plugin.rltoys.algorithms.functions.policydistributions.structures; 00002 00003 import static rlpark.plugin.rltoys.utils.Utils.square; 00004 00005 import java.util.Random; 00006 00007 import rlpark.plugin.rltoys.algorithms.functions.policydistributions.PolicyDistribution; 00008 import rlpark.plugin.rltoys.envio.actions.Action; 00009 import rlpark.plugin.rltoys.envio.actions.ActionArray; 00010 import rlpark.plugin.rltoys.math.vector.RealVector; 00011 import rlpark.plugin.rltoys.math.vector.implementations.Vectors; 00012 import rlpark.plugin.rltoys.utils.Utils; 00013 import zephyr.plugin.core.api.monitoring.annotations.Monitor; 00014 00015 @Monitor 00016 public class NormalDistribution extends AbstractNormalDistribution { 00017 private static final long serialVersionUID = -4074721193363280217L; 00018 protected double sigma2; 00019 private final double initialMean; 00020 private final double initialStddev; 00021 00022 public NormalDistribution(Random random, double mean, double sigma) { 00023 super(random); 00024 initialMean = mean; 00025 initialStddev = sigma; 00026 } 00027 00028 @Override 00029 public RealVector[] computeGradLog(Action a) { 00030 updateSteps(ActionArray.toDouble(a)); 00031 gradMean.set(x).mapMultiplyToSelf(meanStep); 00032 gradStddev.set(x).mapMultiplyToSelf(stddevStep); 00033 assert Vectors.checkValues(gradMean) && Vectors.checkValues(gradStddev); 00034 return new RealVector[] { gradMean, gradStddev }; 00035 } 00036 00037 protected void updateSteps(double a) { 00038 meanStep = (a - mean) / sigma2; 00039 stddevStep = square(a - mean) / sigma2 - 1; 00040 } 00041 00042 @Override 00043 public Action sampleAction() { 00044 a_t = random.nextGaussian() * stddev + mean; 00045 if (!Utils.checkValue(a_t)) 00046 return null; 00047 return new ActionArray(a_t); 00048 } 00049 00050 @Override 00051 protected void updateDistribution() { 00052 mean = u_mean.dotProduct(x) + initialMean; 00053 stddev = Math.exp(u_stddev.dotProduct(x)) * initialStddev + Utils.EPSILON; 00054 sigma2 = square(stddev); 00055 assert Utils.checkValue(mean) && Utils.checkValue(sigma2); 00056 } 00057 00058 @Override 00059 public double pi_s(double a) { 00060 double ammu2 = (a - mean) * (a - mean); 00061 return Math.exp(-ammu2 / (2 * sigma2)) / Math.sqrt(2 * Math.PI * sigma2); 00062 } 00063 00064 static public JointDistribution newJointDistribution(Random random, int nbNormalDistribution, double mean, 00065 double sigma) { 00066 PolicyDistribution[] distributions = new PolicyDistribution[nbNormalDistribution]; 00067 for (int i = 0; i < distributions.length; i++) 00068 distributions[i] = new NormalDistribution(random, mean, sigma); 00069 return new JointDistribution(distributions); 00070 } 00071 00072 @Override 00073 public double piMax() { 00074 return Math.max(pi(new ActionArray(mean)), Utils.EPSILON); 00075 } 00076 }