File size: 436 Bytes
4484b8a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from sudoku.helper import compute_loss
import torch


def test_compute_loss():
    """Smoke-test ``compute_loss`` on a tiny hand-built batch of three samples.

    All samples share an all-zero ``x`` and the same target ``y``, which has
    a single 1 at ``[:, 0, 0]``. Each sample's ``output`` has exactly one
    positive entry:

    - samples 0 and 1 at index 0 of the last axis (matching ``y``);
    - sample 2 at index 1 of the last axis (not matching ``y``).

    ``new_x`` is ``output`` thresholded at 0 and cast to float.

    NOTE(review): this test only checks that ``compute_loss`` runs and returns
    exactly four values; it asserts nothing about those values. Add expected
    values for the losses and counts once the contract of ``compute_loss`` is
    confirmed.
    """
    # Presumably (batch, 2, 729), with 729 = 9 * 9 * 9 as a one-hot sudoku
    # encoding — TODO confirm against sudoku.helper.
    shape = (3, 2, 729)
    x = torch.zeros(shape)
    y = torch.zeros(shape)
    output = torch.zeros(shape)

    y[:, 0, 0] = 1
    # Samples 0 and 1 fire where y is set; sample 2 fires one index over.
    output[0, 0, 0] = 0.1
    output[1, 0, 0] = 0.1
    output[2, 0, 1] = 0.1

    # .float() gives float32 on output's own device, whereas the legacy
    # "torch.FloatTensor" type string always forces a CPU tensor.
    new_x = (output > 0).float()

    # Unpacking into four names fails unless compute_loss returns exactly
    # four values.
    loss_error, loss_no_improve, n_error, n_no_improve = compute_loss(
        x, y, output, new_x
    )