RLPark 1.0.0
Reinforcement Learning Framework in Java

SwingPendulum.java

Go to the documentation of this file.
00001 package rlpark.plugin.rltoys.problems.pendulum;
00002 
00003 import java.util.Random;
00004 
00005 import rlpark.plugin.rltoys.envio.actions.Action;
00006 import rlpark.plugin.rltoys.envio.actions.ActionArray;
00007 import rlpark.plugin.rltoys.envio.observations.Legend;
00008 import rlpark.plugin.rltoys.envio.rl.TRStep;
00009 import rlpark.plugin.rltoys.math.ranges.Range;
00010 import rlpark.plugin.rltoys.problems.ProblemBounded;
00011 import rlpark.plugin.rltoys.problems.ProblemContinuousAction;
00012 import rlpark.plugin.rltoys.problems.ProblemDiscreteAction;
00013 import rlpark.plugin.rltoys.utils.Utils;
00014 import zephyr.plugin.core.api.monitoring.annotations.Monitor;
00015 
00016 public class SwingPendulum implements ProblemBounded, ProblemDiscreteAction, ProblemContinuousAction {
00017   public static final double uMax = 2.0;
00018   public boolean constantEpisodeTime = true;
00019   public static final ActionArray STOP = new ActionArray(0);
00020   public static final ActionArray RIGHT = new ActionArray(uMax);
00021   public static final ActionArray LEFT = new ActionArray(-uMax);
00022   private static final Action[] Actions = new Action[] { LEFT, STOP, RIGHT };
00023   public static final Range ActionRange = new Range(-uMax, uMax);
00024   protected static final String VELOCITY = "velocity";
00025   protected static final String THETA = "theta";
00026   protected static final Legend Legend = new Legend(THETA, VELOCITY);
00027   public static final Range ThetaRange = new Range(-Math.PI, Math.PI);
00028   protected static final double Mass = 1.0;
00029   protected static final double Length = 1.0;
00030   protected static final double G = 9.8;
00031   protected static final double StepTime = 0.01; // seconds
00032   protected static final double RequiredUpTime = 10.0; // seconds
00033   protected static final double UpRange = Math.PI / 4.0; // seconds
00034   protected static final double MaxVelocity = (Math.PI / 4.0) / StepTime;
00035   public static final Range VelocityRange = new Range(-MaxVelocity, MaxVelocity);
00036   public static final Range InitialThetaRange = new Range(-Math.PI, Math.PI);
00037   protected static final double initialVelocity = 0.0;
00038 
00039   final private boolean endOfEpisode;
00040   @Monitor
00041   protected double theta = 0.0;
00042   @Monitor
00043   protected double velocity = 0.0;
00044   protected final Random random;
00045   protected TRStep step;
00046   protected int upTime = 0;
00047 
00048   public SwingPendulum(Random random) {
00049     this(random, true);
00050   }
00051 
00052   public SwingPendulum(Random random, boolean endOfEpisode) {
00053     assert Mass * Length * G > uMax;
00054     this.random = random;
00055     this.endOfEpisode = endOfEpisode;
00056   }
00057 
00058   protected void update(ActionArray action) {
00059     double torque = ActionRange.bound(ActionArray.toDouble(action));
00060     assert Utils.checkValue(torque);
00061     double thetaAcceleration = -StepTime * velocity + Mass * G * Length * Math.sin(theta) + torque;
00062     assert Utils.checkValue(thetaAcceleration);
00063     velocity = VelocityRange.bound(velocity + thetaAcceleration);
00064     theta += velocity * StepTime;
00065     adjustTheta();
00066     upTime = Math.abs(theta) > UpRange ? 0 : upTime + 1;
00067     assert Utils.checkValue(theta);
00068     assert Utils.checkValue(velocity);
00069   }
00070 
00071   protected void adjustTheta() {
00072     if (theta >= Math.PI)
00073       theta -= 2 * Math.PI;
00074     if (theta < -Math.PI)
00075       theta += 2 * Math.PI;
00076   }
00077 
00078   @Override
00079   public TRStep step(Action action) {
00080     assert !step.isEpisodeEnding();
00081     update((ActionArray) action);
00082     step = new TRStep(step, action, new double[] { theta, velocity }, reward());
00083     if (isGoalReached())
00084       forceEndEpisode();
00085     return step;
00086   }
00087 
00088   protected double reward() {
00089     return Math.cos(theta);
00090   }
00091 
00092   private boolean isGoalReached() {
00093     if (!endOfEpisode)
00094       return false;
00095     if (constantEpisodeTime)
00096       return false;
00097     return upTime + 1 >= RequiredUpTime / StepTime;
00098   }
00099 
00100   @Override
00101   public TRStep forceEndEpisode() {
00102     step = step.createEndingStep();
00103     return step;
00104   }
00105 
00106   @Override
00107   public TRStep initialize() {
00108     initializeProblemData();
00109     step = new TRStep(new double[] { theta, velocity }, -1);
00110     return step;
00111   }
00112 
00113   protected void initializeProblemData() {
00114     upTime = 0;
00115     if (random == null) {
00116       theta = Math.PI / 2;
00117       velocity = 0.0;
00118     } else {
00119       theta = InitialThetaRange.choose(random);
00120       velocity = initialVelocity;
00121     }
00122     adjustTheta();
00123   }
00124 
00125   @Override
00126   public Legend legend() {
00127     return Legend;
00128   }
00129 
00130   @Override
00131   public Range[] getObservationRanges() {
00132     return new Range[] { ThetaRange, VelocityRange };
00133   }
00134 
00135   public double theta() {
00136     return theta;
00137   }
00138 
00139   public double velocity() {
00140     return velocity;
00141   }
00142 
00143   @Override
00144   public Action[] actions() {
00145     return Actions;
00146   }
00147 
00148   @Override
00149   public Range[] actionRanges() {
00150     return new Range[] { ActionRange };
00151   }
00152 
00153   @Override
00154   public TRStep lastStep() {
00155     return step;
00156   }
00157 
00158 }
 All Classes Namespaces Files Functions Variables Enumerations
Zephyr
RLPark