Sebastien
improving backtracking algorithm and display in the app
19320e5
import numpy as np
import pandas as pd
import torch.nn.functional as F
# Reference board for display: the digits 1..9 repeated 81 times (729 values),
# laid out so that every run of 9 becomes a 3x3 patch of a 27x27 grid.
x_grid = np.array(list(range(1, 10)) * 81).reshape(9, 9, 3, 3)
df_grid = pd.DataFrame(x_grid.swapaxes(1, 2).reshape(27, 27))

# Base cell style.  For row hover, use <tr> instead of <td> as the selector.
cell = {
    "selector": "td",
    "props": [("color", "black"), ("text-align", "center")],
}
line_height_1 = {"selector": "tr", "props": [("line-height", "2em")]}
line_height_2 = {"selector": "td", "props": "line-height: inherit; padding: 0;"}
border_global = {"selector": "", "props": [("border", "2px solid")]}

# Left borders keyed by column label: 1px every 3 columns, 2px every 9 columns.
border_left_property_1 = [{"selector": "td", "props": "border-left: 1px solid black"}]
border_left_property_2 = [{"selector": "td", "props": "border-left: 2px solid black"}]
border_1 = {
    col: border_left_property_2 if col % 9 == 0 else border_left_property_1
    for col in range(3, 27, 3)
}

# Top borders keyed by row label: same 1px / 2px rhythm as the left borders.
border_top_property_1 = [{"selector": "td", "props": "border-top: 1px solid black"}]
border_top_property_2 = [{"selector": "td", "props": "border-top: 2px solid black"}]
border_2 = {
    row: border_top_property_2 if row % 9 == 0 else border_top_property_1
    for row in range(3, 27, 3)
}
def _channel_values(tensor, channel):
    """Return ``tensor[0, channel, :]`` as a NumPy array.

    ``detach().cpu()`` makes this work for tensors that track gradients or live
    on a GPU; a bare ``.numpy()`` call raises for both.
    """
    return tensor[0, channel, :].detach().cpu().numpy()


def _cells_to_board(values):
    """Lay out 729 values as the 27x27 display board.

    The values are reshaped to (9, 9, 3, 3) and the two middle axes swapped, so
    the 9 values belonging to each position of the leading 9x9 grid form a 3x3
    patch of the board -- the same transform used to build ``df_grid``.
    """
    return values.reshape(9, 9, 3, 3).swapaxes(1, 2).reshape(27, 27)


def display_as_dataframe(x_input, output=None, display="neg"):
    """Render a board state, and optionally a model output, as a styled table.

    Args:
        x_input: tensor indexed as ``x_input[0, channel, :]`` with 729 values
            per channel.  Cells where channel 1 equals 1 are coloured blue, and
            cells where channel 0 equals 1 are coloured red.  Channel-0 marks in
            cells whose channel-1 maximum is non-zero are coloured white.
            (Presumably channel 1 = positive digits and channel 0 = negative
            digits, as the variable names suggest -- TODO confirm.)
        output: optional tensor with the same indexing.  When given, one of its
            channels drives a blue-white-red background gradient over [-3, 3].
        display: ``"neg"`` uses ``output`` channel 0 and ``"pos"`` uses channel
            1.  Any other value leaves the background unstyled.

    Returns:
        A pandas ``Styler`` over the module-level ``df_grid``.
    """
    x_mask_pos = _cells_to_board(_channel_values(x_input, 1))
    x_mask_full_neg = _cells_to_board(_channel_values(x_input, 0))

    # Keep channel-0 values only for cells whose channel-1 maximum is non-zero
    # (zero them elsewhere); these are painted white below.  ``.copy()`` is
    # needed because ``.numpy()`` shares memory with a CPU tensor, so zeroing in
    # place would otherwise write back into ``x_input``.
    cell_has_no_pos = _channel_values(x_input, 1).reshape(9, 9, 9).max(axis=2) == 0
    neg_in_pos_cells = _channel_values(x_input, 0).reshape(9, 9, 9).copy()
    neg_in_pos_cells[cell_has_no_pos] = 0
    x_mask_neg_pos = _cells_to_board(neg_in_pos_cells)

    # Assignment order matters: white overrides red, and red overrides blue.
    cell_color = pd.DataFrame("black", index=df_grid.index, columns=df_grid.columns)
    cell_color[x_mask_pos == 1] = "blue"
    cell_color[x_mask_full_neg == 1] = "red"
    cell_color[x_mask_neg_pos == 1] = "white"

    styler = (
        df_grid.style.hide(axis=1)
        .hide(axis=0)
        .set_properties(subset=list(range(27)), width="2em")
        .set_table_styles(
            [  # internal CSS classes referenced by set_td_classes below
                {"selector": ".red", "props": "color: red; font-weight: bold"},
                {"selector": ".blue", "props": "color: blue; font-weight: bold"},
                {"selector": ".white", "props": "color: white; font-weight: bold"},
            ],
            overwrite=False,
        )
        .set_td_classes(cell_color)
        .set_table_styles(
            [cell, border_global, line_height_1, line_height_2], overwrite=False
        )
        # Left borders are keyed by column label, top borders by row label.
        .set_table_styles(border_1, overwrite=False)
        .set_table_styles(border_2, overwrite=False, axis=1)
    )

    if output is None or display not in ("neg", "pos"):
        return styler
    channel = 0 if display == "neg" else 1
    return styler.background_gradient(
        axis=None,
        vmin=-3,
        vmax=3,
        cmap="bwr",
        gmap=_cells_to_board(_channel_values(output, channel)),
    )
