RLPark 1.0.0
Reinforcement Learning Framework in Java
rlpark/plugin/rltoys/problems/puddleworld/PuddleWorld.java
package rlpark.plugin.rltoys.problems.puddleworld;

import java.util.Arrays;
import java.util.Random;

import rlpark.plugin.rltoys.algorithms.functions.ContinuousFunction;
import rlpark.plugin.rltoys.envio.actions.Action;
import rlpark.plugin.rltoys.envio.actions.ActionArray;
import rlpark.plugin.rltoys.envio.observations.Legend;
import rlpark.plugin.rltoys.envio.rl.TRStep;
import rlpark.plugin.rltoys.math.ranges.Range;
import rlpark.plugin.rltoys.problems.ProblemBounded;
import rlpark.plugin.rltoys.problems.ProblemContinuousAction;
import rlpark.plugin.rltoys.problems.ProblemDiscreteAction;
import rlpark.plugin.rltoys.utils.Utils;
import zephyr.plugin.core.api.monitoring.abstracts.DataMonitor;
import zephyr.plugin.core.api.monitoring.abstracts.MonitorContainer;
import zephyr.plugin.core.api.monitoring.abstracts.Monitored;
import zephyr.plugin.core.api.monitoring.annotations.Monitor;

/**
 * Continuous-state puddle-world environment. The agent moves inside an
 * n-dimensional box; the reward and termination functions are pluggable.
 * The problem can be driven either with the discrete action set created
 * here or with continuous actions bounded by the action range.
 */
public class PuddleWorld implements ProblemBounded, ProblemDiscreteAction, ProblemContinuousAction,
    MonitorContainer {
  private final Action[] actions;
  protected TRStep step = null;
  private final int nbDimensions;
  private double[] start = null;
  @Monitor
  private ContinuousFunction rewardFunction = null;
  private final Legend legend;
  private final Random random;
  private TerminationFunction terminationFunction = null;
  private final Range observationRange;
  private final Range actionRange;
  private final double absoluteNoise;
  @Monitor
  private final double[] lastActions;

  public PuddleWorld(Random random, int nbDimension, Range observationRange, Range actionRange,
      double relativeNoise) {
    this.random = random;
    this.observationRange = observationRange;
    this.actionRange = actionRange;
    this.nbDimensions = nbDimension;
    // The noise amplitude is expressed relative to the half-width of the action range
    this.absoluteNoise = (actionRange.length() / 2.0) * relativeNoise;
    legend = createLegend();
    actions = createActions();
    lastActions = new double[nbDimension];
  }

  // Builds the discrete action set: one -1 move and one +1 move per
  // dimension, plus a final "stay" action with all components at zero.
  private Action[] createActions() {
    Action[] actions = new Action[2 * nbDimensions + 1];
    for (int i = 0; i < actions.length - 1; i++) {
      int dimension = i / 2;
      int dimensionAction = i % 2;
      double[] actionValues = Utils.newFilledArray(nbDimensions, 0);
      if (dimensionAction == 0)
        actionValues[dimension] = -1;
      else
        actionValues[dimension] = 1;
      actions[i] = new ActionArray(actionValues);
    }
    actions[actions.length - 1] = new ActionArray(Utils.newFilledArray(nbDimensions, 0));
    return actions;
  }

  public void setStart(double[] start) {
    this.start = start;
  }

  public void setRewardFunction(ContinuousFunction rewardFunction) {
    this.rewardFunction = rewardFunction;
  }

  public void setTermination(TerminationFunction terminationFunction) {
    this.terminationFunction = terminationFunction;
  }

  private Legend createLegend() {
    String[] labels = new String[nbDimensions];
    for (int i = 0; i < nbDimensions; i++)
      labels[i] = "x" + i;
    return new Legend(labels);
  }

  @Override
  public TRStep initialize() {
    double[] position = start;
    if (position == null) {
      // No fixed start position: sample uniformly until a non-terminal state is found
      position = new double[nbDimensions];
      do {
        for (int i = 0; i < position.length; i++)
          position[i] = observationRange.choose(random);
      } while (isTerminated(position));
    }
    step = new TRStep(position, reward(position));
    return step;
  }

  @Override
  public TRStep step(Action action) {
    if (isTerminated(step.o_tp1)) {
      step = step.createEndingStep();
      return step;
    }
    double[] envAction = computeEnvironmentAction(action);
    double[] x_tp1 = new double[nbDimensions];
    for (int i = 0; i < x_tp1.length; i++)
      x_tp1[i] = observationRange.bound(step.o_tp1[i] + envAction[i]);
    step = new TRStep(step, action, x_tp1, reward(x_tp1));
    return step;
  }

  private double reward(double[] position) {
    if (rewardFunction == null)
      return 0.0;
    return rewardFunction.value(position);
  }

  private boolean isTerminated(double[] position) {
    if (terminationFunction == null)
      return false;
    return terminationFunction.isTerminated(position);
  }

  // Bounds the agent's action to the action range and adds uniform noise
  // in [-absoluteNoise / 2, absoluteNoise / 2)
  private double[] computeEnvironmentAction(Action action) {
    double[] agentAction = ((ActionArray) action).actions;
    System.arraycopy(agentAction, 0, lastActions, 0, nbDimensions);
    double[] envAction = new double[agentAction.length];
    for (int i = 0; i < envAction.length; i++) {
      double noise = (random.nextDouble() * absoluteNoise) - (absoluteNoise / 2);
      envAction[i] = actionRange.bound(agentAction[i]) + noise;
    }
    return envAction;
  }

  @Override
  public Legend legend() {
    return legend;
  }

  @Override
  public Range[] actionRanges() {
    Range[] ranges = new Range[nbDimensions];
    Arrays.fill(ranges, actionRange);
    return ranges;
  }

  @Override
  public Range[] getObservationRanges() {
    Range[] ranges = new Range[nbDimensions];
    Arrays.fill(ranges, observationRange);
    return ranges;
  }

  @Override
  public void addToMonitor(DataMonitor monitor) {
    monitor.add("Reward", new Monitored() {
      @Override
      public double monitoredValue() {
        return step != null ? step.r_tp1 : 0.0;
      }
    });
    for (int i = 0; i < legend.nbLabels(); i++) {
      final int index = i;
      monitor.add(legend.label(i), new Monitored() {
        @Override
        public double monitoredValue() {
          return step != null && step.o_tp1 != null ? step.o_tp1[index] : 0.0;
        }
      });
    }
  }

  public int nbDimensions() {
    return nbDimensions;
  }

  public ContinuousFunction rewardFunction() {
    return rewardFunction;
  }

  public double[] start() {
    return start;
  }

  @Override
  public Action[] actions() {
    return actions;
  }

  @Override
  public TRStep lastStep() {
    return step;
  }

  @Override
  public TRStep forceEndEpisode() {
    step = step.createEndingStep();
    return step;
  }
}
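To make the listing easier to try out, here is a minimal usage sketch. It relies only on signatures visible in the file above (the PuddleWorld constructor and setters, ContinuousFunction.value(double[]), TerminationFunction.isTerminated(double[]), actions(), and TRStep's o_tp1/r_tp1 fields). The Range(min, max) constructor and TRStep.isEpisodeEnding() are assumptions about the surrounding RLToys API, and the reward and termination functions are illustrative placeholders, not the ones shipped with RLPark.

import java.util.Random;

import rlpark.plugin.rltoys.algorithms.functions.ContinuousFunction;
import rlpark.plugin.rltoys.envio.actions.Action;
import rlpark.plugin.rltoys.envio.rl.TRStep;
import rlpark.plugin.rltoys.math.ranges.Range;
import rlpark.plugin.rltoys.problems.puddleworld.PuddleWorld;
import rlpark.plugin.rltoys.problems.puddleworld.TerminationFunction;

public class PuddleWorldSketch {
  public static void main(String[] args) {
    Random random = new Random(0);
    // Assumption: Range(min, max) is the constructor; only length(),
    // bound() and choose() are visible in the listing above.
    Range observationRange = new Range(0.0, 1.0);
    Range actionRange = new Range(-0.05, 0.05);
    // A 2-dimensional world with 10% relative action noise
    PuddleWorld problem = new PuddleWorld(random, 2, observationRange, actionRange, 0.1);
    // Illustrative constant step penalty; the listing only requires value(double[])
    problem.setRewardFunction(new ContinuousFunction() {
      @Override
      public double value(double[] position) {
        return -1.0;
      }
    });
    // Illustrative goal region near the top-right corner of the unit square
    problem.setTermination(new TerminationFunction() {
      @Override
      public boolean isTerminated(double[] position) {
        return position[0] + position[1] > 1.9;
      }
    });
    TRStep step = problem.initialize();
    Action[] actions = problem.actions(); // 2 * 2 + 1 = 5 discrete actions
    for (int t = 0; t < 1000; t++) {
      // Random policy over the discrete action set
      step = problem.step(actions[random.nextInt(actions.length)]);
      // Assumption: TRStep exposes isEpisodeEnding(); only
      // createEndingStep() appears in the file above.
      if (step.isEpisodeEnding())
        break;
    }
    System.out.println("Last reward: " + step.r_tp1);
  }
}

Note that the discrete actions built by createActions() have components of -1 or +1, which actionRange.bound() clamps to the edges of the action range, so each discrete move travels one full half-width of the range (0.05 with the numbers above). On top of that, computeEnvironmentAction() adds uniform noise of at most absoluteNoise / 2 per dimension, i.e. (0.1 / 2) * 0.1 / 2 = 0.0025 here.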