RLPark 1.0.0
Reinforcement Learning Framework in Java

BoltzmannDistribution.java

package rlpark.plugin.rltoys.algorithms.functions.policydistributions.structures;

import java.util.Random;

import rlpark.plugin.rltoys.algorithms.functions.policydistributions.PolicyDistribution;
import rlpark.plugin.rltoys.algorithms.functions.stateactions.StateToStateAction;
import rlpark.plugin.rltoys.envio.actions.Action;
import rlpark.plugin.rltoys.envio.policy.StochasticPolicy;
import rlpark.plugin.rltoys.math.ranges.Range;
import rlpark.plugin.rltoys.math.vector.MutableVector;
import rlpark.plugin.rltoys.math.vector.RealVector;
import rlpark.plugin.rltoys.math.vector.implementations.PVector;
import rlpark.plugin.rltoys.utils.Utils;
import zephyr.plugin.core.api.monitoring.annotations.Monitor;

/**
 * Boltzmann (softmax) policy distribution over a discrete action set. Action
 * preferences are linear in state-action features: pi(a|x) is proportional to
 * exp(u . phi(x, a)), where u is the parameter vector and phi is provided by a
 * StateToStateAction projector.
 */
public class BoltzmannDistribution extends StochasticPolicy implements PolicyDistribution {
  private static final long serialVersionUID = 7036360201611314726L;
  // Per-action state-action feature vectors phi(x, a), refreshed in update()
  private final MutableVector[] xa;
  @Monitor(level = 4)
  private PVector u; // policy parameter (preference weight) vector
  private MutableVector xaBar; // expected feature vector under the current distribution
  private MutableVector gradBuffer; // reusable buffer for the gradient of log pi
  private final StateToStateAction toStateAction;
  private final double[] distribution;
  @Monitor
  private final Range linearRangeOverall = new Range(1.0, 1.0);
  @Monitor
  private final Range linearRangeAveraged = new Range(1.0, 1.0);

  public BoltzmannDistribution(Random random, Action[] actions, StateToStateAction toStateAction) {
    super(random, actions);
    assert toStateAction != null;
    this.toStateAction = toStateAction;
    distribution = new double[actions.length];
    xa = new MutableVector[actions.length];
  }

  @Override
  public double pi(Action a) {
    return distribution[atoi(a)];
  }

  @Override
  public void update(RealVector x) {
    linearRangeAveraged.reset();
    double sum = 0;
    clearBuffers(x);
    for (int a_i = 0; a_i < actions.length; a_i++) {
      // Preference u . phi(x, a), exponentiated into an unnormalized probability
      xa[a_i].set(toStateAction.stateAction(x, actions[a_i]));
      final double linearCombination = u.dotProduct(xa[a_i]);
      linearRangeOverall.update(linearCombination);
      linearRangeAveraged.update(linearCombination);
      double probabilityNotNormalized = Math.exp(linearCombination);
      assert Utils.checkValue(probabilityNotNormalized);
      distribution[a_i] = probabilityNotNormalized;
      sum += probabilityNotNormalized;
      xaBar.addToSelf(probabilityNotNormalized, xa[a_i]);
    }
    // Normalize into a proper probability distribution and finish the expectation
    for (int i = 0; i < distribution.length; i++) {
      distribution[i] /= sum;
      assert Utils.checkValue(distribution[i]);
    }
    xaBar.mapMultiplyToSelf(1.0 / sum);
  }

  private void clearBuffers(RealVector x) {
    if (xaBar == null) {
      // Lazily allocate the buffers on the first update
      xaBar = toStateAction.stateAction(x, actions[0]).newInstance(u.size);
      gradBuffer = xaBar.newInstance(u.size);
      for (int i = 0; i < xa.length; i++)
        xa[i] = xaBar.newInstance(u.size);
      return;
    }
    xaBar.clear();
  }

  @Override
  public Action sampleAction() {
    return chooseAction(distribution);
  }

  @Override
  public PVector[] createParameters(int nbFeatures) {
    u = new PVector(toStateAction.vectorSize());
    return new PVector[] { u };
  }

  @Override
  public RealVector[] computeGradLog(Action a_t) {
    // grad_u log pi(a_t|x) = phi(x, a_t) - sum_a pi(a|x) phi(x, a)
    gradBuffer.clear();
    gradBuffer.set(xa[atoi(a_t)]);
    return new RealVector[] { gradBuffer.subtractToSelf(xaBar) };
  }

  @Override
  public int nbParameterVectors() {
    return 1;
  }

  @Override
  public double[] distribution() {
    return distribution;
  }

  // Preference value that gives probability proba to one action when the
  // nbAction - 1 other actions all have preference 0.
  public static double probaToLinearValue(int nbAction, double proba) {
    double max = Math.log(proba * (nbAction - 1)) - Math.log(1 - proba);
    assert proba > .5 && max > 0 || proba < .5 && max < 0;
    return max;
  }
}
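
Below is a minimal usage sketch, not part of the RLPark sources: it assumes the Action[] set, the StateToStateAction projector and the state feature vector are supplied by a problem definition (the parameters of stepPolicy stand in for those assumptions), and it only exercises calls visible in the listing above.

import java.util.Random;

import rlpark.plugin.rltoys.algorithms.functions.policydistributions.structures.BoltzmannDistribution;
import rlpark.plugin.rltoys.algorithms.functions.stateactions.StateToStateAction;
import rlpark.plugin.rltoys.envio.actions.Action;
import rlpark.plugin.rltoys.math.vector.RealVector;

public class BoltzmannUsageSketch {
  // Builds the policy, refreshes the distribution for the state features x_t,
  // samples an action and computes the log-gradient an actor update would use.
  // The action set, projector and state vector are assumed inputs.
  static Action stepPolicy(Random random, Action[] actions, StateToStateAction toStateAction, RealVector x_t) {
    BoltzmannDistribution policy = new BoltzmannDistribution(random, actions, toStateAction);
    policy.createParameters(toStateAction.vectorSize()); // allocates the weight vector u
    policy.update(x_t);                                  // pi(.|x_t) = softmax of u . phi(x_t, a)
    Action a_t = policy.sampleAction();                  // a_t ~ pi(.|x_t)
    RealVector gradLog = policy.computeGradLog(a_t)[0];  // grad_u log pi(a_t|x_t)
    assert gradLog != null && policy.pi(a_t) > 0;
    return a_t;
  }
}

In an actor-critic, the policy and its weight vector u would persist across time steps rather than being rebuilt per call, with u adjusted along computeGradLog(a_t) after each transition. The static helper probaToLinearValue inverts the softmax for the case where a single action has preference m and the other nbAction - 1 actions have preference 0: solving proba = exp(m) / (exp(m) + nbAction - 1) for m gives m = log(proba * (nbAction - 1)) - log(1 - proba), which is the value the method returns.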