def compute_loss(x, y, output, new_x):
    """Compute two masked BCE-with-logits loss terms.

    Assumes every tensor is shaped (batch, 2, 729): ``view(-1, 2 * 729)`` and
    ``sum(dim=(1, 2))`` below rely on that layout -- TODO confirm with callers.

    Args:
        x: current state; its sum over dims 1 and 2 is the "before" total.
        y: 0/1 targets, same shape as ``output``.
        output: logits.
        new_x: next state.  An entry with ``new_x == 1`` where ``y == 0`` is an
            error.

    Returns:
        loss_error: mean BCE over the erroneous entries.
        loss_no_improve: mean BCE over the ``y == 1`` entries of samples that
            contain no error and whose total (sum over dims 1 and 2) did not
            grow from ``x`` to ``new_x``.
        The number of samples containing at least one error (0-d tensor).
        The number of such "no improvement" samples (0-d tensor).

    Either loss is NaN when its mask selects no element, because it is the mean
    of an empty tensor.
    """
    mask_0_error = (new_x == 1) & (y == 0)
    mask_error = mask_0_error.view(-1, 2 * 729).any(dim=1)

    # A sample counts as "no improvement" only if it has no error.
    mask_no_improve = new_x.sum(dim=(1, 2)) <= x.sum(dim=(1, 2))
    mask_no_improve[mask_error] = False
    mask_1_no_improve = y == 1
    mask_1_no_improve[~mask_no_improve] = False

    # reduction="none" replaces the deprecated reduce=False (identical result).
    loss = F.binary_cross_entropy_with_logits(output, y, reduction="none")
    loss_error = loss[mask_0_error].mean()
    loss_no_improve = loss[mask_1_no_improve].mean()
    return loss_error, loss_no_improve, mask_error.sum(), mask_no_improve.sum()
def solve(i, j, cells, count):  # initially called with count = 0
    """Count completions of the grid ``cells`` by backtracking, stopping at two.

    Cells are visited column by column: ``i`` is the row index, ``j`` is the
    column index, and ``i`` is advanced first.  ``count`` is the number of
    solutions found so far, and the updated count is returned.  A top-level
    call with ``count = 0`` therefore yields 0 (no solution), 1 (unique) or
    2 (more than one).  No further candidates are tried once the count
    reaches 2.

    Empty cells hold 0.  ``cells`` is filled in place while searching, and every
    cell written here is reset to 0 on backtrack.  Only placements made by the
    search are checked with ``legal``; pre-filled digits are not validated
    against each other.
    """
    if i == 9:
        # Fell off the bottom of column j: wrap to the top of the next column.
        i, j = 0, j + 1
        if j == 9:
            # Fell off the last column: the grid is complete, one more solution.
            return count + 1
    if cells[i][j] != 0:
        # Pre-filled cell: nothing to choose, move on.
        return solve(i + 1, j, cells, count)
    for candidate in range(1, 10):
        if count >= 2:
            break  # a second solution already proves non-uniqueness
        if legal(i, j, candidate, cells):
            cells[i][j] = candidate
            count = solve(i + 1, j, cells, count)
            cells[i][j] = 0  # undo before trying the next digit
    return count
def legal(row, col, num, grid):
    """Return True if ``num`` can be written at ``grid[row][col]``.

    ``num`` is rejected when it already appears in row ``row``, in column
    ``col``, or in the 3x3 box containing (row, col).  The scans run in that
    order and stop at the first clash.  The target cell itself is part of every
    scan.
    """
    top = row - row % 3
    left = col - col % 3
    return not (
        any(grid[row][k] == num for k in range(9))
        or any(grid[k][col] == num for k in range(9))
        or any(
            grid[top + dr][left + dc] == num
            for dr in range(3)
            for dc in range(3)
        )
    )
def get_grid_number_soluce(grid):
    """Return how many solutions ``grid`` has: 0, 1, or 2 (meaning "more than one").

    Delegates to ``solve``, starting at the top-left cell with no solutions
    counted yet.
    """
    first_row, first_col, found_so_far = 0, 0, 0
    return solve(first_row, first_col, grid, found_so_far)
def pos_to_digit_col_row(pos):
    """Decode a flat index into 1-based ``(digit, col, row)``.

    Inverse of ``pos = digit_idx + 9 * col_idx + 81 * row_idx``, where every
    ``*_idx`` is 0-based.
    """
    digit_idx = pos % 9
    col_idx = pos // 9 % 9
    row_idx = pos // 81
    return digit_idx + 1, col_idx + 1, row_idx + 1