RLPark 1.0.0
Reinforcement Learning Framework in Java

K1.java

Go to the documentation of this file.
00001 package rlpark.plugin.rltoys.algorithms.predictions.supervised;
00002 
00003 import rlpark.plugin.rltoys.math.vector.RealVector;
00004 import rlpark.plugin.rltoys.math.vector.implementations.PVector;
00005 import zephyr.plugin.core.api.monitoring.annotations.Monitor;
00006 
00007 @Monitor
00008 public class K1 implements LearningAlgorithm {
00009   private static final long serialVersionUID = 2943574757813500087L;
00010   private final double theta;
00011   private final PVector weights;
00012   private final PVector alphas;
00013   private final PVector betas;
00014   private final PVector hs;
00015   private double delta;
00016   private double prediction;
00017 
00018   public K1(int size, double theta) {
00019     this.theta = theta;
00020     weights = new PVector(size);
00021     double initialAlpha = 0.1;
00022     betas = new PVector(size);
00023     betas.set(Math.log(initialAlpha));
00024     alphas = new PVector(size);
00025     hs = new PVector(size);
00026   }
00027 
00028   @Override
00029   public double learn(RealVector rx, double y_tp1) {
00030     PVector x = (PVector) rx;
00031     prediction = predict(x);
00032     delta = y_tp1 - prediction;
00033     double pnorm = 0.0;
00034     for (int i = 0; i < weights.size; i++) {
00035       betas.data[i] += theta * delta * x.data[i] * hs.data[i];
00036       alphas.data[i] = Math.exp(betas.data[i]);
00037       pnorm += alphas.data[i] * x.data[i] * x.data[i];
00038     }
00039     for (int i = 0; i < weights.size; i++) {
00040       double p_i = alphas.data[i] / (1 + pnorm);
00041       weights.data[i] += p_i * delta * x.data[i];
00042       hs.data[i] = (hs.data[i] + p_i * delta * x.data[i]) * Math.max(0, 1 - p_i * x.data[i] * x.data[i]);
00043     }
00044     return delta;
00045   }
00046 
00047   @Override
00048   public double predict(RealVector x) {
00049     return weights.dotProduct(x);
00050   }
00051 
00052   public RealVector alphas() {
00053     return alphas;
00054   }
00055 
00056   public RealVector h() {
00057     return hs;
00058   }
00059 }
 All Classes Namespaces Files Functions Variables Enumerations
Zephyr
RLPark