import io
import os
import re
import threading
from typing import Mapping, Tuple, Dict

import cv2
import gradio as gr
import numpy as np
import pandas as pd
from PIL import Image
from huggingface_hub import hf_hub_download
from onnxruntime import InferenceSession

# noinspection PyUnresolvedReferences
def make_square(img, target_size):
    """Pad ``img`` with white borders so that it becomes square.

    The side length of the result is the larger of the image's longest edge
    and ``target_size``. The original pixels stay centred; when the padding
    is odd, the extra pixel goes to the bottom / right edge.

    :param img: image array indexed as (rows, cols, ...).
    :param target_size: minimum side length of the returned square.
    :returns: the padded array produced by ``cv2.copyMakeBorder``.
    """
    height, width = img.shape[0], img.shape[1]
    side = max(height, width, target_size)

    pad_vertical = side - height
    pad_horizontal = side - width
    top = pad_vertical // 2
    bottom = pad_vertical - top
    left = pad_horizontal // 2
    right = pad_horizontal - left

    white = [255, 255, 255]
    return cv2.copyMakeBorder(
        img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=white
    )


# noinspection PyUnresolvedReferences
def smart_resize(img, size):
    """Resize a square image to ``size`` x ``size`` pixels.

    Expects ``img`` to have gone through ``make_square`` already: only its
    first dimension is compared with ``size``. Shrinking uses ``INTER_AREA``
    and enlarging uses ``INTER_CUBIC``. An image that already has the
    requested size is returned unchanged (the very same object).
    """
    current = img.shape[0]
    if current == size:
        return img

    interpolation = cv2.INTER_AREA if current > size else cv2.INTER_CUBIC
    return cv2.resize(img, (size, size), interpolation=interpolation)


class WaifuDiffusionInterrogator:
    """ONNX image tagger whose model and tag table live in a Hugging Face Hub repo.

    Nothing is downloaded or loaded in ``__init__``; that happens on the first
    call that needs the model, so building many instances up front is cheap.
    """

    def __init__(
            self,
            repo='SmilingWolf/wd-v1-4-vit-tagger',
            model_path='model.onnx',
            tags_path='selected_tags.csv',
            mode: str = "auto"
    ) -> None:
        """
        :param repo: Hugging Face Hub repo id that holds the model and tag files.
        :param model_path: filename of the ONNX model inside ``repo``.
        :param tags_path: filename of the tag CSV inside ``repo``.
        :param mode: stored as ``_provider_mode``; not read by this class.
        """
        self.__repo = repo
        self.__model_path = model_path
        self.__tags_path = tags_path
        self._provider_mode = mode

        self.__initialized = False
        # Serialises the one-time download/load in ``_init``: the Gradio queue
        # can call ``interrogate`` from several worker threads at once, and
        # without this each of them would load its own copy of the model.
        self.__init_lock = threading.Lock()
        self._model, self._tags = None, None

    def _init(self) -> None:
        """Download (if needed) and load the ONNX model and tag CSV exactly once."""
        if self.__initialized:  # fast path: no locking once loaded
            return
        with self.__init_lock:
            # Re-check: another thread may have finished loading while we waited.
            if self.__initialized:
                return

            model_path = hf_hub_download(self.__repo, filename=self.__model_path)
            tags_path = hf_hub_download(self.__repo, filename=self.__tags_path)

            self._model = InferenceSession(str(model_path))
            self._tags = pd.read_csv(tags_path)

            # Flip the flag last so no thread can observe a half-loaded instance.
            self.__initialized = True

    def _calculation(self, image: Image.Image) -> pd.DataFrame:
        """Run the model on ``image`` and return every known tag with its score.

        :param image: PIL image in any mode; transparency is flattened onto white.
        :returns: a copy of the tag table's ``name`` and ``category`` columns with
            an added ``confidence`` column holding the model output for ``image``.
        """
        # TODO: figure out what to do if the input is a URL rather than an image.
        self._init()

        # code for converting the image and running the model is taken from the link below
        # thanks, SmilingWolf!
        # https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags/blob/main/app.py

        # Read the model's expected input size. This assumes a 4-D NHWC input
        # (batch, height, width, channels) with height == width, because the image
        # is padded/resized to height x height below. Confirm per model.
        _, height, _, _ = self._model.get_inputs()[0].shape

        # alpha to white
        image = image.convert('RGBA')
        new_image = Image.new('RGBA', image.size, 'WHITE')
        new_image.paste(image, mask=image)
        image = new_image.convert('RGB')
        image = np.asarray(image)

        # PIL RGB to OpenCV BGR
        image = image[:, :, ::-1]

        image = make_square(image, height)
        image = smart_resize(image, height)
        image = image.astype(np.float32)
        image = np.expand_dims(image, 0)  # add the batch dimension

        # evaluate model
        input_name = self._model.get_inputs()[0].name
        label_name = self._model.get_outputs()[0].name
        confidence = self._model.run([label_name], {input_name: image})[0]

        full_tags = self._tags[['name', 'category']].copy()
        full_tags['confidence'] = confidence[0]  # scores for the single batch item

        return full_tags

    def interrogate(self, image: Image.Image) -> Tuple[Dict[str, float], Dict[str, float]]:
        """Tag ``image`` and split the scores into ratings and regular tags.

        :param image: PIL image to tag.
        :returns: ``(ratings, tags)``. Each maps a tag name to the model's
            confidence. Ratings are the CSV rows with ``category == 9``; every
            other row is a regular tag. Neither mapping is threshold-filtered.
        """
        full_tags = self._calculation(image)

        # Category 9 rows are the rating labels (general, sensitive,
        # questionable, explicit).
        ratings = dict(full_tags[full_tags['category'] == 9][['name', 'confidence']].values)

        # everything else is a regular tag
        tags = dict(full_tags[full_tags['category'] != 9][['name', 'confidence']].values)

        return ratings, tags


