SebastienGuissart commited on
Commit
e3098c3
·
verified ·
1 Parent(s): 6a20fc0

fix pos output

Browse files
Files changed (1) hide show
  1. sudoku/train.py +1 -1
sudoku/train.py CHANGED
@@ -875,7 +875,7 @@ class SudokuTrialErrorLightning(SudokuLightning):
875
  pos = self.search_trial(x.view(2,729), [])
876
  new_x = deepcopy(x.view(2,729))
877
  output = self.forward_layer(x.view(1,2,729))
878
- pos_output = torch.argmax(output[0,:,pos])
879
  if (pos_output[0].item()-self.threshold_abs[0].item())>(pos_output[1].item-self.threshold_pos[0].item()):
880
  trial_abs_pres = 0
881
  else:
 
875
  pos = self.search_trial(x.view(2,729), [])
876
  new_x = deepcopy(x.view(2,729))
877
  output = self.forward_layer(x.view(1,2,729))
878
+ pos_output = output[0,:,pos]
879
  if (pos_output[0].item()-self.threshold_abs[0].item())>(pos_output[1].item-self.threshold_pos[0].item()):
880
  trial_abs_pres = 0
881
  else: