RLPark 1.0.0
Reinforcement Learning Framework in Java
|
00001 package rlpark.plugin.rltoys.algorithms.control.actorcritic.offpolicy; 00002 00003 import rlpark.plugin.rltoys.algorithms.functions.Predictor; 00004 import rlpark.plugin.rltoys.algorithms.functions.states.Projector; 00005 import rlpark.plugin.rltoys.algorithms.predictions.td.OffPolicyTD; 00006 import rlpark.plugin.rltoys.math.vector.RealVector; 00007 import rlpark.plugin.rltoys.math.vector.implementations.PVector; 00008 import rlpark.plugin.rltoys.math.vector.implementations.Vectors; 00009 import zephyr.plugin.core.api.monitoring.annotations.Monitor; 00010 00011 00012 public class CriticAdapterFA implements OffPolicyTD { 00013 private static final long serialVersionUID = 4767252828929104353L; 00014 @Monitor 00015 private final OffPolicyTD offPolicyTD; 00016 private final Projector projector; 00017 private RealVector o_t = null; 00018 private RealVector x_t = null; 00019 00020 public CriticAdapterFA(Projector projector, OffPolicyTD offPolicyTD) { 00021 this.projector = projector; 00022 this.offPolicyTD = offPolicyTD; 00023 } 00024 00025 @Override 00026 public void resetWeight(int index) { 00027 offPolicyTD.resetWeight(index); 00028 } 00029 00030 @Override 00031 public PVector weights() { 00032 return offPolicyTD.weights(); 00033 } 00034 00035 private RealVector projectIFN(RealVector o) { 00036 return projector.project(o instanceof PVector ? ((PVector) o).data : null); 00037 } 00038 00039 @Override 00040 public double predict(RealVector x) { 00041 return offPolicyTD.predict(projectIFN(x)); 00042 } 00043 00044 @Override 00045 public double error() { 00046 return offPolicyTD.error(); 00047 } 00048 00049 @Override 00050 public double update(double pi_t, double b_t, RealVector o_t, RealVector o_tp1, double r_tp1) { 00051 if (o_t != this.o_t) { 00052 x_t = Vectors.bufferedCopy(projectIFN(o_t), x_t); 00053 this.o_t = o_t; 00054 } 00055 RealVector x_tp1 = projectIFN(o_tp1); 00056 double delta = offPolicyTD.update(pi_t, b_t, x_t, x_tp1, r_tp1); 00057 x_t = Vectors.bufferedCopy(x_tp1, x_t); 00058 this.o_t = o_tp1; 00059 return delta; 00060 } 00061 00062 @Override 00063 public double prediction() { 00064 return offPolicyTD.prediction(); 00065 } 00066 00067 @Override 00068 public PVector secondaryWeights() { 00069 return offPolicyTD.secondaryWeights(); 00070 } 00071 00072 public Projector projector() { 00073 return projector; 00074 } 00075 00076 public Predictor predictor() { 00077 return offPolicyTD; 00078 } 00079 }