RLPark 1.0.0
Reinforcement Learning Framework in Java

StateActionCoders.java

Go to the documentation of this file.
00001 package rlpark.plugin.rltoys.algorithms.representations.tilescoding;
00002 
00003 import java.util.Arrays;
00004 
00005 import rlpark.plugin.rltoys.algorithms.functions.stateactions.StateToStateAction;
00006 import rlpark.plugin.rltoys.algorithms.representations.discretizer.ActionDiscretizer;
00007 import rlpark.plugin.rltoys.algorithms.representations.discretizer.Discretizer;
00008 import rlpark.plugin.rltoys.algorithms.representations.discretizer.DiscretizerFactory;
00009 import rlpark.plugin.rltoys.algorithms.representations.tilescoding.hashing.Hashing;
00010 import rlpark.plugin.rltoys.envio.actions.Action;
00011 import rlpark.plugin.rltoys.math.vector.RealVector;
00012 
00013 public class StateActionCoders implements StateToStateAction {
00014   private static final long serialVersionUID = 6906465332938314787L;
00015   private final TileCoders tileCoders;
00016   private final ActionDiscretizer actionDiscretizer;
00017 
00018   public StateActionCoders(ActionDiscretizer actionDiscretizer, DiscretizerFactory discretizerFactory, int nbInputs) {
00019     this(actionDiscretizer, new TileCodersNoHashing(createDiscretizerFactory(actionDiscretizer, discretizerFactory,
00020                                                                              nbInputs), nbInputs
00021         + actionDiscretizer.nbOutput()));
00022   }
00023 
00024   public StateActionCoders(ActionDiscretizer actionDiscretizer, Hashing hashing, DiscretizerFactory discretizerFactory,
00025       int nbInputs) {
00026     this(actionDiscretizer, new TileCodersHashing(hashing, createDiscretizerFactory(actionDiscretizer,
00027                                                                                     discretizerFactory, nbInputs),
00028                                                   nbInputs + actionDiscretizer.nbOutput()));
00029   }
00030 
00031   public StateActionCoders(ActionDiscretizer actionDiscretizer, TileCoders tileCoders) {
00032     this.actionDiscretizer = actionDiscretizer;
00033     this.tileCoders = tileCoders;
00034   }
00035 
00036   @Override
00037   public RealVector stateAction(RealVector s, Action a) {
00038     if (s == null || a == null)
00039       return tileCoders.project(null);
00040     double[] sa = Arrays.copyOf(s.accessData(), s.getDimension() + actionDiscretizer.nbOutput());
00041     System.arraycopy(actionDiscretizer.discretize(a), 0, sa, s.getDimension(), actionDiscretizer.nbOutput());
00042     return tileCoders.project(sa);
00043   }
00044 
00045   @Override
00046   public double vectorNorm() {
00047     return tileCoders.vectorNorm();
00048   }
00049 
00050   @Override
00051   public int vectorSize() {
00052     return tileCoders.vectorSize();
00053   }
00054 
00055   public TileCoders tileCoders() {
00056     return tileCoders;
00057   }
00058 
00059   static public DiscretizerFactory createDiscretizerFactory(ActionDiscretizer actionDiscretizerFactory,
00060       final DiscretizerFactory discretizerFactory, final int nbInputs) {
00061     final Discretizer[] actionDiscretizers = actionDiscretizerFactory.actionDiscretizers();
00062     return new DiscretizerFactory() {
00063       private static final long serialVersionUID = -4362287012399520301L;
00064 
00065       @Override
00066       public Discretizer createDiscretizer(int inputIndex, int resolution, int tilingIndex, int nbTilings) {
00067         if (inputIndex < nbInputs)
00068           return discretizerFactory.createDiscretizer(inputIndex, resolution, tilingIndex, nbTilings);
00069         return actionDiscretizers[inputIndex - nbInputs];
00070       }
00071     };
00072   }
00073 }
 All Classes Namespaces Files Functions Variables Enumerations
Zephyr
RLPark