Sebastien commited on
Commit
19320e5
·
1 Parent(s): 0583ac6

improving backtracking algorithm and display in the app

Browse files
Files changed (3) hide show
  1. app.py +10 -3
  2. sudoku/helper.py +9 -1
  3. sudoku/train.py +59 -18
app.py CHANGED
@@ -117,12 +117,19 @@ if n_sol==1:
117
  i+=1
118
  st.markdown(f'iteration {i}')
119
  try:
120
- new_X = model.predict(new_X)
121
  except TrialEveryPosException:
122
  st.markdown('''## The grid is super evil!
123
  please share it as A Discussion in the `Community` tab.
124
- Except if it is this one: https://www.telegraph.co.uk/news/science/science-news/9359579/Worlds-hardest-sudoku-can-you-crack-it.html''')
125
- is_valid, new_X = model.backtracking_predict(new_X)
 
 
 
 
 
 
 
126
  assert is_valid
127
  st.html(display_as_dataframe(new_X).to_html(escape=False, index=False))
128
  new_X_sum = new_X.sum()
 
117
  i+=1
118
  st.markdown(f'iteration {i}')
119
  try:
120
+ new_X = model.predict(new_X, func_text_display=st.markdown)
121
  except TrialEveryPosException:
122
  st.markdown('''## The grid is super evil!
123
  please share it as A Discussion in the `Community` tab.
124
+ Except if it is this one: https://www.telegraph.co.uk/news/science/science-news/9359579/Worlds-hardest-sudoku-can-you-crack-it.html
125
+
126
+ Using trail error model enhanced by backtracking
127
+ ''')
128
+ is_valid, new_X = model.backtracking_predict(
129
+ new_X,
130
+ func_text_display=st.markdown,
131
+ func_tensor_display=lambda t: st.html(display_as_dataframe(t).to_html(escape=False, index=False)),
132
+ )
133
  assert is_valid
134
  st.html(display_as_dataframe(new_X).to_html(escape=False, index=False))
135
  new_X_sum = new_X.sum()
sudoku/helper.py CHANGED
@@ -157,4 +157,12 @@ def legal(row, col, num, grid):
157
  return True
158
 
159
  def get_grid_number_soluce(grid):
160
- return solve(0,0,grid,0)
 
 
 
 
 
 
 
 
 
157
  return True
158
 
159
  def get_grid_number_soluce(grid):
160
+ return solve(0,0,grid,0)
161
+
162
+
163
+ def pos_to_digit_col_row(pos):
164
+ # pos = i+j*9+k*9*9
165
+ digit = pos%9+1
166
+ col = (pos//9)%9+1
167
+ row=pos//(9*9)+1
168
+ return digit, col, row
sudoku/train.py CHANGED
@@ -8,6 +8,7 @@ import torch.nn.functional as F
8
 
9
  from sudoku.buffer import BufferArray, Buffer
10
  from sudoku.trial_grid import TrialGrid
 
11
 
12
  from copy import deepcopy
13
 
@@ -374,6 +375,8 @@ class SudokuLightning(pl.LightningModule):
374
  (self.sym_preprocess(x)[:, 17].max(dim=1).values > (1 / 8))
375
  | (self.sym_preprocess(x)[:, 18].max(dim=1).values > (1 / 8))
376
  | (self.sym_preprocess(x)[:, 19].max(dim=1).values > (1 / 8))
 
 
377
  )
378
 
379
  # steps to trial error
@@ -828,13 +831,15 @@ class SudokuTrialErrorLightning(SudokuLightning):
828
  self.log(f"{prefix}_y_pos_trial_error", y.sum())
829
  self.log(f"{prefix}_y_neg_trial_eror", y.shape[0]-y.sum())
830
 
831
- def predict(self, x):
832
  """ return an improvement of x
833
 
834
  """
835
 
836
  idx, new_X = self.forward(x.view(-1,2,729))
837
  if (new_X.sum()>x.sum()) or (new_X.sum()==729):
 
 
838
  return new_X
839
  else:
840
  # call trial error until we find a solution
@@ -854,34 +859,61 @@ class SudokuTrialErrorLightning(SudokuLightning):
854
  mask_validated = self.validate_grids(new_X)
855
  if mask_validated.sum()<2:
