/*
 * RLPark 1.0.0
 * Reinforcement Learning Framework in Java
 */
package rlpark.plugin.rltoys.algorithms.representations.tilescoding;

import java.util.Arrays;

import rlpark.plugin.rltoys.algorithms.functions.stateactions.StateToStateAction;
import rlpark.plugin.rltoys.algorithms.representations.discretizer.ActionDiscretizer;
import rlpark.plugin.rltoys.algorithms.representations.discretizer.Discretizer;
import rlpark.plugin.rltoys.algorithms.representations.discretizer.DiscretizerFactory;
import rlpark.plugin.rltoys.algorithms.representations.tilescoding.hashing.Hashing;
import rlpark.plugin.rltoys.envio.actions.Action;
import rlpark.plugin.rltoys.math.vector.RealVector;

/**
 * Maps a (state, action) pair to a feature vector by appending the discretized action to the
 * state's components and projecting the concatenated input through a {@link TileCoders}.
 *
 * <p>The convenience constructors build tile coders over {@code nbInputs +
 * actionDiscretizer.nbOutput()} inputs. The first {@code nbInputs} inputs get discretizers from
 * the supplied {@link DiscretizerFactory}. The remaining inputs reuse the discretizers returned by
 * {@link ActionDiscretizer#actionDiscretizers()} (see {@link #createDiscretizerFactory}).
 */
public class StateActionCoders implements StateToStateAction {
  private static final long serialVersionUID = 6906465332938314787L;
  private final TileCoders tileCoders;
  private final ActionDiscretizer actionDiscretizer;

  /**
   * Creates state-action coders backed by a {@link TileCodersNoHashing}.
   *
   * @param actionDiscretizer converts actions into the numeric inputs appended to the state
   * @param discretizerFactory creates the discretizers for the {@code nbInputs} state inputs
   * @param nbInputs number of state inputs (the state vector dimension expected by
   *     {@link #stateAction})
   */
  public StateActionCoders(ActionDiscretizer actionDiscretizer, DiscretizerFactory discretizerFactory,
      int nbInputs) {
    this(actionDiscretizer, new TileCodersNoHashing(
        createDiscretizerFactory(actionDiscretizer, discretizerFactory, nbInputs),
        nbInputs + actionDiscretizer.nbOutput()));
  }

  /**
   * Creates state-action coders backed by a {@link TileCodersHashing} that uses {@code hashing}.
   *
   * @param actionDiscretizer converts actions into the numeric inputs appended to the state
   * @param hashing hashing strategy handed to the {@link TileCodersHashing}
   * @param discretizerFactory creates the discretizers for the {@code nbInputs} state inputs
   * @param nbInputs number of state inputs (the state vector dimension expected by
   *     {@link #stateAction})
   */
  public StateActionCoders(ActionDiscretizer actionDiscretizer, Hashing hashing,
      DiscretizerFactory discretizerFactory, int nbInputs) {
    this(actionDiscretizer, new TileCodersHashing(hashing,
        createDiscretizerFactory(actionDiscretizer, discretizerFactory, nbInputs),
        nbInputs + actionDiscretizer.nbOutput()));
  }

  /**
   * Creates state-action coders from an already configured {@link TileCoders}.
   *
   * <p>NOTE(review): {@code tileCoders} is assumed to have been built over
   * {@code stateDimension + actionDiscretizer.nbOutput()} inputs; this is not checked here.
   *
   * @param actionDiscretizer converts actions into the numeric inputs appended to the state
   * @param tileCoders tile coders that receive the concatenated state-action input
   */
  public StateActionCoders(ActionDiscretizer actionDiscretizer, TileCoders tileCoders) {
    this.actionDiscretizer = actionDiscretizer;
    this.tileCoders = tileCoders;
  }

  /**
   * Projects the concatenation of {@code s}'s data and the discretized {@code a}.
   *
   * @param s state vector; may be {@code null}
   * @param a action; may be {@code null}
   * @return {@code tileCoders.project(null)} if either argument is {@code null}; otherwise
   *     {@code tileCoders.project(sa)}, where {@code sa} holds the state components followed by
   *     {@code actionDiscretizer.discretize(a)}
   */
  @Override
  public RealVector stateAction(RealVector s, Action a) {
    if (s == null || a == null) {
      // A missing state or action is forwarded to the tile coders as a null input.
      return tileCoders.project(null);
    }
    final int stateDimension = s.getDimension();
    final int actionDimension = actionDiscretizer.nbOutput();
    // Arrays.copyOf pads (with zeros) or truncates the state data to the combined length,
    // leaving room after the state components for the action inputs.
    double[] sa = Arrays.copyOf(s.accessData(), stateDimension + actionDimension);
    System.arraycopy(actionDiscretizer.discretize(a), 0, sa, stateDimension, actionDimension);
    return tileCoders.project(sa);
  }

  /** @return the value reported by the underlying tile coders' {@code vectorNorm()} */
  @Override
  public double vectorNorm() {
    return tileCoders.vectorNorm();
  }

  /** @return the value reported by the underlying tile coders' {@code vectorSize()} */
  @Override
  public int vectorSize() {
    return tileCoders.vectorSize();
  }

  /** @return the tile coders used to project state-action inputs */
  public TileCoders tileCoders() {
    return tileCoders;
  }

  /**
   * Builds a {@link DiscretizerFactory} covering both state inputs and action inputs.
   *
   * <p>For {@code inputIndex < nbInputs}, the returned factory delegates to
   * {@code discretizerFactory}. For any larger index it returns
   * {@code actionDiscretizer.actionDiscretizers()[inputIndex - nbInputs]}. The resolution and
   * tiling arguments are ignored for action inputs, so the same {@link Discretizer} instance is
   * returned for every tiling.
   *
   * @param actionDiscretizer source of the discretizers used for the action inputs
   * @param discretizerFactory factory used for the state inputs
   * @param nbInputs number of state inputs preceding the action inputs
   * @return a combined factory
   */
  public static DiscretizerFactory createDiscretizerFactory(ActionDiscretizer actionDiscretizer,
      final DiscretizerFactory discretizerFactory, final int nbInputs) {
    // Fetched once, so every tiling shares these action discretizer instances.
    final Discretizer[] actionDiscretizers = actionDiscretizer.actionDiscretizers();
    return new DiscretizerFactory() {
      private static final long serialVersionUID = -4362287012399520301L;

      @Override
      public Discretizer createDiscretizer(int inputIndex, int resolution, int tilingIndex,
          int nbTilings) {
        if (inputIndex < nbInputs) {
          return discretizerFactory.createDiscretizer(inputIndex, resolution, tilingIndex,
              nbTilings);
        }
        return actionDiscretizers[inputIndex - nbInputs];
      }
    };
  }
}