RLPark 1.0.0
Reinforcement Learning Framework in Java
SoftMax.java
package rlpark.plugin.rltoys.algorithms.control.acting;

import java.util.Arrays;
import java.util.Random;

import rlpark.plugin.rltoys.algorithms.functions.Predictor;
import rlpark.plugin.rltoys.algorithms.functions.stateactions.StateToStateAction;
import rlpark.plugin.rltoys.envio.actions.Action;
import rlpark.plugin.rltoys.envio.policy.StochasticPolicy;
import rlpark.plugin.rltoys.math.vector.RealVector;
import rlpark.plugin.rltoys.utils.Utils;

/**
 * Softmax (Boltzmann) action-selection policy: the probability of each action
 * is proportional to exp(Q(s, a) / temperature), where Q(s, a) is estimated by
 * a Predictor applied to the state-action feature vector.
 */
public class SoftMax extends StochasticPolicy {
  private static final long serialVersionUID = -2129719316562814077L;
  private final StateToStateAction toStateAction;
  private final double temperature;
  private final Predictor predictor;
  private final double[] distribution;

  public SoftMax(Random random, Predictor predictor, Action[] actions, StateToStateAction toStateAction,
      double temperature) {
    super(random, actions);
    this.toStateAction = toStateAction;
    this.temperature = temperature;
    this.predictor = predictor;
    distribution = new double[actions.length];
  }

  public SoftMax(Random random, Predictor predictor, Action[] actions, StateToStateAction toStateAction) {
    this(random, predictor, actions, toStateAction, 1);
  }

  // Samples an action according to the distribution computed by the most
  // recent call to update().
  @Override
  public Action sampleAction() {
    return chooseAction(distribution);
  }

  // Recomputes the Boltzmann distribution over all actions for state x.
  @Override
  public void update(RealVector x) {
    double sum = 0.0;
    for (int i = 0; i < actions.length; i++) {
      Action action = actions[i];
      RealVector phi_sa = toStateAction.stateAction(x, action);
      double value = Math.exp(predictor.predict(phi_sa) / temperature);
      assert Utils.checkValue(value);
      sum += value;
      distribution[i] = value;
    }
    // If every weight underflowed to zero, fall back to a uniform distribution.
    if (sum == 0) {
      Arrays.fill(distribution, 1.0);
      sum = distribution.length;
    }
    for (int i = 0; i < distribution.length; i++) {
      distribution[i] /= sum;
      assert Utils.checkValue(distribution[i]);
    }
    assert checkDistribution(distribution);
  }

  @Override
  public double pi(Action a) {
    return distribution[atoi(a)];
  }

  @Override
  public double[] distribution() {
    return distribution;
  }
}
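The update(RealVector x) method above computes a Boltzmann (Gibbs) distribution over the action set. Writing \hat{Q}(x, a) for the value returned by predictor.predict(toStateAction.stateAction(x, a)) and \tau for the temperature, the probability assigned to each action is

    \pi(a \mid x) = \frac{\exp(\hat{Q}(x, a) / \tau)}{\sum_{b} \exp(\hat{Q}(x, b) / \tau)}

A high temperature flattens the distribution toward uniform exploration; a low temperature concentrates probability on the highest-valued action.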
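A minimal usage sketch follows, using only the methods shown in this file. The predictor, toStateAction, actions, and state vector x are placeholders assumed to be supplied elsewhere by the agent; the temperature value 0.5 is arbitrary.

    Random random = new Random(0);
    SoftMax policy = new SoftMax(random, predictor, actions, toStateAction, 0.5);

    policy.update(x);                  // recompute the distribution for state x
    Action a = policy.sampleAction();  // sample from that distribution
    double p = policy.pi(a);           // probability assigned to the sampled action

Note that sampleAction() does not recompute anything: update(x) must be called with the current state before each sample, otherwise the policy samples from a stale distribution.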