RLPark 1.0.0
Reinforcement Learning Framework in Java
package rlpark.plugin.rltoys.algorithms.functions.policydistributions.structures;

import java.util.Random;

import rlpark.plugin.rltoys.algorithms.functions.policydistributions.PolicyDistribution;
import rlpark.plugin.rltoys.algorithms.functions.stateactions.StateToStateAction;
import rlpark.plugin.rltoys.envio.actions.Action;
import rlpark.plugin.rltoys.envio.policy.StochasticPolicy;
import rlpark.plugin.rltoys.math.ranges.Range;
import rlpark.plugin.rltoys.math.vector.MutableVector;
import rlpark.plugin.rltoys.math.vector.RealVector;
import rlpark.plugin.rltoys.math.vector.implementations.PVector;
import rlpark.plugin.rltoys.utils.Utils;
import zephyr.plugin.core.api.monitoring.annotations.Monitor;

/**
 * Boltzmann (softmax) policy distribution over a discrete action set: the
 * probability of an action a is proportional to exp(u . phi(x, a)), where u is
 * the parameter vector and phi(x, a) is the state-action feature vector
 * produced by the StateToStateAction projector.
 */
public class BoltzmannDistribution extends StochasticPolicy implements PolicyDistribution {
  private static final long serialVersionUID = 7036360201611314726L;
  private final MutableVector[] xa;
  @Monitor(level = 4)
  private PVector u;
  private MutableVector xaBar;
  private MutableVector gradBuffer;
  private final StateToStateAction toStateAction;
  private final double[] distribution;
  @Monitor
  private final Range linearRangeOverall = new Range(1.0, 1.0);
  @Monitor
  private final Range linearRangeAveraged = new Range(1.0, 1.0);

  public BoltzmannDistribution(Random random, Action[] actions, StateToStateAction toStateAction) {
    super(random, actions);
    assert toStateAction != null;
    this.toStateAction = toStateAction;
    distribution = new double[actions.length];
    xa = new MutableVector[actions.length];
  }

  /** Probability of action a under the distribution computed by the last call to update(). */
  @Override
  public double pi(Action a) {
    return distribution[atoi(a)];
  }

  /**
   * Recomputes the softmax distribution for the state features x and caches the
   * probability-weighted average feature vector xaBar used by computeGradLog().
   */
  @Override
  public void update(RealVector x) {
    linearRangeAveraged.reset();
    double sum = 0;
    clearBuffers(x);
    for (int a_i = 0; a_i < actions.length; a_i++) {
      xa[a_i].set(toStateAction.stateAction(x, actions[a_i]));
      final double linearCombination = u.dotProduct(xa[a_i]);
      linearRangeOverall.update(linearCombination);
      linearRangeAveraged.update(linearCombination);
      double probabilityNotNormalized = Math.exp(linearCombination);
      assert Utils.checkValue(probabilityNotNormalized);
      distribution[a_i] = probabilityNotNormalized;
      sum += probabilityNotNormalized;
      xaBar.addToSelf(probabilityNotNormalized, xa[a_i]);
    }
    for (int i = 0; i < distribution.length; i++) {
      distribution[i] /= sum;
      assert Utils.checkValue(distribution[i]);
    }
    xaBar.mapMultiplyToSelf(1.0 / sum);
  }

  /** Lazily allocates the per-action feature buffers on the first call, then clears xaBar. */
  private void clearBuffers(RealVector x) {
    if (xaBar == null) {
      xaBar = toStateAction.stateAction(x, actions[0]).newInstance(u.size);
      gradBuffer = xaBar.newInstance(u.size);
      for (int i = 0; i < xa.length; i++)
        xa[i] = xaBar.newInstance(u.size);
      return;
    }
    xaBar.clear();
  }

  /** Samples an action according to the current distribution. */
  @Override
  public Action sampleAction() {
    return chooseAction(distribution);
  }

  /** Allocates the single parameter vector u, sized for the state-action features. */
  @Override
  public PVector[] createParameters(int nbFeatures) {
    u = new PVector(toStateAction.vectorSize());
    return new PVector[] { u };
  }

  /** Gradient of log pi(a_t): phi(x, a_t) - sum_a pi(a) phi(x, a). */
  @Override
  public RealVector[] computeGradLog(Action a_t) {
    gradBuffer.clear();
    gradBuffer.set(xa[atoi(a_t)]);
    return new RealVector[] { gradBuffer.subtractToSelf(xaBar) };
  }

  @Override
  public int nbParameterVectors() {
    return 1;
  }

  @Override
  public double[] distribution() {
    return distribution;
  }

  /**
   * Linear preference value an action needs in order to have probability proba
   * when the nbAction - 1 other actions all have preference 0.
   */
  static public double probaToLinearValue(int nbAction, double proba) {
    double max = Math.log(proba * (nbAction - 1)) - Math.log(1 - proba);
    assert proba > .5 && max > 0 || proba < .5 && max < 0;
    return max;
  }
}
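
For orientation, a hypothetical usage sketch follows; it is not part of the listed source. It assumes the surrounding agent already provides the discrete action set actions, a StateToStateAction projector toStateAction, the current feature vector x_t, and the feature count nbFeatures; only methods visible in this listing are called.

    // Hypothetical sketch: actions, toStateAction, x_t and nbFeatures are assumed to be
    // provided by the surrounding agent; only methods shown in the listing above are used.
    BoltzmannDistribution policy = new BoltzmannDistribution(new Random(0), actions, toStateAction);
    PVector[] parameters = policy.createParameters(nbFeatures); // allocates the preference vector u
    policy.update(x_t);                                // pi(a) proportional to exp(u . phi(x_t, a))
    Action a_t = policy.sampleAction();                // draws an action from the softmax distribution
    double p_t = policy.pi(a_t);                       // probability of the sampled action
    RealVector[] gradLog = policy.computeGradLog(a_t); // phi(x_t, a_t) - sum_b pi(b) phi(x_t, b)

Note that update(x) must be called before pi(a), sampleAction() or computeGradLog(a), since those methods read the distribution and feature buffers filled by update(x); computeGradLog returns the term typically consumed by a policy-gradient (e.g. actor-critic) update.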