SebastienGuissart commited on
Commit
2102761
·
verified ·
1 Parent(s): 6f43ac3

finish backtracking algo

Browse files
Files changed (1) hide show
  1. sudoku/train.py +10 -2
sudoku/train.py CHANGED
@@ -870,7 +870,7 @@ class SudokuTrialErrorLightning(SudokuLightning):
870
  return is_valid, new_x
871
  """
872
  try:
873
- x = self.predict(x)
874
  except TrialEveryPosException:
875
  pos = self.search_trial(x.view(2,729), [])
876
  new_x = deepcopy(x.view(2,729))
@@ -883,7 +883,15 @@ class SudokuTrialErrorLightning(SudokuLightning):
883
  new_x[0,trial_abs_pres,pos]=1
884
  is_valid, new_x = self.backtracking_predict(new_x)
885
  if not is_valid:
886
- return is_valid ### TO continue
 
 
 
 
 
 
 
 
887
 
888
 
889
  def on_validation_epoch_start(self) -> None:
 
870
  return is_valid, new_x
871
  """
872
  try:
873
+ next_x = self.predict(x)
874
  except TrialEveryPosException:
875
  pos = self.search_trial(x.view(2,729), [])
876
  new_x = deepcopy(x.view(2,729))
 
883
  new_x[0,trial_abs_pres,pos]=1
884
  is_valid, new_x = self.backtracking_predict(new_x)
885
  if not is_valid:
886
+ new_x = deepcopy(x.view(2,729))
887
+ new_x[0,(trial_abs_pres+1)%2,pos]=1
888
+ return self.backtracking_predict(new_x)
889
+ return True, new_x
890
+ except ValueError:
891
+ return False, None
892
+ if next_x.sum()==729:
893
+ return True, next_x
894
+ return self.backtracking_predict(next_x)
895
 
896
 
897
  def on_validation_epoch_start(self) -> None: