RLPark 1.0.0
Reinforcement Learning Framework in Java

CriticAdapterFA.java

Go to the documentation of this file.
00001 package rlpark.plugin.rltoys.algorithms.control.actorcritic.offpolicy;
00002 
00003 import rlpark.plugin.rltoys.algorithms.functions.Predictor;
00004 import rlpark.plugin.rltoys.algorithms.functions.states.Projector;
00005 import rlpark.plugin.rltoys.algorithms.predictions.td.OffPolicyTD;
00006 import rlpark.plugin.rltoys.math.vector.RealVector;
00007 import rlpark.plugin.rltoys.math.vector.implementations.PVector;
00008 import rlpark.plugin.rltoys.math.vector.implementations.Vectors;
00009 import zephyr.plugin.core.api.monitoring.annotations.Monitor;
00010 
00011 
00012 public class CriticAdapterFA implements OffPolicyTD {
00013   private static final long serialVersionUID = 4767252828929104353L;
00014   @Monitor
00015   private final OffPolicyTD offPolicyTD;
00016   private final Projector projector;
00017   private RealVector o_t = null;
00018   private RealVector x_t = null;
00019 
00020   public CriticAdapterFA(Projector projector, OffPolicyTD offPolicyTD) {
00021     this.projector = projector;
00022     this.offPolicyTD = offPolicyTD;
00023   }
00024 
00025   @Override
00026   public void resetWeight(int index) {
00027     offPolicyTD.resetWeight(index);
00028   }
00029 
00030   @Override
00031   public PVector weights() {
00032     return offPolicyTD.weights();
00033   }
00034 
00035   private RealVector projectIFN(RealVector o) {
00036     return projector.project(o instanceof PVector ? ((PVector) o).data : null);
00037   }
00038 
00039   @Override
00040   public double predict(RealVector x) {
00041     return offPolicyTD.predict(projectIFN(x));
00042   }
00043 
00044   @Override
00045   public double error() {
00046     return offPolicyTD.error();
00047   }
00048 
00049   @Override
00050   public double update(double pi_t, double b_t, RealVector o_t, RealVector o_tp1, double r_tp1) {
00051     if (o_t != this.o_t) {
00052       x_t = Vectors.bufferedCopy(projectIFN(o_t), x_t);
00053       this.o_t = o_t;
00054     }
00055     RealVector x_tp1 = projectIFN(o_tp1);
00056     double delta = offPolicyTD.update(pi_t, b_t, x_t, x_tp1, r_tp1);
00057     x_t = Vectors.bufferedCopy(x_tp1, x_t);
00058     this.o_t = o_tp1;
00059     return delta;
00060   }
00061 
00062   @Override
00063   public double prediction() {
00064     return offPolicyTD.prediction();
00065   }
00066 
00067   @Override
00068   public PVector secondaryWeights() {
00069     return offPolicyTD.secondaryWeights();
00070   }
00071 
00072   public Projector projector() {
00073     return projector;
00074   }
00075 
00076   public Predictor predictor() {
00077     return offPolicyTD;
00078   }
00079 }
 All Classes Namespaces Files Functions Variables Enumerations
Zephyr
RLPark