RLPark 1.0.0
Reinforcement Learning Framework in Java
|
00001 package rlpark.plugin.rltoys.horde; 00002 00003 import java.util.List; 00004 00005 import rlpark.plugin.rltoys.horde.demons.Demon; 00006 import rlpark.plugin.rltoys.math.normalization.MovingMeanVarNormalizer; 00007 import zephyr.plugin.core.api.labels.Labels; 00008 import zephyr.plugin.core.api.monitoring.annotations.IgnoreMonitor; 00009 import zephyr.plugin.core.api.monitoring.annotations.LabelProvider; 00010 import zephyr.plugin.core.api.monitoring.annotations.Monitor; 00011 00012 @Monitor 00013 public class Surprise { 00014 private final MovingMeanVarNormalizer[] errorNormalizers; 00015 @IgnoreMonitor 00016 private final Demon[] demons; 00017 private final double[] errors; 00018 private double surpriseMeasure; 00019 00020 public Surprise(List<Demon> demons, int trackingSpeed) { 00021 this.demons = new Demon[demons.size()]; 00022 demons.toArray(this.demons); 00023 errors = new double[demons.size()]; 00024 errorNormalizers = new MovingMeanVarNormalizer[demons.size()]; 00025 for (int i = 0; i < errorNormalizers.length; i++) 00026 errorNormalizers[i] = new MovingMeanVarNormalizer(trackingSpeed); 00027 } 00028 00029 public double updateSurpriseMeasure() { 00030 surpriseMeasure = 0; 00031 for (int i = 0; i < demons.length; i++) { 00032 double error = demons[i].learner().error(); 00033 errorNormalizers[i].update(error); 00034 double scaledError = errorNormalizers[i].normalize(error); 00035 errors[i] = scaledError; 00036 surpriseMeasure = Math.max(surpriseMeasure, Math.abs(scaledError)); 00037 } 00038 return surpriseMeasure; 00039 } 00040 00041 @LabelProvider(ids = { "demons", "errors", "errorNormalizers" }) 00042 String labelOf(int index) { 00043 return Labels.label(demons[index]); 00044 } 00045 }