Sebastien
commited on
Commit
·
0583ac6
1
Parent(s):
dfa6863
add notebook and fix code
Browse files- experiments/test_harder_sudoku.ipynb +0 -0
- sudoku/train.py +30 -22
experiments/test_harder_sudoku.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
sudoku/train.py
CHANGED
@@ -869,29 +869,37 @@ class SudokuTrialErrorLightning(SudokuLightning):
|
|
869 |
"""
|
870 |
return is_valid, new_x
|
871 |
"""
|
872 |
-
|
873 |
-
|
874 |
-
|
875 |
-
|
876 |
-
|
877 |
-
|
878 |
-
|
879 |
-
if
|
880 |
-
|
881 |
else:
|
882 |
-
|
883 |
-
|
884 |
-
|
885 |
-
|
886 |
-
|
887 |
-
|
888 |
-
|
889 |
-
|
890 |
-
|
891 |
-
|
892 |
-
|
893 |
-
|
894 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
895 |
|
896 |
|
897 |
def on_validation_epoch_start(self) -> None:
|
|
|
869 |
"""
|
870 |
return is_valid, new_x
|
871 |
"""
|
872 |
+
next_X = deepcopy(x)
|
873 |
+
sum_1 = next_X.sum().item()
|
874 |
+
while True:
|
875 |
+
idx, next_X = self.forward(next_X.view(1,2,729))
|
876 |
+
if not self.validate_grids(next_X)[0].item():
|
877 |
+
return False, None
|
878 |
+
sum_2 = next_X.sum().item()
|
879 |
+
if sum_1==sum_2:
|
880 |
+
break
|
881 |
else:
|
882 |
+
sum_1=sum_2
|
883 |
+
if next_X.sum()>=729:
|
884 |
+
return True, next_X
|
885 |
+
|
886 |
+
pos = self.search_trial(next_X.view(2,729), [])
|
887 |
+
new_x = deepcopy(next_X.view(1,2,729))
|
888 |
+
output = self.forward_layer(next_X.view(1,2,729))
|
889 |
+
pos_output = output[0,:,pos]
|
890 |
+
if (pos_output[0].item()-self.threshold_abs[0].item())>(pos_output[1].item()-self.threshold_pres[0].item()):
|
891 |
+
trial_abs_pres = 0
|
892 |
+
else:
|
893 |
+
trial_abs_pres = 1
|
894 |
+
new_x[0,trial_abs_pres,pos]=1
|
895 |
+
print(f'assuming {(pos, pos%9, (pos//9)%9, pos//(9*9), trial_abs_pres)}')# i+j*9+n*9*9
|
896 |
+
is_valid, new_x = self.backtracking_predict(new_x)
|
897 |
+
if not is_valid:
|
898 |
+
new_x = deepcopy(next_X.view(1,2,729))
|
899 |
+
new_x[0,(trial_abs_pres+1)%2,pos]=1
|
900 |
+
print(f'finally assuming {(pos, pos%9, (pos//9)%9, pos//(9*9), (trial_abs_pres+1)%2)}')# i+j*9+n*9*9
|
901 |
+
return self.backtracking_predict(new_x)
|
902 |
+
return True, new_x
|
903 |
|
904 |
|
905 |
def on_validation_epoch_start(self) -> None:
|