import numpy as np
import pandas as pd
import torch.nn.functional as F

# Background table: every sudoku cell is drawn as a 3x3 sub-block holding the
# digits 1..9, which gives a 27x27 grid.
x_grid = np.array([x for x in range(1, 10)] * 9 * 9).reshape(9, 9, 3, 3)
df_grid = pd.DataFrame(x_grid.swapaxes(1, 2).reshape(27, 27))

# Default style applied to every table cell (<td>).
cell = {
    "selector": "td",
    "props": [("color", "black"), ("text-align", "center")],
}
line_height_1 = {"selector": "tr", "props": [("line-height", "2em")]}
line_height_2 = {"selector": "td", "props": "line-height: inherit; padding: 0;"}
border_global = {"selector": "", "props": [("border", "2px solid")]}

# Each sudoku cell spans 3 table columns/rows. Draw a thin divider between
# sudoku cells (every 3rd column/row) and a thick one between 3x3 boxes
# (every 9th column/row).
border_left_property_1 = [{"selector": "td", "props": "border-left: 1px solid black"}]
border_left_property_2 = [{"selector": "td", "props": "border-left: 2px solid black"}]
border_1 = {
    k: border_left_property_2 if k % 9 == 0 else border_left_property_1
    for k in range(3, 27, 3)
}

border_top_property_1 = [{"selector": "td", "props": "border-top: 1px solid black"}]
border_top_property_2 = [{"selector": "td", "props": "border-top: 2px solid black"}]
border_2 = {
    k: border_top_property_2 if k % 9 == 0 else border_top_property_1
    for k in range(3, 27, 3)
}


def _to_numpy(t):
    """Return tensor ``t`` as a NumPy array, detached and moved to CPU first.

    A bare ``.numpy()`` raises for tensors that require grad or live on a GPU.
    """
    return t.detach().cpu().numpy()


def _cells_to_grid(values):
    """Lay out 729 (cell row, cell column, digit) values as the 27x27 table.

    Each sudoku cell becomes a 3x3 sub-block of its nine digit values, using
    the same layout as ``df_grid``.
    """
    return values.reshape(9, 9, 3, 3).swapaxes(1, 2).reshape(27, 27)


def display_as_dataframe(x_input, output=None, display="neg"):
    """Render the first sample of ``x_input`` as a styled 27x27 sudoku table.

    Args:
        x_input: tensor indexed as ``x_input[0, channel, :]``. Channels 0 and
            1 each hold 729 values laid out as (cell row, cell column, digit).
            Only sample 0 is shown. Text colours:
            blue  -- entry equals 1 in channel 1;
            red   -- entry equals 1 in channel 0;
            white -- entry equals 1 in channel 0 in a cell whose channel-1
                     maximum is non-zero (overrides red);
            black -- otherwise.
        output: optional tensor indexed like ``x_input``. When given, the
            channel selected by ``display`` shades the cell backgrounds with
            the "bwr" colormap over the value range [-3, 3].
        display: "neg" shades with ``output[0, 0]``, "pos" with
            ``output[0, 1]``. Any other value disables shading.

    Returns:
        A pandas ``Styler`` over ``df_grid``.
    """
    neg = _to_numpy(x_input[0, 0, :]).reshape(9, 9, 9)
    pos = _to_numpy(x_input[0, 1, :]).reshape(9, 9, 9)

    # Keep channel-0 flags only in cells that have something in channel 1.
    # Copy first: the array returned by .numpy() can share memory with the
    # tensor, which must not be modified.
    neg_in_pos_cells = neg.copy()
    neg_in_pos_cells[pos.max(axis=2) == 0] = 0

    cell_color = pd.DataFrame("black", index=df_grid.index, columns=df_grid.columns)
    # Order matters: later assignments override earlier ones.
    cell_color[_cells_to_grid(pos) == 1] = "blue"
    cell_color[_cells_to_grid(neg) == 1] = "red"
    cell_color[_cells_to_grid(neg_in_pos_cells) == 1] = "white"

    styler = (
        df_grid.style.hide(axis=1)
        .hide(axis=0)
        .set_properties(subset=[x for x in range(27)], **{"width": "2em"})
        .set_table_styles(
            [  # create internal CSS classes
                {"selector": ".red", "props": "color: red; font-weight: bold"},
                {"selector": ".blue", "props": "color: blue; font-weight: bold"},
                {"selector": ".white", "props": "color: white; font-weight: bold"},
            ],
            overwrite=False,
        )
        .set_td_classes(cell_color)
        .set_table_styles(
            [cell, border_global, line_height_1, line_height_2], overwrite=False
        )
        .set_table_styles(border_1, overwrite=False)
        .set_table_styles(border_2, overwrite=False, axis=1)
    )

    channel = 0 if display == "neg" else 1 if display == "pos" else None
    if output is None or channel is None:
        return styler
    gmap = _cells_to_grid(_to_numpy(output[0, channel, :]))
    return styler.background_gradient(
        axis=None, vmin=-3, vmax=3, cmap="bwr", gmap=gmap
    )


