RLPark 1.0.0
Reinforcement Learning Framework in Java

JointDistribution.java

package rlpark.plugin.rltoys.algorithms.functions.policydistributions.structures;

import java.util.ArrayList;
import java.util.List;

import rlpark.plugin.rltoys.algorithms.functions.policydistributions.BoundedPdf;
import rlpark.plugin.rltoys.algorithms.functions.policydistributions.PolicyDistribution;
import rlpark.plugin.rltoys.algorithms.functions.policydistributions.PolicyParameterized;
import rlpark.plugin.rltoys.envio.actions.Action;
import rlpark.plugin.rltoys.envio.actions.ActionArray;
import rlpark.plugin.rltoys.math.vector.RealVector;
import rlpark.plugin.rltoys.math.vector.implementations.PVector;
import zephyr.plugin.core.api.monitoring.annotations.IgnoreMonitor;
import zephyr.plugin.core.api.monitoring.annotations.Monitor;

/** Joint policy distribution over a multi-dimensional action: each action
 * dimension is drawn from its own {@link PolicyDistribution}, and the joint
 * density is the product of the per-dimension densities. */
@Monitor
public class JointDistribution implements PolicyParameterized, BoundedPdf {
  private static final long serialVersionUID = -7545331400083047916L;
  protected final PolicyDistribution[] distributions;
  @IgnoreMonitor
  private int[] weightsToAction;

  public JointDistribution(PolicyDistribution[] distributions) {
    this.distributions = distributions;
  }

  /** Density of a joint action: the product of the per-dimension densities. */
  @Override
  public double pi(Action a) {
    double product = 1.0;
    for (int i = 0; i < distributions.length; i++)
      product *= distributions[i].pi(ActionArray.getDim(a, i));
    return product;
  }

  /** Samples each dimension independently and concatenates the samples into one ActionArray. */
  @Override
  public ActionArray sampleAction() {
    List<ActionArray> actions = new ArrayList<ActionArray>();
    int nbDimension = 0;
    for (PolicyDistribution distribution : distributions) {
      ActionArray a = (ActionArray) distribution.sampleAction();
      nbDimension += a.actions.length;
      actions.add(a);
    }
    double[] result = new double[nbDimension];
    int currentPosition = 0;
    for (ActionArray a : actions) {
      System.arraycopy(a.actions, 0, result, currentPosition, a.actions.length);
      currentPosition += a.actions.length;
    }
    return new ActionArray(result);
  }

  /** Creates one parameter vector per parameter block of each underlying distribution and
   * records which distribution owns each block (see weightsIndexToActionIndex). */
  @Override
  public PVector[] createParameters(int nbFeatures) {
    List<PVector> parameters = new ArrayList<PVector>();
    List<Integer> parametersToAction = new ArrayList<Integer>();
    for (int i = 0; i < distributions.length; i++)
      for (PVector parameterVector : distributions[i].createParameters(nbFeatures)) {
        parameters.add(parameterVector);
        parametersToAction.add(i);
      }
    PVector[] result = new PVector[parameters.size()];
    parameters.toArray(result);
    weightsToAction = new int[parameters.size()];
    for (int i = 0; i < weightsToAction.length; i++)
      weightsToAction[i] = parametersToAction.get(i);
    return result;
  }

  /** Concatenates the grad-log vectors of every underlying distribution, in the same
   * order as the parameter vectors returned by createParameters. */
  @Override
  public RealVector[] computeGradLog(Action a_t) {
    List<RealVector> gradLogs = new ArrayList<RealVector>();
    for (int i = 0; i < distributions.length; i++) {
      PolicyDistribution distribution = distributions[i];
      RealVector[] gradLog = distribution.computeGradLog(ActionArray.getDim(a_t, i));
      for (RealVector parameterVector : gradLog)
        gradLogs.add(parameterVector);
    }
    RealVector[] result = new RealVector[gradLogs.size()];
    gradLogs.toArray(result);
    return result;
  }

  /** Maps the index of a parameter vector (as laid out by createParameters) to the index
   * of the distribution that owns it. */
  public int weightsIndexToActionIndex(int i) {
    return weightsToAction[i];
  }

  public PolicyDistribution policy(int actionIndex) {
    return distributions[actionIndex];
  }

  @Override
  public int nbParameterVectors() {
    int result = 0;
    for (PolicyDistribution distribution : distributions)
      result += distribution.nbParameterVectors();
    return result;
  }

  /** Upper bound on the joint density: the product of each distribution's own bound
   * (every underlying distribution must implement BoundedPdf). */
  @Override
  public double piMax() {
    double result = 1;
    for (PolicyDistribution distribution : distributions)
      result *= ((BoundedPdf) distribution).piMax();
    return result;
  }

  @Override
  public void update(RealVector x) {
    for (PolicyDistribution distribution : distributions)
      distribution.update(x);
  }

  public PolicyDistribution[] policies() {
    return distributions;
  }

  @Override
  public void setParameters(PVector... parameters) {
    int index = 0;
    for (PolicyDistribution distribution : distributions) {
      PVector[] u = new PVector[distribution.nbParameterVectors()];
      System.arraycopy(parameters, index, u, 0, u.length);
      ((PolicyParameterized) distribution).setParameters(u);
      index += u.length;
    }
  }

  @Override
  public PVector[] parameters() {
    PVector[] parameters = new PVector[nbParameterVectors()];
    int index = 0;
    for (PolicyDistribution distribution : distributions) {
      System.arraycopy(((PolicyParameterized) distribution).parameters(), 0, parameters, index,
                       distribution.nbParameterVectors());
      index += distribution.nbParameterVectors();
    }
    return parameters;
  }
}
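
A minimal usage sketch follows. It is not part of RLPark: the wrapper class, package, and argument names are hypothetical, and the per-dimension distributions are passed in as plain PolicyDistribution instances because their concrete types and constructors are not shown in this file. Only calls that appear in the listing above (the constructor, createParameters, update, sampleAction, pi, and computeGradLog) are used.

// Hypothetical usage sketch, not part of RLPark.
package rlpark.example;

import rlpark.plugin.rltoys.algorithms.functions.policydistributions.PolicyDistribution;
import rlpark.plugin.rltoys.algorithms.functions.policydistributions.structures.JointDistribution;
import rlpark.plugin.rltoys.envio.actions.ActionArray;
import rlpark.plugin.rltoys.math.vector.RealVector;
import rlpark.plugin.rltoys.math.vector.implementations.PVector;

public class JointDistributionSketch {
  /** Builds a joint policy from existing per-dimension distributions and runs one sampling step. */
  public static ActionArray sampleOnce(PolicyDistribution[] perDimension, RealVector x_t, int nbFeatures) {
    JointDistribution joint = new JointDistribution(perDimension);
    PVector[] thetas = joint.createParameters(nbFeatures); // one vector per parameter block
    joint.update(x_t);                                     // condition the policy on the current features
    ActionArray a_t = joint.sampleAction();                // per-dimension samples, concatenated
    double pi_t = joint.pi(a_t);                           // product of the per-dimension densities
    RealVector[] gradLog = joint.computeGradLog(a_t);      // grad-log blocks, same layout as thetas
    // In an actor-critic agent, thetas, pi_t, and gradLog would feed the policy-gradient update.
    return a_t;
  }
}

Note that the parameter vectors returned by createParameters and the grad-log vectors returned by computeGradLog are laid out distribution by distribution; weightsIndexToActionIndex exposes that mapping to callers that need to know which distribution a given parameter block belongs to.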