"""Tests for the data-loading helpers in ``sudoku.loader``."""

from torch.utils.data import DataLoader

from sudoku.loader import DataIterBuffer, data_loader, get_datasets, train_dataset

# 81 board cells; each cell is encoded over 9 digits -> 729 features per channel.
N_CELLS = 9 * 9
N_FEATURES = N_CELLS * 9


def test_data_loader():
    """Check the shape and per-channel sums of the first training batch.

    Y's channel-0 and channel-1 sums must equal ``9*9*8`` and ``9*9`` exactly,
    while X's must be strictly smaller -- presumably because X is Y with some
    cells blanked out (TODO confirm against ``sudoku.loader``).
    """
    train_loader, _test_loader = data_loader()
    X, Y = next(iter(train_loader))

    # NOTE(review): 32 is assumed to be data_loader()'s default batch size.
    assert X.shape == Y.shape == (32, 2, N_FEATURES)
    assert (Y.sum(-1)[:, 0] == N_CELLS * 8).all()
    assert (Y.sum(-1)[:, 1] == N_CELLS).all()
    assert (X.sum(-1)[:, 0] < N_CELLS * 8).all()
    assert (X.sum(-1)[:, 1] < N_CELLS).all()


def test_data_iter_buffer():
    """Iterate a DataIterBuffer while appending every third batch back into it.

    NOTE(review): the expected batch count (29) presumably depends on the size
    of ``train_dataset`` and on how ``DataIterBuffer.append`` re-queues data;
    update it if either changes.
    """
    data_iter = DataIterBuffer(train_dataset)
    loader = DataLoader(data_iter, 32)

    n_batches = 0
    for X, Y in loader:
        n_batches += 1
        # Append to the buffer while iteration over it is still in progress.
        if n_batches % 3 == 0:
            data_iter.append(X, Y)

    assert n_batches == 29


def test_max_holes():
    """With ``max_holes=2``, the first training sample has exactly 2 empty cells.

    A cell counts as empty when its 9-entry digit vector in channel 1 of X
    sums to 0.
    """
    # Distinct local names avoid shadowing the module-level ``train_dataset``.
    small_train, _small_test = get_datasets(train_size=2, test_size=2, max_holes=2)
    X, _Y = next(iter(small_train))

    x_holes = X[1].reshape(9, 9, 9).sum(-1) == 0
    assert x_holes.sum() == 2