RLPark 1.0.0
Reinforcement Learning Framework in Java
00001 package rlpark.example.irobot.surprise; 00002 00003 import java.util.ArrayList; 00004 import java.util.List; 00005 import java.util.Random; 00006 00007 import rlpark.plugin.irobot.data.CreateAction; 00008 import rlpark.plugin.irobot.data.IRobotSongs; 00009 import rlpark.plugin.irobot.robots.CreateRobot; 00010 import rlpark.plugin.irobot.robots.IRobotEnvironment; 00011 import rlpark.plugin.rltoys.algorithms.functions.states.AgentState; 00012 import rlpark.plugin.rltoys.algorithms.predictions.td.GTDLambda; 00013 import rlpark.plugin.rltoys.algorithms.predictions.td.OnPolicyTD; 00014 import rlpark.plugin.rltoys.algorithms.predictions.td.TDLambda; 00015 import rlpark.plugin.rltoys.envio.actions.Action; 00016 import rlpark.plugin.rltoys.envio.observations.Legend; 00017 import rlpark.plugin.rltoys.envio.observations.ObsFilter; 00018 import rlpark.plugin.rltoys.envio.observations.Observation; 00019 import rlpark.plugin.rltoys.envio.policy.SingleActionPolicy; 00020 import rlpark.plugin.rltoys.envio.policy.Policies; 00021 import rlpark.plugin.rltoys.envio.policy.Policy; 00022 import rlpark.plugin.rltoys.horde.Horde; 00023 import rlpark.plugin.rltoys.horde.Surprise; 00024 import rlpark.plugin.rltoys.horde.demons.Demon; 00025 import rlpark.plugin.rltoys.horde.demons.PredictionDemon; 00026 import rlpark.plugin.rltoys.horde.demons.PredictionOffPolicyDemon; 00027 import rlpark.plugin.rltoys.horde.functions.HordeUpdatable; 00028 import rlpark.plugin.rltoys.horde.functions.RewardFunction; 00029 import rlpark.plugin.rltoys.horde.functions.RewardObservationFunction; 00030 import rlpark.plugin.rltoys.math.vector.RealVector; 00031 import zephyr.plugin.core.api.Zephyr; 00032 import zephyr.plugin.core.api.monitoring.annotations.Monitor; 00033 import zephyr.plugin.core.api.synchronization.Clock; 00034 00035 @SuppressWarnings("restriction") 00036 @Monitor 00037 public class CreateSurprise implements Runnable { 00038 static final private int SurpriseTrackingSpeed = 100; 00039 static 
final private Action[] Actions = new Action[] { CreateAction.DontMove, CreateAction.SpinLeft, 00040 CreateAction.SpinRight, CreateAction.Forward }; 00041 static final private String[] PredictedLabels = new String[] { "WheelDrop", "Bump", "WheelOverCurrent", "ICOmni", 00042 "DriveDistance", "DriveAngle", "BatteryCurrent", "BatteryCharge", "WallSignal", "CliffSignal", 00043 "ConnectedHomeBase", "OIMode", "WheelRequested" }; 00044 static final private double[] Gammas = new double[] { .0, 0.9, 0.99 }; 00045 static final private Policy[] TargetPolicies = new Policy[] { new SingleActionPolicy(CreateAction.SpinLeft), 00046 new SingleActionPolicy(CreateAction.Forward) }; 00047 static final private double Lambda = .7; 00048 final private IRobotEnvironment robot = new CreateRobot(); 00049 final private Clock clock = new Clock("Surprise"); 00050 final private Horde horde; 00051 final private Surprise surprise; 00052 private final AgentState agentState; 00053 private final Policy robotBehaviour; 00054 private RealVector x_t; 00055 private Action a_t; 00056 00057 public CreateSurprise() { 00058 agentState = new RobotState(); 00059 robotBehaviour = new RobotBehaviour(new Random(0), .25, Actions); 00060 horde = createHorde(); 00061 surprise = new Surprise(horde.demons(), SurpriseTrackingSpeed); 00062 Zephyr.advertise(clock, this); 00063 } 00064 00065 private Horde createHorde() { 00066 List<RewardFunction> rewardFunctions = createRewardFunctions(); 00067 List<Demon> demons = new ArrayList<Demon>(); 00068 for (RewardFunction rewardFunction : rewardFunctions) { 00069 for (double gamma : Gammas) { 00070 demons.add(newNextingPredictionDemon(rewardFunction, gamma)); 00071 for (Policy targetPolicy : TargetPolicies) 00072 demons.add(newOffPolicyPredictionDemon(rewardFunction, gamma, targetPolicy)); 00073 } 00074 } 00075 Horde horde = new Horde(); 00076 horde.demons().addAll(demons); 00077 for (RewardFunction rewardFunction : rewardFunctions) 00078 
horde.addBeforeFunction((HordeUpdatable) rewardFunction); 00079 return horde; 00080 } 00081 00082 private PredictionOffPolicyDemon newOffPolicyPredictionDemon(RewardFunction rewardFunction, double gamma, 00083 Policy targetPolicy) { 00084 GTDLambda gtd = new GTDLambda(Lambda, gamma, .1 / agentState.stateNorm(), 0.0001 / agentState.stateNorm(), 00085 agentState.stateSize()); 00086 return new PredictionOffPolicyDemon(targetPolicy, robotBehaviour, gtd, rewardFunction); 00087 } 00088 00089 private PredictionDemon newNextingPredictionDemon(RewardFunction rewardFunction, double gamma) { 00090 OnPolicyTD td = new TDLambda(Lambda, gamma, .1 / agentState.stateNorm(), agentState.stateSize()); 00091 return new PredictionDemon(rewardFunction, td); 00092 } 00093 00094 private List<RewardFunction> createRewardFunctions() { 00095 ArrayList<RewardFunction> rewardFunctions = new ArrayList<RewardFunction>(); 00096 Legend legend = robot.legend(); 00097 ObsFilter filter = new ObsFilter(legend, PredictedLabels); 00098 for (String label : filter.legend().getLabels()) 00099 rewardFunctions.add(new RewardObservationFunction(legend, label)); 00100 return rewardFunctions; 00101 } 00102 00103 @Override 00104 public void run() { 00105 robot.fullMode(); 00106 while (clock.tick()) { 00107 Observation o_tp1 = robot.waitNewRawObs(); 00108 RealVector x_tp1 = agentState.update(a_t, o_tp1); 00109 horde.update(o_tp1, x_t, a_t, x_tp1); 00110 double surpriseMeasure = surprise.updateSurpriseMeasure(); 00111 if (surpriseMeasure > 8.0) 00112 robot.playSong(IRobotSongs.composeHappySong()); 00113 Action a_tp1 = Policies.decide(robotBehaviour, x_tp1); 00114 robot.sendAction((CreateAction) a_tp1); 00115 x_t = x_tp1; 00116 a_t = a_tp1; 00117 } 00118 } 00119 }