RLPark 1.0.0
Reinforcement Learning Framework in Java
GreedyGQ.java
package rlpark.plugin.rltoys.algorithms.control.gq;

import rlpark.plugin.rltoys.algorithms.control.OffPolicyLearner;
import rlpark.plugin.rltoys.algorithms.functions.stateactions.StateToStateAction;
import rlpark.plugin.rltoys.envio.actions.Action;
import rlpark.plugin.rltoys.envio.policy.Policies;
import rlpark.plugin.rltoys.envio.policy.Policy;
import rlpark.plugin.rltoys.math.vector.MutableVector;
import rlpark.plugin.rltoys.math.vector.RealVector;
import rlpark.plugin.rltoys.math.vector.implementations.PVector;
import rlpark.plugin.rltoys.math.vector.pool.VectorPool;
import rlpark.plugin.rltoys.math.vector.pool.VectorPools;
import rlpark.plugin.rltoys.utils.Prototype;
import rlpark.plugin.rltoys.utils.Utils;
import zephyr.plugin.core.api.monitoring.annotations.Monitor;

/**
 * Greedy-GQ: off-policy control built on the GQ(lambda) gradient-TD predictor.
 * The action-value function of the target policy is learned from transitions
 * generated by a (possibly different) behaviour policy.
 */
public class GreedyGQ implements OffPolicyLearner {
  private static final long serialVersionUID = 7017521530598253457L;
  @Monitor
  protected final GQ gq;
  @Monitor
  protected final Policy target;
  protected final Policy behaviour;
  protected final StateToStateAction toStateAction;
  @Monitor
  public double rho_t;
  private final Action[] actions;
  private double delta_t;
  private final RealVector prototype;

  @SuppressWarnings("unchecked")
  public GreedyGQ(GQ gq, Action[] actions, StateToStateAction toStateAction, Policy target, Policy behaviour) {
    this.gq = gq;
    this.target = target;
    this.behaviour = behaviour;
    this.toStateAction = toStateAction;
    this.actions = actions;
    prototype = ((Prototype<RealVector>) gq.e).prototype();
  }

  public double update(RealVector x_t, Action a_t, double r_tp1, double gamma_tp1, double z_tp1, RealVector x_tp1,
      Action a_tp1) {
    // Importance sampling ratio pi(a_t|x_t) / b(a_t|x_t); zero at episode start (a_t == null).
    rho_t = 0.0;
    if (a_t != null) {
      target.update(x_t);
      behaviour.update(x_t);
      rho_t = target.pi(a_t) / behaviour.pi(a_t);
    }
    assert Utils.checkValue(rho_t);
    VectorPool pool = VectorPools.pool(prototype, gq.v.size);
    // Expected next state-action features under the target policy: sum_a pi(a|x_tp1) * phi(x_tp1, a).
    MutableVector sa_bar_tp1 = pool.newVector();
    if (x_t != null && x_tp1 != null) {
      target.update(x_tp1);
      for (Action a : actions) {
        double pi = target.pi(a);
        if (pi == 0)
          continue;
        sa_bar_tp1.addToSelf(pi, toStateAction.stateAction(x_tp1, a));
      }
    }
    RealVector phi_stat = x_t != null ? toStateAction.stateAction(x_t, a_t) : null;
    // Delegate the gradient-TD update to GQ and return the TD error.
    delta_t = gq.update(phi_stat, rho_t, r_tp1, sa_bar_tp1, z_tp1);
    pool.releaseAll();
    return delta_t;
  }

  public PVector theta() {
    return gq.v;
  }

  public double gamma() {
    // GQ stores the discount as beta_tp1 = 1 - gamma_tp1.
    return 1 - gq.beta_tp1;
  }

  public GQ gq() {
    return gq;
  }

  @Override
  public Policy targetPolicy() {
    return target;
  }

  @Override
  public void learn(RealVector x_t, Action a_t, RealVector x_tp1, Action a_tp1, double reward) {
    update(x_t, a_t, reward, gamma(), 0, x_tp1, a_tp1);
  }

  @Override
  public Action proposeAction(RealVector x_t) {
    return Policies.decide(target, x_t);
  }

  @Override
  public GQ predictor() {
    return gq;
  }
}
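The sketch below is not part of RLPark; it is a minimal illustration of how an agent loop might drive GreedyGQ through the proposeAction and learn methods shown above, assuming the GQ predictor, action set, state-to-state-action projector, and target/behaviour policies have been constructed elsewhere. The GreedyGQAgent class and its step/endEpisode methods are hypothetical names introduced only for this example; a real off-policy agent would typically draw its actions from the behaviour policy rather than from proposeAction, which queries the target policy.

import rlpark.plugin.rltoys.algorithms.control.gq.GreedyGQ;
import rlpark.plugin.rltoys.envio.actions.Action;
import rlpark.plugin.rltoys.math.vector.RealVector;

// Hypothetical wrapper that remembers the previous transition (not part of RLPark).
public class GreedyGQAgent {
  private final GreedyGQ learner;
  private RealVector x_t; // features of the previous observation (null at episode start)
  private Action a_t;     // action taken on the previous step

  public GreedyGQAgent(GreedyGQ learner) {
    this.learner = learner;
  }

  // Called once per environment step with the new feature vector and the reward just received.
  public Action step(RealVector x_tp1, double r_tp1) {
    // Act with respect to the current action-value estimate (target policy).
    Action a_tp1 = learner.proposeAction(x_tp1);
    // Off-policy update on the transition (x_t, a_t, r_tp1, x_tp1).
    if (x_t != null)
      learner.learn(x_t, a_t, x_tp1, a_tp1, r_tp1);
    x_t = x_tp1;
    a_t = a_tp1;
    return a_tp1;
  }

  // Signals the end of an episode so the next step starts a fresh transition.
  public void endEpisode() {
    x_t = null;
    a_t = null;
  }
}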