RLPark 1.0.0
Reinforcement Learning Framework in Java
Autostep.java

package rlpark.plugin.rltoys.algorithms.predictions.supervised;

import rlpark.plugin.rltoys.algorithms.LinearLearner;
import rlpark.plugin.rltoys.math.vector.MutableVector;
import rlpark.plugin.rltoys.math.vector.RealVector;
import rlpark.plugin.rltoys.math.vector.filters.Filters;
import rlpark.plugin.rltoys.math.vector.implementations.PVector;
import rlpark.plugin.rltoys.math.vector.implementations.PVectors;
import rlpark.plugin.rltoys.math.vector.implementations.Vectors;
import rlpark.plugin.rltoys.math.vector.pool.VectorPool;
import rlpark.plugin.rltoys.math.vector.pool.VectorPools;
import zephyr.plugin.core.api.monitoring.annotations.Monitor;

/**
 * Linear supervised learning with the Autostep algorithm: each weight has its
 * own step size, adapted online from the correlation of successive updates.
 */
@Monitor
public class Autostep implements LearningAlgorithm, LinearLearner {
  private static final long serialVersionUID = -3311074550497156281L;
  private static final double DefaultMetaStepSize = 0.01;
  // Time constant of the running normalizer v.
  private final double Tau = 10000;
  // Per-feature step sizes, adapted online.
  @Monitor(level = 4)
  protected final PVector alphas;
  // Linear prediction weights.
  @Monitor(level = 4)
  protected final PVector weights;
  // Decaying trace of recent weight updates.
  @Monitor(level = 4)
  protected final PVector h;
  // Meta step size controlling how fast the alphas adapt.
  private final double kappa;
  // Running normalizer for the step-size updates.
  @Monitor(level = 4)
  private final PVector v;
  private double delta;
  private double prediction;

  public Autostep(PVector weights) {
    this(weights, DefaultMetaStepSize, 1.0);
  }

  public Autostep(int vectorSize) {
    this(new PVector(vectorSize));
  }

  public Autostep(int vectorSize, double kappa, double initStepsize) {
    this(new PVector(vectorSize), kappa, initStepsize);
  }

  public Autostep(PVector weights, double kappa, double initStepsize) {
    this.weights = weights;
    this.kappa = kappa;
    int nbFeatures = weights.size;
    alphas = new PVector(nbFeatures);
    alphas.set(initStepsize);
    h = new PVector(nbFeatures);
    v = new PVector(nbFeatures);
    v.set(1.0);
  }

  protected void updateAlphas(VectorPool pool, RealVector x, RealVector x2, RealVector deltaX) {
    // Component-wise product delta * x_i * h_i: correlation between the
    // current error and the trace of recent updates.
    MutableVector deltaXH = pool.newVector(deltaX).ebeMultiplyToSelf(h);
    MutableVector absDeltaXH = Vectors.absToSelf(pool.newVector(deltaXH));
    // Restrict v to the features active in deltaX.
    MutableVector sparseV = pool.newVector();
    Vectors.toBinary(sparseV, deltaX).ebeMultiplyToSelf(v);
    // Move v toward |delta * x * h| on active features, with time constant Tau.
    MutableVector vUpdate = pool.newVector(absDeltaXH).subtractToSelf(sparseV).ebeMultiplyToSelf(x2)
        .ebeMultiplyToSelf(alphas);
    v.addToSelf(1.0 / Tau, vUpdate);
    Vectors.positiveMaxToSelf(v, absDeltaXH);
    // Multiplicative step-size update: alphas *= exp(kappa * delta*x*h / v),
    // floored at IDBD.MinimumStepsize.
    PVectors.multiplySelfByExponential(alphas, kappa, deltaXH.ebeDivideToSelf(v), IDBD.MinimumStepsize);
    deltaXH = null;
    // Normalize so the effective step size sum_i alphas_i * x_i^2 stays <= 1.
    RealVector x2ByAlphas = pool.newVector(x2).ebeMultiplyToSelf(alphas);
    double sum = x2ByAlphas.sum();
    if (sum > 1)
      Filters.mapMultiplyToSelf(alphas, 1 / sum, x);
  }

  @Override
  public double learn(RealVector x_t, double y_tp1) {
    VectorPool pool = VectorPools.pool(x_t);
    prediction = predict(x_t);
    delta = y_tp1 - prediction;
    MutableVector deltaX = pool.newVector(x_t).mapMultiplyToSelf(delta);
    MutableVector x2 = pool.newVector(x_t).ebeMultiplyToSelf(x_t);
    updateAlphas(pool, x_t, x2, deltaX);
    // Weight update: weights += alphas .* delta .* x.
    RealVector alphasDeltaX = deltaX.ebeMultiplyToSelf(alphas);
    deltaX = null;
    weights.addToSelf(alphasDeltaX);
    // Trace update: h = h .* (1 - alphas .* x^2) + alphas .* delta .* x.
    MutableVector x2AlphasH = x2.ebeMultiplyToSelf(alphas).ebeMultiplyToSelf(h);
    x2 = null;
    h.addToSelf(-1, x2AlphasH).addToSelf(alphasDeltaX);
    pool.releaseAll();
    return delta;
  }

  @Override
  public double predict(RealVector x) {
    return weights.dotProduct(x);
  }

  @Override
  public PVector weights() {
    return weights;
  }

  public PVector alphas() {
    return alphas;
  }

  @Override
  public void resetWeight(int index) {
    weights.setEntry(index, 0);
  }

  @Override
  public double error() {
    return delta;
  }
}
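
A minimal usage sketch, not part of the original listing: the learner is driven one (input, target) pair at a time through learn(x, y), and learn() returns the prediction error delta. The AutostepExample class name, the trueWeights array, and the random training data below are hypothetical and for illustration only; the sketch assumes only the Autostep and PVector APIs shown in the listing above.

import rlpark.plugin.rltoys.algorithms.predictions.supervised.Autostep;
import rlpark.plugin.rltoys.math.vector.implementations.PVector;

import java.util.Random;

public class AutostepExample {
  public static void main(String[] args) {
    int nbFeatures = 3;
    // Default meta step size and initial per-feature step size of 1.0.
    Autostep learner = new Autostep(nbFeatures);

    // Hypothetical target weights the learner should recover.
    double[] trueWeights = { 0.5, -1.0, 2.0 };
    Random random = new Random(0);

    for (int t = 0; t < 10000; t++) {
      // Build a random feature vector and its noiseless target.
      PVector x = new PVector(nbFeatures);
      double target = 0;
      for (int i = 0; i < nbFeatures; i++) {
        double value = random.nextDouble();
        x.setEntry(i, value);
        target += trueWeights[i] * value;
      }
      // learn() updates weights, alphas, and h, and returns delta.
      double delta = learner.learn(x, target);
      if (t % 1000 == 0)
        System.out.println("t=" + t + " delta=" + delta);
    }
  }
}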