import random

import torch

# TODO: test intensively — a "pop index out of range" error was once observed;
# presumably the parallel lists in a buffer got out of sync. Verify against callers.


class BufferArray:
    """An array of replay buffers, one per index.

    Each buffer stores ``batch_n_vectors`` parallel lists of tensors
    (e.g. inputs and targets). ``get_batch`` samples ``batch_size``
    elements without replacement from the first buffer that is full.
    """

    def __init__(self, array_size, batch_size):
        self.array_size = array_size
        self.batch_size = batch_size
        # Number of parallel vectors per sample; learned lazily from the
        # first batch passed to ``append``.
        self.batch_n_vectors = None

    def init_buffers(self):
        """Allocate ``array_size`` buffers of ``batch_n_vectors`` empty lists each."""
        self.buffers = [
            [[] for _ in range(self.batch_n_vectors)]
            for _ in range(self.array_size)
        ]

    def get_batch(self):
        """Return ``(idx, batch)`` for the first buffer holding at least
        ``batch_size`` samples, or ``(0, None)`` when no buffer is ready.

        ``batch`` is a tuple of stacked tensors, one per vector. Sampled
        elements are removed from the buffer (sampling without replacement).
        """
        # Before the first append there are no buffers yet; report
        # "nothing ready" instead of crashing (consistent with Buffer.get_batch).
        if not hasattr(self, "buffers"):
            return 0, None
        for idx, buffer in enumerate(self.buffers):
            if len(buffer[0]) >= self.batch_size:
                vectors = [[] for _ in range(self.batch_n_vectors)]
                for _ in range(self.batch_size):
                    # Draw one random position, then pop it from every
                    # parallel list so the vectors stay aligned.
                    pop_idx = random.randrange(len(buffer[0]))
                    for v, b in zip(vectors, buffer):
                        v.append(b.pop(pop_idx))
                return idx, tuple(torch.stack(v, dim=0) for v in vectors)
        return 0, None

    def append(self, idx, batch: tuple):
        """Append ``batch`` to buffer ``idx``.

        ``batch`` is a tuple of iterables of tensors, e.g. ``(xs, ys)``.
        Out-of-range ``idx`` values are silently ignored.
        """
        if idx >= self.array_size:
            return
        if self.batch_n_vectors is None:
            # Lazily size the buffers from the first batch seen.
            self.batch_n_vectors = len(batch)
            self.init_buffers()
        else:
            assert len(batch) == self.batch_n_vectors
        for i, element_vectors in enumerate(batch):
            # extend() is O(k); the previous list concatenation rebuilt the
            # whole buffer on every call (quadratic over the buffer's life).
            self.buffers[idx][i].extend(element_vectors)
class Buffer:
    """A single replay buffer of ``batch_n_vectors`` parallel tensor lists.

    ``get_batch`` samples ``batch_size`` elements without replacement once
    enough samples have accumulated.
    """

    def __init__(self, batch_size):
        self.batch_size = batch_size
        # Number of parallel vectors per sample; learned lazily from the
        # first batch passed to ``append``.
        self.batch_n_vectors = None

    def init_buffer(self):
        """Allocate ``batch_n_vectors`` empty lists."""
        self.buffer = [[] for _ in range(self.batch_n_vectors)]

    def get_batch(self):
        """Return a tuple of stacked tensors once ``batch_size`` samples are
        available, else ``None``.

        Sampled elements are removed from the buffer (sampling without
        replacement).
        """
        if not hasattr(self, "buffer"):
            # Nothing has been appended yet.
            return None
        if len(self.buffer[0]) >= self.batch_size:
            vectors = [[] for _ in range(self.batch_n_vectors)]
            for _ in range(self.batch_size):
                # Draw one random position, then pop it from every
                # parallel list so the vectors stay aligned.
                pop_idx = random.randrange(len(self.buffer[0]))
                for v, b in zip(vectors, self.buffer):
                    v.append(b.pop(pop_idx))
            return tuple(torch.stack(v, dim=0) for v in vectors)
        return None

    def append(self, batch: tuple):
        """Append ``batch``, a tuple of iterables of tensors such as ``(xs, ys)``."""
        if self.batch_n_vectors is None:
            # Lazily size the buffer from the first batch seen.
            self.batch_n_vectors = len(batch)
            self.init_buffer()
        else:
            assert len(batch) == self.batch_n_vectors
        for i, element_vectors in enumerate(batch):
            # extend() is O(k); the previous list concatenation rebuilt the
            # whole buffer on every call (quadratic over the buffer's life).
            self.buffer[i].extend(element_vectors)