import os
import re
from typing import Mapping, Tuple, Dict

import cv2
import gradio as gr
import numpy as np
import pandas as pd
from PIL import Image
from huggingface_hub import hf_hub_download
from onnxruntime import InferenceSession


# noinspection PyUnresolvedReferences
def make_square(img, target_size):
    old_size = img.shape[:2]
    desired_size = max(old_size)
    desired_size = max(desired_size, target_size)

    delta_w = desired_size - old_size[1]
    delta_h = desired_size - old_size[0]
    top, bottom = delta_h // 2, delta_h - (delta_h // 2)
    left, right = delta_w // 2, delta_w - (delta_w // 2)

    color = [255, 255, 255]
    return cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)


# noinspection PyUnresolvedReferences
def smart_resize(img, size):
    # Assumes the image has already gone through make_square
    if img.shape[0] > size:
        img = cv2.resize(img, (size, size), interpolation=cv2.INTER_AREA)
    elif img.shape[0] < size:
        img = cv2.resize(img, (size, size), interpolation=cv2.INTER_CUBIC)
    else:  # just do nothing
        pass

    return img


class WaifuDiffusionInterrogator:
    def __init__(
            self,
            repo='SmilingWolf/wd-v1-4-vit-tagger',
            model_path='model.onnx',
            tags_path='selected_tags.csv',
            mode: str = "auto"
    ) -> None:
        self.__repo = repo
        self.__model_path = model_path
        self.__tags_path = tags_path
        self._provider_mode = mode

        self.__initialized = False
        self._model, self._tags = None, None

    def _init(self) -> None:
        if self.__initialized:
            return

        model_path = hf_hub_download(self.__repo, filename=self.__model_path)
        tags_path = hf_hub_download(self.__repo, filename=self.__tags_path)

        self._model = InferenceSession(str(model_path))
        self._tags = pd.read_csv(tags_path)

        self.__initialized = True

    def _calculation(self, image: Image.Image) -> pd.DataFrame:
        self._init()

        # code for converting the image and running the model is taken from the link below
        # thanks, SmilingWolf!
        # https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags/blob/main/app.py

        # convert an image to fit the model
        _, height, _, _ = self._model.get_inputs()[0].shape

        # alpha to white
        image = image.convert('RGBA')
        new_image = Image.new('RGBA', image.size, 'WHITE')
        new_image.paste(image, mask=image)
        image = new_image.convert('RGB')
        image = np.asarray(image)

        # PIL RGB to OpenCV BGR
        image = image[:, :, ::-1]

        image = make_square(image, height)
        image = smart_resize(image, height)
        image = image.astype(np.float32)
        image = np.expand_dims(image, 0)

        # evaluate model
        input_name = self._model.get_inputs()[0].name
        label_name = self._model.get_outputs()[0].name
        confidence = self._model.run([label_name], {input_name: image})[0]

        full_tags = self._tags[['name', 'category']].copy()
        full_tags['confidence'] = confidence[0]

        return full_tags

    def interrogate(self, image: Image) -> Tuple[Dict[str, float], Dict[str, float]]:
        full_tags = self._calculation(image)

        # first 4 items are for rating (general, sensitive, questionable, explicit)
        ratings = dict(full_tags[full_tags['category'] == 9][['name', 'confidence']].values)

        # rest are regular tags
        tags = dict(full_tags[full_tags['category'] != 9][['name', 'confidence']].values)

        return ratings, tags


WAIFU_MODELS: Mapping[str, WaifuDiffusionInterrogator] = {
    'wd14-vit': WaifuDiffusionInterrogator(),
    'wd14-convnext': WaifuDiffusionInterrogator(
        repo='SmilingWolf/wd-v1-4-convnext-tagger'
    ),
}
RE_SPECIAL = re.compile(r'([\\()])')


def image_to_wd14_tags(image: Image.Image, model_name: str, threshold: float,
                       use_spaces: bool, use_escape: bool, include_ranks: bool, score_descend: bool) \
        -> Tuple[Mapping[str, float], str, Mapping[str, float]]:
    model = WAIFU_MODELS[model_name]
    ratings, tags = model.interrogate(image)

    filtered_tags = {
        tag: score for tag, score in tags.items()
        if score >= threshold
    }

    text_items = []
    tags_pairs = filtered_tags.items()
    if score_descend:
        tags_pairs = sorted(tags_pairs, key=lambda x: (-x[1], x[0]))
    for tag, score in tags_pairs:
        tag_outformat = tag
        if use_spaces:
            tag_outformat = tag_outformat.replace('_', ' ')
        if use_escape:
            tag_outformat = re.sub(RE_SPECIAL, r'\\\1', tag_outformat)
        if include_ranks:
            tag_outformat = f"({tag_outformat}:{score:.3f})"
        text_items.append(tag_outformat)
    output_text = ', '.join(text_items)

    return ratings, output_text, filtered_tags


if __name__ == '__main__':
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                gr_input_image = gr.Image(type='pil', label='Original Chen')
                with gr.Row():
                    gr_model = gr.Radio(list(WAIFU_MODELS.keys()), value='wd14-vit', label='Waifu Model')
                    gr_threshold = gr.Slider(0.0, 1.0, 0.5, label='Chen Chen Chen Chen Chen')
                with gr.Row():
                    gr_space = gr.Checkbox(value=False, label='Use Space Instead Of _')
                    gr_escape = gr.Checkbox(value=True, label='Use Text Escape')
                    gr_confidence = gr.Checkbox(value=False, label='Keep Confidences')
                    gr_order = gr.Checkbox(value=True, label='Descend By Confidence')

                gr_btn_submit = gr.Button(value='Tagging', variant='primary')

            with gr.Column():
                gr_ratings = gr.Label(label='Ratings')
                with gr.Tabs():
                    with gr.Tab("Tags"):
                        gr_tags = gr.Label(label='Tags')
                    with gr.Tab("Exported Text"):
                        gr_output_text = gr.TextArea(label='Exported Text')

        gr_btn_submit.click(
            image_to_wd14_tags,
            inputs=[gr_input_image, gr_model, gr_threshold, gr_space, gr_escape, gr_confidence, gr_order],
            outputs=[gr_ratings, gr_output_text, gr_tags],
        )
    demo.queue(os.cpu_count()).launch()