RLPark 1.0.0
Reinforcement Learning Framework in Java
|
00001 package rlpark.plugin.rltoys.problems.pendulum; 00002 00003 import java.util.Random; 00004 00005 import rlpark.plugin.rltoys.envio.actions.Action; 00006 import rlpark.plugin.rltoys.envio.actions.ActionArray; 00007 import rlpark.plugin.rltoys.envio.observations.Legend; 00008 import rlpark.plugin.rltoys.envio.rl.TRStep; 00009 import rlpark.plugin.rltoys.math.ranges.Range; 00010 import rlpark.plugin.rltoys.problems.ProblemBounded; 00011 import rlpark.plugin.rltoys.problems.ProblemContinuousAction; 00012 import rlpark.plugin.rltoys.problems.ProblemDiscreteAction; 00013 import rlpark.plugin.rltoys.utils.Utils; 00014 import zephyr.plugin.core.api.monitoring.annotations.Monitor; 00015 00016 public class SwingPendulum implements ProblemBounded, ProblemDiscreteAction, ProblemContinuousAction { 00017 public static final double uMax = 2.0; 00018 public boolean constantEpisodeTime = true; 00019 public static final ActionArray STOP = new ActionArray(0); 00020 public static final ActionArray RIGHT = new ActionArray(uMax); 00021 public static final ActionArray LEFT = new ActionArray(-uMax); 00022 private static final Action[] Actions = new Action[] { LEFT, STOP, RIGHT }; 00023 public static final Range ActionRange = new Range(-uMax, uMax); 00024 protected static final String VELOCITY = "velocity"; 00025 protected static final String THETA = "theta"; 00026 protected static final Legend Legend = new Legend(THETA, VELOCITY); 00027 public static final Range ThetaRange = new Range(-Math.PI, Math.PI); 00028 protected static final double Mass = 1.0; 00029 protected static final double Length = 1.0; 00030 protected static final double G = 9.8; 00031 protected static final double StepTime = 0.01; // seconds 00032 protected static final double RequiredUpTime = 10.0; // seconds 00033 protected static final double UpRange = Math.PI / 4.0; // seconds 00034 protected static final double MaxVelocity = (Math.PI / 4.0) / StepTime; 00035 public static final Range VelocityRange = new Range(-MaxVelocity, MaxVelocity); 00036 public static final Range InitialThetaRange = new Range(-Math.PI, Math.PI); 00037 protected static final double initialVelocity = 0.0; 00038 00039 final private boolean endOfEpisode; 00040 @Monitor 00041 protected double theta = 0.0; 00042 @Monitor 00043 protected double velocity = 0.0; 00044 protected final Random random; 00045 protected TRStep step; 00046 protected int upTime = 0; 00047 00048 public SwingPendulum(Random random) { 00049 this(random, true); 00050 } 00051 00052 public SwingPendulum(Random random, boolean endOfEpisode) { 00053 assert Mass * Length * G > uMax; 00054 this.random = random; 00055 this.endOfEpisode = endOfEpisode; 00056 } 00057 00058 protected void update(ActionArray action) { 00059 double torque = ActionRange.bound(ActionArray.toDouble(action)); 00060 assert Utils.checkValue(torque); 00061 double thetaAcceleration = -StepTime * velocity + Mass * G * Length * Math.sin(theta) + torque; 00062 assert Utils.checkValue(thetaAcceleration); 00063 velocity = VelocityRange.bound(velocity + thetaAcceleration); 00064 theta += velocity * StepTime; 00065 adjustTheta(); 00066 upTime = Math.abs(theta) > UpRange ? 0 : upTime + 1; 00067 assert Utils.checkValue(theta); 00068 assert Utils.checkValue(velocity); 00069 } 00070 00071 protected void adjustTheta() { 00072 if (theta >= Math.PI) 00073 theta -= 2 * Math.PI; 00074 if (theta < -Math.PI) 00075 theta += 2 * Math.PI; 00076 } 00077 00078 @Override 00079 public TRStep step(Action action) { 00080 assert !step.isEpisodeEnding(); 00081 update((ActionArray) action); 00082 step = new TRStep(step, action, new double[] { theta, velocity }, reward()); 00083 if (isGoalReached()) 00084 forceEndEpisode(); 00085 return step; 00086 } 00087 00088 protected double reward() { 00089 return Math.cos(theta); 00090 } 00091 00092 private boolean isGoalReached() { 00093 if (!endOfEpisode) 00094 return false; 00095 if (constantEpisodeTime) 00096 return false; 00097 return upTime + 1 >= RequiredUpTime / StepTime; 00098 } 00099 00100 @Override 00101 public TRStep forceEndEpisode() { 00102 step = step.createEndingStep(); 00103 return step; 00104 } 00105 00106 @Override 00107 public TRStep initialize() { 00108 initializeProblemData(); 00109 step = new TRStep(new double[] { theta, velocity }, -1); 00110 return step; 00111 } 00112 00113 protected void initializeProblemData() { 00114 upTime = 0; 00115 if (random == null) { 00116 theta = Math.PI / 2; 00117 velocity = 0.0; 00118 } else { 00119 theta = InitialThetaRange.choose(random); 00120 velocity = initialVelocity; 00121 } 00122 adjustTheta(); 00123 } 00124 00125 @Override 00126 public Legend legend() { 00127 return Legend; 00128 } 00129 00130 @Override 00131 public Range[] getObservationRanges() { 00132 return new Range[] { ThetaRange, VelocityRange }; 00133 } 00134 00135 public double theta() { 00136 return theta; 00137 } 00138 00139 public double velocity() { 00140 return velocity; 00141 } 00142 00143 @Override 00144 public Action[] actions() { 00145 return Actions; 00146 } 00147 00148 @Override 00149 public Range[] actionRanges() { 00150 return new Range[] { ActionRange }; 00151 } 00152 00153 @Override 00154 public TRStep lastStep() { 00155 return step; 00156 } 00157 00158 }