|
from sudoku.loader import data_loader, DataIterBuffer, get_datasets, train_dataset |
|
from torch.utils.data import DataLoader |
|
|
|
|
|
def test_data_loader():
    """Check the shape and per-channel totals of one training batch.

    Takes the first batch from the training loader returned by
    ``data_loader()`` and verifies:

    * inputs and targets both have shape ``(32, 2, 9 * 9 * 9)``;
    * the targets' channel 0 sums to ``9 * 9 * 8`` and channel 1 sums to
      ``9 * 9`` for every sample;
    * the inputs' channel totals are strictly smaller than those target
      totals (i.e. some entries of the puzzle are missing).
    """
    batch_size = 32
    cells = 9 * 9
    flat_size = cells * 9

    loaders = data_loader()
    train_loader = loaders[0]
    inputs, targets = next(iter(train_loader))

    expected_shape = (batch_size, 2, flat_size)
    assert targets.shape == expected_shape
    assert inputs.shape == targets.shape

    # Sum over the flattened last axis, giving one total per (sample, channel).
    target_totals = targets.sum(-1)
    input_totals = inputs.sum(-1)

    # NOTE(review): the totals suggest channel 1 holds one active entry per
    # cell and channel 0 holds eight per cell for a solved grid — confirm
    # against the encoding in sudoku.loader.
    assert (target_totals[:, 0] == cells * 8).all()
    assert (target_totals[:, 1] == cells).all()

    assert (input_totals[:, 0] < cells * 8).all()
    assert (input_totals[:, 1] < cells).all()
|
|
|
|
|
def test_data_iter_buffer():
    """Iterate a DataIterBuffer while feeding batches back into it.

    Wraps ``train_dataset`` in a ``DataIterBuffer`` and batches it with a
    ``DataLoader`` of batch size 32. Every third batch (3rd, 6th, ...) is
    appended back into the buffer during iteration. The test then checks the
    total number of batches the loader produced.
    """
    buffer = DataIterBuffer(train_dataset)
    loader = DataLoader(buffer, 32)

    # Pre-initialised so the assertion below still runs if the loader is empty.
    batches_seen = 0
    for batches_seen, (inputs, targets) in enumerate(loader, start=1):
        if batches_seen % 3 == 0:
            buffer.append(inputs, targets)

    # NOTE(review): 29 presumably counts the base dataset's batches plus the
    # batches re-appended above — confirm against DataIterBuffer's semantics
    # and the size of train_dataset.
    assert batches_seen == 29
|
|
|
|
|
def test_max_holes():
    """Check that ``get_datasets(max_holes=...)`` limits the number of holes.

    Builds a tiny train/test split (two samples each) with ``max_holes=2``.
    It then counts the empty cells in the first training sample. A cell is
    treated as empty when its 9 entries in channel 1 of the flattened
    encoding sum to zero.
    """
    # Use a distinct local name so the module-level ``train_dataset`` import
    # is not shadowed inside this test; the test split is not needed here.
    small_train, _ = get_datasets(train_size=2, test_size=2, max_holes=2)

    sample, _ = next(iter(small_train))

    # Reshape channel 1 to (row, col, digit) and flag cells with no entry set.
    holes = sample[1].reshape(9, 9, 9).sum(-1) == 0

    # NOTE(review): this asserts exactly ``max_holes`` holes. If get_datasets
    # samples *up to* max_holes, this should probably be ``<= 2`` — confirm.
    assert holes.sum() == 2
|
|