RLPark 1.0.0
Reinforcement Learning Framework in Java
package rlpark.example.demos.learning;

import java.util.Random;

import rlpark.plugin.rltoys.agents.rl.LearnerAgentFA;
import rlpark.plugin.rltoys.algorithms.control.ControlLearner;
import rlpark.plugin.rltoys.algorithms.control.acting.EpsilonGreedy;
import rlpark.plugin.rltoys.algorithms.control.qlearning.QLearning;
import rlpark.plugin.rltoys.algorithms.control.qlearning.QLearningControl;
import rlpark.plugin.rltoys.algorithms.functions.stateactions.TabularAction;
import rlpark.plugin.rltoys.algorithms.functions.states.Projector;
import rlpark.plugin.rltoys.algorithms.traces.RTraces;
import rlpark.plugin.rltoys.envio.policy.Policy;
import rlpark.plugin.rltoys.experiments.helpers.Runner;
import rlpark.plugin.rltoys.experiments.helpers.Runner.RunnerEvent;
import rlpark.plugin.rltoys.math.vector.implementations.PVector;
import rlpark.plugin.rltoys.problems.mazes.Maze;
import rlpark.plugin.rltoys.problems.mazes.MazeValueFunction;
import rlpark.plugin.rltoys.problems.mazes.Mazes;
import zephyr.plugin.core.api.Zephyr;
import zephyr.plugin.core.api.monitoring.annotations.Monitor;
import zephyr.plugin.core.api.signals.Listener;
import zephyr.plugin.core.api.synchronization.Clock;

@Monitor
public class QLearningMaze implements Runnable {
  final MazeValueFunction mazeValueFunction;
  private final Maze problem = Mazes.createBookMaze();
  private final ControlLearner control;
  private final Clock clock = new Clock("QLearningMaze");
  private final Projector projector;
  private final PVector occupancy;
  private final LearnerAgentFA agent;

  public QLearningMaze() {
    // Project maze observations into a tabular (one-active-feature) state representation
    projector = problem.getMarkovProjector();
    occupancy = new PVector(projector.vectorSize());
    TabularAction toStateAction = new TabularAction(problem.actions(), projector.vectorNorm(), projector.vectorSize());
    // Step-size normalized by the number of active features
    double alpha = .15 / projector.vectorNorm();
    double gamma = 1.0;
    double lambda = 0.6;
    QLearning qlearning = new QLearning(problem.actions(), alpha, gamma, lambda, toStateAction, new RTraces());
    // Epsilon-greedy exploration over the learned action values
    double epsilon = 0.3;
    Policy acting = new EpsilonGreedy(new Random(0), problem.actions(), toStateAction, qlearning, epsilon);
    control = new QLearningControl(acting, qlearning);
    agent = new LearnerAgentFA(control, projector);
    mazeValueFunction = new MazeValueFunction(problem, qlearning, toStateAction, qlearning.greedy());
    // Register the clock and the @Monitor-annotated fields with the Zephyr visualization framework
    Zephyr.advertise(clock, this);
  }

  @Override
  public void run() {
    Runner runner = new Runner(problem, agent);
    // Print the episode length each time an episode terminates
    runner.onEpisodeEnd.connect(new Listener<Runner.RunnerEvent>() {
      @Override
      public void listen(RunnerEvent eventInfo) {
        System.out.println(String.format("Episode %d: %d steps", eventInfo.nbEpisodeDone, eventInfo.step.time));
      }
    });
    while (clock.tick()) {
      runner.step();
      // Accumulate state visitation counts for visualization
      occupancy.addToSelf(agent.lastState());
    }
  }

  public static void main(String[] args) {
    new QLearningMaze().run();
  }
}
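The listing above wires together RLPark components, but the underlying algorithm is ordinary Q-learning with epsilon-greedy exploration. For reference, here is a minimal, self-contained sketch of the one-step tabular version, omitting the eligibility traces (lambda, RTraces) that the demo adds; all class and method names in this sketch are hypothetical and are not part of the RLPark API.

import java.util.Random;

// Illustrative sketch only: one-step tabular Q-learning with epsilon-greedy action selection.
public class TabularQLearningSketch {
  private final double[][] q; // q[state][action]: action-value table
  private final double alpha, gamma, epsilon;
  private final Random random = new Random(0);

  public TabularQLearningSketch(int nbStates, int nbActions, double alpha, double gamma, double epsilon) {
    this.q = new double[nbStates][nbActions];
    this.alpha = alpha;
    this.gamma = gamma;
    this.epsilon = epsilon;
  }

  // Epsilon-greedy selection: explore with probability epsilon, otherwise act greedily,
  // analogous to the EpsilonGreedy policy used in the demo
  public int selectAction(int state) {
    if (random.nextDouble() < epsilon)
      return random.nextInt(q[state].length);
    int best = 0;
    for (int a = 1; a < q[state].length; a++)
      if (q[state][a] > q[state][best])
        best = a;
    return best;
  }

  // One-step Q-learning update: Q(s,a) += alpha * (r + gamma * max_a' Q(s',a') - Q(s,a))
  public void update(int s, int a, double reward, int sNext) {
    double maxNext = q[sNext][0];
    for (int aNext = 1; aNext < q[sNext].length; aNext++)
      maxNext = Math.max(maxNext, q[sNext][aNext]);
    q[s][a] += alpha * (reward + gamma * maxNext - q[s][a]);
  }
}

In the demo, QLearning applies this same temporal-difference error but spreads the update over recently visited state-action features via the eligibility traces, with alpha, gamma, and lambda set in the constructor above.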