RLPark 1.0.0
Reinforcement Learning Framework in Java
package rlpark.plugin.rltoys.problems.nostate;

import static rlpark.plugin.rltoys.utils.Utils.square;

import rlpark.plugin.rltoys.envio.actions.Action;
import rlpark.plugin.rltoys.envio.actions.ActionArray;
import rlpark.plugin.rltoys.envio.observations.Legend;
import rlpark.plugin.rltoys.envio.rl.TRStep;
import rlpark.plugin.rltoys.math.ranges.Range;
import rlpark.plugin.rltoys.problems.RLProblem;

/**
 * A state-less, continuing problem: the observation is a single constant
 * feature and the reward depends only on the scalar action.
 */
public class NoStateProblem implements RLProblem {
  /** Maps a scalar action to a reward. */
  public interface NoStateRewardFunction {
    double reward(double action);
  }

  /** Reward given by the density of a normal distribution N(mu, sigma^2) evaluated at the action. */
  public static class NormalReward implements NoStateRewardFunction {
    public final double mu;
    private final double sigma;

    public NormalReward(double mu, double sigma) {
      this.mu = mu;
      this.sigma = sigma;
    }

    @Override
    public double reward(double x) {
      return 1.0 / Math.sqrt(2 * Math.PI * square(sigma)) * Math.exp(-square(x - mu) / (2 * square(sigma)));
    }
  }

  private TRStep step = null;
  private final NoStateRewardFunction reward;
  public final Range range;
  private static final Legend legend = new Legend("State");

  public NoStateProblem(NoStateRewardFunction reward) {
    this(null, reward);
  }

  public NoStateProblem(Range range, NoStateRewardFunction reward) {
    this.reward = reward;
    this.range = range;
  }

  @Override
  public TRStep initialize() {
    step = new TRStep(state(), 0);
    return step;
  }

  /** The observation never changes: a single constant feature. */
  private double[] state() {
    return new double[] { 1.0 };
  }

  @Override
  public TRStep step(Action a_t) {
    assert step != null;
    if (a_t == null)
      return new TRStep(step, null, null, -Double.MAX_VALUE);
    double a = ActionArray.toDouble(a_t);
    // Bound the action to the problem range when one was provided.
    if (range != null)
      a = range.bound(a);
    double r = reward.reward(a);
    step = new TRStep(step, a_t, state(), r);
    return step;
  }

  @Override
  public Legend legend() {
    return legend;
  }

  @Override
  public TRStep lastStep() {
    return step;
  }

  @Override
  public TRStep forceEndEpisode() {
    step = step.createEndingStep();
    return step;
  }
}
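The listing above defines a state-less continuous-action problem: the observation stays constant and the reward depends only on the action, optionally bounded to a Range. A minimal sketch of driving the problem directly, without an agent, could look as follows; the Range(min, max) and ActionArray(double...) constructors and the TRStep.r_tp1 field are assumed from the surrounding RLToys API and are not part of this listing.

// Hypothetical driver class, not part of RLPark.
import rlpark.plugin.rltoys.envio.actions.ActionArray;
import rlpark.plugin.rltoys.envio.rl.TRStep;
import rlpark.plugin.rltoys.math.ranges.Range;
import rlpark.plugin.rltoys.problems.nostate.NoStateProblem;
import rlpark.plugin.rltoys.problems.nostate.NoStateProblem.NormalReward;

public class NoStateProblemExample {
  public static void main(String[] args) {
    // Reward is the N(0.5, 0.1^2) density evaluated at the action, bounded to [-1, 1].
    NoStateProblem problem = new NoStateProblem(new Range(-1, 1), new NormalReward(0.5, 0.1));
    TRStep step = problem.initialize();
    for (int t = 0; t < 10; t++) {
      double action = Math.random() * 2 - 1; // uniform in [-1, 1]
      step = problem.step(new ActionArray(action));
      System.out.println("a=" + action + " r=" + step.r_tp1);
    }
  }
}

Because the reward is the Gaussian density centered at mu, an actor-critic or bandit algorithm maximizing expected reward on this problem should drive its action toward mu, which is what makes the class useful as a small test environment.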