File size: 436 Bytes
4484b8a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 |
from sudoku.helper import compute_loss
import torch
def test_compute_loss():
    """Exercise ``compute_loss`` on a tiny hand-built batch.

    The fixture is a batch of 3 samples, each a zero tensor of shape (2, 729):

    * ``y`` sets index ``[0, 0]`` to 1 in every sample.
    * ``output`` is positive (0.1) at that same index for samples 0 and 1.
      For sample 2 it is positive at index ``[0, 1]``, where ``y`` is 0.
    * ``new_x`` is ``output`` binarized to 1.0 wherever it is positive.

    NOTE(review): only the fixture is asserted here, plus the fact that
    ``compute_loss`` runs and returns exactly four values. The expected
    values of ``loss_error``, ``loss_no_improve``, ``n_error`` and
    ``n_no_improve`` are not pinned. TODO: add exact assertions once the
    intended contract of ``compute_loss`` is confirmed.
    """
    x = torch.zeros((3, 2, 729))
    y = torch.zeros((3, 2, 729))
    output = torch.zeros((3, 2, 729))
    y[:, 0, 0] = 1
    output[0, 0, 0] = 0.1
    output[1, 0, 0] = 0.1
    output[2, 0, 1] = 0.1

    # Binarize to 1.0 wherever output is positive. `.float()` keeps the result
    # on the same device as `output`. The legacy "torch.FloatTensor" string
    # always produced a CPU tensor.
    new_x = (output > 0).float()

    # Sanity-check the fixture so that a broken setup fails loudly here.
    assert new_x.dtype == torch.float32
    assert torch.nonzero(new_x).tolist() == [[0, 0, 0], [1, 0, 0], [2, 0, 1]]

    # Unpacking into four names checks that compute_loss returns exactly
    # four values; see the TODO in the docstring about pinning their values.
    loss_error, loss_no_improve, n_error, n_no_improve = compute_loss(
        x, y, output, new_x
    )
|