RLPark 1.0.0
Reinforcement Learning Framework in Java
|
00001 package rlpark.plugin.rltoys.problems.stategraph02; 00002 00003 import java.io.Serializable; 00004 import java.util.Random; 00005 00006 import rlpark.plugin.rltoys.algorithms.functions.states.Projector; 00007 import rlpark.plugin.rltoys.envio.actions.Action; 00008 import rlpark.plugin.rltoys.envio.observations.Legend; 00009 import rlpark.plugin.rltoys.envio.rl.TRStep; 00010 import rlpark.plugin.rltoys.problems.RLProblem; 00011 00012 public class GraphProblem implements Serializable, RLProblem { 00013 private static final long serialVersionUID = 6251650836939403789L; 00014 private final State s0; 00015 private State currentState; 00016 private TRStep step; 00017 private final Legend legend = new Legend("stateIndex"); 00018 private final StateGraph stateGraph; 00019 private final Random random; 00020 private final Projector projector; 00021 00022 public GraphProblem(Random random, State s0, StateGraph stateGraph, Projector projector) { 00023 this.random = random; 00024 this.stateGraph = stateGraph; 00025 this.s0 = s0; 00026 this.projector = projector; 00027 assert stateGraph.checkDistribution(); 00028 } 00029 00030 @Override 00031 public TRStep initialize() { 00032 currentState = s0; 00033 step = new TRStep(toObs(currentState), currentState.reward); 00034 return step; 00035 } 00036 00037 private double[] toObs(State s) { 00038 return new double[] { stateGraph.indexOf(s) }; 00039 } 00040 00041 @Override 00042 public TRStep step(Action action) { 00043 currentState = stateGraph.sampleNextState(random, currentState, action); 00044 step = new TRStep(step, action, toObs(currentState), currentState.reward); 00045 if (stateGraph.isTerminal(currentState)) 00046 step = step.createEndingStep(); 00047 return step; 00048 } 00049 00050 @Override 00051 public TRStep forceEndEpisode() { 00052 step = step.createEndingStep(); 00053 return step; 00054 } 00055 00056 @Override 00057 public TRStep lastStep() { 00058 return step; 00059 } 00060 00061 @Override 00062 public Legend legend() { 00063 return legend; 00064 } 00065 00066 public StateGraph stateGraph() { 00067 return stateGraph; 00068 } 00069 00070 public Projector projector() { 00071 return projector; 00072 } 00073 }