|
import numpy as np |
|
import pandas as pd |
|
import torch.nn.functional as F |
|
|
|
|
|
# Template of candidate digits: a 9x9 board in which every cell holds the
# digits 1..9 laid out as a 3x3 block.
x_grid = np.tile(np.arange(1, 10), 9 * 9).reshape(9, 9, 3, 3)

# Flatten the (cell row, cell col, block row, block col) layout into a single
# 27x27 table so that each cell's 3x3 block stays contiguous on screen.
df_grid = pd.DataFrame(x_grid.swapaxes(1, 2).reshape(27, 27))
|
|
|
# Table-level CSS rules shared by every rendered grid.
cell = {"selector": "td", "props": [("color", "black"), ("text-align", "center")]}
line_height_1 = {"selector": "tr", "props": [("line-height", "2em")]}
line_height_2 = {"selector": "td", "props": "line-height: inherit; padding: 0;"}
border_global = {"selector": "", "props": [("border", "2px solid")]}

# Left borders, keyed by every third position of the 27-wide grid (3..24):
# a 2px line at multiples of 9, a 1px line at the other multiples of 3.
border_left_property_1 = [{"selector": "td", "props": "border-left: 1px solid black"}]
border_left_property_2 = [{"selector": "td", "props": "border-left: 2px solid black"}]
border_1 = {
    position: border_left_property_2 if position % 9 == 0 else border_left_property_1
    for position in range(3, 27, 3)
}

# Top borders, using the same thin/thick pattern as the left borders.
border_top_property_1 = [{"selector": "td", "props": "border-top: 1px solid black"}]
border_top_property_2 = [{"selector": "td", "props": "border-top: 2px solid black"}]
border_2 = {
    position: border_top_property_2 if position % 9 == 0 else border_top_property_1
    for position in range(3, 27, 3)
}
|
|
|
|
|
def display_as_dataframe(x_input, output=None, display="neg"):
    """Render the first sample of ``x_input`` as a styled 27x27 candidate grid.

    Only ``x_input[0]`` is shown. ``x_input[0, 0]`` and ``x_input[0, 1]`` must
    each hold 729 values (they are reshaped to 9x9x9); judging by the names
    used here they are the "negative" and "positive" candidate flags -- TODO
    confirm against the caller. Digits are coloured through CSS classes, with
    later rules overriding earlier ones:

    * ``blue``  -- positive flag equal to 1,
    * ``red``   -- negative flag equal to 1,
    * ``white`` -- negative flag equal to 1 inside a cell that has at least
      one non-zero positive flag.

    Args:
        x_input: Tensor indexed as ``[sample, channel, position]``; ``.numpy()``
            is called on it, so it presumably has to be a CPU tensor that does
            not require grad.
        output: Optional tensor with the same indexing. When given, one of its
            channels is used as a background gradient (``bwr`` colormap,
            clipped to [-3, 3]).
        display: ``"neg"`` selects channel 0 of ``output`` and ``"pos"``
            selects channel 1; any other value disables the gradient.

    Returns:
        The configured pandas ``Styler``.
    """

    def as_board(flat):
        # 729 values -> 27x27 table in which each cell is a 3x3 block.
        return flat.numpy().reshape(9, 9, 3, 3).swapaxes(1, 2).reshape(27, 27)

    neg_channel = x_input[0, 0, :]
    pos_channel = x_input[0, 1, :]

    board_pos = as_board(pos_channel)
    board_neg = as_board(neg_channel)

    # Negative flags restricted to cells holding a positive flag: rows of the
    # copy that belong to cells with no positive flag are zeroed.
    cell_without_pos = pos_channel.numpy().reshape(9, 9, 9).max(axis=2) == 0
    neg_in_pos_cells = neg_channel.numpy().reshape(9, 9, 9).copy()
    neg_in_pos_cells[cell_without_pos] = 0
    board_neg_in_pos_cells = (
        neg_in_pos_cells.reshape(9, 9, 3, 3).swapaxes(1, 2).reshape(27, 27)
    )

    # The order matters: each assignment overrides the previous ones.
    cell_color = pd.DataFrame("black", index=df_grid.index, columns=df_grid.columns)
    for board, css_class in (
        (board_pos, "blue"),
        (board_neg, "red"),
        (board_neg_in_pos_cells, "white"),
    ):
        cell_color[board == 1] = css_class

    class_rules = [
        {"selector": ".red", "props": "color: red; font-weight: bold"},
        {"selector": ".blue", "props": "color: blue; font-weight: bold"},
        {"selector": ".white", "props": "color: white; font-weight: bold"},
    ]

    # Styler methods return the styler itself, so the steps are applied in turn.
    styler = df_grid.style.hide(axis=1).hide(axis=0)
    styler = styler.set_properties(subset=list(range(27)), width="2em")
    styler = styler.set_table_styles(class_rules, overwrite=False)
    styler = styler.set_td_classes(cell_color)
    styler = styler.set_table_styles(
        [cell, border_global, line_height_1, line_height_2], overwrite=False
    )
    styler = styler.set_table_styles(border_1, overwrite=False)
    styler = styler.set_table_styles(border_2, overwrite=False, axis=1)

    if output is None:
        return styler
    if display == "neg":
        channel_index = 0
    elif display == "pos":
        channel_index = 1
    else:
        return styler

    heat = as_board(output[0, channel_index, :])
    return styler.background_gradient(
        axis=None, vmin=-3, vmax=3, cmap="bwr", gmap=heat
    )
