RLPark 1.0.0
Reinforcement Learning Framework in Java
QLearning.java
package rlpark.plugin.rltoys.algorithms.control.qlearning;

import rlpark.plugin.rltoys.algorithms.LinearLearner;
import rlpark.plugin.rltoys.algorithms.control.acting.Greedy;
import rlpark.plugin.rltoys.algorithms.functions.Predictor;
import rlpark.plugin.rltoys.algorithms.functions.stateactions.StateToStateAction;
import rlpark.plugin.rltoys.algorithms.traces.EligibilityTraceAlgorithm;
import rlpark.plugin.rltoys.algorithms.traces.Traces;
import rlpark.plugin.rltoys.envio.actions.Action;
import rlpark.plugin.rltoys.envio.policy.Policy;
import rlpark.plugin.rltoys.math.vector.RealVector;
import rlpark.plugin.rltoys.math.vector.implementations.PVector;
import zephyr.plugin.core.api.monitoring.annotations.Monitor;

/**
 * Linear Q-learning with eligibility traces (Watkins's Q(lambda)). The action-value
 * function is the dot product of the weight vector theta with the state-action
 * feature vector produced by the StateToStateAction projector.
 */
@Monitor
public class QLearning implements Predictor, LinearLearner, EligibilityTraceAlgorithm {
  private static final long serialVersionUID = -404558746167490755L;
  @Monitor(level = 4)
  protected final PVector theta;
  private final Traces e;
  private final double lambda;
  private final double gamma;
  private final double alpha;
  private final StateToStateAction toStateAction;
  private double delta;
  private final Greedy greedy;

  public QLearning(Action[] actions, double alpha, double gamma, double lambda, StateToStateAction toStateAction,
      Traces prototype) {
    this.alpha = alpha;
    this.gamma = gamma;
    this.lambda = lambda;
    this.toStateAction = toStateAction;
    greedy = new Greedy(this, actions, toStateAction);
    theta = new PVector(toStateAction.vectorSize());
    e = prototype.newTraces(toStateAction.vectorSize());
  }

  public double update(RealVector x_t, Action a_t, RealVector x_tp1, Action a_tp1, double r_tp1) {
    if (x_t == null)
      return initEpisode();
    // Greedy action in x_t, used to decide whether the traces can be carried over.
    greedy.update(x_t);
    Action at_star = greedy.bestAction();
    // Greedy action value in x_tp1 defines the off-policy target; a_tp1 is not used.
    greedy.update(x_tp1);
    RealVector phi_sa_t = toStateAction.stateAction(x_t, a_t);
    // TD error: delta = r_{t+1} + gamma * max_a theta'phi(x_{t+1}, a) - theta'phi(x_t, a_t).
    delta = r_tp1 + gamma * greedy.bestActionValue() - theta.dotProduct(phi_sa_t);
    if (a_t == at_star)
      // Greedy action taken: decay the traces and accumulate phi(x_t, a_t).
      e.update(gamma * lambda, phi_sa_t);
    else {
      // Exploratory action taken: cut the traces (Watkins's Q(lambda)).
      e.clear();
      e.update(0, phi_sa_t);
    }
    theta.addToSelf(alpha * delta, e.vect());
    return delta;
  }

  private double initEpisode() {
    if (e != null)
      e.clear();
    delta = 0.0;
    return delta;
  }

  @Override
  public double predict(RealVector phi_sa) {
    return theta.dotProduct(phi_sa);
  }

  public PVector theta() {
    return theta;
  }

  @Override
  public void resetWeight(int index) {
    theta.setEntry(index, 0);
  }

  @Override
  public PVector weights() {
    return theta;
  }

  @Override
  public double error() {
    return delta;
  }

  public Policy greedy() {
    return greedy;
  }

  @Override
  public Traces traces() {
    return e;
  }
}
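Each call to update() applies theta += alpha * delta * e, where delta = r_{t+1} + gamma * max_a theta'phi(x_{t+1}, a) - theta'phi(x_t, a_t), and the traces are cut whenever the executed action was not greedy. The sketch below is not part of the RLPark sources; it shows one way the class could be driven. Only the QLearning constructor, update(), and greedy() from the listing above are taken from the source; the Environment and Behaviour interfaces, the runSteps helper, and the parameter values are hypothetical placeholders, and the Traces prototype would be one of the framework's trace implementations.

import rlpark.plugin.rltoys.algorithms.control.qlearning.QLearning;
import rlpark.plugin.rltoys.algorithms.functions.stateactions.StateToStateAction;
import rlpark.plugin.rltoys.algorithms.traces.Traces;
import rlpark.plugin.rltoys.envio.actions.Action;
import rlpark.plugin.rltoys.math.vector.RealVector;

public final class QLearningSketch {
  // Hypothetical environment hook: step(null) returns the initial observation,
  // step(a) the observation following action a; reward() is the last reward.
  public interface Environment {
    RealVector step(Action action);
    double reward();
  }

  // Hypothetical behaviour hook: selects the action actually executed,
  // e.g. an epsilon-greedy policy built around qlearning.greedy().
  public interface Behaviour {
    Action choose(RealVector x);
  }

  public static QLearning runSteps(Action[] actions, StateToStateAction toStateAction, Traces prototype,
      Environment env, Behaviour behaviour, int nbSteps) {
    // Illustrative step-size, discount, and trace-decay values.
    QLearning qlearning = new QLearning(actions, 0.1, 0.99, 0.7, toStateAction, prototype);
    RealVector x_t = null; // a null state makes update() clear the traces for a new episode
    Action a_t = null;
    for (int t = 0; t < nbSteps; t++) {
      RealVector x_tp1 = env.step(a_t);
      double r_tp1 = env.reward();
      Action a_tp1 = behaviour.choose(x_tp1);
      // Off-policy update: the bootstrap target uses the greedy value in x_tp1, not a_tp1.
      qlearning.update(x_t, a_t, x_tp1, a_tp1, r_tp1);
      x_t = x_tp1;
      a_t = a_tp1;
    }
    return qlearning;
  }
}

Because Q-learning is off-policy, the behaviour policy can explore freely while theta still tracks the greedy policy returned by greedy().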