856
  x[0, mask_validated, pos] = 1 # TODO check if it work
 
 
 
 
857
  return x
858
  if X_tried.sum()==new_X.sum():
859
- # if both stop to improve -> break it will tried an new pos
860
  break
861
  mask_complete = (X_tried.sum(dim=1)==729)# check if it works
862
  if mask_complete.sum()>0:
863
  x[0, mask_complete, pos] = 1
 
 
 
 
864
  return x
865
  X_tried = new_X
866
  # if one of X_tried is complete (weird but possible) -> return x with tried_position mask_complet set to 1 (cause we still want a step by step resolution)
867
 
868
- def backtracking_predict(self, x):
869
  """
870
  return is_valid, new_x
871
  """
872
  next_X = deepcopy(x)
873
  sum_1 = next_X.sum().item()
874
- while True:
875
- idx, next_X = self.forward(next_X.view(1,2,729))
876
- if not self.validate_grids(next_X)[0].item():
877
- return False, None
878
- sum_2 = next_X.sum().item()
879
- if sum_1==sum_2:
880
- break
881
- else:
882
- sum_1=sum_2
883
- if next_X.sum()>=729:
884
- return True, next_X
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
885
 
886
  pos = self.search_trial(next_X.view(2,729), [])
887
  new_x = deepcopy(next_X.view(1,2,729))
@@ -889,16 +921,25 @@ class SudokuTrialErrorLightning(SudokuLightning):
889
  pos_output = output[0,:,pos]
890
  if (pos_output[0].item()-self.threshold_abs[0].item())>(pos_output[1].item()-self.threshold_pres[0].item()):
891
  trial_abs_pres = 0
 
892
  else:
893
  trial_abs_pres = 1
 
894
  new_x[0,trial_abs_pres,pos]=1
895
- print(f'assuming {(pos, pos%9, (pos//9)%9, pos//(9*9), trial_abs_pres)}')# i+j*9+n*9*9
896
- is_valid, new_x = self.backtracking_predict(new_x)
 
 
 
 
 
897
  if not is_valid:
898
  new_x = deepcopy(next_X.view(1,2,729))
899
  new_x[0,(trial_abs_pres+1)%2,pos]=1
900
- print(f'finally assuming {(pos, pos%9, (pos//9)%9, pos//(9*9), (trial_abs_pres+1)%2)}')# i+j*9+n*9*9
901
- return self.backtracking_predict(new_x)
 
 
902
  return True, new_x
903
 
904
 
 
8
 
9
  from sudoku.buffer import BufferArray, Buffer
10
  from sudoku.trial_grid import TrialGrid
11
+ from sudoku.helper import pos_to_digit_col_row
12
 
13
  from copy import deepcopy
14
 
 
375
  (self.sym_preprocess(x)[:, 17].max(dim=1).values > (1 / 8))
376
  | (self.sym_preprocess(x)[:, 18].max(dim=1).values > (1 / 8))
377
  | (self.sym_preprocess(x)[:, 19].max(dim=1).values > (1 / 8))
378
+ | (x.view(-1,2,9,9,9)[:,1].sum(dim=-1)>1).any(dim=1).any(dim=1)
379
+ | (x.view(-1,2,9,9,9)[:,0].sum(dim=-1)>8).any(dim=1).any(dim=1)
380
  )
381
 
382
  # steps to trial error
 
831
  self.log(f"{prefix}_y_pos_trial_error", y.sum())
832
  self.log(f"{prefix}_y_neg_trial_eror", y.shape[0]-y.sum())
833
 
834
+ def predict(self, x, func_text_display=None):
835
  """ return an improvement of x
836
 
837
  """
838
 
839
  idx, new_X = self.forward(x.view(-1,2,729))
840
  if (new_X.sum()>x.sum()) or (new_X.sum()==729):
841
+ if func_text_display:
842
+ func_text_display(f'boost layer step: {idx}')
843
  return new_X
844
  else:
845
  # call trial error until we find a solution
 
859
  mask_validated = self.validate_grids(new_X)
860
  if mask_validated.sum()<2:
861
  x[0, mask_validated, pos] = 1 # TODO check if it work
862
+ if func_text_display:
863
+ digit, col, row = pos_to_digit_col_row(pos)
864
+ func_text_display('model failed to improve the grid')
865
+ func_text_display(f'trial error alogorithm, found error at digit: {digit}, col: {col}, row: {row}')
866
  return x
