finish backtracking algo
Browse files- sudoku/train.py +10 -2
sudoku/train.py
CHANGED
@@ -870,7 +870,7 @@ class SudokuTrialErrorLightning(SudokuLightning):
|
|
870 |
return is_valid, new_x
|
871 |
"""
|
872 |
try:
|
873 |
-
|
874 |
except TrialEveryPosException:
|
875 |
pos = self.search_trial(x.view(2,729), [])
|
876 |
new_x = deepcopy(x.view(2,729))
|
@@ -883,7 +883,15 @@ class SudokuTrialErrorLightning(SudokuLightning):
|
|
883 |
new_x[0,trial_abs_pres,pos]=1
|
884 |
is_valid, new_x = self.backtracking_predict(new_x)
|
885 |
if not is_valid:
|
886 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
887 |
|
888 |
|
889 |
def on_validation_epoch_start(self) -> None:
|
|
|
870 |
return is_valid, new_x
|
871 |
"""
|
872 |
try:
|
873 |
+
next_x = self.predict(x)
|
874 |
except TrialEveryPosException:
|
875 |
pos = self.search_trial(x.view(2,729), [])
|
876 |
new_x = deepcopy(x.view(2,729))
|
|
|
883 |
new_x[0,trial_abs_pres,pos]=1
|
884 |
is_valid, new_x = self.backtracking_predict(new_x)
|
885 |
if not is_valid:
|
886 |
+
new_x = deepcopy(x.view(2,729))
|
887 |
+
new_x[0,(trial_abs_pres+1)%2,pos]=1
|
888 |
+
return self.backtracking_predict(new_x)
|
889 |
+
return True, new_x
|
890 |
+
except ValueError:
|
891 |
+
return False, None
|
892 |
+
if next_x.sum()==729:
|
893 |
+
return True, next_x
|
894 |
+
return self.backtracking_predict(next_x)
|
895 |
|
896 |
|
897 |
def on_validation_epoch_start(self) -> None:
|