|
import torch |
|
import pytest |
|
import numpy as np |
|
|
|
from sudoku.buffer import BufferArray |
|
|
|
|
|
|
|
def test_get_batch_on_empty():
    """Requesting a batch before anything was appended must fail.

    A freshly built ``BufferArray`` holds no data, so ``get_batch`` is
    expected to trip an ``AssertionError`` rather than return something.
    """
    empty_buffers = BufferArray(batch_size=2, array_size=3)
    with pytest.raises(AssertionError):
        empty_buffers.get_batch()
|
|
|
|
|
def test_append_get_batch():
    """Append one batch's worth of rows to slot 0 and drain it again.

    The appended inputs have a leading dimension of 2, which equals
    ``batch_size``. A single ``get_batch`` call should therefore return
    data from slot 0 with the original shapes and leave that slot empty.
    """
    first_rows = [[1, 2, 3], [4, 5, 6]]
    second_rows = [[11, 12, 13], [14, 15, 16]]
    inputs = torch.tensor([first_rows, second_rows])
    targets = torch.tensor([1, 2])

    buffer_array = BufferArray(array_size=3, batch_size=2)
    buffer_array.append(0, (inputs, targets))

    returned_idx, batch = buffer_array.get_batch()

    # The returned pair must mirror the (inputs, targets) layout that was appended.
    assert batch[0].shape == inputs.shape
    assert batch[1].shape == targets.shape
    assert returned_idx == 0
    # NOTE(review): buffers[0][0] presumably holds slot 0's stored inputs;
    # after draining a full batch it should be empty — confirm against BufferArray.
    assert len(buffer_array.buffers[0][0]) == 0
|
|
|
|
|
@pytest.mark.parametrize("idx", [0, 1])
def test_append_get_batch_2(idx):
    """Append two batches to one slot and take only one of them back.

    After two appends of the same two-row sample, ``get_batch`` should
    report the slot that was filled, return tensors shaped like the
    inputs, and leave the second batch (two rows) still stored.
    """
    inputs = torch.tensor(
        [
            [[1, 2, 3], [4, 5, 6]],
            [[11, 12, 13], [14, 15, 16]],
        ]
    )
    targets = torch.tensor([1, 2])

    buffer_array = BufferArray(array_size=3, batch_size=2)
    for _ in range(2):
        buffer_array.append(idx, (inputs, targets))

    returned_idx, batch = buffer_array.get_batch()

    assert batch[0].shape == inputs.shape
    assert batch[1].shape == targets.shape
    assert idx == returned_idx
    # NOTE(review): buffers[idx][0] presumably holds this slot's stored inputs;
    # one batch was consumed, so two rows should remain — confirm against BufferArray.
    assert len(buffer_array.buffers[idx][0]) == 2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|