fix threshold pos
Browse files- sudoku/train.py +1 -1
sudoku/train.py
CHANGED
@@ -876,7 +876,7 @@ class SudokuTrialErrorLightning(SudokuLightning):
|
|
876 |
new_x = deepcopy(x.view(2,729))
|
877 |
output = self.forward_layer(x.view(1,2,729))
|
878 |
pos_output = output[0,:,pos]
|
879 |
-
if (pos_output[0].item()-self.threshold_abs[0].item())>(pos_output[1].item-self.
|
880 |
trial_abs_pres = 0
|
881 |
else:
|
882 |
trial_abs_pres = 1
|
|
|
876 |
new_x = deepcopy(x.view(2,729))
|
877 |
output = self.forward_layer(x.view(1,2,729))
|
878 |
pos_output = output[0,:,pos]
|
879 |
+
if (pos_output[0].item()-self.threshold_abs[0].item())>(pos_output[1].item-self.threshold_pres[0].item()):
|
880 |
trial_abs_pres = 0
|
881 |
else:
|
882 |
trial_abs_pres = 1
|