Sebastien
first commit
4484b8a
from sudoku.loader import data_loader, DataIterBuffer, get_datasets, train_dataset
from torch.utils.data import DataLoader
def test_data_loader():
    """Sanity-check the first training batch produced by ``data_loader()``.

    Asserts that inputs and targets share the shape ``(32, 2, 729)`` and checks
    the per-sample sums of each of the two channels:

    * targets: channel 0 sums to ``9 * 9 * 8``, channel 1 sums to ``9 * 9``;
    * inputs: both channels sum to strictly less than the target values.

    NOTE(review): the sums suggest channel 1 holds one "on" entry per cell
    (81 cells x 9 values) and channel 0 its complement, with inputs having
    some cells blanked out -- confirm against ``sudoku.loader``.
    """
    n_cells = 9 * 9
    full_ch0 = n_cells * 8
    full_ch1 = n_cells

    train_loader, _unused_test_loader = data_loader()
    inputs, targets = next(iter(train_loader))

    assert inputs.shape == targets.shape == (32, 2, n_cells * 9)

    # Targets are complete, so each channel hits its full count exactly.
    target_totals = targets.sum(-1)
    assert (target_totals[:, 0] == full_ch0).all()
    assert (target_totals[:, 1] == full_ch1).all()

    # Inputs must be strictly below the full counts in both channels.
    input_totals = inputs.sum(-1)
    assert (input_totals[:, 0] < full_ch0).all()
    assert (input_totals[:, 1] < full_ch1).all()
def test_data_iter_buffer():
    """Drive a ``DataIterBuffer`` through a ``DataLoader`` while feeding it back.

    Every third batch (by 1-based position) is handed back to the buffer via
    ``append`` during iteration. The test expects exactly 29 batches to be
    yielded in total.

    NOTE(review): presumably appended batches are re-emitted later by the
    buffer, which is why the count differs from a plain pass over
    ``train_dataset`` -- confirm against ``DataIterBuffer``.
    """
    buffer = DataIterBuffer(train_dataset)
    loader = DataLoader(buffer, 32)

    # Pre-initialised so the assertion still runs if the loader yields nothing.
    n_batches = 0
    for n_batches, (X, Y) in enumerate(loader, start=1):
        if not n_batches % 3:
            buffer.append(X, Y)

    assert n_batches == 29
def test_max_holes():
    """Check that ``get_datasets(max_holes=2)`` yields a sample with two holes.

    A tiny dataset pair is built (2 train / 2 test samples) and the first
    training input is inspected. Channel 1 of the input is reshaped to
    ``(9, 9, 9)``. A position counts as a hole when its trailing 9-vector sums
    to zero, and exactly two such positions are expected.

    NOTE(review): the axes of the ``(9, 9, 9)`` reshape are presumably
    (row, column, digit) -- confirm against ``sudoku.loader``.
    """
    # Distinct local names so the module-level ``train_dataset`` import is
    # not shadowed inside this test.
    small_train, _small_test = get_datasets(train_size=2, test_size=2, max_holes=2)
    X, _Y = next(iter(small_train))

    # A cell whose value vector is all zeros has been blanked out.
    x_holes = X[1].reshape(9, 9, 9).sum(-1) == 0

    assert x_holes.sum() == 2