867
  if X_tried.sum()==new_X.sum():
868
+ # if both stop to improve -> break it will tried an new pos
869
  break
870
  mask_complete = (X_tried.sum(dim=1)==729)# check if it works
871
  if mask_complete.sum()>0:
872
  x[0, mask_complete, pos] = 1
873
+ if func_text_display:
874
+ digit, col, row = pos_to_digit_col_row(pos)
875
+ func_text_display('model failed to improve the grid')
876
+ func_text_display(f'trial error alogorithm, complete the grid with digit: {digit}, col: {col}, row: {row}')
877
  return x
878
  X_tried = new_X
879
  # if one of X_tried is complete (weird but possible) -> return x with tried_position mask_complet set to 1 (cause we still want a step by step resolution)
880
 
881
+ def backtracking_predict(self, x, use_trial_error=False, assumption=[], func_text_display=None, func_tensor_display=None):
882
  """
883
  return is_valid, new_x
884
  """
885
  next_X = deepcopy(x)
886
  sum_1 = next_X.sum().item()
887
+ if use_trial_error:
888
+ while True:
889
+ try:
890
+ next_X = self.predict(next_X.view(1,2,729))
891
+ if not self.validate_grids(next_X)[0].item():
892
+ if assumption is not []:
893
+ func_text_display(f'assumption {assumption[-1]}, failed')
894
+ func_text_display(f'assumption length {len(assumption)}')
895
+ func_tensor_display(next_X)
896
+ return False, None
897
+ if next_X.sum()>=729:
898
+ if assumption is not []:
899
+ func_text_display(f'assumption length {len(assumption)}')
900
+ func_tensor_display(next_X)
901
+ return True, next_X
902
+ except TrialEveryPosException:
903
+ break
904
+
905
+ else:
906
+ while True:
907
+ _idx, next_X = self.forward(next_X.view(1,2,729))
908
+ if not self.validate_grids(next_X)[0].item():
909
+ return False, None
910
+ sum_2 = next_X.sum().item()
911
+ if sum_1==sum_2:
912
+ break
913
+ else:
914
+ sum_1=sum_2
915
+ if next_X.sum()>=729:
916
+ return True, next_X
917
 
918
  pos = self.search_trial(next_X.view(2,729), [])
919
  new_x = deepcopy(next_X.view(1,2,729))
 
921
  pos_output = output[0,:,pos]
922
  if (pos_output[0].item()-self.threshold_abs[0].item())>(pos_output[1].item()-self.threshold_pres[0].item()):
923
  trial_abs_pres = 0
924
+ abs_pres_str = 'abs'
925
  else:
926
  trial_abs_pres = 1
927
+ abs_pres_str = 'pres'
928
  new_x[0,trial_abs_pres,pos]=1
929
+ # print(f'assuming {(pos, pos%9, (pos//9)%9, pos//(9*9), trial_abs_pres)}')# i+j*9+n*9*9
930
+ digit, col, row = pos_to_digit_col_row(pos)
931
+ func_text_display('grid not improving adding the assuption')
932
+ current_assumption = f'{abs_pres_str}, digit: {digit}, col: {col}, row: {row}'
933
+ func_text_display(current_assumption)
934
+ assumption.append(current_assumption)
935
+ is_valid, new_x = self.backtracking_predict(new_x, use_trial_error=use_trial_error, assumption=assumption, func_text_display=func_text_display, func_tensor_display=func_tensor_display)
936
  if not is_valid:
937
  new_x = deepcopy(next_X.view(1,2,729))
938
  new_x[0,(trial_abs_pres+1)%2,pos]=1
939
+ # print(f'finally assuming {(pos, pos%9, (pos//9)%9, pos//(9*9), (trial_abs_pres+1)%2)}')# i+j*9+n*9*9
940
+ current_assumption = f'{"pres" if abs_pres_str=="abs" else "abs"}, digit: {digit}, col: {col}, row: {row}'
941
+ func_text_display(current_assumption)
942
+ return self.backtracking_predict(new_x, use_trial_error=use_trial_error, assumption=assumption, func_text_display=func_text_display, func_tensor_display=func_tensor_display)
943
  return True, new_x
944
 
945