RLPark 1.0.0
Reinforcement Learning Framework in Java
|
00001 package rlpark.plugin.rltoys.algorithms.traces; 00002 00003 import rlpark.plugin.rltoys.math.vector.DenseVector; 00004 import rlpark.plugin.rltoys.math.vector.MutableVector; 00005 import rlpark.plugin.rltoys.math.vector.RealVector; 00006 import rlpark.plugin.rltoys.math.vector.implementations.SVector; 00007 import rlpark.plugin.rltoys.math.vector.implementations.Vectors; 00008 import rlpark.plugin.rltoys.utils.Prototype; 00009 import zephyr.plugin.core.api.monitoring.annotations.Monitor; 00010 00014 public class ATraces implements Traces, Prototype<RealVector> { 00015 private static final long serialVersionUID = 6241887723527497111L; 00016 public static final SVector DefaultPrototype = new SVector(0); 00017 public static final double DefaultThreshold = 1e-8; 00018 @Monitor 00019 protected double threshold = 1e-8; 00020 protected final MutableVector prototype; 00021 @Monitor 00022 protected final MutableVector vector; 00023 00024 public ATraces() { 00025 this(DefaultPrototype); 00026 } 00027 00028 public ATraces(MutableVector prototype) { 00029 this(prototype, DefaultThreshold); 00030 } 00031 00032 public ATraces(MutableVector prototype, double threshold) { 00033 this(prototype, threshold, 0); 00034 } 00035 00036 protected ATraces(MutableVector prototype, double threshold, int size) { 00037 this.prototype = prototype; 00038 vector = size > 0 ? prototype.newInstance(size) : null; 00039 } 00040 00041 @Override 00042 public ATraces newTraces(int size) { 00043 return new ATraces(prototype, threshold, size); 00044 } 00045 00046 @Override 00047 public void update(double lambda, RealVector phi) { 00048 updateVector(lambda, phi); 00049 adjustUpdate(); 00050 if (clearRequired(phi, lambda)) 00051 clearBelowThreshold(); 00052 assert Vectors.checkValues(vector); 00053 } 00054 00055 protected void adjustUpdate() { 00056 } 00057 00058 protected void updateVector(double lambda, RealVector phi) { 00059 vector.mapMultiplyToSelf(lambda); 00060 vector.addToSelf(phi); 00061 } 00062 00063 private boolean clearRequired(RealVector phi, double lambda) { 00064 if (threshold == 0) 00065 return false; 00066 if (vector instanceof DenseVector) 00067 return false; 00068 return true; 00069 } 00070 00071 protected void clearBelowThreshold() { 00072 SVector svector = (SVector) vector; 00073 double[] values = svector.values; 00074 int[] indexes = svector.activeIndexes; 00075 int i = 0; 00076 while (i < svector.nonZeroElements()) { 00077 final double absValue = Math.abs(values[i]); 00078 if (absValue <= threshold) 00079 svector.removeEntry(indexes[i]); 00080 else 00081 i++; 00082 } 00083 } 00084 00085 @Override 00086 public MutableVector vect() { 00087 return vector; 00088 } 00089 00090 @Override 00091 public void clear() { 00092 vector.clear(); 00093 } 00094 00095 @Override 00096 public RealVector prototype() { 00097 return prototype; 00098 } 00099 }