|
import random |
|
import torch |
|
|
|
|
|
|
|
|
|
|
|
class BufferArray:
    """A fixed-size array of sample buffers that emit random mini-batches.

    Each slot ``idx`` holds one buffer. A buffer is a list of
    ``batch_n_vectors`` parallel lists, one per element of the appended
    tuples (e.g. ``x`` and ``y``); the i-th entries of those lists together
    form one sample. Once a buffer holds at least ``batch_size`` samples,
    :meth:`get_batch` removes ``batch_size`` randomly chosen samples from it
    and stacks them with ``torch.stack``.
    """

    def __init__(self, array_size, batch_size):
        """
        Args:
            array_size: number of buffers; valid indices are
                ``0 .. array_size - 1``.
            batch_size: number of samples drawn per batch.
        """
        self.array_size = array_size
        self.batch_size = batch_size
        # Number of parallel vectors per sample; fixed by the first append.
        self.batch_n_vectors = None

    def init_buffers(self):
        """(Re)create ``array_size`` empty buffers of ``batch_n_vectors`` lists each."""
        self.buffers = [
            [[] for _ in range(self.batch_n_vectors)] for _ in range(self.array_size)
        ]

    def get_batch(self):
        """Pop a random batch from the first buffer holding ``batch_size`` samples.

        Returns:
            ``(idx, batch)``. Here ``idx`` is the index of the buffer the samples
            came from. ``batch`` is a tuple with one tensor per vector, each
            built with ``torch.stack(..., dim=0)`` over ``batch_size`` samples.
            Returns ``(0, None)`` when no buffer is ready, including before the
            first :meth:`append`.
        """
        if not hasattr(self, "buffers"):
            # Nothing has been appended yet, so there is nothing to sample.
            return 0, None
        for idx, buffer in enumerate(self.buffers):
            if len(buffer[0]) >= self.batch_size:
                vectors = [[] for _ in range(self.batch_n_vectors)]
                for _ in range(self.batch_size):
                    # Pop the same position from every parallel list so the
                    # components of each sample stay paired.
                    pop_idx = random.randrange(len(buffer[0]))
                    for v, b in zip(vectors, buffer):
                        v.append(b.pop(pop_idx))
                return idx, tuple(torch.stack(v, dim=0) for v in vectors)
        return 0, None

    def append(self, idx, batch: tuple):
        """Append the samples in ``batch`` to buffer ``idx``.

        Args:
            idx: target buffer index. Indices outside ``0 .. array_size - 1``
                (including negative ones) are silently ignored.
            batch: tuple of vectors such as ``(x, y)``. Each element is iterated
                and its items are stored as individual samples, so all elements
                must have the same length.

        Raises:
            ValueError: if ``batch`` is empty, if its length differs from the
                length fixed by the first append, or if its elements have
                different lengths. The buffers are left untouched in that case.
        """
        if not 0 <= idx < self.array_size:
            return
        if len(batch) == 0:
            raise ValueError("batch must contain at least one vector")
        if self.batch_n_vectors is not None and len(batch) != self.batch_n_vectors:
            raise ValueError(
                f"expected {self.batch_n_vectors} vectors per batch, got {len(batch)}"
            )
        # Unequal lengths would desynchronise the parallel lists and mispair samples.
        lengths = {len(element_vectors) for element_vectors in batch}
        if len(lengths) > 1:
            raise ValueError(
                f"all vectors in a batch must have the same length, got {sorted(lengths)}"
            )

        if self.batch_n_vectors is None:
            self.batch_n_vectors = len(batch)
            self.init_buffers()

        for stored, element_vectors in zip(self.buffers[idx], batch):
            # extend() appends in place instead of rebuilding the list each call.
            stored.extend(element_vectors)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Buffer:
    """A single sample buffer that emits random mini-batches.

    The buffer is a list of ``batch_n_vectors`` parallel lists, one per
    element of the appended tuples (e.g. ``x`` and ``y``); the i-th entries
    of those lists together form one sample. Once it holds at least
    ``batch_size`` samples, :meth:`get_batch` removes ``batch_size`` randomly
    chosen samples and stacks them with ``torch.stack``.
    """

    def __init__(self, batch_size):
        """
        Args:
            batch_size: number of samples drawn per batch.
        """
        self.batch_size = batch_size
        # Number of parallel vectors per sample; fixed by the first append.
        self.batch_n_vectors = None

    def init_buffer(self):
        """(Re)create an empty buffer of ``batch_n_vectors`` lists."""
        self.buffer = [[] for _ in range(self.batch_n_vectors)]

    def get_batch(self):
        """Pop a random batch if at least ``batch_size`` samples are buffered.

        Returns:
            A tuple with one tensor per vector, each built with
            ``torch.stack(..., dim=0)`` over ``batch_size`` samples. Returns
            ``None`` when not enough samples are buffered, including before the
            first :meth:`append`.
        """
        if not hasattr(self, "buffer"):
            # Nothing has been appended yet, so there is nothing to sample.
            return None
        if len(self.buffer[0]) < self.batch_size:
            return None
        vectors = [[] for _ in range(self.batch_n_vectors)]
        for _ in range(self.batch_size):
            # Pop the same position from every parallel list so the
            # components of each sample stay paired.
            pop_idx = random.randrange(len(self.buffer[0]))
            for v, b in zip(vectors, self.buffer):
                v.append(b.pop(pop_idx))
        return tuple(torch.stack(v, dim=0) for v in vectors)

    def append(self, batch: tuple):
        """Append the samples in ``batch`` to the buffer.

        Args:
            batch: tuple of vectors such as ``(x, y)``. Each element is iterated
                and its items are stored as individual samples, so all elements
                must have the same length.

        Raises:
            ValueError: if ``batch`` is empty, if its length differs from the
                length fixed by the first append, or if its elements have
                different lengths. The buffer is left untouched in that case.
        """
        if len(batch) == 0:
            raise ValueError("batch must contain at least one vector")
        if self.batch_n_vectors is not None and len(batch) != self.batch_n_vectors:
            raise ValueError(
                f"expected {self.batch_n_vectors} vectors per batch, got {len(batch)}"
            )
        # Unequal lengths would desynchronise the parallel lists and mispair samples.
        lengths = {len(element_vectors) for element_vectors in batch}
        if len(lengths) > 1:
            raise ValueError(
                f"all vectors in a batch must have the same length, got {sorted(lengths)}"
            )

        if self.batch_n_vectors is None:
            self.batch_n_vectors = len(batch)
            self.init_buffer()

        for stored, element_vectors in zip(self.buffer, batch):
            # extend() appends in place instead of rebuilding the list each call.
            stored.extend(element_vectors)
|
|