File size: 1,001 Bytes
4484b8a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from sudoku.loader import data_loader, DataIterBuffer, get_datasets, train_dataset
from torch.utils.data import DataLoader


def test_data_loader():
    """Check the shape and per-channel sums of the first training batch.

    The first batch from the training loader must have inputs and targets of
    shape ``(32, 2, 9 * 9 * 9)``. For every target sample, channel 0 sums to
    ``9 * 9 * 8`` and channel 1 sums to ``9 * 9``. The inputs must sum strictly
    lower on both channels, i.e. some cells are left unfilled.
    """
    n_cells = 9 * 9             # cells in a 9x9 grid
    flat_size = n_cells * 9     # 9 candidate values per cell
    train_loader, _test_loader = data_loader()

    inputs, targets = next(iter(train_loader))
    assert inputs.shape == targets.shape == (32, 2, flat_size)

    # NOTE(review): channel 1 presumably one-hot encodes the digit of each
    # cell (one 1 per cell) and channel 0 its complement (eight 1s per cell);
    # the sums below are consistent with that reading — confirm in loader.
    target_sums = targets.sum(-1)
    assert (target_sums[:, 0] == n_cells * 8).all()
    assert (target_sums[:, 1] == n_cells).all()

    input_sums = inputs.sum(-1)
    assert (input_sums[:, 0] < n_cells * 8).all()
    assert (input_sums[:, 1] < n_cells).all()


def test_data_iter_buffer():
    """Iterate a DataLoader over a DataIterBuffer while feeding batches back.

    Every third batch is appended back into the buffer during iteration. The
    test expects exactly 29 batches to be produced in total.
    """
    buffer = DataIterBuffer(train_dataset)
    loader = DataLoader(buffer, 32)

    # NOTE(review): 29 presumably counts the original batches plus the ones
    # replayed from appended data — depends on DataIterBuffer; confirm.
    n_batches = 0  # stays 0 if the loader yields nothing
    for n_batches, (X, Y) in enumerate(loader, start=1):
        if not n_batches % 3:
            buffer.append(X, Y)
    assert n_batches == 29


def test_max_holes():
    """Check that ``max_holes=2`` yields a first training sample with 2 holes.

    Builds tiny datasets (2 train / 2 test samples) with ``max_holes=2`` and
    inspects the first training sample. A cell counts as a hole when its
    9-value slice of ``X[1]`` is all zeros.
    """
    # Distinct local names so the module-level ``train_dataset`` import is not
    # shadowed inside this test; the test split is not needed here.
    small_train, _small_test = get_datasets(
        train_size=2, test_size=2, max_holes=2
    )
    X, _Y = next(iter(small_train))
    # NOTE(review): assumes X[1] is the per-cell one-hot digit channel
    # (81 cells x 9 values), consistent with test_data_loader — confirm.
    x_holes = X[1].reshape(9, 9, 9).sum(-1) == 0
    n_holes = int(x_holes.sum())
    # NOTE(review): if the hole count is sampled up to max_holes rather than
    # fixed, exact equality could be flaky — confirm in get_datasets.
    assert n_holes == 2, f"expected 2 holes, got {n_holes}"