File size: 5,731 Bytes
4484b8a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 |
import numpy as np
import pandas as pd
import torch.nn.functional as F
# Digits 1..9 repeated once per cell of a 9x9 board, stored as (row, col, 3, 3).
x_grid = np.tile(np.arange(1, 10), 9 * 9).reshape(9, 9, 3, 3)
# Flattened to 27x27 so that every cell shows up as a 3x3 block of digits.
df_grid = pd.DataFrame(x_grid.transpose(0, 2, 1, 3).reshape(27, 27))

# Base look of every table cell.
# for row hover use <tr> instead of <td>
cell = dict(
    selector="td",
    props=[("color", "black"), ("text-align", "center")],
)
line_height_1 = dict(selector="tr", props=[("line-height", "2em")])
line_height_2 = dict(selector="td", props="line-height: inherit; padding: 0;")
border_global = dict(selector="", props=[("border", "2px solid")])

# Left borders keyed by column label: a 1px line every 3 columns and a 2px
# line every 9 columns.
border_left_property_1 = [dict(selector="td", props="border-left: 1px solid black")]
border_left_property_2 = [dict(selector="td", props="border-left: 2px solid black")]
border_1 = {
    position: border_left_property_2 if position % 9 == 0 else border_left_property_1
    for position in range(3, 27, 3)
}

# Top borders keyed by row label, following the same 3 / 9 pattern.
border_top_property_1 = [dict(selector="td", props="border-top: 1px solid black")]
border_top_property_2 = [dict(selector="td", props="border-top: 2px solid black")]
border_2 = {
    position: border_top_property_2 if position % 9 == 0 else border_top_property_1
    for position in range(3, 27, 3)
}
def _to_board_layout(flat_values):
    """Rearrange 729 values into the 27x27 layout used by ``df_grid``.

    The values are read as (row, col, 3, 3) and the two middle axes are
    swapped, so the nine entries of each cell end up in one 3x3 block.
    """
    return flat_values.reshape(9, 9, 3, 3).swapaxes(1, 2).reshape(27, 27)


def display_as_dataframe(x_input, output=None, display="neg"):
    """Build a pandas Styler showing the first sample of ``x_input`` on ``df_grid``.

    Only ``x_input[0]`` is rendered. Channel 0 is called "neg" and channel 1
    "pos" throughout this function (presumably excluded vs. confirmed
    candidates -- confirm against the code that builds ``x_input``).

    Text colours, applied in this order so later ones win:
      * blue  -- entries equal to 1 in channel 1,
      * red   -- entries equal to 1 in channel 0,
      * white -- channel-0 entries equal to 1 inside cells whose channel-1
        maximum is non-zero.

    Args:
        x_input: tensor indexed as ``x_input[0, channel, :]`` with 729 values
            per channel (presumably a torch tensor of shape (batch, 2, 729)).
        output: optional tensor indexed the same way; when given, one of its
            channels is drawn as a background gradient.
        display: ``"neg"`` colours the background from ``output[0, 0]``,
            ``"pos"`` from ``output[0, 1]``. Any other value leaves the
            background untouched.

    Returns:
        The configured ``pandas`` Styler built from ``df_grid``.
    """
    # detach()/cpu() let tensors that track gradients or live on an
    # accelerator be displayed; for plain CPU tensors nothing changes.
    neg_flat = x_input[0, 0, :].detach().cpu().numpy()
    pos_flat = x_input[0, 1, :].detach().cpu().numpy()

    x_mask_pos = _to_board_layout(pos_flat)
    x_mask_full_neg = _to_board_layout(neg_flat)

    # Keep channel-0 marks only in cells whose channel-1 maximum is non-zero.
    # The copy keeps the caller's tensor memory untouched.
    cell_without_pos = pos_flat.reshape(9, 9, 9).max(axis=2) == 0
    x_mask_neg_pos = neg_flat.reshape(9, 9, 9).copy()
    x_mask_neg_pos[cell_without_pos] = 0
    x_mask_neg_pos = _to_board_layout(x_mask_neg_pos)

    # One CSS class name per grid position; assignment order defines priority.
    cell_color = pd.DataFrame("black", index=df_grid.index, columns=df_grid.columns)
    cell_color[x_mask_pos == 1] = "blue"
    cell_color[x_mask_full_neg == 1] = "red"
    cell_color[x_mask_neg_pos == 1] = "white"

    styler = (
        df_grid.style.hide(axis=1)
        .hide(axis=0)
        .set_properties(subset=list(range(27)), **{"width": "2em"})
        .set_table_styles(
            [  # create internal CSS classes
                {"selector": ".red", "props": "color: red; font-weight: bold"},
                {"selector": ".blue", "props": "color: blue; font-weight: bold"},
                {"selector": ".white", "props": "color: white; font-weight: bold"},
            ],
            overwrite=False,
        )
        .set_td_classes(cell_color)
        .set_table_styles(
            [cell, border_global, line_height_1, line_height_2], overwrite=False
        )
        .set_table_styles(border_1, overwrite=False)
        .set_table_styles(border_2, overwrite=False, axis=1)
    )

    if output is None:
        return styler

    # Pick the output channel that drives the background gradient.
    if display == "neg":
        channel = 0
    elif display == "pos":
        channel = 1
    else:
        return styler

    x_output = _to_board_layout(output[0, channel, :].detach().cpu().numpy())
    return styler.background_gradient(
        axis=None, vmin=-3, vmax=3, cmap="bwr", gmap=x_output
    )
