SebastienGuissart commited on
Commit
dfa6863
·
verified ·
1 Parent(s): e3098c3

fix threshold pos

Browse files
Files changed (1) hide show
  1. sudoku/train.py +1 -1
sudoku/train.py CHANGED
@@ -876,7 +876,7 @@ class SudokuTrialErrorLightning(SudokuLightning):
876
  new_x = deepcopy(x.view(2,729))
877
  output = self.forward_layer(x.view(1,2,729))
878
  pos_output = output[0,:,pos]
879
- if (pos_output[0].item()-self.threshold_abs[0].item())>(pos_output[1].item-self.threshold_pos[0].item()):
880
  trial_abs_pres = 0
881
  else:
882
  trial_abs_pres = 1
 
876
  new_x = deepcopy(x.view(2,729))
877
  output = self.forward_layer(x.view(1,2,729))
878
  pos_output = output[0,:,pos]
879
+ if (pos_output[0].item()-self.threshold_abs[0].item())>(pos_output[1].item-self.threshold_pres[0].item()):
880
  trial_abs_pres = 0
881
  else:
882
  trial_abs_pres = 1