|
|
|
|
|
def compute_loss(x, y, output, new_x):
    """Compute two masked binary-cross-entropy terms and their sample counts.

    The tensors are summed over dims 1 and 2 and flattened to ``2 * 729``
    values per sample, so they are presumably shaped ``(batch, 2, 729)`` --
    TODO confirm against the caller.

    Args:
        x: State before the step; only its per-sample sum is used.
        y: Target tensor of 0/1 values, used both as the BCE target and to
            build the masks.
        output: Logits, same shape as ``y``.
        new_x: State after the step.

    Returns:
        A tuple ``(loss_error, loss_no_improve, n_error, n_no_improve)``:

        * ``loss_error`` -- mean BCE over elements where ``new_x == 1`` while
          ``y == 0``.
        * ``loss_no_improve`` -- mean BCE over elements with ``y == 1`` that
          belong to samples whose sum did not increase and that contain no
          error.
        * ``n_error`` -- number of samples with at least one error.
        * ``n_no_improve`` -- number of error-free samples whose sum did not
          increase.

        Each loss is NaN when its mask selects no element (mean of an empty
        tensor); callers have to handle that case.
    """
    # Elements switched on in the new state although the target says "off".
    mask_0_error = (new_x == 1) & (y == 0)
    # One flag per sample. reshape (not view) so that masks inheriting a
    # non-contiguous layout from permuted inputs are accepted as well.
    mask_error = mask_0_error.reshape(-1, 2 * 729).any(dim=1)

    # Samples that made no progress, excluding those already counted as errors.
    mask_no_improve = (new_x.sum(dim=(1, 2)) <= x.sum(dim=(1, 2))) & ~mask_error

    # Positive targets, kept only for the samples that made no progress.
    mask_1_no_improve = y == 1
    mask_1_no_improve[~mask_no_improve] = False

    # reduction="none" keeps the per-element losses; it replaces the
    # deprecated reduce=False argument.
    loss = F.binary_cross_entropy_with_logits(output, y, reduction="none")
    loss_error = loss[mask_0_error].mean()
    loss_no_improve = loss[mask_1_no_improve].mean()

    return loss_error, loss_no_improve, mask_error.sum(), mask_no_improve.sum()
|
|
|
|
|
|
|
def solve(i, j, cells, count):
    """Count the completions of ``cells`` by backtracking, stopping at two.

    Cells are visited with ``i`` running from 0 to 8 first and ``j``
    advancing each time ``i`` wraps around. A value of 0 marks an empty cell.
    Values already present in the grid are skipped without being validated.

    Args:
        i: First index of the cell to process (9 means "wrap to the next j").
        j: Second index of the cell to process.
        cells: 9x9 mutable grid indexed as ``cells[i][j]``. It is modified
            during the search; every cell filled by the search is reset to 0
            before returning.
        count: Number of solutions found so far.

    Returns:
        The updated number of solutions. The search stops trying further
        values as soon as the count reaches 2.
    """
    # Past the last index of the current line: move on, or finish the grid.
    if i == 9:
        i, j = 0, j + 1
        if j == 9:
            return 1 + count

    # Pre-filled cell: nothing to choose here.
    if cells[i][j] != 0:
        return solve(i + 1, j, cells, count)

    for candidate in range(1, 10):
        if count >= 2:
            # Two solutions are enough to know the grid is not unique.
            break
        if legal(i, j, candidate, cells):
            cells[i][j] = candidate
            count = solve(i + 1, j, cells, count)

    # Undo the tentative value before handing control back to the caller.
    cells[i][j] = 0
    return count
|
|
|
|
|
def legal(row, col, num, grid):
    """Tell whether ``num`` can be placed at ``grid[row][col]``.

    Args:
        row: Row index of the target cell (0-8).
        col: Column index of the target cell (0-8).
        num: Value to test.
        grid: 9x9 grid indexed as ``grid[row][col]``.

    Returns:
        ``False`` if ``num`` already appears in the same row, the same column
        or the same 3x3 box (the target cell itself included), ``True``
        otherwise.
    """
    # Same row.
    if any(grid[row][c] == num for c in range(9)):
        return False

    # Same column.
    if any(grid[r][col] == num for r in range(9)):
        return False

    # Same 3x3 box; its top-left corner is the cell position rounded down to
    # a multiple of 3.
    box_top = row - row % 3
    box_left = col - col % 3
    return not any(
        grid[box_top + dr][box_left + dc] == num
        for dr in range(3)
        for dc in range(3)
    )
|
|
|
def get_grid_number_soluce(grid):
    """Count the solutions of ``grid``, capped at two.

    Args:
        grid: 9x9 mutable grid in which 0 marks an empty cell. ``solve``
            writes into it while searching and resets the cells it filled.

    Returns:
        0 when no solution exists, 1 when the solution is unique, 2 when at
        least two solutions exist.
    """
    first_i, first_j, found_so_far = 0, 0, 0
    return solve(first_i, first_j, grid, found_so_far)
|
|
|
|
|
def pos_to_digit_col_row(pos):
    """Decode a flat position into 1-based ``(digit, col, row)``.

    The position is read as ``(row - 1) * 81 + (col - 1) * 9 + (digit - 1)``,
    i.e. the digit varies fastest and the row slowest.

    Args:
        pos: Flat index (0..728 for a 9x9x9 layout).

    Returns:
        A tuple ``(digit, col, row)``, each component starting at 1.
    """
    zero_based = (pos % 9, (pos // 9) % 9, pos // (9 * 9))
    return tuple(component + 1 for component in zero_based)