def compute_loss(x, y, output, new_x):
    """Compute two masked BCE-with-logits losses and the matching sample counts.

    All four tensors are expected to share one shape with 2 * 729 values per
    sample and three dimensions (presumably (batch, 2, 729) -- confirm with
    the caller).

    Args:
        x: state before the update; only its per-sample sum is used.
        y: target tensor; compared against 0 and 1 and used as the BCE target.
        output: logits passed to ``binary_cross_entropy_with_logits``.
        new_x: state after the update.

    Returns:
        A 4-tuple ``(loss_error, loss_no_improve, n_error, n_no_improve)``:
          * ``loss_error`` -- mean element-wise loss over entries where
            ``new_x == 1`` but ``y == 0``;
          * ``loss_no_improve`` -- mean element-wise loss over entries with
            ``y == 1`` in samples that did not improve;
          * ``n_error`` -- number of samples with at least one error entry;
          * ``n_no_improve`` -- number of error-free samples whose ``new_x``
            sum did not exceed the ``x`` sum.
        Either loss is NaN when its mask selects no entry, because it is the
        mean of an empty tensor.
    """
    # Entries set to 1 in new_x although the target says 0.
    mask_0_error = (new_x == 1) & (y == 0)
    # reshape (rather than view) also accepts non-contiguous inputs.
    mask_error = mask_0_error.reshape(-1, 2 * 729).any(dim=1)

    # A sample "did not improve" when its sum did not grow; samples that
    # contain an error are accounted for separately and excluded here.
    mask_no_improve = new_x.sum(dim=(1, 2)) <= x.sum(dim=(1, 2))
    mask_no_improve[mask_error] = False

    # Positive targets, restricted to the samples that did not improve.
    mask_1_no_improve = y == 1
    mask_1_no_improve[~mask_no_improve] = False

    # reduction="none" replaces the deprecated reduce=False and keeps the
    # per-element losses needed for the masked means below.
    loss = F.binary_cross_entropy_with_logits(output, y, reduction="none")
    loss_error = loss[mask_0_error].mean()
    loss_no_improve = loss[mask_1_no_improve].mean()
    return loss_error, loss_no_improve, mask_error.sum(), mask_no_improve.sum()
# Counts the completions of a partially filled grid by backtracking. Started
# with count = 0 the result is 0, 1 or 2, where 2 stands for "more than one"
# because no further digits are tried once two solutions are known.
def solve(i, j, cells, count):  # initially called with count = 0
    """Return ``count`` plus the number of solutions found from cell (i, j) on.

    Cells are visited down each column (``i`` grows first, then ``j``).
    ``cells`` is modified in place while searching; every cell that was
    empty (0) on entry is set back to 0 before returning.

    Args:
        i: row index of the current cell; 9 means "wrap to the next column".
        j: column index of the current cell.
        cells: 9x9 mutable grid indexed as ``cells[i][j]``, 0 marks an
            empty cell.
        count: number of solutions found so far.
    """
    if i == 9:
        # Ran past the last row: continue at the top of the next column.
        i, j = 0, j + 1
        if j == 9:
            # Every cell has been passed, so the current fill is a solution.
            return count + 1

    if cells[i][j] != 0:
        # Already filled: nothing to choose here, move on.
        return solve(i + 1, j, cells, count)

    for digit in range(1, 10):
        # Stop trying digits once two solutions are known.
        if count >= 2 or not legal(i, j, digit, cells):
            continue
        cells[i][j] = digit
        count = solve(i + 1, j, cells, count)

    cells[i][j] = 0  # undo the tentative fill before backtracking
    return count
def legal(row, col, num, grid):
    """Tell whether ``num`` is absent from the row, column and 3x3 box of a cell.

    Args:
        row: row index of the cell (0-8).
        col: column index of the cell (0-8).
        num: value to look for.
        grid: 9x9 grid indexed as ``grid[row][col]``.

    Returns:
        False as soon as ``num`` is found in row ``row``, in column ``col`` or
        in the 3x3 box containing (row, col); True otherwise. The cell
        (row, col) itself is part of all three scans.
    """
    # Same row.
    if any(grid[row][c] == num for c in range(9)):
        return False

    # Same column.
    if any(grid[r][col] == num for r in range(9)):
        return False

    # Same 3x3 box: locate its top-left corner, then scan its nine cells.
    box_top = (row // 3) * 3
    box_left = (col // 3) * 3
    if any(
        grid[box_top + dr][box_left + dc] == num
        for dr in range(3)
        for dc in range(3)
    ):
        return False

    return True
def get_grid_number_soluce(grid):
    """Count the solutions of ``grid``, capped at 2.

    Runs ``solve`` from the first cell with an initial count of 0.

    Args:
        grid: 9x9 mutable grid indexed as ``grid[i][j]``, 0 marks an empty
            cell. ``solve`` writes to it while searching and sets the cells
            it filled back to 0.

    Returns:
        The value returned by ``solve(0, 0, grid, 0)``: 0 when there is no
        solution, 1 for a unique solution, 2 when there is more than one.
    """
    return solve(0, 0, grid, 0)