RLPark 1.0.0
Reinforcement Learning Framework in Java
|
00001 package rlpark.plugin.rltoys.algorithms.representations.ltu; 00002 00003 import rlpark.plugin.rltoys.algorithms.representations.ltu.networks.RandomNetwork; 00004 import rlpark.plugin.rltoys.math.vector.BinaryVector; 00005 import rlpark.plugin.rltoys.math.vector.implementations.BVector; 00006 import zephyr.plugin.core.api.monitoring.annotations.Monitor; 00007 00008 public class StateUpdate { 00009 private final RandomNetwork network; 00010 private final int nbObsInput; 00011 private BVector networkOutput; 00012 @Monitor 00013 private final BVector networkInput; 00014 00015 public StateUpdate(RandomNetwork network, int nbObsInput) { 00016 this.network = network; 00017 this.nbObsInput = nbObsInput; 00018 networkInput = new BVector(network.inputSize); 00019 } 00020 00021 public BVector updateState(BinaryVector o_tp1) { 00022 if (o_tp1 == null) { 00023 networkOutput = null; 00024 return null; 00025 } 00026 networkInput.clear(); 00027 if (networkOutput != null) 00028 networkInput.mergeSubVector(0, networkOutput); 00029 networkInput.mergeSubVector(network.outputSize, o_tp1); 00030 networkOutput = network.project(networkInput); 00031 BVector s_tp1 = new BVector(stateSize()); 00032 s_tp1.mergeSubVector(0, networkOutput); 00033 s_tp1.mergeSubVector(network.outputSize, o_tp1); 00034 s_tp1.setOn(s_tp1.size - 1); 00035 return s_tp1; 00036 } 00037 00038 public int stateSize() { 00039 return network.outputSize + nbObsInput + 1; 00040 } 00041 }