import random
import torch

# TODO: test intensively. A "pop index out of range" error was observed in
# get_batch(); the likely cause is parallel lists of unequal length, which
# the length assertions in append() below now guard against.


class BufferArray:
    """A fixed-size array of buffers, one per index.

    Each buffer stores ``batch_n_vectors`` parallel lists of tensors
    (e.g. inputs and targets) and yields randomly sampled batches of
    ``batch_size`` elements, removed without replacement.
    """

    def __init__(self, array_size, batch_size):
        self.array_size = array_size
        self.batch_size = batch_size
        # Number of parallel tensors per batch, e.g. 2 for (x, y);
        # inferred from the first batch passed to append().
        self.batch_n_vectors = None

    def init_buffers(self):
        # One buffer per index, each made of batch_n_vectors parallel lists.
        self.buffers = [
            [[] for _ in range(self.batch_n_vectors)] for _ in range(self.array_size)
        ]

    def get_batch(self):
        """Return the index of the first buffer holding at least batch_size
        elements together with a sampled batch, or (None, None) if no buffer
        is full enough."""
        assert hasattr(self, "buffers"), "call append() before get_batch()"
        for idx, buffer in enumerate(self.buffers):
            if len(buffer[0]) >= self.batch_size:
                vectors = [[] for _ in range(self.batch_n_vectors)]
                for _ in range(self.batch_size):
                    # Pop the same random position from every parallel list
                    # so the sampled elements stay aligned.
                    pop_idx = random.randrange(len(buffer[0]))
                    for v, b in zip(vectors, buffer):
                        v.append(b.pop(pop_idx))
                return idx, tuple(torch.stack(v, dim=0) for v in vectors)
        return None, None

    def append(self, idx, batch: tuple):
        """Append a batch to the buffer at idx. batch is a tuple of tensors
        such as (x, y), all sharing the same first dimension."""
        if idx >= self.array_size:
            # Out-of-range indices are silently ignored.
            return
        if self.batch_n_vectors is None:
            self.batch_n_vectors = len(batch)
            self.init_buffers()
        else:
            assert len(batch) == self.batch_n_vectors
        # All tensors must contribute the same number of rows; otherwise the
        # parallel lists drift apart and pop(pop_idx) in get_batch() can go
        # out of range.
        assert len({len(element_vectors) for element_vectors in batch}) == 1
        for i, element_vectors in enumerate(batch):
            # Iterating over a tensor yields its rows (first dimension).
            self.buffers[idx][i].extend(element_vectors)
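

# Usage sketch for BufferArray (hypothetical shapes; assumes data arrives as
# (x, y) tensor tuples with equal first dimensions):
#
#     buffers = BufferArray(array_size=4, batch_size=32)
#     x, y = torch.randn(8, 10), torch.randn(8, 1)
#     buffers.append(0, (x, y))           # accumulate into buffer 0
#     idx, batch = buffers.get_batch()    # (None, None) until a buffer fills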


# Legacy, commented-out append() kept for reference; note that its docstring
# does not match the (X, Y) signature:
#     def append(self, X, Y) -> None:
#         """Add experience to the buffer.
#
#         Args:
#             experience: tuple (state, action, reward, done, new_state)
#         """
#
#         X[Y == 0] = 0
#         mask = ~(X == Y).view(-1, 2 * 729).all(dim=1)
#
#         for x, y in zip(X[mask], Y[mask]):
#             self.buffer.append((x, y))


class Buffer:
    """A single buffer with the same sampling behaviour as BufferArray."""

    def __init__(self, batch_size):
        self.batch_size = batch_size
        # Inferred from the first batch passed to append().
        self.batch_n_vectors = None

    def init_buffer(self):
        self.buffer = [[] for _ in range(self.batch_n_vectors)]

    def get_batch(self):
        """Return a randomly sampled batch once at least batch_size elements
        are stored, or None otherwise."""
        if not hasattr(self, "buffer"):
            return None
        if len(self.buffer[0]) >= self.batch_size:
            vectors = [[] for _ in range(self.batch_n_vectors)]
            for _ in range(self.batch_size):
                pop_idx = random.randrange(len(self.buffer[0]))
                for v, b in zip(vectors, self.buffer):
                    v.append(b.pop(pop_idx))
            return tuple(torch.stack(v, dim=0) for v in vectors)
        return None

    def append(self, batch: tuple):
        """Append a batch to the buffer. batch is a tuple of tensors such as
        (x, y), all sharing the same first dimension."""
        if self.batch_n_vectors is None:
            self.batch_n_vectors = len(batch)
            self.init_buffer()
        else:
            assert len(batch) == self.batch_n_vectors
        # Keep the parallel lists aligned (see BufferArray.append).
        assert len({len(element_vectors) for element_vectors in batch}) == 1
        for i, element_vectors in enumerate(batch):
            self.buffer[i].extend(element_vectors)
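

if __name__ == "__main__":
    # Minimal smoke test (not part of the original module): exercises both
    # classes with hypothetical (x, y) shapes and checks that batches only
    # come back once batch_size elements have accumulated.
    torch.manual_seed(0)

    buffers = BufferArray(array_size=2, batch_size=4)
    buffers.append(0, (torch.randn(3, 5), torch.randn(3, 1)))
    idx, batch = buffers.get_batch()
    assert idx is None and batch is None  # only 3 of 4 elements so far
    buffers.append(0, (torch.randn(3, 5), torch.randn(3, 1)))
    idx, batch = buffers.get_batch()
    assert idx == 0 and batch[0].shape == (4, 5) and batch[1].shape == (4, 1)

    buffer = Buffer(batch_size=4)
    buffer.append((torch.randn(6, 5), torch.randn(6, 1)))
    x, y = buffer.get_batch()
    assert x.shape == (4, 5) and y.shape == (4, 1)
    print("smoke test passed")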