fix pos output
Browse files- sudoku/train.py +1 -1
sudoku/train.py
CHANGED
@@ -875,7 +875,7 @@ class SudokuTrialErrorLightning(SudokuLightning):
|
|
875 |
pos = self.search_trial(x.view(2,729), [])
|
876 |
new_x = deepcopy(x.view(2,729))
|
877 |
output = self.forward_layer(x.view(1,2,729))
|
878 |
-
pos_output =
|
879 |
if (pos_output[0].item()-self.threshold_abs[0].item())>(pos_output[1].item-self.threshold_pos[0].item()):
|
880 |
trial_abs_pres = 0
|
881 |
else:
|
|
|
875 |
pos = self.search_trial(x.view(2,729), [])
|
876 |
new_x = deepcopy(x.view(2,729))
|
877 |
output = self.forward_layer(x.view(1,2,729))
|
878 |
+
pos_output = output[0,:,pos]
|
879 |
if (pos_output[0].item()-self.threshold_abs[0].item())>(pos_output[1].item-self.threshold_pos[0].item()):
|
880 |
trial_abs_pres = 0
|
881 |
else:
|