unfinished work on backtracking
Browse files- sudoku/train.py +25 -1
sudoku/train.py
CHANGED
@@ -401,6 +401,9 @@ class SudokuLightning(pl.LightningModule):
|
|
401 |
|
402 |
# TODO adapt training to something softer
|
403 |
#
|
|
|
|
|
|
|
404 |
class SudokuTrialErrorLightning(SudokuLightning):
|
405 |
def __init__(self, **kwargs):
|
406 |
super().__init__(**kwargs)
|
@@ -752,7 +755,7 @@ class SudokuTrialErrorLightning(SudokuLightning):
|
|
752 |
mask_possibility[pos]=False
|
753 |
if mask_possibility.sum()==0:
|
754 |
print('mask_possible=0')
|
755 |
-
raise
|
756 |
|
757 |
with torch.no_grad():
|
758 |
x_reg = self.sym_preprocess.forward(s_new_X.view(1,2,-1))
|
@@ -861,7 +864,28 @@ class SudokuTrialErrorLightning(SudokuLightning):
|
|
861 |
return x
|
862 |
X_tried = new_X
|
863 |
# if one of X_tried is complete (weird but possible) -> return x with tried_position mask_complet set to 1 (cause we still want a step by step resolution)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
864 |
|
|
|
865 |
def on_validation_epoch_start(self) -> None:
|
866 |
# self.buffer = BufferArray(self.nets_number, self.batch_size)
|
867 |
self.trial_error_buffer = Buffer(self.batch_size)
|
|
|
401 |
|
402 |
# TODO adapt training to something softer
|
403 |
#
|
404 |
+
class TrialEveryPosException(Exception):
|
405 |
+
pass
|
406 |
+
|
407 |
class SudokuTrialErrorLightning(SudokuLightning):
|
408 |
def __init__(self, **kwargs):
|
409 |
super().__init__(**kwargs)
|
|
|
755 |
mask_possibility[pos]=False
|
756 |
if mask_possibility.sum()==0:
|
757 |
print('mask_possible=0')
|
758 |
+
raise TrialEveryPosException()
|
759 |
|
760 |
with torch.no_grad():
|
761 |
x_reg = self.sym_preprocess.forward(s_new_X.view(1,2,-1))
|
|
|
864 |
return x
|
865 |
X_tried = new_X
|
866 |
# if one of X_tried is complete (weird but possible) -> return x with tried_position mask_complet set to 1 (cause we still want a step by step resolution)
|
867 |
+
|
868 |
+
def backtracking_predict(self, x):
|
869 |
+
"""
|
870 |
+
return is_valid, new_x
|
871 |
+
"""
|
872 |
+
try:
|
873 |
+
x = self.predict(x)
|
874 |
+
except TrialEveryPosException:
|
875 |
+
pos = self.search_trial(x.view(2,729), [])
|
876 |
+
new_x = deepcopy(x.view(2,729))
|
877 |
+
output = self.forward_layer(x.view(1,2,729))
|
878 |
+
pos_output = torch.argmax(output[0,:,pos])
|
879 |
+
if (pos_output[0].item()-self.threshold_abs)>(pos_output[1].item-self.threshold_pos):
|
880 |
+
trial_abs_pres = 0
|
881 |
+
else:
|
882 |
+
trial_abs_pres = 1
|
883 |
+
new_x[0,trial_abs_pres,pos]=1
|
884 |
+
is_valid, new_x = self.backtracking_predict(new_x)
|
885 |
+
if not is_valid:
|
886 |
+
return is_valid ### TO continue
|
887 |
|
888 |
+
|
889 |
def on_validation_epoch_start(self) -> None:
|
890 |
# self.buffer = BufferArray(self.nets_number, self.batch_size)
|
891 |
self.trial_error_buffer = Buffer(self.batch_size)
|