RLPark 1.0.0
Reinforcement Learning Framework in Java
ScaledPolicyDistribution.java
package rlpark.plugin.rltoys.algorithms.functions.policydistributions.helpers;

import rlpark.plugin.rltoys.algorithms.functions.policydistributions.BoundedPdf;
import rlpark.plugin.rltoys.algorithms.functions.policydistributions.PolicyDistribution;
import rlpark.plugin.rltoys.algorithms.functions.policydistributions.PolicyParameterized;
import rlpark.plugin.rltoys.envio.actions.Action;
import rlpark.plugin.rltoys.envio.actions.ActionArray;
import rlpark.plugin.rltoys.envio.policy.BoundedPolicy;
import rlpark.plugin.rltoys.math.ranges.Range;
import rlpark.plugin.rltoys.math.vector.RealVector;
import rlpark.plugin.rltoys.math.vector.implementations.PVector;
import zephyr.plugin.core.api.monitoring.annotations.Monitor;

/**
 * Wraps a policy distribution defined on {@code policyRange} and exposes it on
 * {@code problemRange}: sampled actions are linearly rescaled from the policy
 * range to the problem range, and problem actions are rescaled back before
 * being passed to the wrapped policy.
 */
public class ScaledPolicyDistribution implements BoundedPdf, PolicyParameterized {
  private static final long serialVersionUID = -7521424991872961399L;
  @Monitor
  protected final PolicyDistribution policy;
  protected final Range policyRange;
  protected final Range problemRange;

  public ScaledPolicyDistribution(BoundedPolicy policy, Range problemRange) {
    this((PolicyDistribution) policy, policy.range(), problemRange);
  }

  public ScaledPolicyDistribution(PolicyDistribution policy, Range policyRange, Range problemRange) {
    this.policy = policy;
    this.policyRange = policyRange;
    this.problemRange = problemRange;
  }

  @Override
  public double pi(Action a) {
    return policy.pi(problemToPolicy(ActionArray.toDouble(a)));
  }

  @Override
  public PVector[] createParameters(int nbFeatures) {
    return policy.createParameters(nbFeatures);
  }

  @Override
  public Action sampleAction() {
    return policyToProblem(ActionArray.toDouble(policy.sampleAction()));
  }

  @Override
  public RealVector[] computeGradLog(Action a_t) {
    return policy.computeGradLog(problemToPolicy(ActionArray.toDouble(a_t)));
  }

  private ActionArray policyToProblem(double policyAction) {
    double normalizedAction = normalize(policyRange, policyAction);
    return new ActionArray(scale(problemRange, normalizedAction));
  }

  protected ActionArray problemToPolicy(double problemAction) {
    double normalizedAction = normalize(problemRange, problemAction);
    return new ActionArray(scale(policyRange, normalizedAction));
  }

  // Maps an action from the given range onto [-1, 1].
  private double normalize(Range range, double a) {
    return (a - range.center()) / (range.length() / 2.0);
  }

  // Maps an action from [-1, 1] back onto the given range.
  private double scale(Range range, double a) {
    return (a * (range.length() / 2.0)) + range.center();
  }

  @Override
  public int nbParameterVectors() {
    return policy.nbParameterVectors();
  }

  @Override
  public double piMax() {
    return ((BoundedPdf) policy).piMax();
  }

  @Override
  public void update(RealVector x) {
    policy.update(x);
  }

  @Override
  public void setParameters(PVector... parameters) {
    ((PolicyParameterized) policy).setParameters(parameters);
  }

  @Override
  public PVector[] parameters() {
    return ((PolicyParameterized) policy).parameters();
  }
}
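The mapping implemented by policyToProblem and problemToPolicy is an affine change of interval: an action is first normalized to [-1, 1] relative to the source range's center and half-length, then rescaled into the target range. The standalone sketch below (a hypothetical PolicyRangeDemo class, not part of RLPark; its Interval type is a stand-in for rlpark's Range exposing only center() and length()) reproduces that arithmetic for a policy range of [-1, 1] and a problem range of [0, 10].

// Standalone sketch of the scaling used by ScaledPolicyDistribution.
public class PolicyRangeDemo {
  static class Interval {
    final double min, max;
    Interval(double min, double max) { this.min = min; this.max = max; }
    double center() { return (min + max) / 2.0; }
    double length() { return max - min; }
  }

  // Same formula as ScaledPolicyDistribution.normalize(...)
  static double normalize(Interval range, double a) {
    return (a - range.center()) / (range.length() / 2.0);
  }

  // Same formula as ScaledPolicyDistribution.scale(...)
  static double scale(Interval range, double a) {
    return (a * (range.length() / 2.0)) + range.center();
  }

  public static void main(String[] args) {
    Interval policyRange = new Interval(-1.0, 1.0);  // range the wrapped policy samples in
    Interval problemRange = new Interval(0.0, 10.0); // range the environment expects
    double policyAction = 0.5;                       // action sampled from the wrapped policy
    double problemAction = scale(problemRange, normalize(policyRange, policyAction));
    System.out.println(problemAction); // 7.5: 0.5 in [-1, 1] maps to 7.5 in [0, 10]
    double back = scale(policyRange, normalize(problemRange, problemAction));
    System.out.println(back);          // 0.5: the mapping is invertible
  }
}

Because the mapping is invertible, pi(a) and computeGradLog(a_t) can evaluate the wrapped policy at the rescaled action, while sampleAction() scales samples in the opposite direction before handing them to the environment.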