import gradio as gr
# import cv2
import os
import base64
from pathlib import Path
    
from core.chessboard_detector import ChessboardDetector

# ONNX weights handed to the detector: one model for the board pose
# (keypoints) and one for classifying the full board layout.
POSE_MODEL_PATH = "onnx/pose/4_v6-0301.onnx"
LAYOUT_CLASSIFIER_MODEL_PATH = "onnx/layout_recognition/nano_v3-0319.onnx"

# Module-level detector instance used by the detection callback below.
detector = ChessboardDetector(
    pose_model_path=POSE_MODEL_PATH,
    full_classifier_model_path=LAYOUT_CLASSIFIER_MODEL_PATH,
)

# Cell vocabulary: Chinese display name -> one-character code.
# "." is a plain board intersection and "x" an occluded position; the
# remaining entries are pieces (upper-case codes for the red-named pieces,
# lower-case codes for the black-named ones).
dict_cate_names = {
    ".": ".",
    "x": "x",
    # Red pieces.
    "红帅": "K",
    "红士": "A",
    "红相": "B",
    "红马": "N",
    "红车": "R",
    "红炮": "C",
    "红兵": "P",
    # Black pieces.
    "黑将": "k",
    "黑仕": "a",
    "黑象": "b",
    "黑傌": "n",
    "黑車": "r",
    "黑砲": "c",
    "黑卒": "p",
}

# Piece code -> image file name ("red_<CODE>.png" / "black_<code>.png").
dict_cate_images = {
    **{code: f"red_{code}.png" for code in "KABNRCP"},
    **{code: f"black_{code}.png" for code in "kabnrcp"},
}

# Inverse lookup: one-character code -> Chinese display name.
dict_cate_names_reverse = dict(zip(dict_cate_names.values(), dict_cate_names))


# Cache of base64 data URLs for the piece images, keyed by image path.
image_base64_cache = {}


def get_image_base64(img_path):
    """Return a ``data:image/png;base64,...`` URL for an image in ``resources/``.

    Successful results are stored in ``image_base64_cache`` under
    ``img_path`` and reused on later calls. Failures are not cached.

    Args:
        img_path: Image path, joined onto the ``resources`` directory.

    Returns:
        The data URL, or an empty string when ``img_path`` is empty, does not
        name a regular file, or the file cannot be read. ``None`` is never
        returned, so the result can be interpolated into HTML directly.
    """
    if not img_path:
        return ""

    cached = image_base64_cache.get(img_path)
    if cached is not None:
        return cached

    img_full_path = Path("resources") / img_path
    # is_file() rather than exists(): a directory would pass exists() and
    # then fail inside the read below.
    if not img_full_path.is_file():
        return ""

    try:
        raw_bytes = img_full_path.read_bytes()
    except OSError as e:
        print(f"Error loading image {img_path}: {e}")
        # Same "no image" value as the missing-file case above.
        return ""

    encoded = base64.b64encode(raw_bytes).decode('utf-8')
    data_url = f"data:image/png;base64,{encoded}"
    image_base64_cache[img_path] = data_url
    return data_url




# Build the list of example images


def build_examples():
    """Collect the sample images found in the ``examples`` directory.

    Returns:
        A list with one single-element row ``[path]`` per ``.jpg`` / ``.png``
        file, ordered by file name. The list is empty when the directory
        does not exist.
    """
    try:
        file_names = os.listdir("examples")
    except (FileNotFoundError, NotADirectoryError):
        # No example images available: the caller simply gets no examples.
        return []

    # os.listdir() returns entries in arbitrary order; sorting keeps the
    # result stable from run to run.
    return [
        [os.path.join("examples", name)]
        for name in sorted(file_names)
        if name.endswith((".jpg", ".png"))
    ]


full_examples = build_examples()

# Stylesheet for the page: limits the height of images inside elements
# carrying the "image" class.
_PAGE_CSS = """
        .image img {
            max-height: 512px;
        }
    """

# Introductory Markdown shown at the top of the page.
_INTRO_MARKDOWN = """
                ## 棋盘检测, 棋子识别

                features: 轻量化模型
                

                x 表示 有遮挡位置   
                . 表示 棋盘上的普通交叉点  

                步骤:  
                    1. 流程分成两步,第一步 keypoints 检测
                    2. 拉伸棋盘,并预测棋子
                    
                log:
                    1. 优化棋子识别,增加对游戏棋盘的识别
                """

