# Sebastien
# first commit
# 4484b8a
import torch
import pytest
import numpy as np
from sudoku.buffer import BufferArray
# Test that get_batch fails when nothing has been appended yet.
def test_get_batch_on_empty():
    """Calling get_batch before any append must raise an AssertionError."""
    untouched = BufferArray(array_size=3, batch_size=2)

    # Nothing has been appended, so requesting a batch is expected to assert.
    with pytest.raises(AssertionError):
        untouched.get_batch()
def test_append_get_batch():
    """Append one (x, y) pair at index 0 and read it back with get_batch.

    Checks that the returned batch elements have the same shapes as the
    appended tensors, that the returned index is 0, and that the first
    entry of ``buffers[0]`` has length 0 afterwards.
    """
    features = torch.tensor(
        [[[1, 2, 3], [4, 5, 6]], [[11, 12, 13], [14, 15, 16]]]
    )
    targets = torch.tensor([1, 2])

    buffer_array = BufferArray(array_size=3, batch_size=2)
    buffer_array.append(0, (features, targets))

    returned_idx, returned_batch = buffer_array.get_batch()
    batch_features = returned_batch[0]
    batch_targets = returned_batch[1]

    assert batch_features.shape == features.shape
    assert batch_targets.shape == targets.shape
    assert returned_idx == 0

    # The batch that was handed out should no longer be held at index 0.
    leftover = buffer_array.buffers[0][0]
    assert len(leftover) == 0
@pytest.mark.parametrize("idx", [0, 1])
def test_append_get_batch_2(idx):
    """Append the same (x, y) pair twice at ``idx`` and fetch a single batch.

    Checks that the returned batch elements have the same shapes as one
    appended pair, that get_batch reports the index that was filled, and
    that the first entry of ``buffers[idx]`` still has length 2 afterwards.
    """
    features = torch.tensor(
        [[[1, 2, 3], [4, 5, 6]], [[11, 12, 13], [14, 15, 16]]]
    )
    targets = torch.tensor([1, 2])

    buffer_array = BufferArray(array_size=3, batch_size=2)
    # Two identical appends to the same index.
    for _ in range(2):
        buffer_array.append(idx, (features, targets))

    batch_idx, returned_batch = buffer_array.get_batch()
    batch_features = returned_batch[0]
    batch_targets = returned_batch[1]

    assert batch_features.shape == features.shape
    assert batch_targets.shape == targets.shape
    assert idx == batch_idx

    # Only one batch was taken out of the two appended.
    leftover = buffer_array.buffers[idx][0]
    assert len(leftover) == 2
# TODO: further BufferArray tests, written as scenarios, e.g.:
#   append
#   append
#   append
#   batch, batch
# Check that each batch is kept properly separate,
# that the batch size limit is respected,
# that sampling takes place,
# and that the asserts fire.