unfinished work on backtracking
#1
by
SebastienGuissart
- opened
- app.py +9 -2
- sudoku/train.py +33 -1
app.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
import streamlit as st
|
2 |
-
from sudoku.train import SudokuTrialErrorLightning
|
3 |
from sudoku.helper import display_as_dataframe, get_grid_number_soluce
|
4 |
import numpy as np
|
5 |
import re
|
@@ -116,7 +116,14 @@ if n_sol==1:
|
|
116 |
while new_X.sum()<729:
|
117 |
i+=1
|
118 |
st.markdown(f'iteration {i}')
|
119 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
120 |
st.html(display_as_dataframe(new_X).to_html(escape=False, index=False))
|
121 |
new_X_sum = new_X.sum()
|
122 |
assert new_X_sum> X_sum
|
|
|
1 |
import streamlit as st
|
2 |
+
from sudoku.train import SudokuTrialErrorLightning, TrialEveryPosException
|
3 |
from sudoku.helper import display_as_dataframe, get_grid_number_soluce
|
4 |
import numpy as np
|
5 |
import re
|
|
|
116 |
while new_X.sum()<729:
|
117 |
i+=1
|
118 |
st.markdown(f'iteration {i}')
|
119 |
+
try:
|
120 |
+
new_X = model.predict(new_X)
|
121 |
+
except TrialEveryPosException:
|
122 |
+
st.markdown('''## The grid is super evil!
|
123 |
+
please share it as A Discussion in the `Community` tab.
|
124 |
+
Except if it is this one: https://www.telegraph.co.uk/news/science/science-news/9359579/Worlds-hardest-sudoku-can-you-crack-it.html''')
|
125 |
+
is_valid, new_X = model.backtracking_predict(new_X)
|
126 |
+
assert is_valid
|
127 |
st.html(display_as_dataframe(new_X).to_html(escape=False, index=False))
|
128 |
new_X_sum = new_X.sum()
|
129 |
assert new_X_sum> X_sum
|
sudoku/train.py
CHANGED
@@ -401,6 +401,9 @@ class SudokuLightning(pl.LightningModule):
|
|
401 |
|
402 |
# TODO adapt training to something softer
|
403 |
#
|
|
|
|
|
|
|
404 |
class SudokuTrialErrorLightning(SudokuLightning):
|
405 |
def __init__(self, **kwargs):
|
406 |
super().__init__(**kwargs)
|
@@ -752,7 +755,7 @@ class SudokuTrialErrorLightning(SudokuLightning):
|
|
752 |
mask_possibility[pos]=False
|
753 |
if mask_possibility.sum()==0:
|
754 |
print('mask_possible=0')
|
755 |
-
raise
|
756 |
|
757 |
with torch.no_grad():
|
758 |
x_reg = self.sym_preprocess.forward(s_new_X.view(1,2,-1))
|
@@ -861,7 +864,36 @@ class SudokuTrialErrorLightning(SudokuLightning):
|
|
861 |
return x
|
862 |
X_tried = new_X
|
863 |
# if one of X_tried is complete (weird but possible) -> return x with tried_position mask_complet set to 1 (cause we still want a step by step resolution)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
864 |
|
|
|
865 |
def on_validation_epoch_start(self) -> None:
|
866 |
# self.buffer = BufferArray(self.nets_number, self.batch_size)
|
867 |
self.trial_error_buffer = Buffer(self.batch_size)
|
|
|
401 |
|
402 |
# TODO adapt training to something softer
|
403 |
#
|
404 |
+
class TrialEveryPosException(Exception):
|
405 |
+
pass
|
406 |
+
|
407 |
class SudokuTrialErrorLightning(SudokuLightning):
|
408 |
def __init__(self, **kwargs):
|
409 |
super().__init__(**kwargs)
|
|
|
755 |
mask_possibility[pos]=False
|
756 |
if mask_possibility.sum()==0:
|
757 |
print('mask_possible=0')
|
758 |
+
raise TrialEveryPosException()
|
759 |
|
760 |
with torch.no_grad():
|
761 |
x_reg = self.sym_preprocess.forward(s_new_X.view(1,2,-1))
|
|
|
864 |
return x
|
865 |
X_tried = new_X
|
866 |
# if one of X_tried is complete (weird but possible) -> return x with tried_position mask_complet set to 1 (cause we still want a step by step resolution)
|
867 |
+
|
868 |
+
def backtracking_predict(self, x):
|
869 |
+
"""
|
870 |
+
return is_valid, new_x
|
871 |
+
"""
|
872 |
+
try:
|
873 |
+
next_x = self.predict(x)
|
874 |
+
except TrialEveryPosException:
|
875 |
+
pos = self.search_trial(x.view(2,729), [])
|
876 |
+
new_x = deepcopy(x.view(2,729))
|
877 |
+
output = self.forward_layer(x.view(1,2,729))
|
878 |
+
pos_output = torch.argmax(output[0,:,pos])
|
879 |
+
if (pos_output[0].item()-self.threshold_abs)>(pos_output[1].item-self.threshold_pos):
|
880 |
+
trial_abs_pres = 0
|
881 |
+
else:
|
882 |
+
trial_abs_pres = 1
|
883 |
+
new_x[0,trial_abs_pres,pos]=1
|
884 |
+
is_valid, new_x = self.backtracking_predict(new_x)
|
885 |
+
if not is_valid:
|
886 |
+
new_x = deepcopy(x.view(2,729))
|
887 |
+
new_x[0,(trial_abs_pres+1)%2,pos]=1
|
888 |
+
return self.backtracking_predict(new_x)
|
889 |
+
return True, new_x
|
890 |
+
except ValueError:
|
891 |
+
return False, None
|
892 |
+
if next_x.sum()==729:
|
893 |
+
return True, next_x
|
894 |
+
return self.backtracking_predict(next_x)
|
895 |
|
896 |
+
|
897 |
def on_validation_epoch_start(self) -> None:
|
898 |
# self.buffer = BufferArray(self.nets_number, self.batch_size)
|
899 |
self.trial_error_buffer = Buffer(self.batch_size)
|