SebastienGuissart commited on
Commit
6f43ac3
·
verified ·
1 Parent(s): 19a1e6a

unfinished work on backtracking

Browse files
Files changed (1) hide show
  1. sudoku/train.py +25 -1
sudoku/train.py CHANGED
@@ -401,6 +401,9 @@ class SudokuLightning(pl.LightningModule):
401
 
402
  # TODO adapt training to something softer
403
  #
 
 
 
404
  class SudokuTrialErrorLightning(SudokuLightning):
405
  def __init__(self, **kwargs):
406
  super().__init__(**kwargs)
@@ -752,7 +755,7 @@ class SudokuTrialErrorLightning(SudokuLightning):
752
  mask_possibility[pos]=False
753
  if mask_possibility.sum()==0:
754
  print('mask_possible=0')
755
- raise ValueError()
756
 
757
  with torch.no_grad():
758
  x_reg = self.sym_preprocess.forward(s_new_X.view(1,2,-1))
@@ -861,7 +864,28 @@ class SudokuTrialErrorLightning(SudokuLightning):
861
  return x
862
  X_tried = new_X
863
  # if one of X_tried is complete (weird but possible) -> return x with tried_position mask_complet set to 1 (cause we still want a step by step resolution)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
864
 
 
865
  def on_validation_epoch_start(self) -> None:
866
  # self.buffer = BufferArray(self.nets_number, self.batch_size)
867
  self.trial_error_buffer = Buffer(self.batch_size)
 
401
 
402
  # TODO adapt training to something softer
403
  #
404
+ class TrialEveryPosException(Exception):
405
+ pass
406
+
407
  class SudokuTrialErrorLightning(SudokuLightning):
408
  def __init__(self, **kwargs):
409
  super().__init__(**kwargs)
 
755
  mask_possibility[pos]=False
756
  if mask_possibility.sum()==0:
757
  print('mask_possible=0')
758
+ raise TrialEveryPosException()
759
 
760
  with torch.no_grad():
761
  x_reg = self.sym_preprocess.forward(s_new_X.view(1,2,-1))
 
864
  return x
865
  X_tried = new_X
866
  # if one of X_tried is complete (weird but possible) -> return x with tried_position mask_complet set to 1 (cause we still want a step by step resolution)
867
+
868
+ def backtracking_predict(self, x):
869
+ """
870
+ return is_valid, new_x
871
+ """
872
+ try:
873
+ x = self.predict(x)
874
+ except TrialEveryPosException:
875
+ pos = self.search_trial(x.view(2,729), [])
876
+ new_x = deepcopy(x.view(2,729))
877
+ output = self.forward_layer(x.view(1,2,729))
878
+ pos_output = torch.argmax(output[0,:,pos])
879
+ if (pos_output[0].item()-self.threshold_abs)>(pos_output[1].item-self.threshold_pos):
880
+ trial_abs_pres = 0
881
+ else:
882
+ trial_abs_pres = 1
883
+ new_x[0,trial_abs_pres,pos]=1
884
+ is_valid, new_x = self.backtracking_predict(new_x)
885
+ if not is_valid:
886
+ return is_valid ### TO continue
887
 
888
+
889
  def on_validation_epoch_start(self) -> None:
890
  # self.buffer = BufferArray(self.nets_number, self.batch_size)
891
  self.trial_error_buffer = Buffer(self.batch_size)