|
from math import copysign |
|
import torch |
|
from torch import nn |
|
from torch.optim.lr_scheduler import ReduceLROnPlateau |
|
import pytorch_lightning as pl |
|
from sudoku.models import SmallNetBis, SymPreprocess |
|
import torch.nn.functional as F |
|
|
|
from sudoku.buffer import BufferArray, Buffer |
|
from sudoku.trial_grid import TrialGrid |
|
from sudoku.helper import pos_to_digit_col_row |
|
|
|
from copy import deepcopy |
|
|
|
|
|
class SudokuLightning(pl.LightningModule): |
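    """Cascade of SmallNetBis stages that iteratively fill a Sudoku grid
    encoded as two binary 729-dim channels: digit absence (channel 0) and
    digit presence (channel 1), one slot per (digit, column, row) triple."""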
|
def __init__( |
|
self, |
|
lr=0.1, |
|
margin=0.1, |
|
coef_0=10, |
|
nets_number=6, |
|
nets_training_number=1, |
|
batch_size=32, |
|
): |
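        """Args (semantics inferred from usage below):

        lr: learning rate for every Adam optimizer.
        margin: safety margin added to thresholds recomputed on validation.
        coef_0: balancing factor between negative and positive targets.
        nets_number: number of SmallNetBis stages in the cascade.
        nets_training_number: number of stages trained per step (stored but
            not used in this file).
        batch_size: batch size used by the replay buffers.
        """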
|
super().__init__() |
|
self.nets_number = nets_number |
|
self.batch_size = batch_size |
|
self.nets_training_number = nets_training_number |
|
|
|
self.nets = nn.ModuleList([SmallNetBis() for _ in range(self.nets_number)]) |
|
self.buffer = BufferArray(self.nets_number, self.batch_size) |
|
self.sym_preprocess = SymPreprocess() |
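        # Class balancing for BCEWithLogitsLoss: each cell's absence channel
        # has eight positives for every one presence positive, so channel 0
        # gets pos_weight 1/8 and element weight 8. Dividing pos_weight by
        # coef_0 while multiplying weight by coef_0 cancels out on positive
        # targets and up-weights negative targets by coef_0.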
|
pos_weight = torch.ones((2, 9 * 9 * 9)) |
|
pos_weight[0, :] = 1.0 / 8.0 |
|
pos_weight[1, :] = 1.0 |
|
pos_weight /= coef_0 |
|
weight = torch.ones((2, 9 * 9 * 9)) |
|
weight[0, :] = 8.0 |
|
weight[1, :] = 1.0 |
|
weight *= coef_0 |
|
|
|
        self.bcewll = nn.BCEWithLogitsLoss(
            pos_weight=pos_weight, weight=weight, reduction="none"
        )
|
self.lr = lr |
|
|
|
|
|
self.margin = margin |
|
self.th_epsilon = margin * 0.01 |
|
self.threshold_pres = torch.tensor([-10.0 for _ in range(nets_number)]) |
|
self.threshold_abs = torch.tensor([-10.0 for _ in range(nets_number)]) |
|
|
|
self.automatic_optimization = False |
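        # Manual optimization: each stage owns its own optimizer and is
        # stepped selectively in layer_training_step.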
|
self.reset_threshold_on_validation = True |
|
|
|
def configure_optimizers(self): |
|
|
|
optimizers = [] |
|
for net in self.nets: |
|
opti = torch.optim.Adam(net.parameters(), lr=self.lr) |
|
optimizers.append( |
|
{ |
|
"optimizer": opti, |
|
"lr_scheduler": ReduceLROnPlateau(opti, "min"), |
|
} |
|
) |
|
return optimizers |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def forward_layer(self, x, idx=0): |
|
x = self.sym_preprocess.forward(x) |
|
return self.nets[idx](x) |
|
|
|
def forward(self, x): |
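        """Run the cascade on x: return the index of the first stage whose
        thresholded output adds a new deduction, with the updated grid; if
        no stage improves x, return the last stage's output."""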
|
for idx in range(self.nets_number): |
|
output = self.forward_layer(x, idx) |
|
new_X = self.compute_new_X(output, x, idx, None, train=False) |
|
improved_mask = ((new_X == 1) & (x == 0)).any(dim=1).any(dim=1) |
|
            if improved_mask.any():
|
return idx, new_X |
|
return idx, new_X |
|
|
|
def predict_from_net(self, x, net, th_abs, th_pres): |
|
x = self.sym_preprocess.forward(x) |
|
x = net(x) |
|
new_x = torch.empty(x.shape, device=x.device) |
|
new_x[:, 0] = (x[:, 0] > th_abs).float() |
|
new_x[:, 1] = (x[:, 1] > th_pres).float() |
|
return new_x |
|
|
|
@staticmethod |
|
def mask_uncomplete(x, y): |
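        """Mask of entries still undecided within incomplete cells.

        Per cell, a full absence channel sums to 8 and a full presence
        channel to 1; only the zero entries of channels below those counts
        are kept. y is unused but kept for a uniform signature.
        """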
|
        threshold = torch.tensor((8, 1)).to(x).reshape(1, 2, 1, 1)
        incomplete = x.reshape(-1, 2, 9, 9, 9).sum(-1) < threshold
        incomplete = incomplete.reshape(-1, 2, 9, 9, 1)
        mask = ((x == 0).reshape(-1, 2, 9, 9, 9) * incomplete).reshape(-1, 2, 9**3)
        return mask.float()
|
|
|
def computing_loss(self, x, y, output): |
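        """Per-element BCE (reduction 'none'), summed over the entries that
        mask_uncomplete marks as still undecided."""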
|
loss = self.bcewll(output, y) |
|
mask = self.mask_uncomplete(x, y) |
|
loss = (loss * mask).sum() |
|
|
|
return loss |
|
|
|
def training_step(self, batch, batch_idx): |
|
self.log( |
|
"train_grid_count", |
|
batch[0].shape[0], |
|
reduce_fx=torch.sum, |
|
on_epoch=True, |
|
on_step=False, |
|
) |
|
|
|
self.layer_training_step(0, batch) |
|
while True: |
|
idx, batch = self.buffer.get_batch() |
|
if batch is None: |
|
break |
|
|
|
|
|
self.layer_training_step(idx, batch) |
|
|
|
def validation_step(self, batch, batch_idx): |
|
self.layer_training_step(0, batch, train=False) |
|
while True: |
|
idx, batch = self.buffer.get_batch() |
|
if batch is None: |
|
break |
|
|
|
|
|
self.layer_training_step(idx, batch, train=False) |
|
|
|
    def layer_training_step(self, idx, batch, train=True):
|
x, y = batch |
|
|
|
prefix = "train" if train else "val" |
|
self.log( |
|
f"{prefix}_grid_count_{idx}", |
|
batch[0].shape[0], |
|
reduce_fx=torch.sum, |
|
on_epoch=True, |
|
on_step=False, |
|
) |
|
|
|
output = self.forward_layer(x, idx) |
|
loss = self.computing_loss(x, y, output) |
|
if train: |
|
|
|
opt = self.optimizers() |
|
if isinstance(opt, list): |
|
                opt = opt[idx]
|
opt.zero_grad() |
|
self.manual_backward(loss) |
|
opt.step() |
|
|
|
loss_0 = F.binary_cross_entropy_with_logits(output[:, [0], :], y[:, [0], :]) |
|
loss_1 = F.binary_cross_entropy_with_logits(output[:, [1], :], y[:, [1], :]) |
|
self.log_dict( |
|
{f"{prefix}_loss_pos": loss_1, f"{prefix}_loss_neg": loss_0}, on_epoch=True |
|
) |
|
|
|
|
|
|
|
|
|
self.log(f"{prefix}_loss_{idx}", loss) |
|
|
|
|
|
new_X = self.compute_new_X(output, x, idx, y, train=train) |
|
solved_mask = (new_X == y).all(dim=1).all(dim=1) |
|
new_X = new_X[~solved_mask] |
|
x = x[~solved_mask] |
|
y = y[~solved_mask] |
|
self.log( |
|
f"{prefix}_resolved_grid_count", |
|
solved_mask.sum(), |
|
on_epoch=True, |
|
on_step=False, |
|
reduce_fx=torch.sum, |
|
) |
|
mask_no_improve = new_X.sum(dim=(1, 2)) <= x.sum(dim=(1, 2)) |
|
self.log( |
|
f"{prefix}_improved_grid_count_{idx}", |
|
(~mask_no_improve).sum(), |
|
on_epoch=True, |
|
on_step=False, |
|
reduce_fx=torch.sum, |
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.store_new_x(idx, new_X, x, y) |
|
|
|
def store_new_x(self, idx, new_X, x, y): |
|
mask_improve = new_X.sum(dim=(1, 2)) > x.sum(dim=(1, 2)) |
|
self.buffer.append( |
|
idx + 1, (new_X[~mask_improve].clone(), y[~mask_improve].clone()) |
|
) |
|
self.buffer.append(0, (new_X[mask_improve].clone(), y[mask_improve].clone())) |
|
|
|
|
|
def compute_new_X(self, output, x, idx, y=None, train=True, mask_adapt_th=None): |
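        """Threshold logits into new deductions without contradicting y.

        With y given, the stage thresholds are first raised to just above
        the highest logit that would assert something false on this batch
        (plus th_epsilon), so the thresholded output stays error-free;
        mask_adapt_th optionally restricts which grids drive the
        adaptation. Givens in x are always kept.
        """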
|
|
|
prefix = "train" if train else "val" |
|
new_X = torch.empty(output.shape, device=output.device) |
|
|
|
if y is not None: |
|
|
|
|
|
|
|
|
|
max_th_abs = output[:, 0][(y[:, 0] == 0)].max().item() + self.th_epsilon |
|
max_th_pres = ( |
|
output[:, 1][(x[:, 1] == 0) & (y[:, 1] == 0)].max().item() |
|
+ self.th_epsilon |
|
) |
|
            if mask_adapt_th is None or mask_adapt_th.sum() > 0:
                if mask_adapt_th is not None and mask_adapt_th.sum() > 0:
                    max_th_abs = (
                        output[mask_adapt_th, 0][y[mask_adapt_th, 0] == 0].max().item()
                        + self.th_epsilon
                    )
                    max_th_pres = (
                        output[mask_adapt_th, 1][
                            (x[mask_adapt_th, 1] == 0) & (y[mask_adapt_th, 1] == 0)
                        ].max().item()
                        + self.th_epsilon
                    )
                self.threshold_abs[idx] = max(max_th_abs, self.threshold_abs[idx])
                self.threshold_pres[idx] = max(max_th_pres, self.threshold_pres[idx])
|
self.log_dict( |
|
{ |
|
f"{prefix}_th_abs_{idx}": self.threshold_abs[idx], |
|
f"{prefix}_th_pres_{idx}": self.threshold_pres[idx], |
|
}, |
|
on_step=True, |
|
) |
|
if not train: |
|
self.threshold_abs_compute[idx] = max( |
|
max_th_abs + self.margin, self.threshold_abs_compute[idx] |
|
) |
|
self.threshold_pres_compute[idx] = max( |
|
max_th_pres + self.margin, self.threshold_pres_compute[idx] |
|
) |
|
|
|
        # The same thresholding applies in training and eval mode.
        new_X[:, 0] = (output[:, 0].detach() > self.threshold_abs[idx]).float()
        new_X[:, 1] = (output[:, 1].detach() > self.threshold_pres[idx]).float()
|
new_X[x.detach() == 1] = 1 |
|
if y is not None: |
|
self.log( |
|
f"{prefix}_count_error_grid_{idx}", |
|
((new_X == 1) & (y == 0)).any(dim=1).any(dim=1).sum(), |
|
on_epoch=True, |
|
on_step=False, |
|
reduce_fx=torch.sum, |
|
) |
|
            if mask_adapt_th is None:
                new_X[y.detach() == 0] = 0
            else:
                # Only grids selected by mask_adapt_th are corrected toward
                # y; excluded grids keep their raw thresholded output.
                y_bis = y.detach().clone()
                y_bis[~mask_adapt_th] = 1
                new_X[y_bis == 0] = 0
|
return new_X |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def on_validation_epoch_start(self) -> None: |
|
if self.reset_threshold_on_validation: |
|
self.threshold_abs_compute = torch.tensor( |
|
[-10.0 for _ in range(self.nets_number)] |
|
) |
|
self.threshold_pres_compute = torch.tensor( |
|
[-10.0 for _ in range(self.nets_number)] |
|
) |
|
else: |
|
self.threshold_abs_compute = self.threshold_abs |
|
self.threshold_pres_compute = self.threshold_pres |
|
|
|
self.buffer = BufferArray(self.nets_number, self.batch_size) |
|
|
|
def on_train_epoch_start(self) -> None: |
|
self.buffer = BufferArray(self.nets_number, self.batch_size) |
|
return super().on_train_epoch_start() |
|
|
|
def on_validation_epoch_end(self): |
|
|
|
self.threshold_abs = self.threshold_abs_compute |
|
self.threshold_pres = self.threshold_pres_compute |
|
|
|
        schs = self.lr_schedulers()
        if not isinstance(schs, list):
            schs = [schs]
        for idx, sch in enumerate(schs):
            try:
                sch.step(self.trainer.callback_metrics[f"val_loss_{idx}"])
            except KeyError:
                # The metric is missing when no batch reached this stage.
                pass
|
|
|
|
|
def on_save_checkpoint(self, checkpoint) -> None: |
|
"Objects to include in checkpoint file" |
|
checkpoint["ths_abs"] = self.threshold_abs |
|
checkpoint["ths_pres"] = self.threshold_pres |
|
|
|
def on_load_checkpoint(self, checkpoint) -> None: |
|
"Objects to retrieve from checkpoint file" |
|
self.threshold_abs = checkpoint["ths_abs"] |
|
self.threshold_pres = checkpoint["ths_pres"] |
|
self.nets = nn.ModuleList([SmallNetBis() for _ in self.threshold_abs]) |
|
|
|
    def validate_grids(self, x) -> torch.Tensor:
        """Boolean mask of grids with no detected contradiction: none of the
        SymPreprocess consistency channels (17-19) exceed 1/8, no cell
        asserts two digits, and no cell rules out all nine."""
        features = self.sym_preprocess(x)
        x5d = x.view(-1, 2, 9, 9, 9)
        return ~(
            (features[:, 17].max(dim=1).values > (1 / 8))
            | (features[:, 18].max(dim=1).values > (1 / 8))
            | (features[:, 19].max(dim=1).values > (1 / 8))
            | (x5d[:, 1].sum(dim=-1) > 1).any(dim=1).any(dim=1)
            | (x5d[:, 0].sum(dim=-1) > 8).any(dim=1).any(dim=1)
        )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TrialEveryPosException(Exception): |
|
pass |
|
|
|
class SudokuTrialErrorLightning(SudokuLightning): |
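    """SudokuLightning extended with trial and error: when no stage can
    improve a grid, a regressor (deep_backtrack_regressor) proposes a
    position to branch on, and both hypotheses, digit present and digit
    absent, are explored as separate grids tagged through x_idx."""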
|
def __init__(self, **kwargs): |
|
super().__init__(**kwargs) |
|
self.deep_backtrack_regressor = SmallNetBis(n_output=1) |
|
self.trial_error_buffer = Buffer(self.batch_size) |
|
self.trial_grids = [None] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def copy_from_model(self, model): |
|
self.nets = model.nets |
|
self.threshold_pres = model.threshold_pres |
|
self.threshold_abs = model.threshold_abs |
|
|
|
def reg(self, x): |
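        """Score a grid with the backtracking regressor and normalize the
        scores with a softmax."""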
|
x_reg = self.sym_preprocess.forward(x) |
|
x_reg = self.deep_backtrack_regressor(x_reg) |
|
return torch.softmax(x_reg, dim=1) |
|
|
|
def configure_optimizers(self): |
|
|
|
|
|
optimizers = [] |
|
for net in self.nets: |
|
opti = torch.optim.Adam(net.parameters(), lr=self.lr) |
|
optimizers.append( |
|
{ |
|
"optimizer": opti, |
|
"lr_scheduler": ReduceLROnPlateau(opti, "min"), |
|
} |
|
) |
|
        opti_reg = torch.optim.Adam(
            self.deep_backtrack_regressor.parameters(), lr=self.lr
        )
        optimizers.append(
            {
                "optimizer": opti_reg,
                # The scheduler watches the regressor's own optimizer.
                "lr_scheduler": ReduceLROnPlateau(opti_reg, "min"),
            }
        )
|
return optimizers |
|
|
|
def training_step(self, batch, batch_idx): |
|
self.log( |
|
"train_grid_count", |
|
batch[0].shape[0], |
|
reduce_fx=torch.sum, |
|
on_epoch=True, |
|
on_step=False, |
|
) |
|
x, y = batch |
|
        x_idx = torch.zeros(x.shape[0], dtype=torch.long)
        counters = torch.zeros(x.shape[0])
|
|
|
self.layer_training_step(0, (x, y, x_idx, counters)) |
|
        idx_while = 0
        while True:
            idx_while += 1
            if idx_while == 10000:
                # Diagnostic: the buffer should drain long before this point.
                print("warning: buffer loop reached 10000 iterations")
|
idx, batch = self.buffer.get_batch() |
|
if batch is None: |
|
break |
|
|
|
|
|
self.layer_training_step(idx, batch) |
|
|
|
while True: |
|
trial_error_batch = self.trial_error_buffer.get_batch() |
|
if trial_error_batch is None: |
|
break |
|
self.trial_error_training_step(trial_error_batch) |
|
|
|
def validation_step(self, batch, batch_idx): |
|
x, y = batch |
|
x_idx = torch.zeros(x.shape[0], dtype=torch.long) |
|
counters = torch.zeros(x.shape[0]) |
|
|
|
self.layer_training_step(0, (x, y, x_idx, counters), train=False) |
|
while True: |
|
idx, batch = self.buffer.get_batch() |
|
if batch is None: |
|
break |
|
|
|
|
|
self.layer_training_step(idx, batch, train=False) |
|
|
|
while True: |
|
trial_error_batch = self.trial_error_buffer.get_batch() |
|
if trial_error_batch is None: |
|
break |
|
self.trial_error_training_step(trial_error_batch, train=False) |
|
|
|
|
|
    def layer_training_step(self, idx, batch, train=True):
|
x, y, x_idx, counters = batch |
|
|
|
prefix = "train" if train else "val" |
|
self.log( |
|
f"{prefix}_grid_count_{idx}", |
|
batch[0].shape[0], |
|
reduce_fx=torch.sum, |
|
on_epoch=True, |
|
on_step=False, |
|
) |
|
output = self.forward_layer(x, idx) |
|
        # The supervised loss only uses grids without a pending trial
        # hypothesis (x_idx == 0).
        loss = self.computing_loss(x[x_idx == 0], y[x_idx == 0], output[x_idx == 0])
|
        if train:
            opt = self.optimizers()[idx]
            opt.zero_grad()
            self.manual_backward(loss)
            opt.step()
|
|
|
loss_0 = F.binary_cross_entropy_with_logits(output[:, [0], :], y[:, [0], :]) |
|
loss_1 = F.binary_cross_entropy_with_logits(output[:, [1], :], y[:, [1], :]) |
|
self.log_dict( |
|
{f"{prefix}_loss_pos": loss_1, f"{prefix}_loss_neg": loss_0}, on_epoch=True |
|
) |
|
self.log(f"{prefix}_loss_{idx}", loss) |
|
|
|
        # Grids whose trial hypothesis already contradicts y must not drive
        # the threshold adaptation.
        mask_bad_x = ((x == 1) & (y == 0)).any(dim=1).any(dim=1)
        new_X = self.compute_new_X(
            output, x, idx, y, train=train, mask_adapt_th=(~mask_bad_x)
        )
|
solved_mask = (new_X == y).all(dim=1).all(dim=1) |
|
|
|
new_X = new_X[~solved_mask] |
|
x = x[~solved_mask] |
|
y = y[~solved_mask] |
|
x_idx = x_idx[~solved_mask] |
|
counters = counters[~solved_mask] |
|
|
|
self.log( |
|
f"{prefix}_resolved_grid_count", |
|
solved_mask.sum(), |
|
on_epoch=True, |
|
on_step=False, |
|
reduce_fx=torch.sum, |
|
) |
|
mask_no_improve = new_X.sum(dim=(1, 2)) <= x.sum(dim=(1, 2)) |
|
self.log( |
|
f"{prefix}_improved_grid_count_{idx}", |
|
(~mask_no_improve).sum(), |
|
on_epoch=True, |
|
on_step=False, |
|
reduce_fx=torch.sum, |
|
) |
|
|
|
self.process_validation(idx, new_X, x, y, x_idx, counters) |
|
|
|
def process_validation(self, idx, new_X, x, y, x_idx, counters): |
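        """Route grids after a stage: enqueue improved grids at stage 0,
        escalate stalled ones to the next stage, and resolve trial branches
        that failed validation or, at the last stage, stopped improving."""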
|
        new_X = self.redresse_new_X(new_X, y, x)
|
mask_validated = self.validate_grids(new_X) |
|
|
|
mask_improve = (new_X.sum(dim=(1, 2)) > x.sum(dim=(1, 2))) & mask_validated |
|
mask_not_improved = (new_X.sum(dim=(1, 2)) == x.sum(dim=(1, 2))) & mask_validated |
|
|
|
|
|
for i, (failed_idx, failed_counter, s_new_X, s_y) in enumerate(zip( |
|
x_idx[~mask_validated], |
|
counters[~mask_validated], |
|
new_X[~mask_validated], |
|
y[~mask_validated], |
|
)): |
|
|
|
|
|
|
|
            if failed_idx == 0:
                self.failed_batch = (x[~mask_validated][i], s_y)
                raise ValueError("validation error on a grid with no trial hypothesis")
            if not ((x[~mask_validated][i] == 0) & (s_y == 1)).any():
                raise ValueError("failed grid already contains every solution entry")
            is_pos = copysign(1, failed_idx) == 1
            trial_grid: TrialGrid = self.trial_grids[int(abs(failed_idx))]
            if is_pos:
                trial_grid.pos_result = 'fail'
            else:
                trial_grid.neg_result = 'fail'

            self.process_search_store_grid(int(abs(failed_idx)), trial_grid, s_y)
|
|
|
if idx == self.nets_number - 1: |
|
for no_improved_idx, s_new_X, s_y in zip( |
|
x_idx[mask_not_improved], new_X[mask_not_improved], y[mask_not_improved] |
|
): |
|
if no_improved_idx == 0: |
|
self.search_trial_buffer_trials(s_new_X, s_y) |
|
continue |
|
|
|
                is_pos = copysign(1, no_improved_idx) == 1
                trial_grid: TrialGrid = self.trial_grids[int(abs(no_improved_idx.item()))]
                if is_pos:
                    trial_grid.pos_result = 'no_improved'
                else:
                    trial_grid.neg_result = 'no_improved'
                assert s_new_X.sum() > trial_grid.initial_grid.sum()
|
|
|
self.process_search_store_grid(int(abs(no_improved_idx)), trial_grid, s_y) |
|
|
|
        self.buffer.append(
            idx + 1,
            (
                new_X[mask_not_improved].clone(),
                y[mask_not_improved].clone(),
                x_idx[mask_not_improved].clone(),
                counters[mask_not_improved].clone(),
            ),
        )

        mask_check = mask_improve & (x_idx.to(self.device) == 0)
        if ((new_X[mask_check] == 1) & (y[mask_check] == 0)).any():
            self.failed_batch = (x[mask_check], y[mask_check])
            raise ValueError("improved non-trial grid contradicts the solution")
        self.buffer.append(
            0,
            (
                new_X[mask_improve].clone(),
                y[mask_improve].clone(),
                x_idx[mask_improve].clone(),
                counters[mask_improve].clone() + 1,
            ),
        )
|
|
|
def process_search_store_grid(self, idx, trial_grid: TrialGrid, s_y): |
|
"""_summary_ |
|
if score is 1: |
|
great |
|
if fail -> the second one should continue (it has his id if it stopped) |
|
so do nothing |
|
if no_improved -> |
|
trial_error and reset trial_error_grid |
|
if score is not None -> store the new grid in the trial_error_buffer |
|
|
|
if score is -1 => also search_trial and store. |
|
if store is 1 => |
|
if both result are here: |
|
get the no_improved -> search_trial and store on a new grid |
|
if one complete grid -> set grid place to None |
|
else: wait |
|
|
|
|
|
if score is None: |
|
if grid_idx==-1 or oposite_grid failed: |
|
we create a new_idx, and store stuff. |
|
else: |
|
we store the grid (in case the second grid fail) |
|
we increment the non_improvement counter |
|
|
|
if non_improvement counter = 2: |
|
we add the initial grid to the search training buffer |
|
we process the search training engine to find another grid postion |
|
else: |
|
we add the initial grid to the search training buffer |
|
if a non improved grid is store we create a new_idx and store stuf. |
|
|
|
Args: |
|
grid_idx (_type_): _description_ |
|
score (_type_): _description_ |
|
s_new_X (_type_): _description_ |
|
s_y (_type_): _description_ |
|
""" |
|
        score = trial_grid.score()
        if score is None:
            # Only one branch has reported; wait for the other.
            self.trial_grids[idx] = trial_grid
            return
|
|
|
        self.trial_error_buffer.append((
            trial_grid.initial_grid.view(-1, 2, 729),
            torch.tensor([score], dtype=torch.float).to(self.device),
            torch.tensor(
                [trial_grid.row_col_digit_position], dtype=torch.long
            ).to(self.device),
        ))
|
|
|
|
|
        if trial_grid.neg_result == 'no_improved':
            if trial_grid.pos_result == 'no_improved':
                # Neither hypothesis made progress: probe another position.
                trial_grid.tried_grid.append(trial_grid.row_col_digit_position)
                trial_grid.neg_result = None
                trial_grid.pos_result = None
                self.trial_grids[idx] = trial_grid
                self.search_trial_buffer_trials(None, s_y, idx)

                return
|
|
|
|
|
|
|
|
|
|
|
            # The positive hypothesis failed, so the digit is absent here.
            grid_neg = deepcopy(trial_grid.initial_grid)
            grid_neg[0, trial_grid.row_col_digit_position] = 1
            if ((grid_neg == 1) & (s_y == 0)).any():
                raise ValueError("deduced absence contradicts the solution")
            self.buffer.append(
                0,
                (
                    grid_neg.view(-1, 2, 729),
                    s_y.clone().view(-1, 2, 729),
                    torch.tensor([0]),
                    torch.tensor([0]),
                )
            )
            self.trial_grids[idx] = None
            return
|
        if trial_grid.pos_result == 'no_improved':
            # The negative hypothesis failed, so the digit is present here.
            grid_pos = deepcopy(trial_grid.initial_grid)
            grid_pos[1, trial_grid.row_col_digit_position] = 1
            if ((grid_pos == 1) & (s_y == 0)).any():
                raise ValueError("deduced presence contradicts the solution")
            self.buffer.append(
                0,
                (
                    grid_pos.view(-1, 2, 729),
                    s_y.clone().view(-1, 2, 729),
                    torch.tensor([0]),
                    torch.tensor([0]),
                )
            )
            self.trial_grids[idx] = None

            return
|
|
|
|
|
|
|
if "complete" in [trial_grid.neg_result, trial_grid.pos_result]: |
|
self.trial_grids[idx]=None |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
    def search_trial(self, s_new_X, tried_pos):
        """Use the deep_backtrack_regressor to pick the next position to probe.

        Args:
            s_new_X: a single grid of shape (2, 729).
            tried_pos: positions already probed on this grid, to be excluded.
        """
        mask_possibility = s_new_X.sum(dim=0) == 0
        for pos in tried_pos:
            mask_possibility[pos] = False
        if mask_possibility.sum() == 0:
            raise TrialEveryPosException()
|
|
|
with torch.no_grad(): |
|
x_reg = self.sym_preprocess.forward(s_new_X.view(1,2,-1)) |
|
output = self.deep_backtrack_regressor(x_reg) |
|
|
|
|
|
|
|
|
|
        output = torch.softmax(output[0][0], dim=0)

        # Push decided and already-tried positions above every softmax value
        # so argmin picks the admissible position with the lowest score.
        output[~mask_possibility] += 1
        return torch.argmin(output, dim=0).item()
|
|
|
def search_trial_buffer_trials(self, s_new_X, s_y, idx_trial_grids=None): |
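        """Open a new trial (idx_trial_grids is None) or retry an existing
        one: pick a position, then enqueue the presence and absence
        hypotheses tagged with +idx and -idx in x_idx."""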
|
|
|
if idx_trial_grids is None: |
|
            row_col_digit_trial = self.search_trial(s_new_X, [])
            trial_grid = TrialGrid(s_new_X, row_col_digit_trial)
            self.trial_grids.append(trial_grid)
            idx_trial_grids = len(self.trial_grids) - 1
|
else: |
|
trial_grid = self.trial_grids[idx_trial_grids] |
|
s_new_X = trial_grid.initial_grid |
|
row_col_digit_trial = self.search_trial(s_new_X, trial_grid.tried_grid) |
|
trial_grid.row_col_digit_position = row_col_digit_trial |
|
self.trial_grids[idx_trial_grids] = trial_grid |
|
|
|
|
|
        grid_pos = deepcopy(s_new_X)
        grid_neg = deepcopy(s_new_X)
        grid_pos[1, row_col_digit_trial] = 1
        grid_neg[0, row_col_digit_trial] = 1
|
self.buffer.append( |
|
0, |
|
( |
|
                torch.stack([grid_pos, grid_neg], dim=0),
                torch.stack([s_y.clone(), s_y.clone()], dim=0),
|
torch.tensor([idx_trial_grids, -idx_trial_grids]), |
|
torch.tensor([0, 0]), |
|
) |
|
) |
|
|
|
def trial_error_training_step(self, batch, train=True): |
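        """One regressor step on (grid, score, position) triples from
        finished trials: the logit at the probed position is regressed
        against the recorded trial score."""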
|
x, y, row_col_digit = batch |
|
prefix = "train" if train else "val" |
|
self.log( |
|
f"{prefix}_grid_count_trial_error_training", |
|
batch[0].shape[0], |
|
reduce_fx=torch.sum, |
|
on_epoch=True, |
|
on_step=False, |
|
) |
|
|
|
|
|
x_reg = self.sym_preprocess.forward(x) |
|
output = self.deep_backtrack_regressor(x_reg) |
|
        loss = F.binary_cross_entropy_with_logits(
            output[torch.arange(x.shape[0]), 0, row_col_digit], y
        )
|
|
|
|
|
|
|
if train: |
|
opt = self.optimizers()[-1] |
|
opt.zero_grad() |
|
self.manual_backward(loss) |
|
opt.step() |
|
|
|
self.log(f"{prefix}_loss_trial_error", loss) |
|
self.log(f"{prefix}_loss_{self.nets_number}", loss) |
|
self.log(f"{prefix}_y_pos_trial_error", y.sum()) |
|
self.log(f"{prefix}_y_neg_trial_eror", y.shape[0]-y.sum()) |
|
|
|
    def predict(self, x, func_text_display=None):
        """Return an improved version of x, falling back to trial and error
        when the cascade alone makes no progress."""
|
|
|
        idx, new_X = self.forward(x.view(-1, 2, 729))
|
        if (new_X.sum() > x.sum()) or (new_X.sum() == 729):
|
if func_text_display: |
|
func_text_display(f'boost layer step: {idx}') |
|
return new_X |
|
else: |
|
|
|
            tried_position = []
            while True:
                pos = self.search_trial(x.view(2, 729), tried_position)
                tried_position.append(pos)
|
|
|
                grid_pos = deepcopy(x.view(2, 729))
                grid_neg = deepcopy(x.view(2, 729))
                grid_pos[1, pos] = 1
                grid_neg[0, pos] = 1
                X_tried = torch.stack([grid_neg, grid_pos], dim=0)
|
|
|
                while True:
                    idx, new_X = self.forward(X_tried)
                    mask_validated = self.validate_grids(new_X)
                    if mask_validated.sum() < 2:
                        # One hypothesis is contradictory, so the surviving
                        # channel at `pos` must hold.
                        x[0, mask_validated, pos] = 1
                        if func_text_display:
                            digit, col, row = pos_to_digit_col_row(pos)
                            func_text_display('model failed to improve the grid')
                            func_text_display(f'trial error algorithm, found error at digit: {digit}, col: {col}, row: {row}')
                        return x
                    if X_tried.sum() == new_X.sum():
                        # Neither hypothesis progressed: try another position.
                        break
                    mask_complete = X_tried.sum(dim=(1, 2)) == 729
                    if mask_complete.sum() > 0:
                        x[0, mask_complete, pos] = 1
                        if func_text_display:
                            digit, col, row = pos_to_digit_col_row(pos)
                            func_text_display('model failed to improve the grid')
                            func_text_display(f'trial error algorithm, completed the grid with digit: {digit}, col: {col}, row: {row}')
                        return x
                    X_tried = new_X
|
|
|
|
|
    def backtracking_predict(
        self,
        x,
        use_trial_error=False,
        assumption=None,
        func_text_display=None,
        func_tensor_display=None,
    ):
        """
        return is_valid, new_x
        """
        if assumption is None:
            # A fresh list per call; a mutable default would be shared.
            assumption = []
        next_X = deepcopy(x)
        sum_1 = next_X.sum().item()
|
if use_trial_error: |
|
while True: |
|
try: |
|
next_X = self.predict(next_X.view(1,2,729)) |
|
if not self.validate_grids(next_X)[0].item(): |
|
                        if assumption:
                            func_text_display(f'assumption {assumption[-1]}, failed')
|
func_text_display(f'assumption length {len(assumption)}') |
|
func_tensor_display(next_X) |
|
return False, None |
|
if next_X.sum()>=729: |
|
                        if assumption:
|
func_text_display(f'assumption length {len(assumption)}') |
|
func_tensor_display(next_X) |
|
return True, next_X |
|
except TrialEveryPosException: |
|
break |
|
|
|
else: |
|
while True: |
|
                _idx, next_X = self.forward(next_X.view(1, 2, 729))
|
if not self.validate_grids(next_X)[0].item(): |
|
return False, None |
|
                sum_2 = next_X.sum().item()
                if sum_1 == sum_2:
                    break
                else:
                    sum_1 = sum_2
|
if next_X.sum()>=729: |
|
return True, next_X |
|
|
|
        pos = self.search_trial(next_X.view(2, 729), [])
        new_x = deepcopy(next_X.view(1, 2, 729))
        output = self.forward_layer(next_X.view(1, 2, 729))
        pos_output = output[0, :, pos]
        # Branch first on the hypothesis whose logit clears its stage-0
        # threshold by the larger margin.
        if (pos_output[0].item() - self.threshold_abs[0].item()) > (
            pos_output[1].item() - self.threshold_pres[0].item()
        ):
|
trial_abs_pres = 0 |
|
abs_pres_str = 'abs' |
|
else: |
|
trial_abs_pres = 1 |
|
abs_pres_str = 'pres' |
|
        new_x[0, trial_abs_pres, pos] = 1
|
|
|
digit, col, row = pos_to_digit_col_row(pos) |
|
        func_text_display('grid not improving, adding the assumption')
|
current_assumption = f'{abs_pres_str}, digit: {digit}, col: {col}, row: {row}' |
|
func_text_display(current_assumption) |
|
assumption.append(current_assumption) |
|
        is_valid, new_x = self.backtracking_predict(
            new_x,
            use_trial_error=use_trial_error,
            assumption=assumption,
            func_text_display=func_text_display,
            func_tensor_display=func_tensor_display,
        )
|
if not is_valid: |
|
new_x = deepcopy(next_X.view(1,2,729)) |
|
            new_x[0, (trial_abs_pres + 1) % 2, pos] = 1
|
|
|
current_assumption = f'{"pres" if abs_pres_str=="abs" else "abs"}, digit: {digit}, col: {col}, row: {row}' |
|
func_text_display(current_assumption) |
|
            return self.backtracking_predict(
                new_x,
                use_trial_error=use_trial_error,
                assumption=assumption,
                func_text_display=func_text_display,
                func_tensor_display=func_tensor_display,
            )
|
return True, new_x |
|
|
|
|
|
def on_validation_epoch_start(self) -> None: |
|
|
|
self.trial_error_buffer = Buffer(self.batch_size) |
|
self.trial_grids = [None] |
|
return super().on_validation_epoch_start() |
|
|
|
def on_train_epoch_start(self) -> None: |
|
self.trial_error_buffer = Buffer(self.batch_size) |
|
self.trial_grids = [None] |
|
return super().on_train_epoch_start() |
|
|
|
    def redresse_new_X(self, new_X, y, x):
        """Zero any deduction contradicting y, except on grids whose input x
        already contradicts y (trial branches carrying a wrong hypothesis)."""
        mask_bad_x = ((x == 1) & (y == 0)).any(dim=1).any(dim=1)
        y_bis = y.clone()
        y_bis[mask_bad_x] = 1
        new_X[y_bis == 0] = 0
        return new_X
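
# Minimal usage sketch (assumptions: `train_ds` and `val_ds` are hypothetical
# datasets yielding (x, y) pairs of shape (2, 729) with the binary
# absence/presence encoding; a loader batch size matching `batch_size` keeps
# the internal buffers consistent):
#
#     import pytorch_lightning as pl
#     from torch.utils.data import DataLoader
#
#     model = SudokuTrialErrorLightning(lr=0.1, nets_number=6, batch_size=32)
#     trainer = pl.Trainer(max_epochs=10)
#     trainer.fit(
#         model,
#         DataLoader(train_ds, batch_size=32, drop_last=True),
#         DataLoader(val_ds, batch_size=32, drop_last=True),
#     )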
|
|
|
|
|
|
|
|