Sebastien
commited on
Commit
·
19320e5
1
Parent(s):
0583ac6
improving backtracking algorithm and display in the app
Browse files- app.py +10 -3
- sudoku/helper.py +9 -1
- sudoku/train.py +59 -18
app.py
CHANGED
@@ -117,12 +117,19 @@ if n_sol==1:
|
|
117 |
i+=1
|
118 |
st.markdown(f'iteration {i}')
|
119 |
try:
|
120 |
-
new_X = model.predict(new_X)
|
121 |
except TrialEveryPosException:
|
122 |
st.markdown('''## The grid is super evil!
|
123 |
please share it as A Discussion in the `Community` tab.
|
124 |
-
Except if it is this one: https://www.telegraph.co.uk/news/science/science-news/9359579/Worlds-hardest-sudoku-can-you-crack-it.html
|
125 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
126 |
assert is_valid
|
127 |
st.html(display_as_dataframe(new_X).to_html(escape=False, index=False))
|
128 |
new_X_sum = new_X.sum()
|
|
|
117 |
i+=1
|
118 |
st.markdown(f'iteration {i}')
|
119 |
try:
|
120 |
+
new_X = model.predict(new_X, func_text_display=st.markdown)
|
121 |
except TrialEveryPosException:
|
122 |
st.markdown('''## The grid is super evil!
|
123 |
please share it as A Discussion in the `Community` tab.
|
124 |
+
Except if it is this one: https://www.telegraph.co.uk/news/science/science-news/9359579/Worlds-hardest-sudoku-can-you-crack-it.html
|
125 |
+
|
126 |
+
Using trail error model enhanced by backtracking
|
127 |
+
''')
|
128 |
+
is_valid, new_X = model.backtracking_predict(
|
129 |
+
new_X,
|
130 |
+
func_text_display=st.markdown,
|
131 |
+
func_tensor_display=lambda t: st.html(display_as_dataframe(t).to_html(escape=False, index=False)),
|
132 |
+
)
|
133 |
assert is_valid
|
134 |
st.html(display_as_dataframe(new_X).to_html(escape=False, index=False))
|
135 |
new_X_sum = new_X.sum()
|
sudoku/helper.py
CHANGED
@@ -157,4 +157,12 @@ def legal(row, col, num, grid):
|
|
157 |
return True
|
158 |
|
159 |
def get_grid_number_soluce(grid):
|
160 |
-
return solve(0,0,grid,0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
157 |
return True
|
158 |
|
159 |
def get_grid_number_soluce(grid):
|
160 |
+
return solve(0,0,grid,0)
|
161 |
+
|
162 |
+
|
163 |
+
def pos_to_digit_col_row(pos):
|
164 |
+
# pos = i+j*9+k*9*9
|
165 |
+
digit = pos%9+1
|
166 |
+
col = (pos//9)%9+1
|
167 |
+
row=pos//(9*9)+1
|
168 |
+
return digit, col, row
|
sudoku/train.py
CHANGED
@@ -8,6 +8,7 @@ import torch.nn.functional as F
|
|
8 |
|
9 |
from sudoku.buffer import BufferArray, Buffer
|
10 |
from sudoku.trial_grid import TrialGrid
|
|
|
11 |
|
12 |
from copy import deepcopy
|
13 |
|
@@ -374,6 +375,8 @@ class SudokuLightning(pl.LightningModule):
|
|
374 |
(self.sym_preprocess(x)[:, 17].max(dim=1).values > (1 / 8))
|
375 |
| (self.sym_preprocess(x)[:, 18].max(dim=1).values > (1 / 8))
|
376 |
| (self.sym_preprocess(x)[:, 19].max(dim=1).values > (1 / 8))
|
|
|
|
|
377 |
)
|
378 |
|
379 |
# steps to trial error
|
@@ -828,13 +831,15 @@ class SudokuTrialErrorLightning(SudokuLightning):
|
|
828 |
self.log(f"{prefix}_y_pos_trial_error", y.sum())
|
829 |
self.log(f"{prefix}_y_neg_trial_eror", y.shape[0]-y.sum())
|
830 |
|
831 |
-
def predict(self, x):
|
832 |
""" return an improvement of x
|
833 |
|
834 |
"""
|
835 |
|
836 |
idx, new_X = self.forward(x.view(-1,2,729))
|
837 |
if (new_X.sum()>x.sum()) or (new_X.sum()==729):
|
|
|
|
|
838 |
return new_X
|
839 |
else:
|
840 |
# call trial error until we find a solution
|
@@ -854,34 +859,61 @@ class SudokuTrialErrorLightning(SudokuLightning):
|
|
854 |
mask_validated = self.validate_grids(new_X)
|
855 |
if mask_validated.sum()<2:
|
856 |
x[0, mask_validated, pos] = 1 # TODO check if it work
|
|
|
|
|
|
|
|
|
857 |
return x
|
858 |
if X_tried.sum()==new_X.sum():
|
859 |
-
|
860 |
break
|
861 |
mask_complete = (X_tried.sum(dim=1)==729)# check if it works
|
862 |
if mask_complete.sum()>0:
|
863 |
x[0, mask_complete, pos] = 1
|
|
|
|
|
|
|
|
|
864 |
return x
|
865 |
X_tried = new_X
|
866 |
# if one of X_tried is complete (weird but possible) -> return x with tried_position mask_complet set to 1 (cause we still want a step by step resolution)
|
867 |
|
868 |
-
def backtracking_predict(self, x):
|
869 |
"""
|
870 |
return is_valid, new_x
|
871 |
"""
|
872 |
next_X = deepcopy(x)
|
873 |
sum_1 = next_X.sum().item()
|
874 |
-
|
875 |
-
|
876 |
-
|
877 |
-
|
878 |
-
|
879 |
-
|
880 |
-
|
881 |
-
|
882 |
-
|
883 |
-
|
884 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
885 |
|
886 |
pos = self.search_trial(next_X.view(2,729), [])
|
887 |
new_x = deepcopy(next_X.view(1,2,729))
|
@@ -889,16 +921,25 @@ class SudokuTrialErrorLightning(SudokuLightning):
|
|
889 |
pos_output = output[0,:,pos]
|
890 |
if (pos_output[0].item()-self.threshold_abs[0].item())>(pos_output[1].item()-self.threshold_pres[0].item()):
|
891 |
trial_abs_pres = 0
|
|
|
892 |
else:
|
893 |
trial_abs_pres = 1
|
|
|
894 |
new_x[0,trial_abs_pres,pos]=1
|
895 |
-
print(f'assuming {(pos, pos%9, (pos//9)%9, pos//(9*9), trial_abs_pres)}')# i+j*9+n*9*9
|
896 |
-
|
|
|
|
|
|
|
|
|
|
|
897 |
if not is_valid:
|
898 |
new_x = deepcopy(next_X.view(1,2,729))
|
899 |
new_x[0,(trial_abs_pres+1)%2,pos]=1
|
900 |
-
print(f'finally assuming {(pos, pos%9, (pos//9)%9, pos//(9*9), (trial_abs_pres+1)%2)}')# i+j*9+n*9*9
|
901 |
-
|
|
|
|
|
902 |
return True, new_x
|
903 |
|
904 |
|
|
|
8 |
|
9 |
from sudoku.buffer import BufferArray, Buffer
|
10 |
from sudoku.trial_grid import TrialGrid
|
11 |
+
from sudoku.helper import pos_to_digit_col_row
|
12 |
|
13 |
from copy import deepcopy
|
14 |
|
|
|
375 |
(self.sym_preprocess(x)[:, 17].max(dim=1).values > (1 / 8))
|
376 |
| (self.sym_preprocess(x)[:, 18].max(dim=1).values > (1 / 8))
|
377 |
| (self.sym_preprocess(x)[:, 19].max(dim=1).values > (1 / 8))
|
378 |
+
| (x.view(-1,2,9,9,9)[:,1].sum(dim=-1)>1).any(dim=1).any(dim=1)
|
379 |
+
| (x.view(-1,2,9,9,9)[:,0].sum(dim=-1)>8).any(dim=1).any(dim=1)
|
380 |
)
|
381 |
|
382 |
# steps to trial error
|
|
|
831 |
self.log(f"{prefix}_y_pos_trial_error", y.sum())
|
832 |
self.log(f"{prefix}_y_neg_trial_eror", y.shape[0]-y.sum())
|
833 |
|
834 |
+
def predict(self, x, func_text_display=None):
|
835 |
""" return an improvement of x
|
836 |
|
837 |
"""
|
838 |
|
839 |
idx, new_X = self.forward(x.view(-1,2,729))
|
840 |
if (new_X.sum()>x.sum()) or (new_X.sum()==729):
|
841 |
+
if func_text_display:
|
842 |
+
func_text_display(f'boost layer step: {idx}')
|
843 |
return new_X
|
844 |
else:
|
845 |
# call trial error until we find a solution
|
|
|
859 |
mask_validated = self.validate_grids(new_X)
|
860 |
if mask_validated.sum()<2:
|
861 |
x[0, mask_validated, pos] = 1 # TODO check if it work
|
862 |
+
if func_text_display:
|
863 |
+
digit, col, row = pos_to_digit_col_row(pos)
|
864 |
+
func_text_display('model failed to improve the grid')
|
865 |
+
func_text_display(f'trial error alogorithm, found error at digit: {digit}, col: {col}, row: {row}')
|
866 |
return x
|
867 |
if X_tried.sum()==new_X.sum():
|
868 |
+
# if both stop to improve -> break it will tried an new pos
|
869 |
break
|
870 |
mask_complete = (X_tried.sum(dim=1)==729)# check if it works
|
871 |
if mask_complete.sum()>0:
|
872 |
x[0, mask_complete, pos] = 1
|
873 |
+
if func_text_display:
|
874 |
+
digit, col, row = pos_to_digit_col_row(pos)
|
875 |
+
func_text_display('model failed to improve the grid')
|
876 |
+
func_text_display(f'trial error alogorithm, complete the grid with digit: {digit}, col: {col}, row: {row}')
|
877 |
return x
|
878 |
X_tried = new_X
|
879 |
# if one of X_tried is complete (weird but possible) -> return x with tried_position mask_complet set to 1 (cause we still want a step by step resolution)
|
880 |
|
881 |
+
def backtracking_predict(self, x, use_trial_error=False, assumption=[], func_text_display=None, func_tensor_display=None):
|
882 |
"""
|
883 |
return is_valid, new_x
|
884 |
"""
|
885 |
next_X = deepcopy(x)
|
886 |
sum_1 = next_X.sum().item()
|
887 |
+
if use_trial_error:
|
888 |
+
while True:
|
889 |
+
try:
|
890 |
+
next_X = self.predict(next_X.view(1,2,729))
|
891 |
+
if not self.validate_grids(next_X)[0].item():
|
892 |
+
if assumption is not []:
|
893 |
+
func_text_display(f'assumption {assumption[-1]}, failed')
|
894 |
+
func_text_display(f'assumption length {len(assumption)}')
|
895 |
+
func_tensor_display(next_X)
|
896 |
+
return False, None
|
897 |
+
if next_X.sum()>=729:
|
898 |
+
if assumption is not []:
|
899 |
+
func_text_display(f'assumption length {len(assumption)}')
|
900 |
+
func_tensor_display(next_X)
|
901 |
+
return True, next_X
|
902 |
+
except TrialEveryPosException:
|
903 |
+
break
|
904 |
+
|
905 |
+
else:
|
906 |
+
while True:
|
907 |
+
_idx, next_X = self.forward(next_X.view(1,2,729))
|
908 |
+
if not self.validate_grids(next_X)[0].item():
|
909 |
+
return False, None
|
910 |
+
sum_2 = next_X.sum().item()
|
911 |
+
if sum_1==sum_2:
|
912 |
+
break
|
913 |
+
else:
|
914 |
+
sum_1=sum_2
|
915 |
+
if next_X.sum()>=729:
|
916 |
+
return True, next_X
|
917 |
|
918 |
pos = self.search_trial(next_X.view(2,729), [])
|
919 |
new_x = deepcopy(next_X.view(1,2,729))
|
|
|
921 |
pos_output = output[0,:,pos]
|
922 |
if (pos_output[0].item()-self.threshold_abs[0].item())>(pos_output[1].item()-self.threshold_pres[0].item()):
|
923 |
trial_abs_pres = 0
|
924 |
+
abs_pres_str = 'abs'
|
925 |
else:
|
926 |
trial_abs_pres = 1
|
927 |
+
abs_pres_str = 'pres'
|
928 |
new_x[0,trial_abs_pres,pos]=1
|
929 |
+
# print(f'assuming {(pos, pos%9, (pos//9)%9, pos//(9*9), trial_abs_pres)}')# i+j*9+n*9*9
|
930 |
+
digit, col, row = pos_to_digit_col_row(pos)
|
931 |
+
func_text_display('grid not improving adding the assuption')
|
932 |
+
current_assumption = f'{abs_pres_str}, digit: {digit}, col: {col}, row: {row}'
|
933 |
+
func_text_display(current_assumption)
|
934 |
+
assumption.append(current_assumption)
|
935 |
+
is_valid, new_x = self.backtracking_predict(new_x, use_trial_error=use_trial_error, assumption=assumption, func_text_display=func_text_display, func_tensor_display=func_tensor_display)
|
936 |
if not is_valid:
|
937 |
new_x = deepcopy(next_X.view(1,2,729))
|
938 |
new_x[0,(trial_abs_pres+1)%2,pos]=1
|
939 |
+
# print(f'finally assuming {(pos, pos%9, (pos//9)%9, pos//(9*9), (trial_abs_pres+1)%2)}')# i+j*9+n*9*9
|
940 |
+
current_assumption = f'{"pres" if abs_pres_str=="abs" else "abs"}, digit: {digit}, col: {col}, row: {row}'
|
941 |
+
func_text_display(current_assumption)
|
942 |
+
return self.backtracking_predict(new_x, use_trial_error=use_trial_error, assumption=assumption, func_text_display=func_text_display, func_tensor_display=func_tensor_display)
|
943 |
return True, new_x
|
944 |
|
945 |
|