RLPark 1.0.0
Reinforcement Learning Framework in Java
|
package rlpark.plugin.rltoys.math.vector.pool;

import java.util.Arrays;
import java.util.Stack;

import rlpark.plugin.rltoys.math.vector.MutableVector;
import rlpark.plugin.rltoys.math.vector.RealVector;
import zephyr.plugin.core.api.monitoring.annotations.Monitor;

/**
 * A {@link VectorPool} that hands out reusable vectors grouped in nested
 * "frames".
 *
 * <p>
 * {@link #allocate()} opens a frame, {@link #newVector()} and
 * {@link #newVector(RealVector)} hand out vectors belonging to the current
 * frame, and {@link #releaseAll()} closes the current frame and restores the
 * enclosing one. The backing array of a closed frame, together with the
 * vectors it holds, is kept and reused by a later {@link #allocate()}, so
 * vectors are only instantiated when a frame needs more of them than any
 * earlier frame that used the same array.
 *
 * <p>
 * The pool remembers the thread that constructed it; requesting a vector from
 * any other thread throws a {@link RuntimeException}. {@link #allocate()} and
 * {@link #releaseAll()} do not perform that check.
 */
public class ThreadVectorPool implements VectorPool {
  /**
   * Saved state of an enclosing frame while a nested frame is active.
   * Declared static because it never reads the enclosing pool instance.
   */
  static class AllocatedBuffer {
    final MutableVector[] buffers;
    final int lastAllocation;
    final RealVector prototype;

    AllocatedBuffer(RealVector prototype, MutableVector[] buffers, int lastAllocation) {
      this.prototype = prototype;
      this.buffers = buffers;
      this.lastAllocation = lastAllocation;
    }
  }

  /** Value of {@code lastAllocation} while no frame is active ({@code buffers == null}). */
  private static final int NO_FRAME = -2;
  /** Value of {@code lastAllocation} for an active frame that has not handed out a vector yet. */
  private static final int EMPTY_FRAME = -1;

  /** Number of vectors instantiated from the prototype since the pool was created. */
  @Monitor
  int nbAllocation;
  /** Thread that constructed the pool; the only one allowed to request vectors. */
  private final Thread thread;
  /** Backing arrays of closed frames, waiting to be reused by {@link #allocate()}. */
  private final Stack<MutableVector[]> stackedVectors = new Stack<MutableVector[]>();
  /** Enclosing frames suspended by nested calls to {@link #allocate()}. */
  private final Stack<AllocatedBuffer> stackedBuffers = new Stack<AllocatedBuffer>();
  /** Backing array of the current frame, or {@code null} when no frame is active. */
  private MutableVector[] buffers;
  /** Index in {@code buffers} of the last vector handed out in the current frame. */
  private int lastAllocation = NO_FRAME;
  private final RealVector prototype;
  private final int dimension;

  /**
   * Creates a pool bound to the calling thread.
   *
   * @param prototype vector whose {@code newInstance(int)} is used to create pooled vectors
   * @param dimension dimension passed to {@code newInstance(int)} for every pooled vector
   */
  public ThreadVectorPool(RealVector prototype, int dimension) {
    this.dimension = dimension;
    this.thread = Thread.currentThread();
    this.prototype = prototype;
  }

  /**
   * Opens a new frame. If a frame is already active it is suspended and will
   * be restored by the matching {@link #releaseAll()}.
   */
  public void allocate() {
    if (buffers != null) {
      stackedBuffers.push(new AllocatedBuffer(prototype, buffers, lastAllocation));
    }
    // Reuse the array (and cached vectors) of a previously closed frame when one is available.
    buffers = stackedVectors.isEmpty() ? new MutableVector[1] : stackedVectors.pop();
    lastAllocation = EMPTY_FRAME;
  }

  /**
   * Hands out the next pooled vector of the current frame after calling
   * {@code clear()} on it.
   *
   * @return the value returned by {@code clear()} on the pooled vector
   * @throws RuntimeException if called from a thread other than the one that created the pool
   * @throws IllegalStateException if no frame is active
   */
  @Override
  public MutableVector newVector() {
    return vectorCached().clear();
  }

  /**
   * Returns the next vector slot of the current frame, growing the backing
   * array and instantiating a vector from the prototype when needed.
   */
  private MutableVector vectorCached() {
    if (Thread.currentThread() != thread) {
      throw new RuntimeException("Called from a wrong thread");
    }
    // Fail with a clear message, and before touching any state, instead of a
    // NullPointerException on buffers.length below.
    if (buffers == null) {
      throw new IllegalStateException("allocate() must be called before requesting a vector");
    }
    lastAllocation++;
    if (lastAllocation == buffers.length) {
      // Arrays start at length 1, so doubling always makes room.
      buffers = Arrays.copyOf(buffers, buffers.length * 2);
    }
    MutableVector cached = buffers[lastAllocation];
    if (cached == null) {
      nbAllocation++;
      cached = prototype.newInstance(dimension);
      buffers[lastAllocation] = cached;
    }
    return cached;
  }

  /**
   * Hands out the next pooled vector of the current frame after calling
   * {@code set(v)} on it.
   *
   * @param v vector passed to {@code set}; asserted to have the pool's dimension
   * @return the value returned by {@code set(v)} on the pooled vector
   * @throws RuntimeException if called from a thread other than the one that created the pool
   * @throws IllegalStateException if no frame is active
   */
  @Override
  public MutableVector newVector(RealVector v) {
    assert dimension == v.getDimension();
    return vectorCached().set(v);
  }

  /**
   * Closes the current frame, keeping its backing array for reuse, and
   * restores the enclosing frame if there is one. Does nothing when no frame
   * is active.
   */
  @Override
  public void releaseAll() {
    if (buffers == null) {
      // Unbalanced release: pushing null here would make a later allocate()
      // pop it and leave the pool without a backing array.
      return;
    }
    stackedVectors.push(buffers);
    if (stackedBuffers.isEmpty()) {
      buffers = null;
      lastAllocation = NO_FRAME;
    } else {
      AllocatedBuffer allocated = stackedBuffers.pop();
      buffers = allocated.buffers;
      lastAllocation = allocated.lastAllocation;
    }
  }
}