RLPark 1.0.0
Reinforcement Learning Framework in Java
|
00001 package rlpark.plugin.rltoys.math.normalization; 00002 00003 import rlpark.plugin.rltoys.math.averages.MeanVar; 00004 import rlpark.plugin.rltoys.utils.Utils; 00005 import zephyr.plugin.core.api.monitoring.annotations.Monitor; 00006 00007 @Monitor 00008 public class MovingMeanVarNormalizer implements Normalizer, MeanVar { 00009 private static final long serialVersionUID = -1340053804929435288L; 00010 private double mean = 0.0; 00011 private double var = 1.0; 00012 private double c = 0.0; 00013 private final int trackingSpeed; 00014 private final double alpha; 00015 00016 public MovingMeanVarNormalizer(int trackingSpeed) { 00017 this.trackingSpeed = trackingSpeed; 00018 this.alpha = 1 - Utils.timeStepsToDiscount(trackingSpeed); 00019 } 00020 00021 @Override 00022 final public double normalize(double x) { 00023 if (var == 0.0) 00024 return 0.0; 00025 return ((x - mean) / Math.sqrt(var)) * c; 00026 } 00027 00028 @Override 00029 public void update(double x) { 00030 double delta = x - mean; 00031 mean = mean + alpha * delta; 00032 var = var + alpha * ((x - mean) * (x - mean) - var); 00033 c = c + alpha * (1 - c); 00034 } 00035 00036 @Override 00037 public double mean() { 00038 return mean; 00039 } 00040 00041 @Override 00042 public double var() { 00043 return var; 00044 } 00045 00046 @Override 00047 public MovingMeanVarNormalizer newInstance() { 00048 return new MovingMeanVarNormalizer(trackingSpeed); 00049 } 00050 }