RLPark 1.0.0
Reinforcement Learning Framework in Java
|
00001 package rlpark.plugin.rltoys.algorithms.functions.policydistributions.structures; 00002 00003 import java.util.Random; 00004 00005 import rlpark.plugin.rltoys.algorithms.functions.policydistributions.BoundedPdf; 00006 import rlpark.plugin.rltoys.algorithms.functions.policydistributions.PolicyParameterized; 00007 import rlpark.plugin.rltoys.envio.actions.Action; 00008 import rlpark.plugin.rltoys.envio.actions.ActionArray; 00009 import rlpark.plugin.rltoys.envio.actions.Actions; 00010 import rlpark.plugin.rltoys.math.vector.MutableVector; 00011 import rlpark.plugin.rltoys.math.vector.RealVector; 00012 import rlpark.plugin.rltoys.math.vector.implementations.PVector; 00013 import zephyr.plugin.core.api.internal.monitoring.wrappers.Abs; 00014 import zephyr.plugin.core.api.internal.monitoring.wrappers.Squared; 00015 import zephyr.plugin.core.api.monitoring.abstracts.LabeledCollection; 00016 import zephyr.plugin.core.api.monitoring.annotations.Monitor; 00017 00018 @Monitor 00019 @SuppressWarnings("restriction") 00020 public abstract class AbstractNormalDistribution implements PolicyParameterized, LabeledCollection, BoundedPdf { 00021 private static final long serialVersionUID = -6707070542157254303L; 00022 @Monitor(level = 4) 00023 protected PVector u_mean; 00024 @Monitor(level = 4) 00025 protected PVector u_stddev; 00026 @Monitor(wrappers = { Abs.ID }) 00027 protected double mean = 0; 00028 protected double stddev = 0; 00029 protected final Random random; 00030 public double a_t; 00031 @Monitor(wrappers = { Squared.ID, Abs.ID }) 00032 protected double meanStep; 00033 @Monitor(wrappers = { Squared.ID, Abs.ID }) 00034 protected double stddevStep; 00035 00036 protected RealVector x = null; 00037 protected MutableVector gradMean = null; 00038 protected MutableVector gradStddev = null; 00039 00040 public AbstractNormalDistribution(Random random) { 00041 this.random = random; 00042 } 00043 00044 @Override 00045 public PVector[] createParameters(int nbFeatures) { 00046 setParameters(new PVector(nbFeatures), new PVector(nbFeatures)); 00047 return new PVector[] { u_mean, u_stddev }; 00048 } 00049 00050 @Override 00051 public void setParameters(PVector... u) { 00052 assert u.length == 2; 00053 u_mean = u[0]; 00054 u_stddev = u[1]; 00055 } 00056 00057 @Override 00058 public PVector[] parameters() { 00059 return new PVector[] { u_mean, u_stddev }; 00060 } 00061 00062 public double stddev() { 00063 return stddev; 00064 } 00065 00066 public double mean() { 00067 return mean; 00068 } 00069 00070 @Override 00071 final public void update(RealVector x) { 00072 if (this.x == null) 00073 allocateBuffers(x); 00074 ((MutableVector) this.x).set(x); 00075 updateDistribution(); 00076 } 00077 00078 protected void allocateBuffers(RealVector prototype) { 00079 x = prototype.copyAsMutable(); 00080 gradMean = prototype.copyAsMutable(); 00081 gradStddev = prototype.copyAsMutable(); 00082 } 00083 00084 abstract protected void updateDistribution(); 00085 00086 @Override 00087 public double pi(Action a) { 00088 assert Actions.isOneDimension(a); 00089 return pi_s(ActionArray.toDouble(a)); 00090 } 00091 00092 public abstract double pi_s(double a); 00093 00094 @Override 00095 public String label(int index) { 00096 return index == 0 ? "mean" : "stddev"; 00097 } 00098 00099 @Override 00100 public int nbParameterVectors() { 00101 return 2; 00102 } 00103 }