# Author: Sebastien
# first commit
# rev: 4484b8a
import random
import torch
# TODO(Sebastien): test this intensively!
# An IndexError ("pop index out of range") was observed during training;
# it occurs when the per-vector sub-buffers fall out of sync in length.
class BufferArray:
    """An array of replay buffers, one per index.

    Each buffer holds ``batch_n_vectors`` parallel lists of tensors
    (e.g. one list of inputs and one of targets).  Batches are drawn by
    popping random elements, i.e. sampling without replacement.
    """

    def __init__(self, array_size, batch_size):
        # Number of independent buffers in the array.
        self.array_size = array_size
        # Number of elements popped to assemble one batch.
        self.batch_size = batch_size
        # Number of parallel vectors per batch (e.g. 2 for (x, y));
        # learned lazily from the first append().
        self.batch_n_vectors = None

    def init_buffers(self):
        """Allocate array_size buffers, each with batch_n_vectors empty lists."""
        self.buffers = [
            [[] for _ in range(self.batch_n_vectors)] for _ in range(self.array_size)
        ]

    def get_batch(self):
        """Return (idx, batch) for the first buffer holding >= batch_size
        elements, or (0, None) when no buffer is ready.

        ``batch`` is a tuple of tensors, one per vector, each stacked
        along a new leading dimension of size batch_size.
        """
        # Graceful no-op before the first append() — consistent with
        # Buffer.get_batch() instead of the old hard assert.
        if not hasattr(self, "buffers"):
            return 0, None
        for idx, buffer in enumerate(self.buffers):
            # Draw from the shortest sub-buffer so desynchronized lengths
            # can never raise "pop index out of range" (the TODO above).
            n_available = min(len(b) for b in buffer)
            if n_available >= self.batch_size:
                vectors = [[] for _ in range(self.batch_n_vectors)]
                for _ in range(self.batch_size):
                    pop_idx = random.randrange(n_available)
                    for v, b in zip(vectors, buffer):
                        v.append(b.pop(pop_idx))
                    n_available -= 1
                return idx, tuple(torch.stack(v, dim=0) for v in vectors)
        return 0, None

    def append(self, idx, batch: tuple):
        """Append ``batch`` (a tuple of tensors such as (x, y)) to buffer ``idx``.

        Indices >= array_size are silently ignored.  The first call fixes
        batch_n_vectors and allocates the buffers.
        """
        if idx >= self.array_size:
            return
        if self.batch_n_vectors is None:
            self.batch_n_vectors = len(batch)
            self.init_buffers()
        else:
            assert len(batch) == self.batch_n_vectors
        for i, element_vectors in enumerate(batch):
            # extend() is O(k); the old `list + list` rebuilt the whole
            # buffer on every append (quadratic over time).
            self.buffers[idx][i].extend(element_vectors)
class Buffer:
    """A single replay buffer of ``batch_n_vectors`` parallel tensor lists.

    Single-buffer counterpart of BufferArray: batches are drawn by
    popping random elements, i.e. sampling without replacement.
    """

    def __init__(self, batch_size):
        # Number of elements popped to assemble one batch.
        self.batch_size = batch_size
        # Number of parallel vectors per batch (e.g. 2 for (x, y));
        # learned lazily from the first append().
        self.batch_n_vectors = None

    def init_buffer(self):
        """Allocate batch_n_vectors empty lists."""
        self.buffer = [[] for _ in range(self.batch_n_vectors)]

    def get_batch(self):
        """Return a batch once batch_size elements are stored, else None.

        The batch is a tuple of tensors, one per vector, each stacked
        along a new leading dimension of size batch_size.
        """
        if not hasattr(self, "buffer"):
            return None
        # Draw from the shortest sub-buffer so desynchronized lengths
        # can never raise "pop index out of range".
        n_available = min(len(b) for b in self.buffer)
        if n_available < self.batch_size:
            return None
        vectors = [[] for _ in range(self.batch_n_vectors)]
        for _ in range(self.batch_size):
            pop_idx = random.randrange(n_available)
            for v, b in zip(vectors, self.buffer):
                v.append(b.pop(pop_idx))
            n_available -= 1
        return tuple(torch.stack(v, dim=0) for v in vectors)

    def append(self, batch: tuple):
        """Append ``batch`` (a tuple of tensors such as (x, y)) to the buffer.

        The first call fixes batch_n_vectors and allocates the storage.
        """
        if self.batch_n_vectors is None:
            self.batch_n_vectors = len(batch)
            self.init_buffer()
        else:
            assert len(batch) == self.batch_n_vectors
        for i, element_vectors in enumerate(batch):
            # extend() is O(k); the old `list + list` rebuilt the whole
            # buffer on every append (quadratic over time).
            self.buffer[i].extend(element_vectors)