RLPark 1.0.0
Reinforcement Learning Framework in Java
|
00001 package rlpark.plugin.rltoys.algorithms.predictions.supervised; 00002 00003 import rlpark.plugin.rltoys.algorithms.LinearLearner; 00004 import rlpark.plugin.rltoys.math.vector.RealVector; 00005 import rlpark.plugin.rltoys.math.vector.implementations.PVector; 00006 import zephyr.plugin.core.api.internal.monitoring.wrappers.Abs; 00007 import zephyr.plugin.core.api.internal.monitoring.wrappers.Squared; 00008 import zephyr.plugin.core.api.monitoring.annotations.Monitor; 00009 00010 @Monitor 00011 @SuppressWarnings("restriction") 00012 public class Adaline implements LearningAlgorithm, LinearLearner { 00013 private static final long serialVersionUID = -1427180343679219960L; 00014 private final double alpha; 00015 @Monitor(level = 4) 00016 private final PVector weights; 00017 private double prediction; 00018 private double target; 00019 @Monitor(wrappers = { Squared.ID, Abs.ID }) 00020 private double error; 00021 00022 public Adaline(int size, double alpha) { 00023 weights = new PVector(size); 00024 this.alpha = alpha; 00025 } 00026 00027 @Override 00028 public double learn(RealVector x_t, double y_tp1) { 00029 prediction = predict(x_t); 00030 target = y_tp1; 00031 error = target - prediction; 00032 weights.addToSelf(x_t.mapMultiply(alpha * error)); 00033 return error; 00034 } 00035 00036 @Override 00037 public double predict(RealVector x) { 00038 return weights.dotProduct(x); 00039 } 00040 00041 @Override 00042 public PVector weights() { 00043 return weights; 00044 } 00045 00046 @Override 00047 public void resetWeight(int i) { 00048 weights.data[i] = 0; 00049 } 00050 00051 @Override 00052 public double error() { 00053 return error; 00054 } 00055 }