File size: 1,397 Bytes
4484b8a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import torch

from sudoku.models import SudokuNet


def test_same_output_under_rotation():
    """Check that SudokuNet's output follows a relabelling of its input.

    Despite the name, the transformation is not a geometric rotation. arr2
    is arr1 with rows 1 <-> 2 and columns 2 <-> 3 swapped, and with the
    last-axis index 3 mapped to 4 (and 4 to 6 for the queried outputs).
    Axes 2 and 3 are taken to be row and column from the grid comment below
    this test; the last axis is presumably the digit and channel 1's meaning
    is not visible here -- TODO confirm against SudokuNet.

    Each entry of ``matching_indices`` pairs an output position for arr1
    with the position it maps to under that same relabelling for arr2. The
    model's predictions at the two positions must agree (up to float
    rounding).
    """
    model = SudokuNet()
    # Inference mode: any stochastic layers (e.g. dropout, if the model has
    # them) must not make the two forward passes differ at random.
    model.eval()

    arr1 = torch.zeros((1, 2, 9, 9, 9))
    arr1[0, 0, 1, 2, 3] = 1
    arr2 = torch.zeros((1, 2, 9, 9, 9))
    arr2[0, 0, 2, 3, 4] = 1

    flat_shape = (1, 2, 9 * 9 * 9)
    with torch.no_grad():
        # Call the module itself rather than .forward() so hooks still run.
        output_1 = model(arr1.view(flat_shape))
        output_2 = model(arr2.view(flat_shape))

    assert output_1.shape == flat_shape, output_1.shape
    assert output_2.shape == flat_shape, output_2.shape

    grid_1 = output_1.reshape(1, 2, 9, 9, 9)
    grid_2 = output_2.reshape(1, 2, 9, 9, 9)

    matching_indices = [
        # The given entry itself.
        ((0, 0, 1, 2, 3), (0, 0, 2, 3, 4)),
        # The given cell, a different last-axis index (4 -> 6).
        ((0, 0, 1, 2, 4), (0, 0, 2, 3, 6)),
        # Same position as above, channel 1.
        ((0, 1, 1, 2, 4), (0, 1, 2, 3, 6)),
        # A cell sharing the given cell's column, other row.
        ((0, 1, 2, 2, 4), (0, 1, 1, 3, 6)),
        # A cell sharing neither row nor column with the given cell.
        ((0, 1, 2, 3, 4), (0, 1, 1, 2, 6)),
    ]
    for idx_1, idx_2 in matching_indices:
        value_1 = grid_1[idx_1]
        value_2 = grid_2[idx_2]
        # Exact float equality is brittle: an equivariant model computes the
        # same quantities, but possibly with a different accumulation order.
        assert torch.allclose(value_1, value_2), (idx_1, idx_2, value_1, value_2)


# Board positions used in the test above (row index, column index):
#   a = (1, 2), the entry set in arr1
#   b = (2, 3), the entry set in arr2
#  0, 1, 2 | 3, 4, 5 | 6, 7, 8
#  0, 1, a | 3, 4, 5 | 6, 7, 8
#  0, 1, 2 | b, 4, 5 | 6, 7, 8
# ----------------------------
#  0, 1, 2 | 3, 4, 5 | 6, 7, 8
#  0, 1, 2 | 3, 4, 5 | 6, 7, 8
#  0, 1, 2 | 3, 4, 5 | 6, 7, 8
# ----------------------------
#  0, 1, 2 | 3, 4, 5 | 6, 7, 8
#  0, 1, 2 | 3, 4, 5 | 6, 7, 8
#  0, 1, 2 | 3, 4, 5 | 6, 7, 8