# Source: commit 4484b8a ("first commit") by Sebastien; file size 436 bytes.
from sudoku.helper import compute_loss
import torch
def test_compute_loss():
    """Exercise ``compute_loss`` on a tiny hand-built batch.

    Builds all-zero ``x``, ``y`` and ``output`` tensors of shape (3, 2, 729).
    It sets ``y[:, 0, 0] = 1`` for every batch element. Each batch element
    gets exactly one positive ``output`` entry: elements 0 and 1 fire at the
    target index 0, and element 2 fires at index 1, where ``y`` is zero.
    ``new_x`` is ``output`` thresholded at zero.

    NOTE(review): the meaning of the tensor dimensions (presumably
    729 = 81 cells x 9 digits) and of the four returned values is defined in
    ``sudoku.helper.compute_loss``, which is not visible here. TODO: confirm
    and pin the expected values.
    """
    x = torch.zeros((3, 2, 729))
    y = torch.zeros((3, 2, 729))
    output = torch.zeros((3, 2, 729))
    y[:, 0, 0] = 1
    output[0, 0, 0] = 0.1
    output[1, 0, 0] = 0.1
    output[2, 0, 1] = 0.1
    # Threshold to a 0/1 float mask. `.float()` keeps the tensor on its
    # current device. The legacy type string "torch.FloatTensor" always
    # meant a CPU tensor.
    new_x = (output > 0).float()

    # Sanity-check the fixture so that a future edit to the setup cannot
    # silently change the scenario compute_loss is scored on: three positive
    # predictions, two of which coincide with a target entry. These checks
    # run before the call in case compute_loss modifies its inputs.
    assert new_x.dtype == torch.float32
    assert int(new_x.sum()) == 3
    assert int((new_x * y).sum()) == 2

    loss_error, loss_no_improve, n_error, n_no_improve = compute_loss(
        x, y, output, new_x
    )
    # TODO(review): the four results are not checked yet, so this test only
    # catches exceptions raised inside compute_loss. Add assertions on
    # loss_error, loss_no_improve, n_error and n_no_improve once their
    # expected values for this fixture are confirmed against sudoku.helper.