RLPark 1.0.0
Reinforcement Learning Framework in Java
rlpark/plugin/rltoys/algorithms/functions/policydistributions/structures/JointDistribution.java

package rlpark.plugin.rltoys.algorithms.functions.policydistributions.structures;

import java.util.ArrayList;
import java.util.List;

import rlpark.plugin.rltoys.algorithms.functions.policydistributions.BoundedPdf;
import rlpark.plugin.rltoys.algorithms.functions.policydistributions.PolicyDistribution;
import rlpark.plugin.rltoys.algorithms.functions.policydistributions.PolicyParameterized;
import rlpark.plugin.rltoys.envio.actions.Action;
import rlpark.plugin.rltoys.envio.actions.ActionArray;
import rlpark.plugin.rltoys.math.vector.RealVector;
import rlpark.plugin.rltoys.math.vector.implementations.PVector;
import zephyr.plugin.core.api.monitoring.annotations.IgnoreMonitor;
import zephyr.plugin.core.api.monitoring.annotations.Monitor;

@Monitor
public class JointDistribution implements PolicyParameterized, BoundedPdf {
  private static final long serialVersionUID = -7545331400083047916L;
  protected final PolicyDistribution[] distributions;
  @IgnoreMonitor
  private int[] weightsToAction;

  public JointDistribution(PolicyDistribution[] distributions) {
    this.distributions = distributions;
  }

  @Override
  public double pi(Action a) {
    // The joint probability is the product of the per-dimension probabilities
    double product = 1.0;
    for (int i = 0; i < distributions.length; i++)
      product *= distributions[i].pi(ActionArray.getDim(a, i));
    return product;
  }

  @Override
  public ActionArray sampleAction() {
    // Sample each sub-distribution, then concatenate the samples into one action
    List<ActionArray> actions = new ArrayList<ActionArray>();
    int nbDimension = 0;
    for (PolicyDistribution distribution : distributions) {
      ActionArray a = (ActionArray) distribution.sampleAction();
      nbDimension += a.actions.length;
      actions.add(a);
    }
    double[] result = new double[nbDimension];
    int currentPosition = 0;
    for (ActionArray a : actions) {
      System.arraycopy(a.actions, 0, result, currentPosition, a.actions.length);
      currentPosition += a.actions.length;
    }
    return new ActionArray(result);
  }

  @Override
  public PVector[] createParameters(int nbFeatures) {
    // Collect the parameter vectors of every sub-distribution and remember
    // which action dimension each vector belongs to
    List<PVector> parameters = new ArrayList<PVector>();
    List<Integer> parametersToAction = new ArrayList<Integer>();
    for (int i = 0; i < distributions.length; i++)
      for (PVector parameterVector : distributions[i].createParameters(nbFeatures)) {
        parameters.add(parameterVector);
        parametersToAction.add(i);
      }
    PVector[] result = new PVector[parameters.size()];
    parameters.toArray(result);
    weightsToAction = new int[parameters.size()];
    for (int i = 0; i < weightsToAction.length; i++)
      weightsToAction[i] = parametersToAction.get(i);
    return result;
  }

  @Override
  public RealVector[] computeGradLog(Action a_t) {
    // The gradient of the log joint probability is the concatenation of the
    // per-dimension log-probability gradients
    List<RealVector> gradLogs = new ArrayList<RealVector>();
    for (int i = 0; i < distributions.length; i++) {
      PolicyDistribution distribution = distributions[i];
      RealVector[] gradLog = distribution.computeGradLog(ActionArray.getDim(a_t, i));
      for (RealVector parameterVector : gradLog)
        gradLogs.add(parameterVector);
    }
    RealVector[] result = new RealVector[gradLogs.size()];
    gradLogs.toArray(result);
    return result;
  }

  public int weightsIndexToActionIndex(int i) {
    return weightsToAction[i];
  }

  public PolicyDistribution policy(int actionIndex) {
    return distributions[actionIndex];
  }

  @Override
  public int nbParameterVectors() {
    int result = 0;
    for (PolicyDistribution distribution : distributions)
      result += distribution.nbParameterVectors();
    return result;
  }

  @Override
  public double piMax() {
    // Upper bound of the joint density: product of the per-dimension bounds
    double result = 1;
    for (PolicyDistribution distribution : distributions)
      result *= ((BoundedPdf) distribution).piMax();
    return result;
  }

  @Override
  public void update(RealVector x) {
    for (PolicyDistribution distribution : distributions)
      distribution.update(x);
  }

  public PolicyDistribution[] policies() {
    return distributions;
  }

  @Override
  public void setParameters(PVector... parameters) {
    // Dispatch the flat array of parameter vectors back to each sub-distribution
    int index = 0;
    for (PolicyDistribution distribution : distributions) {
      PVector[] u = new PVector[distribution.nbParameterVectors()];
      System.arraycopy(parameters, index, u, 0, u.length);
      ((PolicyParameterized) distribution).setParameters(u);
      index += u.length;
    }
  }

  @Override
  public PVector[] parameters() {
    // Concatenate the parameter vectors of every sub-distribution into one flat array
    PVector[] parameters = new PVector[nbParameterVectors()];
    int index = 0;
    for (PolicyDistribution distribution : distributions) {
      System.arraycopy(((PolicyParameterized) distribution).parameters(), 0, parameters, index,
          distribution.nbParameterVectors());
      index += distribution.nbParameterVectors();
    }
    return parameters;
  }
}
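JointDistribution composes several independent PolicyDistribution instances into a single joint policy: pi(a) is the product of the per-dimension probabilities, sampleAction() concatenates one sample per sub-distribution, and createParameters()/computeGradLog() stack the parameter and gradient vectors of all sub-distributions, with weightsIndexToActionIndex() recording which action dimension each vector belongs to. The minimal sketch below shows how the class might be driven; the NormalDistribution(Random, mean, sigma) and PVector(int size) constructors, as well as the example class name, are assumptions about the surrounding toolkit rather than part of this file.

import java.util.Random;

import rlpark.plugin.rltoys.algorithms.functions.policydistributions.PolicyDistribution;
import rlpark.plugin.rltoys.algorithms.functions.policydistributions.structures.JointDistribution;
import rlpark.plugin.rltoys.algorithms.functions.policydistributions.structures.NormalDistribution;
import rlpark.plugin.rltoys.envio.actions.ActionArray;
import rlpark.plugin.rltoys.math.vector.RealVector;
import rlpark.plugin.rltoys.math.vector.implementations.PVector;

// Hypothetical example class, not part of RLPark
public class JointDistributionExample {
  public static void main(String[] args) {
    Random random = new Random(0);
    // One Gaussian policy per action dimension; the (random, mean, sigma)
    // constructor arguments are an assumption about NormalDistribution's API
    PolicyDistribution[] perDimension = new PolicyDistribution[] {
        new NormalDistribution(random, 0.0, 1.0),
        new NormalDistribution(random, 0.0, 1.0) };
    JointDistribution joint = new JointDistribution(perDimension);

    // One stacked array of parameter vectors for all sub-distributions
    int nbFeatures = 10;
    PVector[] parameters = joint.createParameters(nbFeatures);

    // Condition the joint policy on the current feature vector, then sample
    RealVector x_t = new PVector(nbFeatures);
    joint.update(x_t);
    ActionArray a_t = joint.sampleAction();

    // Probability of the sampled action and one log-probability gradient
    // per parameter vector, in the same order as createParameters()
    double probability = joint.pi(a_t);
    RealVector[] gradLog = joint.computeGradLog(a_t);
    System.out.println("pi(a_t) = " + probability + ", " + parameters.length
        + " parameter vectors, " + gradLog.length + " gradient vectors");
  }
}

Because the flat parameter layout returned by createParameters() matches the gradient layout returned by computeGradLog(), an actor update can pair the two arrays index by index; setParameters() and parameters() convert between this flat layout and the per-distribution layout when the sub-distributions also implement PolicyParameterized.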