RLPark 1.0.0
Reinforcement Learning Framework in Java

NoisyInputSum.java

Go to the documentation of this file.
00001 package rlpark.plugin.rltoys.problems.noisyinputsum;
00002 
00003 import java.util.Random;
00004 
00005 import rlpark.plugin.rltoys.math.vector.RealVector;
00006 import rlpark.plugin.rltoys.math.vector.implementations.PVector;
00007 import rlpark.plugin.rltoys.problems.PredictionProblem;
00008 import zephyr.plugin.core.api.monitoring.annotations.Monitor;
00009 
00010 public class NoisyInputSum implements PredictionProblem {
00011   private final Random random;
00012   private int nbSteps = 0;
00013   @Monitor(level = 4)
00014   private final PVector weights;
00015   @Monitor(level = 4)
00016   private final PVector inputs;
00017   @Monitor
00018   private double target;
00019   private final int nbChangingWeights;
00020   private int changePeriod = 20;
00021 
00022   public NoisyInputSum(Random random, int nbNonZeroWeights, int nbInputs) {
00023     this(random, nbNonZeroWeights, nbNonZeroWeights, nbInputs);
00024   }
00025 
00026   public NoisyInputSum(Random random, int nbChangingWeights, int nbNonZeroWeights, int nbInputs) {
00027     this.random = random;
00028     this.nbChangingWeights = nbChangingWeights;
00029     weights = createWeights(random, nbNonZeroWeights, nbInputs);
00030     inputs = new PVector(nbInputs);
00031   }
00032 
00033   private PVector createWeights(Random random, int nbNonZeroWeights, int nbInputs) {
00034     PVector weights = new PVector(nbInputs);
00035     for (int i = 0; i < weights.size; i++)
00036       if (i < nbNonZeroWeights)
00037         weights.data[i] = random.nextBoolean() ? 1 : -1;
00038       else
00039         weights.data[i] = 0;
00040     return weights;
00041   }
00042 
00043   private void changeWeight() {
00044     weights.data[random.nextInt(nbChangingWeights)] *= -1;
00045   }
00046 
00047   @Override
00048   public boolean update() {
00049     nbSteps++;
00050     if (nbSteps % changePeriod == 0)
00051       changeWeight();
00052     for (int i = 0; i < inputs.size; i++)
00053       inputs.data[i] = random.nextGaussian();
00054     target = weights.dotProduct(inputs);
00055     return true;
00056   }
00057 
00058   @Override
00059   public RealVector input() {
00060     return inputs;
00061   }
00062 
00063   @Override
00064   public double target() {
00065     return target;
00066   }
00067 
00068   public void setChangePeriod(int changePeriod) {
00069     this.changePeriod = changePeriod;
00070   }
00071 
00072   @Override
00073   public int inputDimension() {
00074     return inputs.getDimension();
00075   }
00076 
00077   public PVector weights() {
00078     return weights;
00079   }
00080 }
 All Classes Namespaces Files Functions Variables Enumerations
Zephyr
RLPark