RLPark 1.0.0
Reinforcement Learning Framework in Java
|
package rlpark.plugin.rltoys.problems.noisyinputsum;

import java.util.Random;

import rlpark.plugin.rltoys.math.vector.RealVector;
import rlpark.plugin.rltoys.math.vector.implementations.PVector;
import rlpark.plugin.rltoys.problems.PredictionProblem;
import zephyr.plugin.core.api.monitoring.annotations.Monitor;

/**
 * Synthetic, non-stationary linear prediction problem.
 * <p>
 * At each step every input is drawn independently with
 * {@link Random#nextGaussian()}, and the target is the dot product of those
 * inputs with a hidden weight vector. The first {@code nbNonZeroWeights}
 * weights are initialised to +1 or -1 at random; the remaining weights are 0.
 * Every {@code changePeriod} steps (20 by default) the sign of one weight,
 * picked uniformly among the first {@code nbChangingWeights}, is flipped.
 */
public class NoisyInputSum implements PredictionProblem {
  /** Single source of randomness: initial weights, sign flips and inputs. */
  private final Random random;
  /** Number of calls to {@link #update()} so far. */
  private int nbSteps = 0;
  @Monitor(level = 4)
  private final PVector weights;
  @Monitor(level = 4)
  private final PVector inputs;
  @Monitor
  private double target;
  /** Size of the prefix of {@link #weights} whose entries may be sign-flipped. */
  private final int nbChangingWeights;
  private int changePeriod = 20;

  /**
   * Creates a problem in which every non-zero weight may be sign-flipped.
   *
   * @param random source of all randomness (weights, flips and inputs)
   * @param nbNonZeroWeights number of leading weights initialised to +1/-1;
   *          also used as the number of changing weights
   * @param nbInputs dimension of the input vector
   * @throws IllegalArgumentException if {@code nbNonZeroWeights} is negative or
   *           greater than {@code nbInputs}
   */
  public NoisyInputSum(Random random, int nbNonZeroWeights, int nbInputs) {
    this(random, nbNonZeroWeights, nbNonZeroWeights, nbInputs);
  }

  /**
   * Creates a problem with separately chosen numbers of changing and non-zero
   * weights.
   *
   * @param random source of all randomness (weights, flips and inputs)
   * @param nbChangingWeights size of the weight prefix eligible for sign flips;
   *          0 makes the problem stationary. Flipping a zero weight has no
   *          effect, so values above {@code nbNonZeroWeights} only dilute the
   *          flips.
   * @param nbNonZeroWeights number of leading weights initialised to +1/-1
   * @param nbInputs dimension of the input vector
   * @throws IllegalArgumentException if {@code nbChangingWeights} is negative or
   *           greater than {@code nbInputs}
   */
  public NoisyInputSum(Random random, int nbChangingWeights, int nbNonZeroWeights, int nbInputs) {
    // Fail fast: an out-of-range value would otherwise only blow up inside
    // update() once the first sign flip is attempted.
    if (nbChangingWeights < 0 || nbChangingWeights > nbInputs) {
      throw new IllegalArgumentException("nbChangingWeights must be in [0, nbInputs]: nbChangingWeights="
          + nbChangingWeights + ", nbInputs=" + nbInputs);
    }
    this.random = random;
    this.nbChangingWeights = nbChangingWeights;
    weights = createWeights(random, nbNonZeroWeights, nbInputs);
    inputs = new PVector(nbInputs);
  }

  /**
   * Builds the initial weight vector: the first {@code nbNonZeroWeights}
   * entries are +1 or -1 (one {@code nextBoolean()} draw each), the rest are 0.
   */
  private static PVector createWeights(Random random, int nbNonZeroWeights, int nbInputs) {
    PVector result = new PVector(nbInputs);
    for (int i = 0; i < result.size; i++) {
      if (i < nbNonZeroWeights) {
        result.data[i] = random.nextBoolean() ? 1 : -1;
      } else {
        result.data[i] = 0;
      }
    }
    return result;
  }

  /** Flips the sign of one weight chosen uniformly in the changing prefix. */
  private void changeWeight() {
    // Nothing is eligible; Random.nextInt(0) would throw.
    if (nbChangingWeights == 0) {
      return;
    }
    weights.data[random.nextInt(nbChangingWeights)] *= -1;
  }

  /**
   * Advances the problem by one step.
   * <p>
   * This may flip a weight (every {@code changePeriod} steps). It then redraws
   * every input and recomputes the target.
   *
   * @return always {@code true} (presumably "problem continues" per
   *         {@link PredictionProblem} — confirm against the interface)
   */
  @Override
  public boolean update() {
    nbSteps++;
    // The flip happens before the inputs are drawn, so the new target already
    // reflects the changed weight.
    if (nbSteps % changePeriod == 0) {
      changeWeight();
    }
    for (int i = 0; i < inputs.size; i++) {
      inputs.data[i] = random.nextGaussian();
    }
    target = weights.dotProduct(inputs);
    return true;
  }

  /**
   * Returns the current input vector.
   * <p>
   * This is the internal vector itself, overwritten in place by every
   * {@link #update()}.
   */
  @Override
  public RealVector input() {
    return inputs;
  }

  /** Returns the target computed by the last {@link #update()} (0 before the first call). */
  @Override
  public double target() {
    return target;
  }

  /**
   * Sets how often a weight is sign-flipped.
   * <p>
   * A flip occurs on each step whose count is a multiple of
   * {@code changePeriod}. The step counter is not reset by this call.
   *
   * @param changePeriod number of steps between flips; must be positive
   * @throws IllegalArgumentException if {@code changePeriod <= 0}, which would
   *           otherwise make {@link #update()} divide by zero or misbehave
   */
  public void setChangePeriod(int changePeriod) {
    if (changePeriod <= 0) {
      throw new IllegalArgumentException("changePeriod must be positive: " + changePeriod);
    }
    this.changePeriod = changePeriod;
  }

  /** Returns the dimension of the input vector. */
  @Override
  public int inputDimension() {
    return inputs.getDimension();
  }

  /**
   * Returns the hidden weight vector.
   * <p>
   * This is the live internal vector, so it reflects future sign flips and
   * mutating it alters the problem.
   */
  public PVector weights() {
    return weights;
  }
}