from sudoku.helper import compute_loss | |
import torch | |
def test_compute_loss():
    """Exercise ``compute_loss`` on a tiny hand-built batch.

    Fixture (every tensor is shape ``(3, 2, 729)``; what the three axes
    mean is not visible here -- TODO confirm against ``sudoku.helper``):

    * ``x`` is all zeros.
    * ``y`` has a 1 at ``[:, 0, 0]`` for every sample.
    * ``output`` is positive only at ``[0, 0, 0]`` and ``[1, 0, 0]`` (these
      agree with ``y``) and at ``[2, 0, 1]`` (this disagrees with ``y``).
    * ``new_x`` is ``output`` thresholded at 0 and cast to float.
    """
    x = torch.zeros((3, 2, 729))
    y = torch.zeros((3, 2, 729))
    output = torch.zeros((3, 2, 729))
    y[:, 0, 0] = 1
    output[0, 0, 0] = 0.1
    output[1, 0, 0] = 0.1
    output[2, 0, 1] = 0.1  # deliberately disagrees with y for sample 2
    # .float() keeps the tensor on its current device; the legacy
    # .type("torch.FloatTensor") would force the result onto the CPU.
    new_x = (output > 0).float()

    loss_error, loss_no_improve, n_error, n_no_improve = compute_loss(
        x, y, output, new_x
    )

    # The previous version called compute_loss without checking anything, so
    # it could only fail on an exception. These sanity checks hold for any
    # well-behaved loss / count. Exact expected values (e.g. whether sample 2
    # is counted in n_error) should be pinned once compute_loss's contract is
    # confirmed -- TODO.
    for loss in (loss_error, loss_no_improve):
        assert torch.isfinite(torch.as_tensor(loss)).all()
    for count in (n_error, n_no_improve):
        assert (torch.as_tensor(count) >= 0).all()