/*
 * RLPark 1.0.0 - Reinforcement Learning Framework in Java
 */
00001 package rlpark.plugin.rltoys.algorithms.predictions.supervised; 00002 00003 import rlpark.plugin.rltoys.math.vector.MutableVector; 00004 import rlpark.plugin.rltoys.math.vector.RealVector; 00005 import rlpark.plugin.rltoys.math.vector.implementations.PVector; 00006 import rlpark.plugin.rltoys.math.vector.implementations.PVectors; 00007 import rlpark.plugin.rltoys.math.vector.pool.VectorPool; 00008 import rlpark.plugin.rltoys.math.vector.pool.VectorPools; 00009 import zephyr.plugin.core.api.monitoring.annotations.Monitor; 00010 00011 public class IDBD implements LearningAlgorithm { 00012 private static final long serialVersionUID = 6961877310325699208L; 00013 public final static double MinimumStepsize = 1e-6; 00014 private final double theta; 00015 @Monitor(level = 4) 00016 private final PVector weights; 00017 @Monitor(level = 4) 00018 private final PVector alphas; 00019 @Monitor(level = 4) 00020 private final PVector hs; 00021 00022 public IDBD(int size, double theta) { 00023 this(size, theta, .1 / size); 00024 } 00025 00026 public IDBD(int size, double theta, double alphaInit) { 00027 this.theta = theta; 00028 weights = new PVector(size); 00029 alphas = new PVector(size); 00030 alphas.set(alphaInit); 00031 hs = new PVector(size); 00032 } 00033 00034 @Override 00035 public double learn(RealVector x_t, double y_tp1) { 00036 VectorPool pool = VectorPools.pool(x_t); 00037 double delta = y_tp1 - predict(x_t); 00038 MutableVector deltaX = pool.newVector(x_t).mapMultiplyToSelf(delta); 00039 RealVector deltaXH = pool.newVector(deltaX).ebeMultiplyToSelf(hs); 00040 PVectors.multiplySelfByExponential(alphas, theta, deltaXH, MinimumStepsize); 00041 RealVector alphaDeltaX = deltaX.ebeMultiplyToSelf(alphas); 00042 deltaX = null; 00043 weights.addToSelf(alphaDeltaX); 00044 RealVector alphaX2 = pool.newVector(x_t).ebeMultiplyToSelf(x_t).ebeMultiplyToSelf(alphas).ebeMultiplyToSelf(hs); 00045 hs.addToSelf(-1, alphaX2); 00046 hs.addToSelf(alphaDeltaX); 00047 pool.releaseAll(); 
00048 return delta; 00049 } 00050 00051 @Override 00052 public double predict(RealVector x) { 00053 return weights.dotProduct(x); 00054 } 00055 00056 public PVector alphas() { 00057 return alphas; 00058 } 00059 00060 public RealVector h() { 00061 return hs; 00062 } 00063 }