with gr.Blocks(css=_PAGE_CSS) as demo:
    gr.Markdown(_INTRO_MARKDOWN)
    # Row 1: uploaded picture (left) and the same picture annotated with
    # the detected keypoints (right).
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(
                type="numpy",
                label="上传棋盘图片",
                elem_classes="image",
            )

        with gr.Column():
            original_image_with_keypoints = gr.Image(
                elem_classes="image",
                label="step1: 原图带关键点",
                visible=True,
                interactive=False,
            )

    # Row 2: stretched board image (left); elapsed time and the recognition
    # results (right).
    with gr.Row():
        with gr.Column():
            transformed_image = gr.Image(
                elem_classes="image",
                label="step2: 拉伸棋盘",
                visible=True,
                interactive=False,
            )

        with gr.Column():
            use_time = gr.Textbox(label="用时", visible=True, interactive=False)

            # Accordion, collapsed by default, holding the textual result.
            with gr.Accordion("文字识别", open=False):
                layout_pred_info = gr.Dataframe(
                    label="棋子识别",
                    visible=True,
                    interactive=False,
                )

            with gr.Accordion("棋子识别", open=True):
                # Empty placeholder for the 10-row by 9-column board table.
                table_html = gr.HTML(
                    """
                    <table>
                    </table>
                    """
                )

    # Row 3: example pickers. The first two take ten examples each; the
    # third takes everything that is left.
    with gr.Row():
        with gr.Column():
            example_pages = (
                ("示例图片1", full_examples[:10]),
                ("示例图片2", full_examples[10:20]),
                ("示例图片3", full_examples[20:]),
            )
            for page_label, page_rows in example_pages:
                gr.Examples(
                    page_rows,
                    inputs=[image_input],
                    label=page_label,
                    examples_per_page=10,
                )


    def gen_table_html(annotation_arr_10_9):
        """Render a grid of cell codes as an HTML table.

        Args:
            annotation_arr_10_9: Iterable of rows, each an iterable of cell
                codes: ``'.'`` (plain intersection), ``'x'`` (occluded
                position) or a piece code looked up in ``dict_cate_images``.

        Returns:
            The ``<table>`` markup as a single string, one ``<td>`` per cell.
            Piece cells embed the piece image as a base64 data URL. When the
            code has no known image, or the image cannot be loaded, the cell
            shows the HTML-escaped code as text instead of a broken image.
        """
        from html import escape

        cell_style = "width: 60px; height: 60px; text-align: center;"
        # Collect fragments and join once instead of growing a string.
        parts = ["<table border='1' style='margin: auto;'>"]

        for row in annotation_arr_10_9:
            parts.append("<tr>")
            for cell in row:
                if cell == '.':
                    # Plain intersection: empty cell.
                    parts.append(f"<td style='{cell_style}'></td>")
                    continue
                if cell == 'x':
                    # Occluded position.
                    parts.append(f"<td style='{cell_style}'>x</td>")
                    continue

                # Piece: look up its image file and embed it inline.
                img_file = dict_cate_images.get(cell)
                img_data_base64 = get_image_base64(img_file) if img_file else None
                if img_data_base64:
                    parts.append(
                        f"<td style='{cell_style} padding: 0;'>"
                        f"<img src='{img_data_base64}' width='58' height='58'></td>"
                    )
                else:
                    # No usable image: fall back to the code itself.
                    parts.append(f"<td style='{cell_style}'>{escape(str(cell))}</td>")
            parts.append("</tr>")

        parts.append("</table>")
        return "".join(parts)
        


    def detect_chessboard(image):
        """Run the detector on an image and prepare the values for the UI.

        Args:
            image: The value of the image input component, or ``None`` when
                no image is set.

        Returns:
            A 5-tuple in the order expected by the ``image_input.change``
            outputs: (image with keypoints, stretched board image, rows of
            Chinese cell names, board table HTML, time info). All five
            entries are ``None`` when ``image`` is ``None`` or when detection
            or label decoding fails.
        """
        no_result = (None, None, None, None, None)
        if image is None:
            return no_result

        try:
            # The detector also returns a scores value that this UI does
            # not display.
            keypoints_image, stretched_image, cells_labels_str, _scores, time_info = (
                detector.pred_detect_board_and_classifier(image)
            )

            # The label string holds one board row per line and one
            # character per cell.
            annotation_arr_10_9_short = [
                list(line) for line in cells_labels_str.split("\n")
            ]

            # Translate the one-character codes to their Chinese names.
            annotation_arr_10_9 = [
                [dict_cate_names_reverse[code] for code in row]
                for row in annotation_arr_10_9_short
            ]
        except Exception as e:
            # UI boundary: keep the app responsive, but log the actual cause
            # so that failures can be diagnosed.
            print(f"Chessboard detection failed: {e!r}")
            gr.Warning("检测失败 图片或者视频布局错误")
            return no_result

        board_table_html = gen_table_html(annotation_arr_10_9_short)

        return keypoints_image, stretched_image, annotation_arr_10_9, board_table_html, time_info

    # Run detection whenever the input image changes. The handler's five
    # return values are written to these components, in this order.
    detection_outputs = [
        original_image_with_keypoints,
        transformed_image,
        layout_pred_info,
        table_html,
        use_time,
    ]
    image_input.change(
        fn=detect_chessboard,
        inputs=[image_input],
        outputs=detection_outputs,
    )

def main():
    """Launch the Gradio demo."""
    demo.launch()


if __name__ == "__main__":
    main()