Sebastien commited on
Commit
0583ac6
·
1 Parent(s): dfa6863

add notebook and fix code

Browse files
experiments/test_harder_sudoku.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
sudoku/train.py CHANGED
@@ -869,29 +869,37 @@ class SudokuTrialErrorLightning(SudokuLightning):
869
  """
870
  return is_valid, new_x
871
  """
872
- try:
873
- next_x = self.predict(x)
874
- except TrialEveryPosException:
875
- pos = self.search_trial(x.view(2,729), [])
876
- new_x = deepcopy(x.view(2,729))
877
- output = self.forward_layer(x.view(1,2,729))
878
- pos_output = output[0,:,pos]
879
- if (pos_output[0].item()-self.threshold_abs[0].item())>(pos_output[1].item-self.threshold_pres[0].item()):
880
- trial_abs_pres = 0
881
  else:
882
- trial_abs_pres = 1
883
- new_x[0,trial_abs_pres,pos]=1
884
- is_valid, new_x = self.backtracking_predict(new_x)
885
- if not is_valid:
886
- new_x = deepcopy(x.view(2,729))
887
- new_x[0,(trial_abs_pres+1)%2,pos]=1
888
- return self.backtracking_predict(new_x)
889
- return True, new_x
890
- except ValueError:
891
- return False, None
892
- if next_x.sum()==729:
893
- return True, next_x
894
- return self.backtracking_predict(next_x)
 
 
 
 
 
 
 
 
895
 
896
 
897
  def on_validation_epoch_start(self) -> None:
 
869
  """
870
  return is_valid, new_x
871
  """
872
+ next_X = deepcopy(x)
873
+ sum_1 = next_X.sum().item()
874
+ while True:
875
+ idx, next_X = self.forward(next_X.view(1,2,729))
876
+ if not self.validate_grids(next_X)[0].item():
877
+ return False, None
878
+ sum_2 = next_X.sum().item()
879
+ if sum_1==sum_2:
880
+ break
881
  else:
882
+ sum_1=sum_2
883
+ if next_X.sum()>=729:
884
+ return True, next_X
885
+
886
+ pos = self.search_trial(next_X.view(2,729), [])
887
+ new_x = deepcopy(next_X.view(1,2,729))
888
+ output = self.forward_layer(next_X.view(1,2,729))
889
+ pos_output = output[0,:,pos]
890
+ if (pos_output[0].item()-self.threshold_abs[0].item())>(pos_output[1].item()-self.threshold_pres[0].item()):
891
+ trial_abs_pres = 0
892
+ else:
893
+ trial_abs_pres = 1
894
+ new_x[0,trial_abs_pres,pos]=1
895
+ print(f'assuming {(pos, pos%9, (pos//9)%9, pos//(9*9), trial_abs_pres)}')# i+j*9+n*9*9
896
+ is_valid, new_x = self.backtracking_predict(new_x)
897
+ if not is_valid:
898
+ new_x = deepcopy(next_X.view(1,2,729))
899
+ new_x[0,(trial_abs_pres+1)%2,pos]=1
900
+ print(f'finally assuming {(pos, pos%9, (pos//9)%9, pos//(9*9), (trial_abs_pres+1)%2)}')# i+j*9+n*9*9
901
+ return self.backtracking_predict(new_x)
902
+ return True, new_x
903
 
904
 
905
  def on_validation_epoch_start(self) -> None: