RLPark 1.0.0
Reinforcement Learning Framework in Java
|
00001 package rlpark.plugin.rltoys.math.vector.implementations; 00002 00003 import java.util.Arrays; 00004 00005 import rlpark.plugin.rltoys.math.vector.BinaryVector; 00006 import rlpark.plugin.rltoys.math.vector.MutableVector; 00007 import rlpark.plugin.rltoys.math.vector.RealVector; 00008 import rlpark.plugin.rltoys.math.vector.SparseRealVector; 00009 import rlpark.plugin.rltoys.math.vector.SparseVector; 00010 import zephyr.plugin.core.api.monitoring.abstracts.DataMonitor; 00011 import zephyr.plugin.core.api.monitoring.abstracts.MonitorContainer; 00012 import zephyr.plugin.core.api.monitoring.abstracts.Monitored; 00013 import zephyr.plugin.core.api.monitoring.annotations.Monitor; 00014 00015 public class SVector extends AbstractVector implements SparseRealVector, MonitorContainer { 00016 private static final long serialVersionUID = -3324707947990480491L; 00017 public int[] activeIndexes; 00018 public double[] values; 00019 public int[] indexesPosition; 00020 @Monitor 00021 int nbActive = 0; 00022 00023 public SVector(int size) { 00024 super(size); 00025 values = new double[10]; 00026 activeIndexes = new int[10]; 00027 indexesPosition = new int[size]; 00028 Arrays.fill(indexesPosition, -1); 00029 } 00030 00031 public SVector(SVector other) { 00032 this(other.getDimension()); 00033 set(other); 00034 } 00035 00036 public SVector(BVector other, double value) { 00037 this(other.size); 00038 if (value == 0) 00039 return; 00040 setFromBVector(other, value); 00041 } 00042 00043 @Override 00044 public SVector copy() { 00045 return new SVector(this); 00046 } 00047 00048 @Override 00049 public MutableVector newInstance(int size) { 00050 return new SVector(size); 00051 } 00052 00053 @Override 00054 public MutableVector copyAsMutable() { 00055 return copy(); 00056 } 00057 00058 @Override 00059 public MutableVector addToSelf(RealVector other) { 00060 return addToSelf(1, other); 00061 } 00062 00063 private MutableVector addToSelf(SVector other, double factor) { 00064 for (int position = 0; position < other.nbActive; position++) { 00065 final int index = other.activeIndexes[position]; 00066 setNonZeroEntry(index, getEntry(index) + factor * other.values[position]); 00067 } 00068 return this; 00069 } 00070 00071 private MutableVector addToSelf(BinaryVector other, double factor) { 00072 int[] nonNullIndexes = other.getActiveIndexes(); 00073 for (int index : nonNullIndexes) 00074 setNonZeroEntry(index, getEntry(index) + factor); 00075 return this; 00076 } 00077 00078 @Override 00079 public MutableVector addToSelf(double factor, RealVector other) { 00080 if (other instanceof SVector) 00081 return addToSelf((SVector) other, factor); 00082 if (other instanceof BinaryVector) 00083 return addToSelf((BinaryVector) other, factor); 00084 for (int i = 0; i < other.getDimension(); i++) 00085 setEntry(i, getEntry(i) + factor * other.getEntry(i)); 00086 return this; 00087 } 00088 00089 @Override 00090 public MutableVector subtractToSelf(RealVector other) { 00091 return addToSelf(-1, other); 00092 } 00093 00094 @Override 00095 public MutableVector mapMultiplyToSelf(double factor) { 00096 if (factor == 0) { 00097 clear(); 00098 return this; 00099 } 00100 for (int position = 0; position < nbActive; position++) 00101 values[position] *= factor; 00102 return this; 00103 } 00104 00105 @Override 00106 public void removeEntry(int index) { 00107 int position = indexesPosition[index]; 00108 if (position != -1) 00109 removeEntry(position, index); 00110 } 00111 00112 @Override 00113 public void setEntry(int index, double value) { 00114 if (value == 0) 00115 removeEntry(index); 00116 else 00117 setNonZeroEntry(index, value); 00118 } 00119 00120 private void setNonZeroEntry(int index, double value) { 00121 int position = indexesPosition[index]; 00122 if (position != -1) 00123 updateEntry(index, value, position); 00124 else 00125 insertEntry(index, value); 00126 } 00127 00128 protected void insertEntry(int index, double value) { 00129 appendEntry(index, value); 00130 } 00131 00132 protected void appendEntry(int index, double value) { 00133 allocate(nbActive + 1); 00134 activeIndexes[nbActive] = index; 00135 values[nbActive] = value; 00136 indexesPosition[index] = nbActive; 00137 nbActive++; 00138 } 00139 00140 protected void allocate(int sizeRequired) { 00141 if (activeIndexes.length >= sizeRequired) 00142 return; 00143 int newCapacity = (sizeRequired * 3) / 2 + 1; 00144 activeIndexes = Arrays.copyOf(activeIndexes, newCapacity); 00145 values = Arrays.copyOf(values, newCapacity); 00146 } 00147 00148 protected void updateEntry(int index, double value, int position) { 00149 values[position] = value; 00150 } 00151 00152 protected void removeEntry(int position, int index) { 00153 swapEntry(nbActive - 1, position); 00154 indexesPosition[activeIndexes[nbActive - 1]] = -1; 00155 nbActive--; 00156 } 00157 00158 private void swapEntry(int positionA, int positionB) { 00159 final int indexA = activeIndexes[positionA]; 00160 final double valueA = values[positionA]; 00161 final int indexB = activeIndexes[positionB]; 00162 final double valueB = values[positionB]; 00163 indexesPosition[indexA] = positionB; 00164 indexesPosition[indexB] = positionA; 00165 activeIndexes[positionA] = indexB; 00166 activeIndexes[positionB] = indexA; 00167 values[positionA] = valueB; 00168 values[positionB] = valueA; 00169 } 00170 00171 @Override 00172 public MutableVector ebeDivideToSelf(RealVector other) { 00173 for (int position = 0; position < nbActive; position++) { 00174 final int index = activeIndexes[position]; 00175 values[position] /= other.getEntry(index); 00176 } 00177 return this; 00178 } 00179 00180 @Override 00181 public MutableVector ebeMultiplyToSelf(RealVector other) { 00182 int position = 0; 00183 while (position < nbActive) { 00184 final int index = activeIndexes[position]; 00185 double value = values[position] * other.getEntry(index); 00186 if (value != 0) { 00187 values[position] = value; 00188 position++; 00189 } else 00190 removeEntry(position, index); 00191 } 00192 return this; 00193 } 00194 00195 @Override 00196 public double getEntry(int index) { 00197 final int position = indexesPosition[index]; 00198 return position != -1 ? values[position] : 0; 00199 } 00200 00201 @Override 00202 public double[] accessData() { 00203 double[] result = new double[size]; 00204 for (int position = 0; position < nbActive; position++) { 00205 final int index = activeIndexes[position]; 00206 result[index] = values[position]; 00207 } 00208 return result; 00209 } 00210 00211 @Override 00212 public SVector clear() { 00213 for (int i = 0; i < nbActive; i++) 00214 indexesPosition[activeIndexes[i]] = -1; 00215 nbActive = 0; 00216 return this; 00217 } 00218 00219 @Override 00220 public double dotProduct(double[] data) { 00221 double result = 0.0; 00222 for (int position = 0; position < nbActive; position++) 00223 result += data[activeIndexes[position]] * values[position]; 00224 return result; 00225 } 00226 00227 @Override 00228 public double dotProduct(RealVector other) { 00229 if (other instanceof SparseVector && ((SparseVector) other).nonZeroElements() < nonZeroElements()) 00230 return other.dotProduct(this); 00231 double result = 0.0; 00232 for (int position = 0; position < nbActive; position++) 00233 result += other.getEntry(activeIndexes[position]) * values[position]; 00234 return result; 00235 } 00236 00237 @Override 00238 public void addSelfTo(double[] data) { 00239 for (int position = 0; position < nbActive; position++) 00240 data[activeIndexes[position]] += values[position]; 00241 } 00242 00243 @Override 00244 public void subtractSelfTo(double[] data) { 00245 for (int position = 0; position < nbActive; position++) 00246 data[activeIndexes[position]] -= values[position]; 00247 } 00248 00249 public void addSelfTo(double factor, double[] data) { 00250 for (int position = 0; position < nbActive; position++) 00251 data[activeIndexes[position]] += factor * values[position]; 00252 } 00253 00254 public int[] getActiveIndexes() { 00255 return Arrays.copyOf(activeIndexes, nbActive); 00256 } 00257 00258 @Override 00259 public int nonZeroElements() { 00260 return nbActive; 00261 } 00262 00263 @Override 00264 public String toString() { 00265 StringBuilder result = new StringBuilder("["); 00266 for (int position = 0; position < nbActive; position++) { 00267 result.append(activeIndexes[position]); 00268 result.append(":"); 00269 result.append(values[position]); 00270 if (position < nbActive - 1) 00271 result.append(", "); 00272 } 00273 result.append("]"); 00274 return result.toString(); 00275 } 00276 00277 @Override 00278 public SVector set(RealVector other) { 00279 if (other instanceof SVector) 00280 return set((SVector) other); 00281 if (other instanceof BVector) 00282 return setFromBVector((BVector) other, 1.0); 00283 clear(); 00284 for (int i = 0; i < other.getDimension(); i++) 00285 setEntry(i, other.getEntry(i)); 00286 return this; 00287 } 00288 00289 private SVector set(SVector other) { 00290 clear(); 00291 allocate(other.nbActive); 00292 nbActive = other.nbActive; 00293 System.arraycopy(other.activeIndexes, 0, activeIndexes, 0, nbActive); 00294 System.arraycopy(other.values, 0, values, 0, nbActive); 00295 for (int position = 0; position < nbActive; position++) 00296 indexesPosition[activeIndexes[position]] = position; 00297 return this; 00298 } 00299 00300 00301 @Override 00302 public MutableVector set(RealVector other, int start) { 00303 for (int i = 0; i < other.getDimension(); i++) 00304 setEntry(start + i, other.getEntry(i)); 00305 return this; 00306 } 00307 00308 private SVector setFromBVector(BVector other, double value) { 00309 clear(); 00310 allocate(other.nonZeroElements()); 00311 for (int i = 0; i < other.nonZeroElements(); i++) 00312 setNonZeroEntry(other.activeIndexes[i], value); 00313 return this; 00314 } 00315 00316 @Override 00317 public void addToMonitor(DataMonitor monitor) { 00318 monitor.add("l1norm", new Monitored() { 00319 @Override 00320 public double monitoredValue() { 00321 return Vectors.l1Norm(SVector.this); 00322 } 00323 }); 00324 } 00325 00326 @Override 00327 public int[] nonZeroIndexes() { 00328 return activeIndexes; 00329 } 00330 00331 @Override 00332 public double sum() { 00333 double sum = 0; 00334 for (int i = 0; i < nbActive; i++) 00335 sum += values[i]; 00336 return sum; 00337 } 00338 }