/*
 * RLPark 1.0.0
 * Reinforcement Learning Framework in Java
 */
00001 package rlpark.plugin.rltoys.problems.mazes; 00002 00003 import rlpark.plugin.rltoys.algorithms.functions.Predictor; 00004 import rlpark.plugin.rltoys.algorithms.functions.stateactions.StateToStateAction; 00005 import rlpark.plugin.rltoys.envio.actions.Action; 00006 import rlpark.plugin.rltoys.envio.policy.Policy; 00007 import rlpark.plugin.rltoys.math.vector.RealVector; 00008 00009 public class MazeValueFunction implements MazeFunction { 00010 private final Predictor predictor; 00011 private final Policy policy; 00012 private final MazeProjector mazeProjector; 00013 00014 public MazeValueFunction(Maze maze, Predictor predictor) { 00015 this(maze, predictor, null, null); 00016 } 00017 00018 public MazeValueFunction(Maze maze, Predictor predictor, StateToStateAction toStateAction, Policy policy) { 00019 this.predictor = predictor; 00020 this.policy = policy; 00021 mazeProjector = new MazeProjector(maze, maze.getMarkovProjector(), toStateAction); 00022 } 00023 00024 @Override 00025 public float value(int x, int y) { 00026 float sum = 0.0f; 00027 RealVector v_x = mazeProjector.toState(x, y); 00028 if (mazeProjector.toStateAction() == null) 00029 return (float) predictor.predict(v_x); 00030 policy.update(v_x); 00031 for (Action a : mazeProjector.maze().actions()) { 00032 double prob = policy.pi(a); 00033 if (prob == 0) 00034 continue; 00035 RealVector v_xa = mazeProjector.stateAction(v_x, a); 00036 sum += predictor.predict(v_xa) * prob; 00037 } 00038 return sum; 00039 } 00040 }