"""Tests for ``sudoku.buffer.BufferArray``."""

import numpy as np
import pytest
import torch

from sudoku.buffer import BufferArray


def _make_sample():
    """Return a small ``(x, y)`` pair of tensors used as one appended item.

    ``x`` has shape ``(2, 2, 3)`` and ``y`` has shape ``(2,)``. New tensors
    are built on every call so tests never share mutable data.
    """
    x = torch.tensor(
        [
            [[1, 2, 3], [4, 5, 6]],
            [[11, 12, 13], [14, 15, 16]],
        ]
    )
    y = torch.tensor([1, 2])
    return x, y


def test_get_batch_on_empty():
    """``get_batch`` on a BufferArray that was never appended to must fail."""
    ba = BufferArray(array_size=3, batch_size=2)
    with pytest.raises(AssertionError):
        ba.get_batch()


def test_append_get_batch():
    """A single appended item into buffer 0 can be fetched back as a batch.

    The returned index must be 0, the batch components must have the same
    shapes as the appended ``(x, y)``, and ``ba.buffers[0][0]`` is expected
    to be empty once the batch has been taken.
    """
    x, y = _make_sample()
    ba = BufferArray(array_size=3, batch_size=2)

    ba.append(0, (x, y))
    idx, batch = ba.get_batch()

    assert batch[0].shape == x.shape
    assert batch[1].shape == y.shape
    assert idx == 0
    assert len(ba.buffers[0][0]) == 0


@pytest.mark.parametrize("idx", [0, 1])
def test_append_get_batch_2(idx):
    """Appending twice to buffer ``idx`` yields a batch from that same buffer.

    The returned batch index must equal ``idx``, the batch components must
    match the shapes of a single appended ``(x, y)``, and
    ``ba.buffers[idx][0]`` is expected to still hold 2 entries afterwards.
    """
    x, y = _make_sample()
    ba = BufferArray(array_size=3, batch_size=2)

    ba.append(idx, (x, y))
    ba.append(idx, (x, y))
    batch_idx, batch = ba.get_batch()

    assert batch[0].shape == x.shape
    assert batch[1].shape == y.shape
    assert idx == batch_idx
    assert len(ba.buffers[idx][0]) == 2


# TODO: add scenario-based tests for BufferArray, e.g.
#   append, append, append, then get_batch, get_batch, and verify that:
#   - each returned batch is properly separated from the others,
#   - the batch size limit is respected,
#   - sampling across buffers actually happens,
#   - the internal assertions fire when expected.