RLPark 1.0.0
Reinforcement Learning Framework in Java
package rlpark.example.demos.learning;

import java.util.Random;

import rlpark.plugin.rltoys.agents.functions.FunctionProjected2D;
import rlpark.plugin.rltoys.agents.functions.ValueFunction2D;
import rlpark.plugin.rltoys.agents.offpolicy.OffPolicyAgentDirect;
import rlpark.plugin.rltoys.agents.offpolicy.OffPolicyAgentEvaluable;
import rlpark.plugin.rltoys.agents.offpolicy.OffPolicyAgentFA;
import rlpark.plugin.rltoys.algorithms.control.actorcritic.offpolicy.ActorLambdaOffPolicy;
import rlpark.plugin.rltoys.algorithms.control.actorcritic.offpolicy.ActorOffPolicy;
import rlpark.plugin.rltoys.algorithms.control.actorcritic.offpolicy.CriticAdapterFA;
import rlpark.plugin.rltoys.algorithms.control.actorcritic.offpolicy.OffPAC;
import rlpark.plugin.rltoys.algorithms.functions.ContinuousFunction;
import rlpark.plugin.rltoys.algorithms.functions.policydistributions.PolicyDistribution;
import rlpark.plugin.rltoys.algorithms.functions.policydistributions.helpers.RandomPolicy;
import rlpark.plugin.rltoys.algorithms.functions.policydistributions.structures.BoltzmannDistribution;
import rlpark.plugin.rltoys.algorithms.functions.stateactions.StateToStateAction;
import rlpark.plugin.rltoys.algorithms.functions.states.Projector;
import rlpark.plugin.rltoys.algorithms.predictions.td.GTDLambda;
import rlpark.plugin.rltoys.algorithms.predictions.td.OffPolicyTD;
import rlpark.plugin.rltoys.algorithms.representations.discretizer.TabularActionDiscretizer;
import rlpark.plugin.rltoys.algorithms.representations.discretizer.partitions.AbstractPartitionFactory;
import rlpark.plugin.rltoys.algorithms.representations.discretizer.partitions.BoundedSmallPartitionFactory;
import rlpark.plugin.rltoys.algorithms.representations.tilescoding.StateActionCoders;
import rlpark.plugin.rltoys.algorithms.representations.tilescoding.TileCoders;
import rlpark.plugin.rltoys.algorithms.representations.tilescoding.TileCodersHashing;
import rlpark.plugin.rltoys.algorithms.representations.tilescoding.hashing.Hashing;
import rlpark.plugin.rltoys.algorithms.representations.tilescoding.hashing.MurmurHashing;
import rlpark.plugin.rltoys.algorithms.traces.ATraces;
import rlpark.plugin.rltoys.envio.policy.Policy;
import rlpark.plugin.rltoys.experiments.helpers.Runner;
import rlpark.plugin.rltoys.experiments.helpers.Runner.RunnerEvent;
import rlpark.plugin.rltoys.math.ranges.Range;
import rlpark.plugin.rltoys.problems.ProblemBounded;
import rlpark.plugin.rltoys.problems.puddleworld.ConstantFunction;
import rlpark.plugin.rltoys.problems.puddleworld.LocalFeatureSumFunction;
import rlpark.plugin.rltoys.problems.puddleworld.PuddleWorld;
import rlpark.plugin.rltoys.problems.puddleworld.SmoothPuddle;
import rlpark.plugin.rltoys.problems.puddleworld.TargetReachedL1NormTermination;
import zephyr.plugin.core.api.Zephyr;
import zephyr.plugin.core.api.monitoring.abstracts.Monitored;
import zephyr.plugin.core.api.monitoring.annotations.Monitor;
import zephyr.plugin.core.api.signals.Listener;
import zephyr.plugin.core.api.synchronization.Clock;

@SuppressWarnings("restriction")
@Monitor
public class OffPACPuddleWorld implements Runnable {
  private final Random random = new Random(0);
  private final PuddleWorld behaviourEnvironment = createEnvironment(random);
  private final PuddleWorld evaluationEnvironment = createEnvironment(random);
  private final Runner learningRunner;
  private final Runner evaluationRunner;
  // Visualization with Zephyr
  final FunctionProjected2D valueFunction;
  final Clock clock = new Clock("Off-PAC Demo");
  final Clock episodeClock = new Clock("Episodes");

  public OffPACPuddleWorld() {
    // Behaviour policy: actions chosen uniformly at random; Off-PAC learns its target policy from this data
    Policy behaviour = new RandomPolicy(random, behaviourEnvironment.actions());
    OffPolicyAgentEvaluable agent = createOffPACAgent(random, behaviourEnvironment, behaviour, .99);
    learningRunner = new Runner(behaviourEnvironment, agent, -1, 5000);
    // Evaluation runner measures the learned target policy on a separate environment instance
    evaluationRunner = new Runner(evaluationEnvironment, agent.createEvaluatedAgent(), -1, 5000);
    CriticAdapterFA criticAdapter = (CriticAdapterFA) ((OffPolicyAgentFA) agent).learner().predictor();
    valueFunction = new ValueFunction2D(criticAdapter.projector(), behaviourEnvironment, criticAdapter.predictor());
    connectEpisodesEventsForZephyr();
    Zephyr.advertise(clock, this);
  }

  private void connectEpisodesEventsForZephyr() {
    final double[] episodeInfos = new double[2];
    evaluationRunner.onEpisodeEnd.connect(new Listener<Runner.RunnerEvent>() {
      @Override
      public void listen(RunnerEvent eventInfo) {
        episodeInfos[0] = eventInfo.step.time;
        episodeInfos[1] = eventInfo.episodeReward;
        episodeClock.tick();
        System.out.println(String.format("Episodes %d: %d, %f", eventInfo.nbEpisodeDone, eventInfo.step.time,
            eventInfo.episodeReward));
      }
    });
    Zephyr.advertise(episodeClock, new Monitored() {
      @Override
      public double monitoredValue() {
        return episodeInfos[0];
      }
    }, "length");
    Zephyr.advertise(episodeClock, new Monitored() {
      @Override
      public double monitoredValue() {
        return episodeInfos[1];
      }
    }, "reward");
  }

  static private Hashing createHashing(Random random) {
    return new MurmurHashing(random, 1000000);
  }

  static private void setTileCoders(TileCoders projector) {
    projector.addFullTilings(10, 10);
    projector.includeActiveFeature();
  }

  static private AbstractPartitionFactory createPartitionFactory(Random random, Range[] observationRanges) {
    AbstractPartitionFactory partitionFactory = new BoundedSmallPartitionFactory(observationRanges);
    partitionFactory.setRandom(random, .2);
    return partitionFactory;
  }

  static public Projector createProjector(Random random, PuddleWorld problem) {
    // Hashed tile coding of the continuous observation, used as the critic's feature vector
    final Range[] observationRanges = ((ProblemBounded) problem).getObservationRanges();
    final AbstractPartitionFactory discretizerFactory = createPartitionFactory(random, observationRanges);
    Hashing hashing = createHashing(random);
    TileCodersHashing projector = new TileCodersHashing(hashing, discretizerFactory, observationRanges.length);
    setTileCoders(projector);
    return projector;
  }

  static public StateToStateAction createToStateAction(Random random, PuddleWorld problem) {
    // State-action tile coding used to parameterize the Boltzmann target policy
    final Range[] observationRanges = problem.getObservationRanges();
    final AbstractPartitionFactory discretizerFactory = createPartitionFactory(random, observationRanges);
    TabularActionDiscretizer actionDiscretizer = new TabularActionDiscretizer(problem.actions());
    Hashing hashing = createHashing(random);
    StateActionCoders stateActionCoders = new StateActionCoders(actionDiscretizer, hashing, discretizerFactory,
        observationRanges.length);
    setTileCoders(stateActionCoders.tileCoders());
    return stateActionCoders;
  }

  private OffPolicyTD createCritic(Projector criticProjector, double gamma) {
    // GTD(lambda) critic, adapted to the tile-coded features by CriticAdapterFA
    double alpha_v = .1 / criticProjector.vectorNorm();
    GTDLambda gtd = new GTDLambda(.4, gamma, alpha_v, 0, criticProjector.vectorSize(), new ATraces());
    return new CriticAdapterFA(criticProjector, gtd);
  }

  private OffPolicyAgentEvaluable createOffPACAgent(Random random, PuddleWorld problem, Policy behaviour, double gamma) {
    Projector criticProjector = createProjector(random, problem);
    OffPolicyTD critic = createCritic(criticProjector, gamma);
    StateToStateAction toStateAction = createToStateAction(random, problem);
    PolicyDistribution target = new BoltzmannDistribution(random, problem.actions(), toStateAction);
    double alpha_u = .001 / criticProjector.vectorNorm();
    ActorOffPolicy actor = new ActorLambdaOffPolicy(.4, gamma, target, alpha_u, toStateAction.vectorSize(),
        new ATraces());
    return new OffPolicyAgentDirect(behaviour, new OffPAC(behaviour, critic, actor));
  }

  static private PuddleWorld createEnvironment(Random random) {
    // Two-dimensional puddle world: a constant -1 reward per step plus penalties weighted by -2 inside three puddles
    PuddleWorld problem = new PuddleWorld(random, 2, new Range(0, 1), new Range(-.05, .05), .1);
    final int[] patternIndexes = new int[] { 0, 1 };
    final double smallStddev = 0.03;
    ContinuousFunction[] features = new ContinuousFunction[] { new ConstantFunction(1),
        new SmoothPuddle(patternIndexes, new double[] { .3, .6 }, new double[] { .1, smallStddev }),
        new SmoothPuddle(patternIndexes, new double[] { .4, .5 }, new double[] { smallStddev, .1 }),
        new SmoothPuddle(patternIndexes, new double[] { .8, .9 }, new double[] { smallStddev, .1 }) };
    final double puddleMalus = -2;
    double[] weights = new double[] { -1, puddleMalus, puddleMalus, puddleMalus };
    problem.setRewardFunction(new LocalFeatureSumFunction(weights, features, 0));
    problem.setTermination(new TargetReachedL1NormTermination(new double[] { 1, 1 }, .1));
    problem.setStart(new double[] { .2, .4 });
    return problem;
  }

  @Override
  public void run() {
    // Each clock tick performs one learning step in the behaviour environment and one evaluation step
    while (clock.tick()) {
      learningRunner.step();
      evaluationRunner.step();
    }
  }

  public static void main(String[] args) {
    new OffPACPuddleWorld().run();
  }
}
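For context, here is a brief, hedged sketch of the off-policy correction this demo relies on, following the Off-PAC algorithm of Degris, White and Sutton (2012); the exact trace bookkeeping inside RLPark's GTDLambda and ActorLambdaOffPolicy may differ in detail. With x_t the tile-coded features, b the random behaviour policy and pi_u the Boltzmann target policy, the per-step importance-sampling ratio and the temporal-difference error are

\[
\rho_t = \frac{\pi_u(a_t \mid x_t)}{b(a_t \mid x_t)}, \qquad
\delta_t = r_{t+1} + \gamma\, v^\top x_{t+1} - v^\top x_t .
\]

Both the GTD(lambda) critic (weights v, with gamma = 0.99 and lambda = 0.4 in the listing) and the actor (weights u of the Boltzmann distribution) use rho_t to reweight their eligibility-trace updates, which is what lets the target policy be learned while actions are drawn uniformly at random by the behaviour policy.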