WAIFU_MODELS: Mapping[str, WaifuDiffusionInterrogator] = {
    'chen-vit': WaifuDiffusionInterrogator(),
    'chen-convnext': WaifuDiffusionInterrogator(
        repo='SmilingWolf/wd-v1-4-convnext-tagger'
    ),
    'chen-convnext2-v2': WaifuDiffusionInterrogator(
        repo="SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
    ),
    'chen-swin2': WaifuDiffusionInterrogator(
        repo='SmilingWolf/wd-v1-4-swinv2-tagger-v2'
    ),
    'chen-moatv2': WaifuDiffusionInterrogator(
        repo='SmilingWolf/wd-v1-4-moat-tagger-v2'
    ),
    'chen-convnextv3': WaifuDiffusionInterrogator(
        repo='SmilingWolf/wd-convnext-tagger-v3'
    ),
    'chen-vitv3': WaifuDiffusionInterrogator(
        repo='SmilingWolf/wd-vit-tagger-v3'
    ),
    'chen-swinv3': WaifuDiffusionInterrogator(
        repo='SmilingWolf/wd-swinv2-tagger-v3'
    ),
    'chen-vit-largev3': WaifuDiffusionInterrogator(
        repo='SmilingWolf/wd-vit-large-tagger-v3'
    ),
    'chen-evangelion': WaifuDiffusionInterrogator(
        repo='SmilingWolf/wd-eva02-large-tagger-v3'
    ),
    'chenkaku-evangelion': WaifuDiffusionInterrogator(
        repo='deepghs/idolsankaku-eva02-large-tagger-v1'
    ),
    'chenkaku-swinv2': WaifuDiffusionInterrogator(
        repo='deepghs/idolsankaku-swinv2-tagger-v1'
    ),
}
RE_SPECIAL = re.compile(r'([\\()])')


