RLPark 1.0.0
Reinforcement Learning Framework in Java
|
00001 package rlpark.plugin.rltoys.algorithms.predictions.supervised; 00002 00003 import rlpark.plugin.rltoys.math.vector.RealVector; 00004 import rlpark.plugin.rltoys.math.vector.implementations.PVector; 00005 import zephyr.plugin.core.api.monitoring.annotations.Monitor; 00006 00007 @Monitor 00008 public class K1 implements LearningAlgorithm { 00009 private static final long serialVersionUID = 2943574757813500087L; 00010 private final double theta; 00011 private final PVector weights; 00012 private final PVector alphas; 00013 private final PVector betas; 00014 private final PVector hs; 00015 private double delta; 00016 private double prediction; 00017 00018 public K1(int size, double theta) { 00019 this.theta = theta; 00020 weights = new PVector(size); 00021 double initialAlpha = 0.1; 00022 betas = new PVector(size); 00023 betas.set(Math.log(initialAlpha)); 00024 alphas = new PVector(size); 00025 hs = new PVector(size); 00026 } 00027 00028 @Override 00029 public double learn(RealVector rx, double y_tp1) { 00030 PVector x = (PVector) rx; 00031 prediction = predict(x); 00032 delta = y_tp1 - prediction; 00033 double pnorm = 0.0; 00034 for (int i = 0; i < weights.size; i++) { 00035 betas.data[i] += theta * delta * x.data[i] * hs.data[i]; 00036 alphas.data[i] = Math.exp(betas.data[i]); 00037 pnorm += alphas.data[i] * x.data[i] * x.data[i]; 00038 } 00039 for (int i = 0; i < weights.size; i++) { 00040 double p_i = alphas.data[i] / (1 + pnorm); 00041 weights.data[i] += p_i * delta * x.data[i]; 00042 hs.data[i] = (hs.data[i] + p_i * delta * x.data[i]) * Math.max(0, 1 - p_i * x.data[i] * x.data[i]); 00043 } 00044 return delta; 00045 } 00046 00047 @Override 00048 public double predict(RealVector x) { 00049 return weights.dotProduct(x); 00050 } 00051 00052 public RealVector alphas() { 00053 return alphas; 00054 } 00055 00056 public RealVector h() { 00057 return hs; 00058 } 00059 }