unfinished work on backtracking

#1
Files changed (2) hide show
  1. app.py +9 -2
  2. sudoku/train.py +33 -1
app.py CHANGED
@@ -1,5 +1,5 @@
1
  import streamlit as st
2
- from sudoku.train import SudokuTrialErrorLightning
3
  from sudoku.helper import display_as_dataframe, get_grid_number_soluce
4
  import numpy as np
5
  import re
@@ -116,7 +116,14 @@ if n_sol==1:
116
  while new_X.sum()<729:
117
  i+=1
118
  st.markdown(f'iteration {i}')
119
- new_X = model.predict(new_X)
 
 
 
 
 
 
 
120
  st.html(display_as_dataframe(new_X).to_html(escape=False, index=False))
121
  new_X_sum = new_X.sum()
122
  assert new_X_sum> X_sum
 
1
  import streamlit as st
2
+ from sudoku.train import SudokuTrialErrorLightning, TrialEveryPosException
3
  from sudoku.helper import display_as_dataframe, get_grid_number_soluce
4
  import numpy as np
5
  import re
 
116
  while new_X.sum()<729:
117
  i+=1
118
  st.markdown(f'iteration {i}')
119
+ try:
120
+ new_X = model.predict(new_X)
121
+ except TrialEveryPosException:
122
+ st.markdown('''## The grid is super evil!
123
+ please share it as A Discussion in the `Community` tab.
124
+ Except if it is this one: https://www.telegraph.co.uk/news/science/science-news/9359579/Worlds-hardest-sudoku-can-you-crack-it.html''')
125
+ is_valid, new_X = model.backtracking_predict(new_X)
126
+ assert is_valid
127
  st.html(display_as_dataframe(new_X).to_html(escape=False, index=False))
128
  new_X_sum = new_X.sum()
129
  assert new_X_sum> X_sum
sudoku/train.py CHANGED
@@ -401,6 +401,9 @@ class SudokuLightning(pl.LightningModule):
401
 
402
  # TODO adapt training to something softer
403
  #
 
 
 
404
  class SudokuTrialErrorLightning(SudokuLightning):
405
  def __init__(self, **kwargs):
406
  super().__init__(**kwargs)
@@ -752,7 +755,7 @@ class SudokuTrialErrorLightning(SudokuLightning):
752
  mask_possibility[pos]=False
753
  if mask_possibility.sum()==0:
754
  print('mask_possible=0')
755
- raise ValueError()
756
 
757
  with torch.no_grad():
758
  x_reg = self.sym_preprocess.forward(s_new_X.view(1,2,-1))
@@ -861,7 +864,36 @@ class SudokuTrialErrorLightning(SudokuLightning):
861
  return x
862
  X_tried = new_X
863
  # if one of X_tried is complete (weird but possible) -> return x with tried_position mask_complet set to 1 (cause we still want a step by step resolution)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
864
 
 
865
  def on_validation_epoch_start(self) -> None:
866
  # self.buffer = BufferArray(self.nets_number, self.batch_size)
867
  self.trial_error_buffer = Buffer(self.batch_size)
 
401
 
402
  # TODO adapt training to something softer
403
  #
404
+ class TrialEveryPosException(Exception):
405
+ pass
406
+
407
  class SudokuTrialErrorLightning(SudokuLightning):
408
  def __init__(self, **kwargs):
409
  super().__init__(**kwargs)
 
755
  mask_possibility[pos]=False
756
  if mask_possibility.sum()==0:
757
  print('mask_possible=0')
758
+ raise TrialEveryPosException()
759
 
760
  with torch.no_grad():
761
  x_reg = self.sym_preprocess.forward(s_new_X.view(1,2,-1))
 
864
  return x
865
  X_tried = new_X
866
  # if one of X_tried is complete (weird but possible) -> return x with tried_position mask_complet set to 1 (cause we still want a step by step resolution)
867
+
868
+ def backtracking_predict(self, x):
869
+ """
870
+ return is_valid, new_x
871
+ """
872
+ try:
873
+ next_x = self.predict(x)
874
+ except TrialEveryPosException:
875
+ pos = self.search_trial(x.view(2,729), [])
876
+ new_x = deepcopy(x.view(2,729))
877
+ output = self.forward_layer(x.view(1,2,729))
878
+ pos_output = torch.argmax(output[0,:,pos])
879
+ if (pos_output[0].item()-self.threshold_abs)>(pos_output[1].item-self.threshold_pos):
880
+ trial_abs_pres = 0
881
+ else:
882
+ trial_abs_pres = 1
883
+ new_x[0,trial_abs_pres,pos]=1
884
+ is_valid, new_x = self.backtracking_predict(new_x)
885
+ if not is_valid:
886
+ new_x = deepcopy(x.view(2,729))
887
+ new_x[0,(trial_abs_pres+1)%2,pos]=1
888
+ return self.backtracking_predict(new_x)
889
+ return True, new_x
890
+ except ValueError:
891
+ return False, None
892
+ if next_x.sum()==729:
893
+ return True, next_x
894
+ return self.backtracking_predict(next_x)
895
 
896
+
897
  def on_validation_epoch_start(self) -> None:
898
  # self.buffer = BufferArray(self.nets_number, self.batch_size)
899
  self.trial_error_buffer = Buffer(self.batch_size)