RLPark 1.0.0
Reinforcement Learning Framework in Java
|
00001 package rlpark.plugin.rltoys.problems.mountaincar; 00002 00003 import java.util.Random; 00004 00005 import rlpark.plugin.rltoys.envio.actions.Action; 00006 import rlpark.plugin.rltoys.envio.actions.ActionArray; 00007 import rlpark.plugin.rltoys.envio.observations.Legend; 00008 import rlpark.plugin.rltoys.envio.rl.TRStep; 00009 import rlpark.plugin.rltoys.math.ranges.Range; 00010 import rlpark.plugin.rltoys.problems.ProblemBounded; 00011 import rlpark.plugin.rltoys.problems.ProblemContinuousAction; 00012 import rlpark.plugin.rltoys.problems.ProblemDiscreteAction; 00013 import zephyr.plugin.core.api.monitoring.annotations.Monitor; 00014 00015 public class MountainCar implements ProblemBounded, ProblemDiscreteAction, ProblemContinuousAction { 00016 static private final double MaxActionValue = 1.0; 00017 public static final ActionArray LEFT = new ActionArray(-MaxActionValue); 00018 public static final ActionArray RIGHT = new ActionArray(MaxActionValue); 00019 public static final ActionArray STOP = new ActionArray(0.0); 00020 protected static final Action[] Actions = { LEFT, STOP, RIGHT }; 00021 static public final Range ActionRange = new Range(-MaxActionValue, MaxActionValue); 00022 00023 public static final String VELOCITY = "velocity"; 00024 public static final String POSITION = "position"; 00025 public static final Legend legend = new Legend(POSITION, VELOCITY); 00026 00027 @Monitor 00028 protected double position; 00029 @Monitor 00030 protected double velocity = 0.0; 00031 protected static final Range positionRange = new Range(-1.2, 0.6); 00032 protected static final Range velocityRange = new Range(-0.07, 0.07); 00033 00034 private static final double target = positionRange.max(); 00035 private double throttleFactor = 1.0; 00036 private final Random random; 00037 private TRStep step; 00038 private final int episodeLengthMax; 00039 00040 public MountainCar(Random random) { 00041 this(random, -1); 00042 } 00043 00044 public MountainCar(Random random, int episodeLengthMax) { 00045 this.random = random; 00046 this.episodeLengthMax = episodeLengthMax; 00047 } 00048 00049 protected void update(ActionArray action) { 00050 double actionThrottle = ActionRange.bound(ActionArray.toDouble(action)); 00051 double throttle = actionThrottle * throttleFactor; 00052 velocity = velocityRange.bound(velocity + 0.001 * throttle - 0.0025 * Math.cos(3 * position)); 00053 position += velocity; 00054 if (position < positionRange.min()) 00055 velocity = 0.0; 00056 position = positionRange.bound(position); 00057 } 00058 00059 @Override 00060 public TRStep step(Action action) { 00061 update((ActionArray) action); 00062 step = new TRStep(step, action, new double[] { position, velocity }, -1.0); 00063 if (isGoalReached()) 00064 forceEndEpisode(); 00065 return step; 00066 } 00067 00068 @Override 00069 public TRStep forceEndEpisode() { 00070 step = step.createEndingStep(); 00071 return step; 00072 } 00073 00074 private boolean isGoalReached() { 00075 return position >= target || (episodeLengthMax > 0 && step != null && step.time > episodeLengthMax); 00076 } 00077 00078 @Override 00079 public TRStep initialize() { 00080 if (random == null) { 00081 position = -0.5; 00082 velocity = 0.0; 00083 } else { 00084 position = positionRange.choose(random); 00085 velocity = velocityRange.choose(random); 00086 } 00087 step = new TRStep(new double[] { position, velocity }, -1); 00088 return step; 00089 } 00090 00091 @Override 00092 public Legend legend() { 00093 return legend; 00094 } 00095 00096 @Override 00097 public Action[] actions() { 00098 return Actions; 00099 } 00100 00101 public void setThrottleFactor(double factor) { 00102 throttleFactor = factor; 00103 } 00104 00105 @Override 00106 public Range[] getObservationRanges() { 00107 return new Range[] { positionRange, velocityRange }; 00108 } 00109 00110 @Override 00111 public Range[] actionRanges() { 00112 return new Range[] { ActionRange }; 00113 } 00114 00115 @Override 00116 public TRStep lastStep() { 00117 return step; 00118 } 00119 00120 static public double height(double position) { 00121 return Math.sin(3.0 * position); 00122 } 00123 }