from copy import deepcopy
from math import copysign

import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch import nn
from torch.optim.lr_scheduler import ReduceLROnPlateau

from sudoku.buffer import Buffer, BufferArray
from sudoku.helper import pos_to_digit_col_row
from sudoku.models import SmallNetBis, SymPreprocess
from sudoku.trial_grid import TrialGrid


class SudokuLightning(pl.LightningModule):
    def __init__(
        self,
        lr=0.1,
        margin=0.1,  # threshold margin
        coef_0=10,
        nets_number=6,
        nets_training_number=1,
        batch_size=32,
    ):
        super().__init__()
        self.nets_number = nets_number
        self.batch_size = batch_size
        self.nets_training_number = nets_training_number
        # self.nets = [SmallNetBis() for _ in range(self.nets_number)]
        self.nets = nn.ModuleList([SmallNetBis() for _ in range(self.nets_number)])
        self.buffer = BufferArray(self.nets_number, self.batch_size)
        self.sym_preprocess = SymPreprocess()
        pos_weight = torch.ones((2, 9 * 9 * 9))
        pos_weight[0, :] = 1.0 / 8.0
        pos_weight[1, :] = 1.0
        pos_weight /= coef_0
        weight = torch.ones((2, 9 * 9 * 9))
        weight[0, :] = 8.0
        weight[1, :] = 1.0
        weight *= coef_0
        # element-wise loss: the reduction is done in `computing_loss` after masking
        self.bcewll = nn.BCEWithLogitsLoss(
            pos_weight=pos_weight, weight=weight, reduction="none"
        )
        self.lr = lr
        # self.auroc = AUROC(task='binary')
        self.margin = margin
        self.th_epsilon = margin * 0.01
        self.threshold_pres = torch.tensor([-10.0 for _ in range(nets_number)])
        self.threshold_abs = torch.tensor([-10.0 for _ in range(nets_number)])
        self.automatic_optimization = False
        self.reset_threshold_on_validation = True

    def configure_optimizers(self):
        # no scheduler config needed here -> manual optimisation
        optimizers = []
        for net in self.nets:
            opti = torch.optim.Adam(net.parameters(), lr=self.lr)
            optimizers.append(
                {
                    "optimizer": opti,
                    "lr_scheduler": ReduceLROnPlateau(opti, "min"),
                }
            )
        return optimizers

    # Reference (condensed from the PyTorch Lightning docs): `configure_optimizers`
    # may also return per-optimizer dicts carrying a full scheduler config, e.g.
    # lr_scheduler_config = {
    #     "scheduler": lr_scheduler,  # REQUIRED: the scheduler instance
    #     "interval": "epoch",        # 'epoch' steps at epoch end, 'step' after each optimizer update
    #     "frequency": 1,             # epochs/steps between `scheduler.step()` calls
    #     "monitor": "val_loss",      # metric to track for ReduceLROnPlateau-like schedulers
    #     "strict": True,             # if True, stop training when the monitored metric is missing; if False, only warn
    #     "name": None,               # custom name for the LearningRateMonitor callback
    # }
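    # Note: with `automatic_optimization = False`, Lightning does not step these
    # ReduceLROnPlateau schedulers automatically; they are stepped explicitly in
    # `on_validation_epoch_end` below, using the per-net `val_loss_{idx}` metrics.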
    def forward_layer(self, x, idx=0):
        x = self.sym_preprocess.forward(x)
        return self.nets[idx](x)

    def forward(self, x):
        for idx in range(self.nets_number):
            output = self.forward_layer(x, idx)
            new_X = self.compute_new_X(output, x, idx, None, train=False)
            improved_mask = ((new_X == 1) & (x == 0)).any(dim=1).any(dim=1)
            if improved_mask.sum() > 0:
                return idx, new_X
        return idx, new_X

    def predict_from_net(self, x, net, th_abs, th_pres):
        x = self.sym_preprocess.forward(x)
        x = net(x)
        new_x = torch.empty(x.shape, device=x.device)
        new_x[:, 0] = (x[:, 0] > th_abs).float()
        new_x[:, 1] = (x[:, 1] > th_pres).float()
        return new_x

    @staticmethod
    def mask_uncomplete(x, y):
        mask_uncomplete = x.reshape(-1, 2, 9, 9, 9).sum(-1) < torch.tensor((8, 1)).to(
            x
        ).reshape(1, 2, 1, 1)
        mask_uncomplete = mask_uncomplete.reshape(-1, 2, 9, 9, 1)
        mask = ((x == 0).reshape(-1, 2, 9, 9, 9) * mask_uncomplete).reshape(
            -1, 2, 9**3
        )
        mask = mask.float()
        return mask

    def computing_loss(self, x, y, output):
        loss = self.bcewll(output, y)
        mask = self.mask_uncomplete(x, y)
        loss = (loss * mask).sum()
        return loss

    def training_step(self, batch, batch_idx):
        self.log(
            "train_grid_count",
            batch[0].shape[0],
            reduce_fx=torch.sum,
            on_epoch=True,
            on_step=False,
        )
        self.layer_training_step(0, batch)
        while True:
            idx, batch = self.buffer.get_batch()
            if batch is None:
                break
            # TODO: decide whether this layer should still be trained by comparing
            # its lr with the scheduler, e.g.
            # sch = self.lr_schedulers()
            # if self.lr != sch[idx].get_last_lr():
            self.layer_training_step(idx, batch)

    def validation_step(self, batch, batch_idx):
        self.layer_training_step(0, batch, train=False)
        while True:
            idx, batch = self.buffer.get_batch()
            if batch is None:
                break
            self.layer_training_step(idx, batch, train=False)

    def layer_training_step(self, idx, batch, train=True):
        x, y = batch
        prefix = "train" if train else "val"
        self.log(
            f"{prefix}_grid_count_{idx}",
            batch[0].shape[0],
            reduce_fx=torch.sum,
            on_epoch=True,
            on_step=False,
        )
        output = self.forward_layer(x, idx)
        loss = self.computing_loss(x, y, output)
        if train:
            opt = self.optimizers()
            if isinstance(opt, list):
                opt = opt[idx]
            opt.zero_grad()
            self.manual_backward(loss)
            opt.step()
        loss_0 = F.binary_cross_entropy_with_logits(output[:, [0], :], y[:, [0], :])
        loss_1 = F.binary_cross_entropy_with_logits(output[:, [1], :], y[:, [1], :])
        self.log_dict(
            {f"{prefix}_loss_pos": loss_1, f"{prefix}_loss_neg": loss_0}, on_epoch=True
        )
        # accuracy_1 = torch.mean(torch.eq(transform_to_number_1(output), transform_to_number_1(x)).type(torch.float))
        # accuracy_0 = torch.mean(torch.eq(transform_to_number_0(output), transform_to_number_0(x)).type(torch.float))
        # self.log_dict({'accuracy_1': accuracy_1, 'accuracy_0': accuracy_0}, on_epoch=True)
        self.log(f"{prefix}_loss_{idx}", loss)
        # add a count log on (X and x == y)
        new_X = self.compute_new_X(output, x, idx, y, train=train)
        solved_mask = (new_X == y).all(dim=1).all(dim=1)
        new_X = new_X[~solved_mask]
        x = x[~solved_mask]
        y = y[~solved_mask]
        self.log(
            f"{prefix}_resolved_grid_count",
            solved_mask.sum(),
            on_epoch=True,
            on_step=False,
            reduce_fx=torch.sum,
        )
        mask_no_improve = new_X.sum(dim=(1, 2)) <= x.sum(dim=(1, 2))
        self.log(
            f"{prefix}_improved_grid_count_{idx}",
            (~mask_no_improve).sum(),
            on_epoch=True,
            on_step=False,
            reduce_fx=torch.sum,
        )
        # TODO keep the logging in this method:
        # - loss per epoch per boost-layer model
        # - number of errors per epoch per boost-layer model
        # - number of resolved puzzles per epoch
        # - threshold per epoch per boost-layer model
        # - number of sudoku grids
        # - number of filled digits per boost-layer model per epoch, for both pres and abs
        # - pass reduce_fx=torch.sum to the count logs
        # - thresholds -> on_epoch=False
        self.store_new_x(idx, new_X, x, y)

    def store_new_x(self, idx, new_X, x, y):
        mask_improve = new_X.sum(dim=(1, 2)) > x.sum(dim=(1, 2))
        self.buffer.append(
            idx + 1, (new_X[~mask_improve].clone(), y[~mask_improve].clone())
        )
        self.buffer.append(0, (new_X[mask_improve].clone(), y[mask_improve].clone()))
        # TODO if there is no improvement -> add one digit from y to new_X and add it to idx=0

    def compute_new_X(self, output, x, idx, y=None, train=True, mask_adapt_th=None):
        # y may be None
        prefix = "train" if train else "val"
        new_X = torch.empty(output.shape, device=output.device)
        # the thresholds are adapted here (only when a target is available)
        if y is not None:
            # max_th_abs = (
            #     output[:, 0][(x[:, 0] == 0) & (y[:, 0] == 0)].max().item()
            #     + self.th_epsilon
            # )
            max_th_abs = output[:, 0][(y[:, 0] == 0)].max().item() + self.th_epsilon
            max_th_pres = (
                output[:, 1][(x[:, 1] == 0) & (y[:, 1] == 0)].max().item()
                + self.th_epsilon
            )
            if mask_adapt_th is None or (mask_adapt_th.sum() > 0):
                if mask_adapt_th is not None and (mask_adapt_th.sum() > 0):
                    max_th_abs = (
                        output[mask_adapt_th, 0][(y[mask_adapt_th, 0] == 0)].max().item()
                        + self.th_epsilon
                    )
                    max_th_pres = (
                        output[mask_adapt_th, 1][
                            (x[mask_adapt_th, 1] == 0) & (y[mask_adapt_th, 1] == 0)
                        ].max().item()
                        + self.th_epsilon
                    )
                self.threshold_abs[idx] = max(max_th_abs, self.threshold_abs[idx])
                self.threshold_pres[idx] = max(max_th_pres, self.threshold_pres[idx])
                self.log_dict(
                    {
                        f"{prefix}_th_abs_{idx}": self.threshold_abs[idx],
                        f"{prefix}_th_pres_{idx}": self.threshold_pres[idx],
                    },
                    on_step=True,
                )
                if not train:
                    self.threshold_abs_compute[idx] = max(
                        max_th_abs + self.margin, self.threshold_abs_compute[idx]
                    )
                    self.threshold_pres_compute[idx] = max(
                        max_th_pres + self.margin, self.threshold_pres_compute[idx]
                    )
        # the same thresholds are applied in training and evaluation mode
        new_X[:, 0] = (output[:, 0].detach() > self.threshold_abs[idx]).float()
        new_X[:, 1] = (output[:, 1].detach() > self.threshold_pres[idx]).float()
        new_X[x.detach() == 1] = 1
        if y is not None:
            self.log(
                f"{prefix}_count_error_grid_{idx}",
                ((new_X == 1) & (y == 0)).any(dim=1).any(dim=1).sum(),
                on_epoch=True,
                on_step=False,
                reduce_fx=torch.sum,
            )
            if mask_adapt_th is None:
                new_X[y.detach() == 0] = 0  # do not remove the error!
            else:
                y_bis = y.detach().clone()
                y_bis[~mask_adapt_th] = 1
                new_X[y_bis == 0] = 0
        return new_X
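    # Illustration of the thresholding above (the numbers are made up): with
    # threshold_pres[idx] = 0.2, a "presence" logit of 0.7 for some cell/digit sets
    # new_X[:, 1] to 1 there, while a logit of 0.1 leaves it at 0. Because the
    # thresholds are pushed above the largest logit observed on candidates the
    # target marks as wrong (plus th_epsilon / margin), a unit that crosses its
    # threshold should score higher than any mistake seen so far.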
    # TODO one lr scheduler per net (use idx)
    # def on_train_epoch_end(self):
    #     sch = self.lr_schedulers()
    #     # If the selected scheduler is a ReduceLROnPlateau scheduler.
    #     if isinstance(sch, torch.optim.lr_scheduler.ReduceLROnPlateau):
    #         sch.step(self.trainer.callback_metrics["loss"])

    def on_validation_epoch_start(self) -> None:
        if self.reset_threshold_on_validation:
            self.threshold_abs_compute = torch.tensor(
                [-10.0 for _ in range(self.nets_number)]
            )
            self.threshold_pres_compute = torch.tensor(
                [-10.0 for _ in range(self.nets_number)]
            )
        else:
            self.threshold_abs_compute = self.threshold_abs
            self.threshold_pres_compute = self.threshold_pres
        self.buffer = BufferArray(self.nets_number, self.batch_size)

    def on_train_epoch_start(self) -> None:
        self.buffer = BufferArray(self.nets_number, self.batch_size)
        return super().on_train_epoch_start()

    def on_validation_epoch_end(self):
        # tensorboard = self.logger.experiment
        self.threshold_abs = self.threshold_abs_compute
        self.threshold_pres = self.threshold_pres_compute
        schs = self.lr_schedulers()
        if not isinstance(schs, list):
            schs = [schs]
        for idx, sch in enumerate(schs):
            try:
                sch.step(self.trainer.callback_metrics[f"val_loss_{idx}"])
            except KeyError:
                # the per-layer metric may be missing (e.g. no batch reached this layer)
                pass

    def on_save_checkpoint(self, checkpoint) -> None:
        """Objects to include in the checkpoint file."""
        checkpoint["ths_abs"] = self.threshold_abs
        checkpoint["ths_pres"] = self.threshold_pres

    def on_load_checkpoint(self, checkpoint) -> None:
        """Objects to retrieve from the checkpoint file."""
        self.threshold_abs = checkpoint["ths_abs"]
        self.threshold_pres = checkpoint["ths_pres"]
        self.nets = nn.ModuleList([SmallNetBis() for _ in self.threshold_abs])

    def validate_grids(self, x) -> torch.Tensor:
        return ~(
            (self.sym_preprocess(x)[:, 17].max(dim=1).values > (1 / 8))
            | (self.sym_preprocess(x)[:, 18].max(dim=1).values > (1 / 8))
            | (self.sym_preprocess(x)[:, 19].max(dim=1).values > (1 / 8))
            | (x.view(-1, 2, 9, 9, 9)[:, 1].sum(dim=-1) > 1).any(dim=1).any(dim=1)
            | (x.view(-1, 2, 9, 9, 9)[:, 0].sum(dim=-1) > 8).any(dim=1).any(dim=1)
        )


# steps of the trial-and-error process:
# - get the grids that stop improving
# - choose a position -> store it
# - process until we reach either a new stop or a grid-validation failure
#   if the validation fails, backtrack
#   else choose another position
# add a counter to each grid
# add an id to each grid: id = batch_id + position
# add validation
# if a grid stops without improvement ->
#   - check if its id already exists; if so, increment its no-improve counter
#     if the no-improve counter reaches 2 -> add the grid to trial_error_model_buffer with a 1000-step target
#   - store the grid in the trial_error_model deep-search dict
#   - create two grids with counter 0 and the same id
#     add them to the buffer
# - when validation fails ->
#   - check if the id already exists
#     if it does: add the grid to trial_error_model with the counter
#     if it does not: raise an error
# TODO adapt training to something softer
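# Sketch of the intended flow (a restatement of the plan above, not an exact trace
# of the implementation below):
#
#   for each grid that stops improving at the last boost layer:
#       pick an untried position with `search_trial`
#       enqueue an "absent" branch and a "present" branch, tagged -idx / +idx
#   when a branch fails validation or stops improving:
#       record the outcome on the corresponding TrialGrid
#       once both branches have reported, score the trial and push
#       (initial_grid, score, position) to `trial_error_buffer`;
#       re-enqueue the surviving branch, or retry another position if neither improved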


class TrialEveryPosException(Exception):
    pass


class SudokuTrialErrorLightning(SudokuLightning):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.deep_backtrack_regressor = SmallNetBis(n_output=1)
        self.trial_error_buffer = Buffer(self.batch_size)
        self.trial_grids = [None]
        # schema:
        # [
        #     idx:
        #         "tried_pos": []
        #         "pos": pos
        #         "no_improve_counter": 0
        # ]
        # self.tracking_grid = []

    def copy_from_model(self, model):
        self.nets = model.nets
        self.threshold_pres = model.threshold_pres
        self.threshold_abs = model.threshold_abs

    def reg(self, x):
        x_reg = self.sym_preprocess.forward(x)
        x_reg = self.deep_backtrack_regressor(x_reg)
        return torch.softmax(x_reg, dim=1)

    def configure_optimizers(self):
        # no scheduler config needed here -> manual optimisation
        optimizers = []
        for net in self.nets:
            opti = torch.optim.Adam(net.parameters(), lr=self.lr)
            optimizers.append(
                {
                    "optimizer": opti,
                    "lr_scheduler": ReduceLROnPlateau(opti, "min"),
                }
            )
        opti_reg = torch.optim.Adam(
            self.deep_backtrack_regressor.parameters(), lr=self.lr
        )
        optimizers.append(
            {
                "optimizer": opti_reg,
                "lr_scheduler": ReduceLROnPlateau(opti_reg, "min"),
            }
        )
        return optimizers

    def training_step(self, batch, batch_idx):
        self.log(
            "train_grid_count",
            batch[0].shape[0],
            reduce_fx=torch.sum,
            on_epoch=True,
            on_step=False,
        )
        x, y = batch
        # x_idx == 0 means the grid does not come from a trial-error branch
        x_idx = torch.zeros(x.shape[0], dtype=torch.long)
        counters = torch.zeros(x.shape[0])
        self.layer_training_step(0, (x, y, x_idx, counters))
        idx_while = 0
        while True:
            idx_while += 1
            if idx_while == 10000:
                print("warning: more than 10000 buffer iterations in one step")
            idx, batch = self.buffer.get_batch()
            if batch is None:
                break
            self.layer_training_step(idx, batch)
        while True:
            trial_error_batch = self.trial_error_buffer.get_batch()
            if trial_error_batch is None:
                break
            self.trial_error_training_step(trial_error_batch)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        # x_idx == 0 means the grid does not come from a trial-error branch
        x_idx = torch.zeros(x.shape[0], dtype=torch.long)
        counters = torch.zeros(x.shape[0])
        self.layer_training_step(0, (x, y, x_idx, counters), train=False)
        while True:
            idx, batch = self.buffer.get_batch()
            if batch is None:
                break
            self.layer_training_step(idx, batch, train=False)
        while True:
            trial_error_batch = self.trial_error_buffer.get_batch()
            if trial_error_batch is None:
                break
            self.trial_error_training_step(trial_error_batch, train=False)

    def layer_training_step(self, idx, batch, train=True):
        x, y, x_idx, counters = batch
        prefix = "train" if train else "val"
        self.log(
            f"{prefix}_grid_count_{idx}",
            batch[0].shape[0],
            reduce_fx=torch.sum,
            on_epoch=True,
            on_step=False,
        )
        output = self.forward_layer(x, idx)
        # only non trial-error grids (x_idx == 0) contribute to the loss
        loss = self.computing_loss(x[x_idx == 0], y[x_idx == 0], output[x_idx == 0])
        if train:
            opt = self.optimizers()[idx]
            opt.zero_grad()
            self.manual_backward(loss)
            opt.step()
        loss_0 = F.binary_cross_entropy_with_logits(output[:, [0], :], y[:, [0], :])
        loss_1 = F.binary_cross_entropy_with_logits(output[:, [1], :], y[:, [1], :])
        self.log_dict(
            {f"{prefix}_loss_pos": loss_1, f"{prefix}_loss_neg": loss_0}, on_epoch=True
        )
        self.log(f"{prefix}_loss_{idx}", loss)
        mask_bad_x = ((x == 1) & (y == 0)).any(dim=1).any(dim=1)
        new_X = self.compute_new_X(
            output, x, idx, y, train=train, mask_adapt_th=(~mask_bad_x)
        )
        solved_mask = (new_X == y).all(dim=1).all(dim=1)
        new_X = new_X[~solved_mask]
        x = x[~solved_mask]
        y = y[~solved_mask]
        x_idx = x_idx[~solved_mask]
        counters = counters[~solved_mask]
        self.log(
            f"{prefix}_resolved_grid_count",
            solved_mask.sum(),
            on_epoch=True,
            on_step=False,
            reduce_fx=torch.sum,
        )
        mask_no_improve = new_X.sum(dim=(1, 2)) <= x.sum(dim=(1, 2))
        self.log(
            f"{prefix}_improved_grid_count_{idx}",
            (~mask_no_improve).sum(),
            on_epoch=True,
            on_step=False,
            reduce_fx=torch.sum,
        )
        # self.store_new_x(idx, new_X, x, y)
        # unlike the parent class, we also need to increment counters and validate
        self.process_validation(idx, new_X, x, y, x_idx, counters)

    def process_validation(self, idx, new_X, x, y, x_idx, counters):
        new_X = self.redresse_new_X(new_X, y, x)
        mask_validated = self.validate_grids(new_X)
        # mask_not_validated = (~self.validate_grids(new_X)) & ((x==0)&(y==1)).any(dim=(1,2))
        mask_improve = (new_X.sum(dim=(1, 2)) > x.sum(dim=(1, 2))) & mask_validated
        mask_not_improved = (new_X.sum(dim=(1, 2)) == x.sum(dim=(1, 2))) & mask_validated
        for i, (failed_idx, failed_counter, s_new_X, s_y) in enumerate(
            zip(
                x_idx[~mask_validated],
                counters[~mask_validated],
                new_X[~mask_validated],
                y[~mask_validated],
            )
        ):
            # when a grid fails validation:
            # - we keep the good branch so it can continue the process
            #   /!\ not strictly necessary: the other half keeps being processed anyway
            # - we store the initial grid with its score (to train the regressor)
            if failed_idx == 0:
                self.failed_batch = (x[~mask_validated][i], s_y)
                raise ValueError("validation error on a non trial-error grid")
            if not ((x[~mask_validated][i] == 0) & (s_y == 1)).any():
                raise ValueError()
            is_pos = copysign(1, failed_idx) == 1
            trial_grid: TrialGrid = self.trial_grids[int(abs(failed_idx))]
            if is_pos:
                trial_grid.pos_result = "fail"
            else:
                trial_grid.neg_result = "fail"
            self.process_search_store_grid(int(abs(failed_idx)), trial_grid, s_y)
        if idx == self.nets_number - 1:
            for no_improved_idx, s_new_X, s_y in zip(
                x_idx[mask_not_improved], new_X[mask_not_improved], y[mask_not_improved]
            ):
                if no_improved_idx == 0:
                    self.search_trial_buffer_trials(s_new_X, s_y)
                    continue
                is_pos = copysign(1, no_improved_idx) == 1
                trial_grid: TrialGrid = self.trial_grids[int(abs(no_improved_idx.item()))]
                if is_pos:
                    trial_grid.pos_result = "no_improved"
                else:
                    trial_grid.neg_result = "no_improved"
                assert s_new_X.sum() > trial_grid.initial_grid.sum()
                self.process_search_store_grid(int(abs(no_improved_idx)), trial_grid, s_y)
        self.buffer.append(
            idx + 1,
            (
                new_X[mask_not_improved].clone(),
                y[mask_not_improved].clone(),
                x_idx[mask_not_improved].clone(),
                counters[mask_not_improved].clone(),
            ),
        )
        # assert mask_improve.sum() > 0
        mask_improve_plain = mask_improve & (x_idx.to(self.device) == 0)
        if ((new_X[mask_improve_plain] == 1) & (y[mask_improve_plain] == 0)).any():
            self.failed_batch = (x[mask_improve_plain], y[mask_improve_plain])
            raise ValueError()
        self.buffer.append(
            0,
            (
                new_X[mask_improve].clone(),
                y[mask_improve].clone(),
                x_idx[mask_improve].clone(),
                counters[mask_improve].clone() + 1,
            ),
        )

    def process_search_store_grid(self, idx, trial_grid: TrialGrid, s_y):
        """Score a finished trial and decide what to do with its branches.

        if score is 1: great
        if a branch failed -> the other one should continue (it keeps its id if it stopped), so do nothing
        if no_improved -> trial_error and reset trial_error_grid
        if score is not None -> store the new grid in the trial_error_buffer
            if score is -1 -> also search_trial and store.
            if score is 1 -> if both results are here:
                get the no_improved one -> search_trial and store on a new grid
                if one grid is complete -> set the grid slot to None
                else: wait
        if score is None:
            if grid_idx == -1 or the opposite grid failed:
                we create a new idx and store the grid.
            else:
                we store the grid (in case the second grid fails)
                we increment the non-improvement counter
                if the non-improvement counter reaches 2:
                    we add the initial grid to the search training buffer
                    we run the search engine to find another grid position
                else:
                    we add the initial grid to the search training buffer
                    if a non-improved grid is stored,
                    we create a new idx and store the grid.

        Args:
            idx: index of the trial in `self.trial_grids`
            trial_grid: the TrialGrid holding the branch results
            s_y: target grid for this trial
        """
        score = trial_grid.score()
        if score is None:
            self.trial_grids[idx] = trial_grid
            return
        # add the trial to the regressor buffer: (initial_grid, score, position)
        self.trial_error_buffer.append(
            (
                trial_grid.initial_grid.view(-1, 2, 729),
                torch.tensor([score], dtype=torch.float).to(self.device),
                torch.tensor(
                    [trial_grid.row_col_digit_position], dtype=torch.long
                ).to(self.device),
            )
        )
        # find the no_improve grid ~and search_trial~ and add it to the buffer
        if trial_grid.neg_result == "no_improved":
            if trial_grid.pos_result == "no_improved":
                trial_grid.tried_grid.append(trial_grid.row_col_digit_position)
                trial_grid.neg_result = None
                trial_grid.pos_result = None
                self.trial_grids[idx] = trial_grid
                self.search_trial_buffer_trials(None, s_y, idx)  # new trial with the same idx
                return
            # add the neg grid to the buffer:
            # take back the initial grid, set the tried row/col/digit to "absent",
            # append it to the buffer and free the trial slot
            grid_neg = deepcopy(trial_grid.initial_grid)
            grid_neg[0, trial_grid.row_col_digit_position] = 1
            if ((grid_neg == 1) & (s_y == 0)).any():
                raise ValueError()
            self.buffer.append(
                0,
                (
                    grid_neg.view(-1, 2, 729),
                    s_y.clone().view(-1, 2, 729),
                    torch.tensor([0]),
                    torch.tensor([0]),
                ),
            )
            self.trial_grids[idx] = None
            return
        if trial_grid.pos_result == "no_improved":
            # same as above, but the tried row/col/digit is set to "present"
            grid_pos = deepcopy(trial_grid.initial_grid)
            grid_pos[1, trial_grid.row_col_digit_position] = 1
            if ((grid_pos == 1) & (s_y == 0)).any():
                raise ValueError()
            self.buffer.append(
                0,
                (
                    grid_pos.view(-1, 2, 729),
                    s_y.clone().view(-1, 2, 729),
                    torch.tensor([0]),
                    torch.tensor([0]),
                ),
            )
            self.trial_grids[idx] = None
            return
        # if one branch is complete: free the trial slot
        if "complete" in [trial_grid.neg_result, trial_grid.pos_result]:
            self.trial_grids[idx] = None

    # def store_new_trial_error_grid(self, new_X, y):
    #     """build a new idx, add the grid to the tracking structure
    #     and add the grid to the buffer"""
    #     ...

    # def store_training_trail_search_batch(self, grid, score):
    #     """store a grid to train the trial_search nn model"""
    #     ...
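    # Assumed interface of `sudoku.trial_grid.TrialGrid`, as used in this module (not
    # checked against the actual class): it is constructed as
    # `TrialGrid(initial_grid, row_col_digit_position)` and exposes `initial_grid`,
    # `row_col_digit_position`, `tried_grid` (positions already tried),
    # `pos_result` / `neg_result` (None, 'fail', 'no_improved' or 'complete') and a
    # `score()` method that returns None while a branch is still pending.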
    def search_trial(self, s_new_X, tried_pos):
        """Use the trial-search nn model to pick a new position to probe.

        Args:
            s_new_X: grid tensor of shape (2, 729)
            tried_pos: positions (0..728) that were already tried
        """
        mask_possibility = s_new_X.sum(dim=0) == 0
        for pos in tried_pos:
            mask_possibility[pos] = False
        if mask_possibility.sum() == 0:
            print("every position has already been tried")
            raise TrialEveryPosException()
        with torch.no_grad():
            x_reg = self.sym_preprocess.forward(s_new_X.view(1, 2, -1))
            output = self.deep_backtrack_regressor(x_reg)  # shape (1, 1, 729)
        # could be a regression -> we want the smallest value
        # could be a logistic regression -> we want the smallest value
        # with a softmax we can add 1 to every excluded position
        output = torch.softmax(output[0][0], dim=0)
        output[~mask_possibility] += 1
        return torch.argmin(output, dim=0).item()

    def search_trial_buffer_trials(self, s_new_X, s_y, idx_trial_grids=None):
        if idx_trial_grids is None:
            row_col_digit_trial = self.search_trial(s_new_X, [])
            trial_grid = TrialGrid(s_new_X, row_col_digit_trial)
            self.trial_grids.append(trial_grid)
            idx_trial_grids = len(self.trial_grids) - 1
        else:
            trial_grid = self.trial_grids[idx_trial_grids]
            s_new_X = trial_grid.initial_grid
            row_col_digit_trial = self.search_trial(s_new_X, trial_grid.tried_grid)
            trial_grid.row_col_digit_position = row_col_digit_trial
            self.trial_grids[idx_trial_grids] = trial_grid
        # add both branches to the buffer
        grid_pos = deepcopy(s_new_X)
        grid_neg = deepcopy(s_new_X)
        grid_pos[1, row_col_digit_trial] = 1
        grid_neg[0, row_col_digit_trial] = 1
        self.buffer.append(
            0,
            (
                torch.stack([grid_pos, grid_neg], dim=0),
                torch.stack([s_y.clone(), s_y.clone()], dim=0),
                torch.tensor([idx_trial_grids, -idx_trial_grids]),
                torch.tensor([0, 0]),
            ),
        )

    def trial_error_training_step(self, batch, train=True):
        x, y, row_col_digit = batch
        prefix = "train" if train else "val"
        self.log(
            f"{prefix}_grid_count_trial_error_training",
            batch[0].shape[0],
            reduce_fx=torch.sum,
            on_epoch=True,
            on_step=False,
        )
        x_reg = self.sym_preprocess.forward(x)
        output = self.deep_backtrack_regressor(x_reg)
        # depending on the distribution of the target, a weight could be added here
        loss = F.binary_cross_entropy_with_logits(
            output[torch.arange(x.shape[0]), 0, row_col_digit], y, reduction="mean"
        )
        # loss = self.computing_loss(x, y, output)
        if train:
            opt = self.optimizers()[-1]
            opt.zero_grad()
            self.manual_backward(loss)
            opt.step()
        self.log(f"{prefix}_loss_trial_error", loss)
        self.log(f"{prefix}_loss_{self.nets_number}", loss)
        self.log(f"{prefix}_y_pos_trial_error", y.sum())
        self.log(f"{prefix}_y_neg_trial_error", y.shape[0] - y.sum())

    def predict(self, x, func_text_display=None):
        """Return an improvement of x."""
        idx, new_X = self.forward(x.view(-1, 2, 729))
        if (new_X.sum() > x.sum()) or (new_X.sum() == 729):
            if func_text_display:
                func_text_display(f"boost layer step: {idx}")
            return new_X
        else:
            # call trial-and-error until we find a solution
            tried_position = []
            while True:
                pos = self.search_trial(x.view(2, 729), tried_position)
                tried_position.append(pos)
                # create the pos/neg pair of grids
                grid_pos = deepcopy(x.view(2, 729))
                grid_neg = deepcopy(x.view(2, 729))
                grid_pos[1, pos] = 1
                grid_neg[0, pos] = 1
                X_tried = torch.stack([grid_neg, grid_pos], dim=0)
                # process the pair
                while True:
                    idx, new_X = self.forward(X_tried)
                    mask_validated = self.validate_grids(new_X)
                    if mask_validated.sum() < 2:
                        x[0, mask_validated, pos] = 1  # TODO check that this works
                        if func_text_display:
                            digit, col, row = pos_to_digit_col_row(pos)
                            func_text_display("model failed to improve the grid")
                            func_text_display(
                                f"trial error algorithm, found error at digit: {digit}, col: {col}, row: {row}"
                            )
                        return x
                    if X_tried.sum() == new_X.sum():
                        # neither branch improves any more -> break and try a new position
                        break
                    mask_complete = X_tried.sum(dim=1) == 729  # TODO check that this works
                    if mask_complete.sum() > 0:
                        x[0, mask_complete, pos] = 1
                        if func_text_display:
                            digit, col, row = pos_to_digit_col_row(pos)
                            func_text_display("model failed to improve the grid")
                            func_text_display(
                                f"trial error algorithm, completed the grid with digit: {digit}, col: {col}, row: {row}"
                            )
                        return x
                    X_tried = new_X
            # if one of X_tried is complete (weird but possible) -> return x with the
            # tried position set to 1 (because we still want a step-by-step resolution)

    def backtracking_predict(
        self,
        x,
        use_trial_error=False,
        assumption=None,
        func_text_display=None,
        func_tensor_display=None,
    ):
        """Return (is_valid, new_x)."""
        if assumption is None:
            assumption = []
        next_X = deepcopy(x)
        sum_1 = next_X.sum().item()
        if use_trial_error:
            while True:
                try:
                    next_X = self.predict(next_X.view(1, 2, 729))
                    if not self.validate_grids(next_X)[0].item():
                        if assumption:
                            func_text_display(f"assumption {assumption[-1]}, failed")
                            func_text_display(f"assumption length {len(assumption)}")
                            func_tensor_display(next_X)
                        return False, None
                    if next_X.sum() >= 729:
                        if assumption:
                            func_text_display(f"assumption length {len(assumption)}")
                            func_tensor_display(next_X)
                        return True, next_X
                except TrialEveryPosException:
                    break
        else:
            while True:
                _idx, next_X = self.forward(next_X.view(1, 2, 729))
                if not self.validate_grids(next_X)[0].item():
                    return False, None
                sum_2 = next_X.sum().item()
                if sum_1 == sum_2:
                    break
                else:
                    sum_1 = sum_2
                if next_X.sum() >= 729:
                    return True, next_X
        pos = self.search_trial(next_X.view(2, 729), [])
        new_x = deepcopy(next_X.view(1, 2, 729))
        output = self.forward_layer(next_X.view(1, 2, 729))
        pos_output = output[0, :, pos]
        # assume whichever channel (absent/present) is closest to crossing its threshold
        if (pos_output[0].item() - self.threshold_abs[0].item()) > (
            pos_output[1].item() - self.threshold_pres[0].item()
        ):
            trial_abs_pres = 0
            abs_pres_str = "abs"
        else:
            trial_abs_pres = 1
            abs_pres_str = "pres"
        new_x[0, trial_abs_pres, pos] = 1
        # print(f'assuming {(pos, pos % 9, (pos // 9) % 9, pos // (9 * 9), trial_abs_pres)}')  # i + j*9 + n*9*9
        digit, col, row = pos_to_digit_col_row(pos)
        func_text_display("grid not improving, adding an assumption")
        current_assumption = f"{abs_pres_str}, digit: {digit}, col: {col}, row: {row}"
        func_text_display(current_assumption)
        assumption.append(current_assumption)
        is_valid, new_x = self.backtracking_predict(
            new_x,
            use_trial_error=use_trial_error,
            assumption=assumption,
            func_text_display=func_text_display,
            func_tensor_display=func_tensor_display,
        )
        if not is_valid:
            # the first assumption failed: try the opposite channel at the same position
            new_x = deepcopy(next_X.view(1, 2, 729))
            new_x[0, (trial_abs_pres + 1) % 2, pos] = 1
            # print(f'finally assuming {(pos, pos % 9, (pos // 9) % 9, pos // (9 * 9), (trial_abs_pres + 1) % 2)}')
            current_assumption = (
                f"{'pres' if abs_pres_str == 'abs' else 'abs'}, "
                f"digit: {digit}, col: {col}, row: {row}"
            )
            func_text_display(current_assumption)
            return self.backtracking_predict(
                new_x,
                use_trial_error=use_trial_error,
                assumption=assumption,
                func_text_display=func_text_display,
                func_tensor_display=func_tensor_display,
            )
        return True, new_x

    def on_validation_epoch_start(self) -> None:
        # self.buffer is reset in the parent class
        self.trial_error_buffer = Buffer(self.batch_size)
        self.trial_grids = [None]
        return super().on_validation_epoch_start()

    def on_train_epoch_start(self) -> None:
        self.trial_error_buffer = Buffer(self.batch_size)
        self.trial_grids = [None]
        return super().on_train_epoch_start()

    def redresse_new_X(self, new_X, y, x):
        mask_bad_x = ((x == 1) & (y == 0)).any(dim=1).any(dim=1)
        y_bis = y.clone()
        y_bis[mask_bad_x] = 1
        new_X[y_bis == 0] = 0
        return new_X

    # TODO: add threshold adjustment during prediction,
    # or maybe run the validation on it instead? yes, probably.
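

# Minimal usage sketch (commented out; not part of the module). It assumes a
# PyTorch Lightning `Trainer` and hypothetical `train_loader` / `val_loader`
# DataLoaders yielding `(x, y)` pairs of shape (batch, 2, 729) with 0/1 entries;
# none of these data objects are defined in this file.
#
# if __name__ == "__main__":
#     model = SudokuTrialErrorLightning(lr=0.1, nets_number=6, batch_size=32)
#     trainer = pl.Trainer(max_epochs=10)
#     trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
#
#     # inference on a single encoded grid `x` of shape (1, 2, 729)
#     is_valid, solution = model.backtracking_predict(
#         x,
#         use_trial_error=False,
#         func_text_display=print,
#         func_tensor_display=print,
#     )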