def image_to_wd14_tags(image: Image.Image, model_name: str, threshold: float,
                       use_spaces: bool, use_escape: bool, include_ranks=False, score_descend=True) \
        -> Tuple[Mapping[str, float], str, Mapping[str, float]]:
    """Tag ``image`` with the chosen model and format the result for display.

    :param image: picture to tag. Gradio passes ``None`` when nothing was uploaded.
    :param model_name: key into ``WAIFU_MODELS``.
    :param threshold: regular tags scoring below this are dropped. Ratings are
        returned unfiltered.
    :param use_spaces: if true, ``_`` becomes ``-`` and tags are joined with a
        single space. Otherwise ``_`` becomes a space and tags are joined with
        ``", "``.
    :param use_escape: backslash-escape backslashes and parentheses in each tag.
    :param include_ranks: render each tag as ``(tag:score)`` with 3 decimals.
    :param score_descend: order tags by descending score (ties broken by name)
        instead of the order they appear in the model's tag table.
    :returns: ``(ratings, output_text, filtered_tags)``.
    :raises gr.Error: if ``image`` is ``None``.
    """
    if image is None:
        # Fail fast. Otherwise the model would be downloaded and loaded on the
        # first call, only to crash afterwards on ``None.convert``.
        raise gr.Error('Please provide an image first.')

    model = WAIFU_MODELS[model_name]
    ratings, tags = model.interrogate(image)

    filtered_tags = {
        tag: score for tag, score in tags.items()
        if score >= threshold
    }

    tags_pairs = filtered_tags.items()
    if score_descend:
        tags_pairs = sorted(tags_pairs, key=lambda x: (-x[1], x[0]))

    text_items = []
    for tag, score in tags_pairs:
        if use_spaces:
            tag_outformat = tag.replace('_', '-')
        else:
            # Order matters: pre-existing spaces become separators before
            # underscores are turned into spaces.
            tag_outformat = tag.replace(' ', ', ').replace('_', ' ')
        if use_escape:
            tag_outformat = RE_SPECIAL.sub(r'\\\1', tag_outformat)
        if include_ranks:
            tag_outformat = f"({tag_outformat}:{score:.3f})"
        text_items.append(tag_outformat)

    separator = ' ' if use_spaces else ', '
    output_text = separator.join(text_items)

    return ratings, output_text, filtered_tags

if __name__ == '__main__':
    # Build the UI: image and tagging options on the left, ratings and the
    # resulting tags (as labels and as text) on the right.
    with gr.Blocks(analytics_enabled=False, theme="NoCrypt/miku") as demo:
        with gr.Row():
            with gr.Column():
                gr_input_image = gr.Image(type='pil', label='Chen Chen', sources=['upload', 'clipboard'])
                with gr.Row():
                    gr_model = gr.Radio(list(WAIFU_MODELS.keys()), value='chen-vit-largev3', label='Chen')
                    gr_threshold = gr.Slider(0.0, 1.0, 0.5, label='Chen Chen Chen Chen Chen')
                with gr.Row():
                    gr_space = gr.Checkbox(value=False, label='Use DashSpace')
                    gr_escape = gr.Checkbox(value=True, label='Chen Text Escape')

                gr_btn_submit = gr.Button(value='橙', variant='primary')

            with gr.Column():
                gr_ratings = gr.Label(label='橙 橙')
                with gr.Tabs():
                    with gr.Tab("Chens"):
                        gr_tags = gr.Label(label='Chens')
                    with gr.Tab("Chen Text"):
                        gr_output_text = gr.TextArea(label='Chen Text')

        # Inputs map positionally onto image_to_wd14_tags(image, model_name,
        # threshold, use_spaces, use_escape); include_ranks and score_descend
        # keep their defaults.
        gr_btn_submit.click(
            image_to_wd14_tags,
            inputs=[gr_input_image, gr_model, gr_threshold, gr_space, gr_escape],
            outputs=[gr_ratings, gr_output_text, gr_tags],
            api_name="classify"
        )
    # The ``sources=`` argument above is Gradio 4 API. In Gradio 4+ the first
    # positional parameter of ``queue()`` is ``status_update_rate``, not a worker
    # count, so the concurrency limit must be passed by keyword.
    # ``os.cpu_count()`` may return None, hence the fallback to 1.
    demo.queue(default_concurrency_limit=os.cpu_count() or 1).launch()