RLPark 1.0.0
Reinforcement Learning Framework in Java

NoisyInputSumEvaluation.java

Go to the documentation of this file.
00001 package rlpark.plugin.rltoys.experiments.testing.predictions;
00002 
00003 import java.util.Random;
00004 
00005 import rlpark.plugin.rltoys.algorithms.predictions.supervised.LearningAlgorithm;
00006 import rlpark.plugin.rltoys.math.vector.RealVector;
00007 import rlpark.plugin.rltoys.math.vector.implementations.BVector;
00008 import rlpark.plugin.rltoys.math.vector.implementations.PVector;
00009 import rlpark.plugin.rltoys.math.vector.implementations.Vectors;
00010 import rlpark.plugin.rltoys.problems.noisyinputsum.NoisyInputSum;
00011 import rlpark.plugin.rltoys.utils.Utils;
00012 
00013 public class NoisyInputSumEvaluation {
00014   public static final int NbInputs = 20;
00015   public static final int NbNonZeroWeights = 5;
00016 
00017   static public double evaluateLearner(LearningAlgorithm algorithm, int learningEpisodes, int evaluationEpisodes) {
00018     NoisyInputSum noisyInputSum = new NoisyInputSum(new Random(0), NbNonZeroWeights, NbInputs);
00019     for (int i = 0; i < learningEpisodes; i++) {
00020       noisyInputSum.update();
00021       algorithm.learn(noisyInputSum.input(), noisyInputSum.target());
00022     }
00023     PVector errors = new PVector(evaluationEpisodes);
00024     for (int i = 0; i < evaluationEpisodes; i++) {
00025       noisyInputSum.update();
00026       errors.data[i] = algorithm.learn(noisyInputSum.input(), noisyInputSum.target());
00027       assert Utils.checkValue(errors.data[i]);
00028     }
00029     double mse = errors.dotProduct(errors) / errors.size;
00030     assert Utils.checkValue(mse);
00031     return mse;
00032   }
00033 
00034   static public double evaluateLearner(LearningAlgorithm algorithm) {
00035     return evaluateLearner(algorithm, 20000, 10000);
00036   }
00037 
00038   public static String infoString(RealVector v) {
00039     StringBuilder result = new StringBuilder();
00040     result.append("L1Norm: ");
00041     result.append(Vectors.l1Norm(v));
00042     result.append(" Ave Non Zero: ");
00043     double averageNonZero = average(v, 0, NbNonZeroWeights);
00044     result.append(averageNonZero);
00045     result.append(" Ave Zero: ");
00046     double averageZero = average(v, NbNonZeroWeights, NbInputs);
00047     result.append(averageZero);
00048     result.append(" Ratio=" + averageNonZero / averageZero);
00049     return result.toString();
00050   }
00051 
00052   private static double average(RealVector v, int start, int end) {
00053     return Vectors.l1Norm(v.ebeMultiply(mask(start, end))) / (end - start);
00054   }
00055 
00056   private static RealVector mask(int start, int end) {
00057     BVector mask = new BVector(NbInputs);
00058     for (int i = start; i < end; i++)
00059       mask.setOn(i);
00060     return mask;
00061   }
00062 }
 All Classes Namespaces Files Functions Variables Enumerations
Zephyr
RLPark