RLPark 1.0.0
Reinforcement Learning Framework in Java

SoftMax.java

Go to the documentation of this file.
00001 package rlpark.plugin.rltoys.algorithms.control.acting;
00002 
00003 import java.util.Arrays;
00004 import java.util.Random;
00005 
00006 import rlpark.plugin.rltoys.algorithms.functions.Predictor;
00007 import rlpark.plugin.rltoys.algorithms.functions.stateactions.StateToStateAction;
00008 import rlpark.plugin.rltoys.envio.actions.Action;
00009 import rlpark.plugin.rltoys.envio.policy.StochasticPolicy;
00010 import rlpark.plugin.rltoys.math.vector.RealVector;
00011 import rlpark.plugin.rltoys.utils.Utils;
00012 
00013 public class SoftMax extends StochasticPolicy {
00014   private static final long serialVersionUID = -2129719316562814077L;
00015   private final StateToStateAction toStateAction;
00016   private final double temperature;
00017   private final Predictor predictor;
00018   private final double[] distribution;
00019 
00020   public SoftMax(Random random, Predictor predictor, Action[] actions, StateToStateAction toStateAction,
00021       double temperature) {
00022     super(random, actions);
00023     this.toStateAction = toStateAction;
00024     this.temperature = temperature;
00025     this.predictor = predictor;
00026     distribution = new double[actions.length];
00027   }
00028 
00029   public SoftMax(Random random, Predictor predictor, Action[] actions, StateToStateAction toStateAction) {
00030     this(random, predictor, actions, toStateAction, 1);
00031   }
00032 
00033   @Override
00034   public Action sampleAction() {
00035     return chooseAction(distribution);
00036   }
00037 
00038   @Override
00039   public void update(RealVector x) {
00040     double sum = 0.0;
00041     for (int i = 0; i < actions.length; i++) {
00042       Action action = actions[i];
00043       RealVector phi_sa = toStateAction.stateAction(x, action);
00044       double value = Math.exp(predictor.predict(phi_sa) / temperature);
00045       assert Utils.checkValue(value);
00046       sum += value;
00047       distribution[i] = value;
00048     }
00049     if (sum == 0) {
00050       Arrays.fill(distribution, 1.0);
00051       sum = distribution.length;
00052     }
00053     for (int i = 0; i < distribution.length; i++) {
00054       distribution[i] /= sum;
00055       assert Utils.checkValue(distribution[i]);
00056     }
00057     assert checkDistribution(distribution);
00058   }
00059 
00060   @Override
00061   public double pi(Action a) {
00062     return distribution[atoi(a)];
00063   }
00064 
00065   @Override
00066   public double[] distribution() {
00067     return distribution;
00068   }
00069 }
 All Classes Namespaces Files Functions Variables Enumerations
Zephyr
RLPark