File size: 1,529 Bytes
4484b8a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import torch
import pytest
import numpy as np

from sudoku.buffer import BufferArray


# Test that get_batch fails when the BufferArray is still empty.
def test_get_batch_on_empty():
    """Requesting a batch before anything was appended must raise AssertionError."""
    empty_array = BufferArray(batch_size=2, array_size=3)
    with pytest.raises(AssertionError):
        empty_array.get_batch()


def test_append_get_batch():
    """Append once to buffer 0, then pull a batch back out.

    The batch must report index 0 and keep the shapes of the appended
    tensors. Afterwards the storage at ``buffers[0][0]`` must be empty.
    """
    inputs = torch.tensor(
        [[[1, 2, 3], [4, 5, 6]], [[11, 12, 13], [14, 15, 16]]]
    )
    targets = torch.tensor([1, 2])

    buffer_array = BufferArray(array_size=3, batch_size=2)
    buffer_array.append(0, (inputs, targets))

    returned_idx, batch = buffer_array.get_batch()

    assert batch[0].shape == inputs.shape
    assert batch[1].shape == targets.shape
    assert returned_idx == 0
    # The single append should have been fully consumed by get_batch.
    assert len(buffer_array.buffers[0][0]) == 0


@pytest.mark.parametrize("idx", [0, 1, 2])
def test_append_get_batch_2(idx):
    """Append twice to buffer ``idx``, then take one batch.

    Every slot of an ``array_size=3`` array is exercised, including the last
    index (``array_size - 1``), which is the usual place for off-by-one bugs.
    The batch must come from buffer ``idx`` and keep the appended shapes.
    Exactly two entries must remain in ``buffers[idx][0]`` afterwards.
    """
    x = torch.tensor(
        [
            [[1, 2, 3], [4, 5, 6]],
            [[11, 12, 13], [14, 15, 16]],
        ]
    )
    y = torch.tensor([1, 2])
    ba = BufferArray(array_size=3, batch_size=2)
    # Each append presumably stores 2 samples (leading dim of x / length of
    # y). With batch_size=2, one get_batch should leave 2 behind, as the
    # final assertion checks.
    ba.append(idx, (x, y))
    ba.append(idx, (x, y))
    batch_idx, batch = ba.get_batch()
    assert batch[0].shape == x.shape
    assert batch[1].shape == y.shape
    assert idx == batch_idx
    assert len(ba.buffers[idx][0]) == 2


# TODO: more BufferArray tests.
#
# Write scenario-based tests, for example:
#   append
#   append
#   append
#   get_batch, get_batch
#
# Check that:
#   - each batch is properly kept separate,
#   - the batch-size limit is respected,
#   - sampling actually takes place,
#   - the assertions fire when they should.