RLPark 1.0.0
Reinforcement Learning Framework in Java
package rlpark.plugin.rltoys.problems.mazes;

import java.awt.Point;

import rlpark.plugin.rltoys.algorithms.functions.states.Projector;
import rlpark.plugin.rltoys.envio.actions.Action;
import rlpark.plugin.rltoys.envio.actions.ActionArray;
import rlpark.plugin.rltoys.envio.observations.Legend;
import rlpark.plugin.rltoys.envio.rl.TRStep;
import rlpark.plugin.rltoys.math.vector.RealVector;
import rlpark.plugin.rltoys.math.vector.implementations.BVector;
import rlpark.plugin.rltoys.problems.ProblemDiscreteAction;

public class Maze implements ProblemDiscreteAction {
  static private Legend legend = new Legend("x", "y");
  public static final ActionArray Left = new ActionArray(-1, 0);
  public static final ActionArray Right = new ActionArray(+1, 0);
  public static final ActionArray Up = new ActionArray(0, +1);
  public static final ActionArray Down = new ActionArray(0, -1);
  public static final ActionArray Stop = new ActionArray(0, 0);
  static final public Action[] Actions = { Left, Right, Stop, Up, Down };
  private final byte[][] layout;
  private final Point start;
  private final boolean[][] endEpisode;
  private final double[][] rewardFunction;
  private TRStep step;

  public Maze(byte[][] layout, double[][] rewardFunction, boolean[][] endEpisode, Point start) {
    this.layout = layout;
    this.rewardFunction = rewardFunction;
    this.endEpisode = endEpisode;
    this.start = start;
    initialize();
  }

  @Override
  public TRStep initialize() {
    step = new TRStep(new double[] { start.x, start.y }, rewardFunction[start.x][start.y]);
    return step;
  }

  // Apply the action; a move into a wall (non-zero layout cell) leaves the position unchanged.
  @Override
  public TRStep step(Action action) {
    double[] actions = ((ActionArray) action).actions;
    int newX = (int) (step.o_tp1[0] + actions[0]);
    int newY = (int) (step.o_tp1[1] + actions[1]);
    if (layout[newX][newY] != 0) {
      newX = (int) step.o_tp1[0];
      newY = (int) step.o_tp1[1];
    }
    step = new TRStep(step, action, new double[] { newX, newY }, rewardFunction[newX][newY]);
    if (endEpisode[newX][newY])
      forceEndEpisode();
    return step;
  }

  @Override
  public TRStep forceEndEpisode() {
    step = step.createEndingStep();
    return step;
  }

  @Override
  public TRStep lastStep() {
    return step;
  }

  @Override
  public Legend legend() {
    return legend;
  }

  @Override
  public Action[] actions() {
    return Actions;
  }

  public byte[][] layout() {
    return layout;
  }

  public Point size() {
    return new Point(layout.length, layout[0].length);
  }

  public boolean[][] endEpisode() {
    return endEpisode;
  }

  // Tabular projector: a binary vector with a single bit set for the current (x, y) cell.
  @SuppressWarnings("serial")
  public Projector getMarkovProjector() {
    final Point size = size();
    final BVector projection = new BVector(size.x * size.y);
    return new Projector() {
      @Override
      public int vectorSize() {
        return projection.size;
      }

      @Override
      public double vectorNorm() {
        return 1;
      }

      @Override
      public RealVector project(double[] obs) {
        projection.clear();
        if (obs != null)
          projection.setOn((int) (obs[0] * size.y + obs[1]));
        return projection;
      }
    };
  }
}
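
A minimal usage sketch (not part of the library source above): it builds an illustrative 5x5 maze whose border cells are walls, runs a uniformly random policy, and projects each observation to a tabular one-hot feature vector via getMarkovProjector(). It relies only on the API shown in the listing, except TRStep.isEpisodeEnding(), which is assumed to exist on TRStep; the class name MazeExample, the layout, the reward values, and the goal placement are purely hypothetical.

import java.awt.Point;
import java.util.Random;

import rlpark.plugin.rltoys.algorithms.functions.states.Projector;
import rlpark.plugin.rltoys.envio.rl.TRStep;
import rlpark.plugin.rltoys.math.vector.RealVector;
import rlpark.plugin.rltoys.problems.mazes.Maze;

public class MazeExample {
  public static void main(String[] args) {
    int size = 5;
    byte[][] layout = new byte[size][size];
    double[][] rewards = new double[size][size];
    boolean[][] goals = new boolean[size][size];
    for (int x = 0; x < size; x++)
      for (int y = 0; y < size; y++) {
        boolean border = x == 0 || y == 0 || x == size - 1 || y == size - 1;
        layout[x][y] = (byte) (border ? 1 : 0); // non-zero cells are walls
        rewards[x][y] = -1.0; // uniform step cost
      }
    rewards[3][3] = 0.0;
    goals[3][3] = true; // the episode terminates on this cell

    Maze maze = new Maze(layout, rewards, goals, new Point(1, 1));
    Projector projector = maze.getMarkovProjector();
    Random random = new Random(0);

    TRStep step = maze.initialize();
    for (int t = 0; t < 1000 && !step.isEpisodeEnding(); t++) { // isEpisodeEnding() assumed from TRStep
      RealVector features = projector.project(step.o_tp1); // one bit set for the current cell
      System.out.println("t=" + t + " features=" + features);
      step = maze.step(Maze.Actions[random.nextInt(Maze.Actions.length)]);
    }
  }
}

The projector indexes the one-hot vector with x * size.y + y, so a linear learner over these features is exactly tabular on the maze's cells.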