def compute_loss(x, y, output, new_x):
    """Split the element-wise BCE-with-logits loss into two masked means.

    All four tensors are expected to share one shape. NOTE(review): this is
    presumably (batch, 2, 729); confirm against the caller. At least three
    dimensions are required by the ``sum(dim=(1, 2))`` calls below.

    Args:
        x: input state. Only its per-sample sum is used.
        y: target. Its entries are compared with 0 and 1.
        output: logits, with the same shape as ``y``.
        new_x: updated state, compared against ``x`` and ``y``.

    Returns:
        A tuple ``(loss_error, loss_no_improve, n_error, n_no_improve)``:
        - ``loss_error``: mean loss over entries where ``new_x == 1`` but
          ``y == 0``.
        - ``loss_no_improve``: mean loss over entries where ``y == 1``, taken
          only in samples that have no such error entry and whose ``new_x``
          total is not larger than their ``x`` total.
        - ``n_error``: number of samples with at least one error entry.
        - ``n_no_improve``: number of "no improvement" samples.
        Either mean is NaN when its mask selects nothing, because the mean of
        an empty tensor is NaN.
    """
    # Entries set in new_x that the target says must be 0.
    mask_0_error = (new_x == 1) & (y == 0)
    # Per-sample flag: any error entry anywhere in the sample.
    mask_error = mask_0_error.flatten(start_dim=1).any(dim=1)
    # Samples where new_x did not grow relative to x. Samples with errors are
    # covered by the error term instead.
    mask_no_improve = new_x.sum(dim=(1, 2)) <= x.sum(dim=(1, 2))
    mask_no_improve[mask_error] = False
    # Target-positive entries, restricted to the "no improvement" samples.
    mask_1_no_improve = y == 1
    mask_1_no_improve[~mask_no_improve] = False

    # reduction="none" replaces the deprecated reduce=False. The result is
    # the same, without the deprecation warning.
    loss = F.binary_cross_entropy_with_logits(output, y, reduction="none")
    loss_error = loss[mask_0_error].mean()
    loss_no_improve = loss[mask_1_no_improve].mean()
    return loss_error, loss_no_improve, mask_error.sum(), mask_no_improve.sum()


# returns 0, 1 or more than 1 depending on whether 0, 1 or more than 1
# solutions are found.


def solve(i, j, cells, count):
    """Count completions of a sudoku grid by backtracking, stopping at two.

    Cells are visited down each column: ``i`` is the row index and advances
    first. When it passes the last row, the search moves to column ``j + 1``.

    Args:
        i, j: row and column of the cell to examine next.
        cells: 9x9 grid indexed ``cells[row][col]``; 0 marks an empty cell.
            It is mutated during the search, but every cell filled here is
            reset to 0 before returning.
        count: solutions found so far (0 on the initial call).

    Returns:
        ``count`` plus the completions found from ``(i, j)`` onward. Once the
        running total reaches 2, no further values are tried.
    """
    if i == 9:
        # Past the bottom row: wrap to the top of the next column.
        i, j = 0, j + 1
        if j == 9:
            # Every column is filled: one more complete solution.
            return count + 1

    if cells[i][j] != 0:
        # Pre-filled cell: skip it without checking its value.
        return solve(i + 1, j, cells, count)

    for candidate in range(1, 10):
        if count >= 2:
            # Two solutions are enough to know the grid is not unique.
            break
        if legal(i, j, candidate, cells):
            cells[i][j] = candidate
            count = solve(i + 1, j, cells, count)

    cells[i][j] = 0  # undo the placement before backtracking
    return count


def legal(row, col, num, grid):
    """Return True if ``num`` may be placed at ``grid[row][col]``.

    ``num`` must not already appear in the same row, the same column, or the
    3x3 box that contains the cell.
    """
    if any(grid[row][k] == num for k in range(9)):
        return False
    if any(grid[k][col] == num for k in range(9)):
        return False
    top, left = 3 * (row // 3), 3 * (col // 3)
    return not any(
        grid[r][c] == num
        for r in range(top, top + 3)
        for c in range(left, left + 3)
    )


def get_grid_number_soluce(grid):
    """Return the number of solutions of ``grid``, capped at 2.

    ``grid`` is a 9x9 grid indexed ``grid[row][col]``, with 0 for empty
    cells. Its contents are the same on return as on entry. The result is 0,
    1, or 2 (meaning "two or more"). Pre-filled cells are not validated
    against each other; they only constrain the empty ones.
    """
    return solve(i=0, j=0, cells=grid, count=0)


def pos_to_digit_col_row(pos):
    """Decode a flat index ``pos = digit + 9 * col + 81 * row`` (0-based parts).

    Returns the 1-based ``(digit, col, row)`` triple.
    """
    cell_index = pos // 9
    digit = pos % 9 + 1
    col = cell_index % 9 + 1
    row = pos // 81 + 1
    return digit, col, row