from math import copysign
import torch
from torch import nn
from torch.optim.lr_scheduler import ReduceLROnPlateau
import pytorch_lightning as pl
from sudoku.models import SmallNetBis, SymPreprocess
import torch.nn.functional as F
from sudoku.buffer import BufferArray, Buffer
from sudoku.trial_grid import TrialGrid
from sudoku.helper import pos_to_digit_col_row
from copy import deepcopy
class SudokuLightning(pl.LightningModule):
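    """Stack of boosted networks that fills in a Sudoku grid step by step.

    Each of the ``nets_number`` networks (one "boost layer") predicts, from the
    current grid state, which entries can be marked absent (channel 0) or
    present (channel 1). Per-layer thresholds ``threshold_abs``/``threshold_pres``
    are raised during training so that the thresholded outputs do not contradict
    the solution. Grids improved by a layer are fed back to layer 0; grids it
    cannot improve are buffered for the next layer.
    """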
def __init__(
self,
lr=0.1,
        margin=0.1,  # threshold margin?
coef_0=10,
nets_number=6,
nets_training_number=1,
batch_size=32,
):
super().__init__()
self.nets_number = nets_number
self.batch_size = batch_size
self.nets_training_number = nets_training_number
# self.nets=[SmallNetBis() for _ in range(self.nets_number)]
self.nets = nn.ModuleList([SmallNetBis() for _ in range(self.nets_number)])
self.buffer = BufferArray(self.nets_number, self.batch_size)
self.sym_preprocess = SymPreprocess()
pos_weight = torch.ones((2, 9 * 9 * 9))
pos_weight[0, :] = 1.0 / 8.0
pos_weight[1, :] = 1.0
pos_weight /= coef_0
weight = torch.ones((2, 9 * 9 * 9))
weight[0, :] = 8.0
weight[1, :] = 1.0
weight *= coef_0
        self.bcewll = nn.BCEWithLogitsLoss(
            pos_weight=pos_weight, weight=weight, reduction="none"
        )
self.lr = lr
# self.auroc = AUROC(task='binary')
self.margin = margin
self.th_epsilon = margin * 0.01
self.threshold_pres = torch.tensor([-10.0 for _ in range(nets_number)])
self.threshold_abs = torch.tensor([-10.0 for _ in range(nets_number)])
self.automatic_optimization = False
self.reset_threshold_on_validation = True
def configure_optimizers(self):
        # no need to configure the scheduler here -> manual optimization
optimizers = []
for net in self.nets:
opti = torch.optim.Adam(net.parameters(), lr=self.lr)
optimizers.append(
{
"optimizer": opti,
"lr_scheduler": ReduceLROnPlateau(opti, "min"),
}
)
return optimizers
# def configure_optimizers(self):
# optimizer1 = Adam(...)
# optimizer2 = SGD(...)
# scheduler1 = ReduceLROnPlateau(optimizer1, ...)
# scheduler2 = LambdaLR(optimizer2, ...)
# return (
# {
# "optimizer": optimizer1,
# "lr_scheduler": {
# "scheduler": scheduler1,
# "monitor": "metric_to_track",
# },
# },
# {"optimizer": optimizer2, "lr_scheduler": scheduler2},
# )
# lr_scheduler_config = {
# # REQUIRED: The scheduler instance
# "scheduler": lr_scheduler,
# # The unit of the scheduler's step size, could also be 'step'.
# # 'epoch' updates the scheduler on epoch end whereas 'step'
    # updates it after an optimizer update.
# "interval": "epoch",
# # How many epochs/steps should pass between calls to
# # `scheduler.step()`. 1 corresponds to updating the learning
# # rate after every epoch/step.
# "frequency": 1,
    # Metric to monitor for schedulers like `ReduceLROnPlateau`
# "monitor": "val_loss",
    # If set to `True`, will enforce that the value specified in 'monitor'
# # is available when the scheduler is updated, thus stopping
# # training if not found. If set to `False`, it will only produce a warning
# "strict": True,
# # If using the `LearningRateMonitor` callback to monitor the
# # learning rate progress, this keyword can be used to specify
# # a custom logged name
# "name": None,
# }
    # lr_scheduler_config = {'scheduler': lr_sch, 'interval': 'epoch', 'frequency': 1, 'monitor': 'val_loss'}
def forward_layer(self, x, idx=0):
x = self.sym_preprocess.forward(x)
return self.nets[idx](x)
def forward(self, x):
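        """Apply the boost layers in order and return ``(idx, new_X)``.

        Stops at the first layer whose thresholded output adds at least one
        new entry to any grid of the batch; otherwise returns the output of
        the last layer.
        """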
for idx in range(self.nets_number):
output = self.forward_layer(x, idx)
new_X = self.compute_new_X(output, x, idx, None, train=False)
improved_mask = ((new_X == 1) & (x == 0)).any(dim=1).any(dim=1)
if improved_mask.sum() > 0:
return idx, new_X
return idx, new_X
def predict_from_net(self, x, net, th_abs, th_pres):
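        """Threshold the raw output of ``net`` with explicit thresholds.

        Returns a tensor shaped like the grid where channel 0 is 1 when the
        absence logit exceeds ``th_abs`` and channel 1 is 1 when the presence
        logit exceeds ``th_pres``.
        """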
x = self.sym_preprocess.forward(x)
x = net(x)
new_x = torch.empty(x.shape, device=x.device)
new_x[:, 0] = (x[:, 0] > th_abs).float()
new_x[:, 1] = (x[:, 1] > th_pres).float()
return new_x
@staticmethod
def mask_uncomplete(x, y):
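        """Mask of the entries that still need to be predicted.

        A group of 9 consecutive entries is considered complete once channel 0
        already holds 8 absences or channel 1 holds 1 presence; only the zero
        entries of incomplete groups contribute to the loss.
        """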
mask_uncomplete = x.reshape(-1, 2, 9, 9, 9).sum(-1) < torch.tensor((8, 1)).to(
x
).reshape(1, 2, 1, 1)
mask_uncomplete = mask_uncomplete.reshape(-1, 2, 9, 9, 1)
mask = ((x == 0).reshape(-1, 2, 9, 9, 9) * mask_uncomplete).reshape(
-1, 2, 9**3
)
mask = mask.float()
return mask
def computing_loss(self, x, y, output):
loss = self.bcewll(output, y)
mask = self.mask_uncomplete(x, y)
loss = (loss * mask).sum()
return loss
def training_step(self, batch, batch_idx):
self.log(
"train_grid_count",
batch[0].shape[0],
reduce_fx=torch.sum,
on_epoch=True,
on_step=False,
)
self.layer_training_step(0, batch)
while True:
idx, batch = self.buffer.get_batch()
if batch is None:
break
# check if the train should be done by comparing lr from sch = self.lr_schedulers()
# if self.lr != sch[idx].get_last_lr():
self.layer_training_step(idx, batch)
def validation_step(self, batch, batch_idx):
self.layer_training_step(0, batch, train=False)
while True:
idx, batch = self.buffer.get_batch()
if batch is None:
break
# check if the train should be done by comparing lr from sch = self.lr_schedulers()
# if self.lr != sch[idx].get_last_lr():
self.layer_training_step(idx, batch, train=False)
    def layer_training_step(self, idx, batch, train=True):
x, y = batch
prefix = "train" if train else "val"
self.log(
f"{prefix}_grid_count_{idx}",
batch[0].shape[0],
reduce_fx=torch.sum,
on_epoch=True,
on_step=False,
)
output = self.forward_layer(x, idx)
loss = self.computing_loss(x, y, output)
if train:
            opt = self.optimizers()
            if isinstance(opt, list):
                opt = opt[idx]
opt.zero_grad()
self.manual_backward(loss)
opt.step()
loss_0 = F.binary_cross_entropy_with_logits(output[:, [0], :], y[:, [0], :])
loss_1 = F.binary_cross_entropy_with_logits(output[:, [1], :], y[:, [1], :])
self.log_dict(
{f"{prefix}_loss_pos": loss_1, f"{prefix}_loss_neg": loss_0}, on_epoch=True
)
# accuracy_1 = torch.mean(torch.eq(transform_to_number_1(output), transform_to_number_1(x)).type(torch.float))
# accuracy_0 = torch.mean(torch.eq(transform_to_number_0(output), transform_to_number_0(x)).type(torch.float))
# self.log_dict({'accuracy_1': accuracy_1, 'accuracy_0': accuracy_0}, on_epoch=True)
self.log(f"{prefix}_loss_{idx}", loss)
# add a count log on (X and x == y)
new_X = self.compute_new_X(output, x, idx, y, train=train)
solved_mask = (new_X == y).all(dim=1).all(dim=1)
new_X = new_X[~solved_mask]
x = x[~solved_mask]
y = y[~solved_mask]
self.log(
f"{prefix}_resolved_grid_count",
solved_mask.sum(),
on_epoch=True,
on_step=False,
reduce_fx=torch.sum,
)
mask_no_improve = new_X.sum(dim=(1, 2)) <= x.sum(dim=(1, 2))
self.log(
f"{prefix}_improved_grid_count_{idx}",
(~mask_no_improve).sum(),
on_epoch=True,
on_step=False,
reduce_fx=torch.sum,
)
        # store_new_x
        # TODO keep the logging in this method:
        # - loss per epoch per boost-layer model
        # - number of errors per epoch per boost-layer model
        # - number of resolved puzzles per epoch
        # - threshold per epoch per boost-layer model
        # - number of sudoku grids
        # - number of filled digits per boost-layer model per epoch, for both pres and abs
        # - add parameter reduce_fx=torch.sum to the count logs
        # - thresholds -> on_epoch=False
self.store_new_x(idx, new_X, x, y)
def store_new_x(self, idx, new_X, x, y):
mask_improve = new_X.sum(dim=(1, 2)) > x.sum(dim=(1, 2))
self.buffer.append(
idx + 1, (new_X[~mask_improve].clone(), y[~mask_improve].clone())
)
self.buffer.append(0, (new_X[mask_improve].clone(), y[mask_improve].clone()))
        # TODO on non-improvement -> add one digit from y to new_X and add it at idx=0
def compute_new_X(self, output, x, idx, y=None, train=True, mask_adapt_th=None):
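        """Threshold ``output`` into a new grid state and adapt the layer thresholds.

        When ``y`` is given, the per-layer thresholds are raised just above the
        largest logit emitted on entries that should stay 0, so the thresholded
        output does not contradict the solution on this batch (restricted to
        ``mask_adapt_th`` when it is provided). Entries already set in ``x`` are
        kept, and entries contradicting ``y`` are cleared.
        """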
# y could be None
prefix = "train" if train else "val"
new_X = torch.empty(output.shape, device=output.device)
        # we could try to make the threshold evolve here
if y is not None:
# max_th_abs = (
# output[:, 0][(x[:, 0] == 0) & (y[:, 0] == 0)].max().item()
# + self.th_epsilon
# )
max_th_abs = output[:, 0][(y[:, 0] == 0)].max().item() + self.th_epsilon
max_th_pres = (
output[:, 1][(x[:, 1] == 0) & (y[:, 1] == 0)].max().item()
+ self.th_epsilon
)
            if mask_adapt_th is None or (mask_adapt_th.sum() > 0):
                if mask_adapt_th is not None and (mask_adapt_th.sum() > 0):
                    max_th_abs = (
                        output[mask_adapt_th, 0][(y[mask_adapt_th, 0] == 0)].max().item()
                        + self.th_epsilon
                    )
                    max_th_pres = (
                        output[mask_adapt_th, 1][
                            (x[mask_adapt_th, 1] == 0) & (y[mask_adapt_th, 1] == 0)
                        ].max().item()
                        + self.th_epsilon
                    )
self.threshold_abs[idx] = max(max_th_abs, self.threshold_abs[idx])
self.threshold_pres[idx] = max(max_th_pres, self.threshold_pres[idx])
self.log_dict(
{
f"{prefix}_th_abs_{idx}": self.threshold_abs[idx],
f"{prefix}_th_pres_{idx}": self.threshold_pres[idx],
},
on_step=True,
)
if not train:
self.threshold_abs_compute[idx] = max(
max_th_abs + self.margin, self.threshold_abs_compute[idx]
)
self.threshold_pres_compute[idx] = max(
max_th_pres + self.margin, self.threshold_pres_compute[idx]
)
if self.training:
new_X[:, 0] = (output[:, 0].detach() > self.threshold_abs[idx]).float()
new_X[:, 1] = (output[:, 1].detach() > self.threshold_pres[idx]).float()
else:
new_X[:, 0] = (output[:, 0].detach() > self.threshold_abs[idx]).float()
new_X[:, 1] = (output[:, 1].detach() > self.threshold_pres[idx]).float()
new_X[x.detach() == 1] = 1
if y is not None:
self.log(
f"{prefix}_count_error_grid_{idx}",
((new_X == 1) & (y == 0)).any(dim=1).any(dim=1).sum(),
on_epoch=True,
on_step=False,
reduce_fx=torch.sum,
)
if mask_adapt_th is None:
new_X[y.detach() == 0] = 0 # do not remove the error!!!!!!
else:
y_bis = y.detach().clone()
y_bis[~mask_adapt_th]=1
new_X[y_bis==0] = 0
return new_X
# TODO add idx stuff (one lr scheduler per net)
# def on_train_epoch_end(self):
# sch = self.lr_schedulers()
# # If the selected scheduler is a ReduceLROnPlateau scheduler.
# if isinstance(sch, torch.optim.lr_scheduler.ReduceLROnPlateau):
# sch.step(self.trainer.callback_metrics["loss"])
def on_validation_epoch_start(self) -> None:
if self.reset_threshold_on_validation:
self.threshold_abs_compute = torch.tensor(
[-10.0 for _ in range(self.nets_number)]
)
self.threshold_pres_compute = torch.tensor(
[-10.0 for _ in range(self.nets_number)]
)
else:
self.threshold_abs_compute = self.threshold_abs
self.threshold_pres_compute = self.threshold_pres
self.buffer = BufferArray(self.nets_number, self.batch_size)
def on_train_epoch_start(self) -> None:
self.buffer = BufferArray(self.nets_number, self.batch_size)
return super().on_train_epoch_start()
def on_validation_epoch_end(self):
# tensorboard = self.logger.experiment
self.threshold_abs = self.threshold_abs_compute
self.threshold_pres = self.threshold_pres_compute
        schs = self.lr_schedulers()
        if not isinstance(schs, list):
            schs = [schs]
        for idx, sch in enumerate(schs):
            try:
                sch.step(self.trainer.callback_metrics[f"val_loss_{idx}"])
            except KeyError:
                # f"val_loss_{idx}" was not logged during this epoch
                pass
def on_save_checkpoint(self, checkpoint) -> None:
"Objects to include in checkpoint file"
checkpoint["ths_abs"] = self.threshold_abs
checkpoint["ths_pres"] = self.threshold_pres
def on_load_checkpoint(self, checkpoint) -> None:
"Objects to retrieve from checkpoint file"
self.threshold_abs = checkpoint["ths_abs"]
self.threshold_pres = checkpoint["ths_pres"]
self.nets = nn.ModuleList([SmallNetBis() for _ in self.threshold_abs])
    def validate_grids(self, x) -> torch.Tensor:
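        """Boolean mask of the grids that still look consistent.

        A grid is rejected when one of the symmetry-preprocessing channels
        17-19 exceeds 1/8 (presumably signalling a duplicated digit along one
        of the Sudoku axes), or when a group of 9 entries already holds more
        than one presence or more than eight absences.
        """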
return ~(
(self.sym_preprocess(x)[:, 17].max(dim=1).values > (1 / 8))
| (self.sym_preprocess(x)[:, 18].max(dim=1).values > (1 / 8))
| (self.sym_preprocess(x)[:, 19].max(dim=1).values > (1 / 8))
| (x.view(-1,2,9,9,9)[:,1].sum(dim=-1)>1).any(dim=1).any(dim=1)
| (x.view(-1,2,9,9,9)[:,0].sum(dim=-1)>8).any(dim=1).any(dim=1)
)
# steps for trial and error:
# - get the stops
# - choose a number -> store it
# - process until we get either a new stop or a grid-validation failure
#   - if the grid validation fails, backtrack
#   - else choose a number
# add a counter to each grid
# add an id to each grid: id = batch_id + position
# add validation
# on a non-improvement stop:
#   - check if the id already exists; if true, increment the non-improvement counter
#     if the non-improvement counter reaches 2 -> add the grid to trial_error_model_buffer with a 1000-step target
#   - store the grid in the trial_error_model deep-search dict
#   - create two grids with counter set to 0 and the same id
#     and add them to the buffer
# when validation fails:
#   - check if the id already exists
#     if true: add the grid to trial_error_model with the counter
#     if false: raise an error
# TODO adapt the training to something softer
class TrialEveryPosException(Exception):
pass
class SudokuTrialErrorLightning(SudokuLightning):
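    """SudokuLightning variant with a trial-and-error (backtracking) stage.

    When the boost layers stop improving a grid, a position is chosen by the
    ``deep_backtrack_regressor`` and both hypotheses (digit present / digit
    absent) are pushed back through the layers; the outcome of each branch is
    tracked in a ``TrialGrid`` and used both to resolve the grid and to train
    the regressor.
    """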
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.deep_backtrack_regressor = SmallNetBis(n_output=1)
self.trial_error_buffer = Buffer(self.batch_size)
self.trial_grids = [None]
        # schema:
        # [
        #     idx: {
        #         "tried_pos": [...],
        #         "pos": pos,
        #         "no_improve_counter": 0,
        #     }
        # ]
#
# self.tracking_grid = []
def copy_from_model(self, model):
self.nets = model.nets
self.threshold_pres = model.threshold_pres
self.threshold_abs = model.threshold_abs
def reg(self, x):
x_reg = self.sym_preprocess.forward(x)
x_reg = self.deep_backtrack_regressor(x_reg)
return torch.softmax(x_reg, dim=1)
def configure_optimizers(self):
        # no need to configure the scheduler here -> manual optimization
# optimizers = [torch.optim.Adam(net.parameters(), lr=self.lr) for net in self.nets]
optimizers = []
for net in self.nets:
opti = torch.optim.Adam(net.parameters(), lr=self.lr)
optimizers.append(
{
"optimizer": opti,
"lr_scheduler": ReduceLROnPlateau(opti, "min"),
}
)
        opti_reg = torch.optim.Adam(
            self.deep_backtrack_regressor.parameters(), lr=self.lr
        )
        optimizers.append(
            {
                "optimizer": opti_reg,
                "lr_scheduler": ReduceLROnPlateau(opti_reg, "min"),
            }
        )
return optimizers
def training_step(self, batch, batch_idx):
self.log(
"train_grid_count",
batch[0].shape[0],
reduce_fx=torch.sum,
on_epoch=True,
on_step=False,
)
        x, y = batch
        x_idx = torch.zeros(x.shape[0], dtype=torch.long)  # if we are not on trial error, x_idx=0
        counters = torch.zeros(x.shape[0])
        self.layer_training_step(0, (x, y, x_idx, counters))
        idx_while = 0
        while True:
            idx_while += 1
            if idx_while == 10000:
                print("warning: layer training loop reached 10000 iterations")
idx, batch = self.buffer.get_batch()
if batch is None:
break
# check if the train should be done by comparing lr from sch = self.lr_schedulers()
# if self.lr != sch[idx].get_last_lr():
self.layer_training_step(idx, batch)
while True:
trial_error_batch = self.trial_error_buffer.get_batch()
if trial_error_batch is None:
break
self.trial_error_training_step(trial_error_batch)
def validation_step(self, batch, batch_idx):
x, y = batch
x_idx = torch.zeros(x.shape[0], dtype=torch.long) # if we are not on trial error x_idx=0
counters = torch.zeros(x.shape[0])
self.layer_training_step(0, (x, y, x_idx, counters), train=False)
while True:
idx, batch = self.buffer.get_batch()
if batch is None:
break
# check if the train should be done by comparing lr from sch = self.lr_schedulers()
# if self.lr != sch[idx].get_last_lr():
self.layer_training_step(idx, batch, train=False)
while True:
trial_error_batch = self.trial_error_buffer.get_batch()
if trial_error_batch is None:
break
self.trial_error_training_step(trial_error_batch, train=False)
    def layer_training_step(self, idx, batch, train=True):
x, y, x_idx, counters = batch
prefix = "train" if train else "val"
self.log(
f"{prefix}_grid_count_{idx}",
batch[0].shape[0],
reduce_fx=torch.sum,
on_epoch=True,
on_step=False,
)
output = self.forward_layer(x, idx)
loss = self.computing_loss(x[x_idx==0], y[x_idx==0], output[x_idx==0])
        if train:
            opt = self.optimizers()[idx]
opt.zero_grad()
self.manual_backward(loss)
opt.step()
loss_0 = F.binary_cross_entropy_with_logits(output[:, [0], :], y[:, [0], :])
loss_1 = F.binary_cross_entropy_with_logits(output[:, [1], :], y[:, [1], :])
self.log_dict(
{f"{prefix}_loss_pos": loss_1, f"{prefix}_loss_neg": loss_0}, on_epoch=True
)
self.log(f"{prefix}_loss_{idx}", loss)
mask_bad_x = ((x==1)&(y==0)).any(dim=1).any(dim=1)
new_X = self.compute_new_X(output, x, idx, y, train=train, mask_adapt_th=(~mask_bad_x))
solved_mask = (new_X == y).all(dim=1).all(dim=1)
new_X = new_X[~solved_mask]
x = x[~solved_mask]
y = y[~solved_mask]
x_idx = x_idx[~solved_mask]
counters = counters[~solved_mask]
self.log(
f"{prefix}_resolved_grid_count",
solved_mask.sum(),
on_epoch=True,
on_step=False,
reduce_fx=torch.sum,
)
mask_no_improve = new_X.sum(dim=(1, 2)) <= x.sum(dim=(1, 2))
self.log(
f"{prefix}_improved_grid_count_{idx}",
(~mask_no_improve).sum(),
on_epoch=True,
on_step=False,
reduce_fx=torch.sum,
)
# self.store_new_x(idx, new_X, x, y) # TODO create another function (need to increment counter and validate)
self.process_validation(idx, new_X, x, y, x_idx, counters)
def process_validation(self, idx, new_X, x, y, x_idx, counters):
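        """Dispatch the grids produced by a boost layer.

        Grids that fail the Sudoku validation report a 'fail' on their trial
        branch; on the last layer, grids that no longer improve either start a
        new trial (when x_idx == 0) or report 'no_improved'. Improved grids go
        back to layer 0 with an incremented counter, non-improved ones are
        buffered for the next layer.
        """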
new_X = self.redresse_new_X(new_X,y,x)
mask_validated = self.validate_grids(new_X)
# mask_not_validated = (~self.validate_grids(new_X)) & ((x==0)&(y==1)).any(dim=(1,2))
mask_improve = (new_X.sum(dim=(1, 2)) > x.sum(dim=(1, 2))) & mask_validated
mask_not_improved = (new_X.sum(dim=(1, 2)) == x.sum(dim=(1, 2))) & mask_validated
for i, (failed_idx, failed_counter, s_new_X, s_y) in enumerate(zip(
x_idx[~mask_validated],
counters[~mask_validated],
new_X[~mask_validated],
y[~mask_validated],
)):
            # when we find a failed grid:
            # - we store the good grid to continue the process  # /!\ not necessary, the second half will continue processing.
            # - we store the initial grid with the score (to train the regressor)
if failed_idx == 0:
self.failed_batch = (x[~mask_validated][i], s_y)
raise ValueError("validation error on no trial-error grid")
if not ((x[~mask_validated][i]==0)&(s_y==1)).any():
raise ValueError()
is_pos = copysign(1, failed_idx)==1
trial_grid: TrialGrid = self.trial_grids[int(abs(failed_idx))]
if is_pos:
trial_grid.pos_result = 'fail'
else:
trial_grid.neg_result = 'fail'
self.process_search_store_grid(int(abs(failed_idx)), trial_grid, s_y)
if idx == self.nets_number - 1:
for no_improved_idx, s_new_X, s_y in zip(
x_idx[mask_not_improved], new_X[mask_not_improved], y[mask_not_improved]
):
if no_improved_idx == 0:
self.search_trial_buffer_trials(s_new_X, s_y)
continue
is_pos = copysign(1, no_improved_idx)==1
trial_grid: TrialGrid = self.trial_grids[int(abs(no_improved_idx.item()))]
if is_pos:
trial_grid.pos_result = 'no_improved'
else:
trial_grid.neg_result = 'no_improved'
assert s_new_X.sum()> trial_grid.initial_grid.sum()
self.process_search_store_grid(int(abs(no_improved_idx)), trial_grid, s_y)
self.buffer.append(
idx + 1,
(new_X[mask_not_improved].clone(), y[mask_not_improved].clone(), x_idx[mask_not_improved].clone(), counters[mask_not_improved].clone()),
)
# assert mask_improve.sum()>0
if ((new_X[mask_improve & (x_idx.to(self.device)==0)]==1) & (y[mask_improve & (x_idx.to(self.device)==0)]==0)).any():
self.failed_batch=(x[mask_improve & (x_idx.to(self.device)==0)],y[mask_improve & (x_idx.to(self.device)==0)] )
raise ValueError()
self.buffer.append(
0,
(new_X[mask_improve].clone(), y[mask_improve].clone(), x_idx[mask_improve].clone(), counters[mask_improve].clone() + 1),
)
def process_search_store_grid(self, idx, trial_grid: TrialGrid, s_y):
"""_summary_
if score is 1:
great
if fail -> the second one should continue (it has his id if it stopped)
so do nothing
if no_improved ->
trial_error and reset trial_error_grid
if score is not None -> store the new grid in the trial_error_buffer
if score is -1 => also search_trial and store.
if store is 1 =>
if both result are here:
get the no_improved -> search_trial and store on a new grid
if one complete grid -> set grid place to None
else: wait
if score is None:
if grid_idx==-1 or oposite_grid failed:
we create a new_idx, and store stuff.
else:
we store the grid (in case the second grid fail)
we increment the non_improvement counter
if non_improvement counter = 2:
we add the initial grid to the search training buffer
we process the search training engine to find another grid postion
else:
we add the initial grid to the search training buffer
if a non improved grid is store we create a new_idx and store stuf.
Args:
grid_idx (_type_): _description_
score (_type_): _description_
s_new_X (_type_): _description_
s_y (_type_): _description_
"""
score = trial_grid.score()
if score is None:
self.trial_grids[idx]=trial_grid
return
# add grid to buffer (initial_grid, score)
self.trial_error_buffer.append((
trial_grid.initial_grid.view(-1,2,729),
torch.tensor([score,],dtype=torch.float).to(self.device),
torch.tensor([trial_grid.row_col_digit_position,], dtype=torch.long).to(self.device),
))
# find the no_improve_grid ~and search_trial~ and add it to buffer
if trial_grid.neg_result == 'no_improved':
if trial_grid.pos_result == 'no_improved':
trial_grid.tried_grid.append(trial_grid.row_col_digit_position)
trial_grid.neg_result= None
trial_grid.pos_result= None
self.trial_grids[idx] = trial_grid
self.search_trial_buffer_trials(None, s_y, idx)
# new trial with same idx
return
# add to buffer neg grid
# we get back the initial grid
# set the correct row col digit
# add it the buffer
# set trial_grids to None
grid_neg = deepcopy(trial_grid.initial_grid)
grid_neg[0,trial_grid.row_col_digit_position] = 1
if ((grid_neg==1) & (s_y==0)).any():
raise ValueError()
self.buffer.append(
0,
(
grid_neg.view(-1,2,729),
s_y.clone().view(-1,2,729),
torch.tensor([0]),
torch.tensor([0]),
)
)
self.trial_grids[idx] = None
return
if trial_grid.pos_result == 'no_improved':
grid_pos = deepcopy(trial_grid.initial_grid)
grid_pos[1,trial_grid.row_col_digit_position] = 1
if ((grid_pos==1) & (s_y==0)).any():
raise ValueError()
self.buffer.append(
0,
(
grid_pos.view(-1,2,729),
s_y.clone().view(-1,2,729),
torch.tensor([0]),
torch.tensor([0]),
)
)
self.trial_grids[idx] = None
# add to buffer pos grid
return
# if complete: replace grid by none.
if "complete" in [trial_grid.neg_result, trial_grid.pos_result]:
self.trial_grids[idx]=None
# def store_new_trial_error_grid(self, new_X, y):
# """build a new idx add the grid in the tracking stuff
# and add grid in the buffer
# Args:
# new_X (_type_): _description_
# y (_type_): _description_
# """
# ...
# def store_training_trail_search_batch(self, grid, score):
# """store grid to train trial_search nn model
# Args:
# grid (_type_): _description_
# score (_type_): _description_
# """
# ...
    def search_trial(self, s_new_X, tried_pos):
        """Use the trial-search nn model to probe a new position to try.

        Args:
            s_new_X: current grid, shape (2, 729)
            tried_pos: positions already tried, excluded from the search
        """
mask_possibility = s_new_X.sum(dim=0)==0
for pos in tried_pos:
mask_possibility[pos]=False
        if mask_possibility.sum() == 0:
            print("no position left to try")
            raise TrialEveryPosException()
with torch.no_grad():
x_reg = self.sym_preprocess.forward(s_new_X.view(1,2,-1))
output = self.deep_backtrack_regressor(x_reg)
# shape (1, 729)
# can be regression -> i want the smallest
# can be logistic regression -> i want the smallest
# if i do softmax -> i can add 1 to each tried pos
output = torch.softmax(output[0][0],dim=0)
# for pos in tried_pos:
# output[pos]=1
output[~mask_possibility]+=1
return torch.argmin(output, dim=0).item()
def search_trial_buffer_trials(self, s_new_X, s_y, idx_trial_grids=None):
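        """Open (or continue) a trial on a grid and enqueue both hypotheses.

        A position is chosen with ``search_trial``; the grid with the digit
        assumed present and the grid with it assumed absent are both pushed to
        layer 0, tagged with +idx and -idx so their outcomes can be matched
        back to the TrialGrid.
        """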
        if idx_trial_grids is None:
            row_col_digit_trial = self.search_trial(s_new_X, [])
            trial_grid = TrialGrid(s_new_X, row_col_digit_trial)
            self.trial_grids.append(trial_grid)
            idx_trial_grids = len(self.trial_grids) - 1
else:
trial_grid = self.trial_grids[idx_trial_grids]
s_new_X = trial_grid.initial_grid
row_col_digit_trial = self.search_trial(s_new_X, trial_grid.tried_grid)
trial_grid.row_col_digit_position = row_col_digit_trial
self.trial_grids[idx_trial_grids] = trial_grid
# and we add both into buffer.
grid_pos = deepcopy(s_new_X)
grid_neg = deepcopy(s_new_X)
grid_pos[1,row_col_digit_trial] = 1
grid_neg[0,row_col_digit_trial] = 1
self.buffer.append(
0,
(
torch.stack([grid_pos,grid_neg], dim=0),
torch.stack([s_y.clone(),s_y.clone()], dim=0),
torch.tensor([idx_trial_grids, -idx_trial_grids]),
torch.tensor([0, 0]),
)
)
def trial_error_training_step(self, batch, train=True):
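        """Train the backtracking regressor on stored (grid, score, position) triples.

        The regressor's logit at the tried position is trained against the
        stored score with a binary cross-entropy loss.
        """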
x, y, row_col_digit = batch
prefix = "train" if train else "val"
self.log(
f"{prefix}_grid_count_trial_error_training",
batch[0].shape[0],
reduce_fx=torch.sum,
on_epoch=True,
on_step=False,
)
x_reg = self.sym_preprocess.forward(x)
output = self.deep_backtrack_regressor(x_reg)
        batch_len = x.shape[0]
        loss = F.binary_cross_entropy_with_logits(
            output[[i for i in range(batch_len)], 0, row_col_digit], y, reduction="mean"
        )
# # depending the distribution of the target, the weight could be different
# loss = binary (output[:,0,row_col_digit], y)
# loss = self.computing_loss(x, y, output)
if train:
opt = self.optimizers()[-1]
opt.zero_grad()
self.manual_backward(loss)
opt.step()
self.log(f"{prefix}_loss_trial_error", loss)
self.log(f"{prefix}_loss_{self.nets_number}", loss)
self.log(f"{prefix}_y_pos_trial_error", y.sum())
self.log(f"{prefix}_y_neg_trial_eror", y.shape[0]-y.sum())
def predict(self, x, func_text_display=None):
""" return an improvement of x
"""
idx, new_X = self.forward(x.view(-1,2,729))
if (new_X.sum()>x.sum()) or (new_X.sum()==729):
if func_text_display:
func_text_display(f'boost layer step: {idx}')
return new_X
else:
# call trial error until we find a solution
tried_position = []
while True:
pos = self.search_trial(x.view(2,729), tried_position)
tried_position.append(pos)
                # create the pos/neg tensors
grid_pos = deepcopy(x.view(2,729))
grid_neg = deepcopy(x.view(2,729))
grid_pos[1,pos] = 1
grid_neg[0,pos] = 1
X_tried = torch.stack([grid_neg, grid_pos], dim=0)
# process it
while True:
idx, new_X = self.forward(X_tried)
mask_validated = self.validate_grids(new_X)
if mask_validated.sum()<2:
                        x[0, mask_validated, pos] = 1  # TODO check if it works
if func_text_display:
digit, col, row = pos_to_digit_col_row(pos)
func_text_display('model failed to improve the grid')
                            func_text_display(f'trial error algorithm, found error at digit: {digit}, col: {col}, row: {row}')
return x
                    if X_tried.sum() == new_X.sum():
                        # if both stop improving -> break; it will try a new pos
                        break
                    mask_complete = (new_X.sum(dim=(1, 2)) == 729)  # a finished grid sums to 729
if mask_complete.sum()>0:
x[0, mask_complete, pos] = 1
if func_text_display:
digit, col, row = pos_to_digit_col_row(pos)
func_text_display('model failed to improve the grid')
                            func_text_display(f'trial error algorithm, completed the grid with digit: {digit}, col: {col}, row: {row}')
return x
X_tried = new_X
        # if one of X_tried is complete (weird but possible) -> return x with the completed position set to 1 (because we still want a step-by-step resolution)
    def backtracking_predict(self, x, use_trial_error=False, assumption=None, func_text_display=None, func_tensor_display=None):
        """
        return is_valid, new_x
        """
        if assumption is None:
            assumption = []
        next_X = deepcopy(x)
        sum_1 = next_X.sum().item()
if use_trial_error:
while True:
try:
next_X = self.predict(next_X.view(1,2,729))
if not self.validate_grids(next_X)[0].item():
                        if assumption:
func_text_display(f'assumption {assumption[-1]}, failed')
func_text_display(f'assumption length {len(assumption)}')
func_tensor_display(next_X)
return False, None
if next_X.sum()>=729:
                        if assumption:
func_text_display(f'assumption length {len(assumption)}')
func_tensor_display(next_X)
return True, next_X
except TrialEveryPosException:
break
else:
while True:
_idx, next_X = self.forward(next_X.view(1,2,729))
if not self.validate_grids(next_X)[0].item():
return False, None
sum_2 = next_X.sum().item()
if sum_1==sum_2:
break
else:
sum_1=sum_2
if next_X.sum()>=729:
return True, next_X
pos = self.search_trial(next_X.view(2,729), [])
new_x = deepcopy(next_X.view(1,2,729))
output = self.forward_layer(next_X.view(1,2,729))
pos_output = output[0,:,pos]
if (pos_output[0].item()-self.threshold_abs[0].item())>(pos_output[1].item()-self.threshold_pres[0].item()):
trial_abs_pres = 0
abs_pres_str = 'abs'
else:
trial_abs_pres = 1
abs_pres_str = 'pres'
new_x[0,trial_abs_pres,pos]=1
# print(f'assuming {(pos, pos%9, (pos//9)%9, pos//(9*9), trial_abs_pres)}')# i+j*9+n*9*9
digit, col, row = pos_to_digit_col_row(pos)
        func_text_display('grid not improving, adding an assumption')
current_assumption = f'{abs_pres_str}, digit: {digit}, col: {col}, row: {row}'
func_text_display(current_assumption)
assumption.append(current_assumption)
is_valid, new_x = self.backtracking_predict(new_x, use_trial_error=use_trial_error, assumption=assumption, func_text_display=func_text_display, func_tensor_display=func_tensor_display)
if not is_valid:
new_x = deepcopy(next_X.view(1,2,729))
new_x[0,(trial_abs_pres+1)%2,pos]=1
# print(f'finally assuming {(pos, pos%9, (pos//9)%9, pos//(9*9), (trial_abs_pres+1)%2)}')# i+j*9+n*9*9
current_assumption = f'{"pres" if abs_pres_str=="abs" else "abs"}, digit: {digit}, col: {col}, row: {row}'
func_text_display(current_assumption)
return self.backtracking_predict(new_x, use_trial_error=use_trial_error, assumption=assumption, func_text_display=func_text_display, func_tensor_display=func_tensor_display)
return True, new_x
def on_validation_epoch_start(self) -> None:
# self.buffer = BufferArray(self.nets_number, self.batch_size)
self.trial_error_buffer = Buffer(self.batch_size)
self.trial_grids = [None]
return super().on_validation_epoch_start()
def on_train_epoch_start(self) -> None:
self.trial_error_buffer = Buffer(self.batch_size)
self.trial_grids = [None]
return super().on_train_epoch_start()
def redresse_new_X(self, new_X,y,x):
mask_bad_x = ((x==1)&(y==0)).any(dim=1).any(dim=1)
y_bis = y.clone()
y_bis[mask_bad_x]=1
new_X[y_bis==0]=0
return new_X
# TODO add threshold adjustment during prediction
# or maybe validate on it? well, yes!
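
# Hypothetical usage sketch (illustration only, not part of the training code):
# load a trained checkpoint and run the backtracking predictor on one grid.
# The checkpoint path and the zero tensor below are placeholders; a real call
# expects a (1, 2, 729) tensor where channel 0 encodes known absences and
# channel 1 encodes known digits.
if __name__ == "__main__":
    model = SudokuTrialErrorLightning.load_from_checkpoint("path/to/checkpoint.ckpt")
    model.eval()
    x = torch.zeros(1, 2, 729)  # placeholder: replace with a real encoded puzzle
    is_valid, solved = model.backtracking_predict(
        x,
        use_trial_error=False,
        func_text_display=print,
        func_tensor_display=print,
    )
    print("valid:", is_valid)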