import base64
# NOTE(review): `imp` was removed from the standard library in Python 3.12 and
# looks unused in this file — confirm before dropping it.
import imp
import io
import json
import os
import re
import time

import gradio as gr
import requests
from PIL import Image, ImageChops, PngImagePlugin

import ui_functions as uifn
from css_and_js import call_JS, js

# FlagStudio API endpoint and access token. Both can be overridden through the
# environment (FLAGSTUDIO_URL_HOST / FLAGSTUDIO_TOKEN) so the credential can be
# rotated or injected at deploy time without editing the source.
# NOTE(review): the fallback token below is a credential committed to source
# control; it should be revoked and removed once deployments set the env var.
url_host = os.environ.get("FLAGSTUDIO_URL_HOST", "https://flagstudio.baai.ac.cn")
token = os.environ.get(
    "FLAGSTUDIO_TOKEN",
    "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VyX2lkIjoiMGY4M2QxMDg3N2MzMTFlZGFiYzYwZmU5ZGFjMTI1ZDMiLCJhcHBfbmFtZSI6IndlYiIsImlkZW50aXR5X3R5cGUiOiIyIiwidXNlcl9yb2xlIjoiMiIsImp0aSI6ImE3YTE1N2I3LTllNTItNDllMS04YzA0LWEzZmI5YjZiZjNlYSIsIm5iZiI6MTY3MDU5MTcwMSwiZXhwIjoxOTg1OTUxNzAxLCJpYXQiOjE2NzA1OTE3MDF9.OcfGayna-wr_5mo4LT6OJHSCokna8vqKSmmCftFUsx8",
)

def read_content(file_path: str) -> str:
    """Return the entire text of ``file_path``, decoded as UTF-8.

    Used to load the page stylesheet and the header/footer HTML fragments.
    """
    with open(file_path, encoding='utf-8') as handle:
        return handle.read()

def filter_content(raw_style: str):
    """Drop a parenthesised gloss from a UI label.

    Everything from the first "(" onward is removed, e.g.
    "国画(traditional Chinese painting)" -> "国画". A label without "(" is
    returned unchanged.
    """
    head, _sep, _rest = raw_style.partition("(")
    return head

def upload_image(img, timeout=60):
    """Upload a PIL image to FlagStudio storage as PNG.

    First asks the API for an upload link (target URL plus the headers to send),
    then PUTs the PNG-encoded bytes to that URL.

    Args:
        img: PIL image to upload.
        timeout: per-request timeout in seconds passed to ``requests``. Without
            one, a stalled connection would block the calling worker forever.

    Returns:
        ``(image_id, image_url)`` as reported by the upload-link endpoint.

    Raises:
        gr.Error: if either HTTP request returns a non-200 status, or the API
            reports a non-zero ``code``.
        requests.RequestException: on network failures, including timeouts.
    """
    url = url_host + "/api/v1/image/get-upload-link"
    r = requests.post(url, json={}, headers={"token": token}, timeout=timeout)
    if r.status_code != 200:
        raise gr.Error(r.reason)
    head_res = r.json()
    if head_res["code"] != 0:
        raise gr.Error("Unknown error")
    link = head_res["data"]
    image_id = link["image_id"]
    image_url = link["url"]

    with io.BytesIO() as buf:
        img.save(buf, "PNG")
        png_bytes = buf.getvalue()

    # The upload link comes with its own headers, which must accompany the PUT.
    r = requests.put(image_url, data=png_bytes, headers=link["headers"], timeout=timeout)
    if r.status_code != 200:
        raise gr.Error(r.reason)
    return image_id, image_url

def _create_tasks(payload, headers, image_num, timeout):
    """Submit ``payload`` once per requested image and return the task handles.

    Each handle is the ``data`` object of a create response; it is later posted
    back unchanged to the status endpoint.
    """
    url = url_host + "/api/v1/task/create"
    tasks = []
    for _ in range(image_num):
        r = requests.post(url, json=payload, headers=headers, timeout=timeout)
        if r.status_code != 200:
            raise gr.Error(r.reason)
        create_res = r.json()
        if create_res["code"] == 3002:
            raise gr.Error("Inappropriate prompt detected.")
        if create_res["code"] != 0:
            raise gr.Error("Unknown error")
        tasks.append(create_res["data"])
    return tasks


def _download_image(url, timeout):
    """Fetch one generated image and decode it as an RGB PIL image."""
    r = requests.get(url, timeout=timeout)
    # Check the status so an HTTP error surfaces as a readable message rather
    # than as a PIL decode failure on an error page.
    if r.status_code != 200:
        raise gr.Error(r.reason)
    return Image.open(io.BytesIO(r.content)).convert("RGB")


def _wait_for_images(tasks, headers, timeout, poll_interval):
    """Poll the status endpoint until every task has finished; return all images.

    Status codes handled: 6002 = still running, 0 = finished, 6005 = NSFW
    result. Any other code aborts with ``gr.Error``.
    """
    url = url_host + "/api/v1/task/status"
    images = []
    pending = list(tasks)
    while pending:
        still_running = []
        # Poll newest-first; images from tasks finishing in the same round are
        # appended in that order.
        for task in reversed(pending):
            r = requests.post(url, json=task, headers=headers, timeout=timeout)
            if r.status_code != 200:
                raise gr.Error(r.reason)
            res = r.json()
            code = res["code"]
            if code == 6002:
                still_running.append(task)
            elif code == 6005:
                raise gr.Error("NSFW image detected.")
            elif code == 0:
                for img_info in res["data"]["images"]:
                    images.append(_download_image(img_info["url"], timeout))
            else:
                raise gr.Error(f"Error code: {code}")
        # Restore submission order for the next round.
        pending = still_running[::-1]
        # Sleep only while work remains, so we return as soon as the last task
        # finishes.
        if pending:
            time.sleep(poll_interval)
    return images


def post_reqest(seed, prompt, width, height, image_num, img=None, mask=None,
                timeout=60, poll_interval=1):
    """Generate images for ``prompt`` through the FlagStudio task API.

    The name (sic) is kept for compatibility with existing callers. One task is
    created per requested image, all with identical parameters. The status
    endpoint is then polled every ``poll_interval`` seconds until all tasks
    finish. There is no overall deadline.

    Args:
        seed: generation seed; converted with ``int()``.
        prompt: text prompt.
        width: requested output image width in pixels.
        height: requested output image height in pixels.
        image_num: number of generation tasks to create.
        img: optional PIL init image for image-to-image. It is uploaded first.
        mask: optional PIL mask. It is uploaded only if it contains a non-zero
            pixel (after conversion to grayscale).
        timeout: per-HTTP-request timeout in seconds.
        poll_interval: seconds to wait between status polling rounds.

    Returns:
        List of RGB PIL images, in the order their tasks completed.

    Raises:
        gr.Error: on non-200 responses or error codes reported by the API.
        requests.RequestException: on network failures, including timeouts.
    """
    params = {
        "width": width,  # output image width
        "height": height,  # output image height
        "prompts": [prompt],
        "seed": int(seed),
    }
    if img is not None:
        image_id, image_url = upload_image(img)
        params["init_image"] = {
            "image_id": image_id,
            "url": image_url,
            "width": img.width,
            "height": img.height,
        }
    if mask is not None:
        # An all-zero mask (presumably nothing was painted) is not sent at all.
        _, mask_max = mask.convert("L").getextrema()
        if mask_max > 0:
            mask_id, mask_url = upload_image(mask)
            params["mask_image"] = {
                "image_id": mask_id,
                "url": mask_url,
                "width": mask.width,
                "height": mask.height,
            }

    payload = {"type": "gen-image", "parameters": params}
    headers = {"token": token}
    tasks = _create_tasks(payload, headers, image_num, timeout)
    return _wait_for_images(tasks, headers, timeout, poll_interval)

def request_images(raw_text, class_draw, style_draw, batch_size, w, h, seed):
    """Build the final prompt from the text-to-image UI selections and generate.

    Labels have their parenthesised gloss stripped via ``filter_content``:
    - "国画": a fixed keyword suffix is appended and the selected styles are
      not applied.
    - Any other type: the type keyword (unless it is "通用") and each selected
      style keyword are appended, comma-separated.

    Args:
        raw_text: user prompt text.
        class_draw: selected generation-type label.
        style_draw: selected style labels. ``None`` is treated as no selection.
        batch_size: number of images to create, as a string or int.
        w: output width.
        h: output height.
        seed: generation seed.

    Returns:
        List of PIL images produced by ``post_reqest``.
    """
    category = filter_content(class_draw)
    if category == "国画":
        raw_text = raw_text + ",国画,水墨画,大作,黑白,高清,传统"
    else:
        keywords = [] if category == "通用" else [category]
        keywords.extend(filter_content(sty) for sty in (style_draw or ()))
        if keywords:
            raw_text = raw_text + "," + ",".join(keywords)
    print(f"raw text is {raw_text}")

    return post_reqest(seed, raw_text, w, h, int(batch_size))


def img2img(prompt, image_and_mask):
    """Run image-to-image generation on an uploaded image and its painted mask.

    The image's shorter side is scaled to 512 px. The longer side keeps the
    aspect ratio and is truncated to an int. A single result is requested with
    seed 0.

    Args:
        prompt: text prompt.
        image_and_mask: value of the sketch-tool image input, i.e. a dict with
            "image" and "mask" PIL images. It is ``None`` when nothing has been
            uploaded.

    Returns:
        List of PIL images produced by ``post_reqest``.

    Raises:
        gr.Error: if no image was uploaded, or if the generation request fails.
    """
    # Clicking Generate before uploading an image passes None; fail with a
    # readable message instead of a TypeError.
    if not image_and_mask or image_and_mask.get("image") is None:
        raise gr.Error("Please upload an image first.")
    image = image_and_mask["image"]
    mask = image_and_mask.get("mask")

    if image.width <= image.height:
        width = 512
        height = int((width / image.width) * image.height)
    else:
        height = 512
        width = int((height / image.height) * image.width)
    return post_reqest(0, prompt, width, height, 1, image, mask)


# Sample prompts listed under the text-to-image gallery via gr.Examples.
# A mix of Chinese, English and mixed-language prompts.
examples = [
    '水墨蝴蝶和牡丹花,国画',
    '苍劲有力的墨竹,国画',
    '暴风雨中的灯塔',
    '机械小松鼠,科学幻想',
    '中国水墨山水画,国画',
    "Lighthouse in the storm",
    "A dog",
    "Landscape by 张大千",
    "A tiger 长了兔子耳朵",
    "A baby bird 铅笔素描",
]

if __name__ == "__main__":
    # Build and launch the Gradio UI: a text-to-image tab and an image-to-image
    # (sketch/mask) tab, framed by header/footer HTML loaded from local files.
    # NOTE(review): `.style(...)`, `gr.Box` and `queue(concurrency_count=...)`
    # are Gradio 3.x APIs; this UI will not run unchanged on Gradio 4.x.
    block = gr.Blocks(css=read_content('style.css'))

    with block:
        gr.HTML(read_content("header.html"))
        # `tabs` is kept so the "send to img2img" handler below can list it as
        # an output.
        with gr.Tabs(elem_id='tabss') as tabs:

            # ---- Text-to-image tab ----
            with gr.TabItem("文生图(Text-to-img)", id='txt2img_tab'):

                with gr.Group():
                    with gr.Box():
                        # Prompt textbox and "Generate image" button on one row.
                        with gr.Row().style(mobile_collapse=False, equal_height=True):
                            text = gr.Textbox(
                                label="Prompt",
                                show_label=False,
                                max_lines=1,
                                placeholder="Input text(输入文字)",
                                interactive=True,
                            ).style(
                                border=(True, False, True, True),
                                rounded=(True, False, False, True),
                                container=False,
                            )
                            
                            btn = gr.Button("Generate image").style(
                                margin=False,
                                rounded=(True, True, True, True),
                            )
                        # Generation type. request_images() turns the selected
                        # label into prompt keywords.
                        with gr.Row().style(mobile_collapse=False, equal_height=True):
                            class_draw = gr.Radio(choices=["通用(general)","国画(traditional Chinese painting)",], value="通用(general)", show_label=True, label='生成类型(type)')
                            # class_draw = gr.Dropdown(["通用(general)", "国画(traditional Chinese painting)", 
                            #                             "照片,摄影(picture photography)", "油画(oil painting)", 
                            #                             "铅笔素描(pencil sketch)", "CG", 
                            #                             "水彩画(watercolor painting)", "水墨画(ink and wash)",
                            #                             "插画(illustrations)", "3D", "图生图(img2img)"],
                            #                         label="生成类型(type)",
                            #                         show_label=True,
                            #                         value="通用(general)")
                        # Optional style keywords appended to the prompt by
                        # request_images() (not used for the 国画 type).
                        with gr.Row().style(mobile_collapse=False, equal_height=True):
                            style_draw = gr.CheckboxGroup(["蒸汽朋克(steampunk)", "电影摄影风格(film photography)", 
                                                        "概念艺术(concept art)", "Warming lighting", 
                                                        "Dramatic lighting", "Natural lighting", 
                                                        "虚幻引擎(unreal engine)", "4k", "8k",
                                                        "充满细节(full details)"],
                                                    label="画面风格(style)",
                                                    show_label=True,
                                                    )
                        # Batch size and seed.
                        with gr.Row().style(mobile_collapse=False, equal_height=True):
                            # sample_size = gr.Slider(minimum=1,
                            #                         maximum=4,
                            #                         step=1,
                            #                         label="生成数量(number)",
                            #                         show_label=True,
                            #                         interactive=True,
                            #                         )
                            # Choice values are strings; request_images() applies int().
                            sample_size = gr.Radio(choices=["1","2","3","4"], value="1", show_label=True, label='生成数量(number)')
                            seed = gr.Number(0, label='seed', interactive=True)
                        # Output size: 512-1024 px in steps of 64.
                        with gr.Row().style(mobile_collapse=False, equal_height=True):
                            w = gr.Slider(512,1024,value=512, step=64, label="width")
                            h = gr.Slider(512,1024,value=512, step=64, label="height")

                    gallery = gr.Gallery(
                        label="Generated images", show_label=False, elem_id="gallery"
                    ).style(grid=[2,2])
                    # NOTE(review): request_images takes 7 arguments but only
                    # `text` is wired as an input here. This is presumably fine
                    # while clicking an example only fills the textbox. If
                    # example caching is enabled, fn would be called with one
                    # argument and fail — confirm.
                    gr.Examples(examples=examples, fn=request_images, inputs=text, outputs=gallery, examples_per_page=100)
                    # Pick which generated image to send to the img2img tab.
                    with gr.Row().style(mobile_collapse=False, equal_height=True):
                        img_choices = gr.Dropdown(["图片1(img1)"],label='请选择一张图片发送到图生图',show_label=True,value="图片1(img1)")
                    with gr.Row().style(mobile_collapse=False, equal_height=True):
                        output_txt2img_copy_to_input_btn = gr.Button("发送图片到图生图(Sent the image to img2img)").style(
                                margin=False,
                                rounded=(True, True, True, True),
                            )

                    # Hint texts, hidden initially. They are outputs of the
                    # send-to-img2img handler, which presumably reveals them
                    # (implemented in ui_functions, not shown).
                    with gr.Row():
                        prompt = gr.Markdown("提示(Prompt):", visible=False)
                    with gr.Row():
                        move_prompt_zh = gr.Markdown("请移至图生图部分进行编辑(拉到顶部)", visible=False)
                    with gr.Row():
                        move_prompt_en = gr.Markdown("Please move to the img2img section for editing(Pull to the top)", visible=False)

                        

                    # Pressing Enter in the textbox and clicking the button
                    # both run text-to-image generation into the gallery.
                    text.submit(request_images, inputs=[text, class_draw, style_draw, sample_size, w, h, seed], outputs=gallery)
                    btn.click(request_images, inputs=[text, class_draw, style_draw, sample_size, w, h, seed], outputs=gallery)

                    # Keep the image-choice dropdown in sync with the batch size
                    # (logic lives in ui_functions.change_img_choices, not shown).
                    sample_size.change(
                        fn=uifn.change_img_choices,
                        inputs=[sample_size],
                        outputs=[img_choices]
                    )

            # ---- Image-to-image tab ----
            with gr.TabItem("图生图(Img-to-Img)", id="img2img_tab"):
                # Prompt row. Its elem_id is targeted by the Enter-key JS below.
                with gr.Row(elem_id="prompt_row"):
                    img2img_prompt = gr.Textbox(label="Prompt",
                                                elem_id='img2img_prompt_input',
                                                placeholder="神奇的森林,流淌的河流.",
                                                lines=1,
                                                max_lines=1,
                                                value="",
                                                show_label=False).style()

                    # Two Generate buttons wired to the same img2img handler;
                    # only the editor button is visible.
                    img2img_btn_mask = gr.Button("Generate", variant="primary", visible=False,
                                                 elem_id="img2img_mask_btn")
                    img2img_btn_editor = gr.Button("Generate", variant="primary", elem_id="img2img_edit_btn")
                # Heading: "Input image".
                gr.Markdown('#### 输入图像')
                with gr.Row().style(equal_height=False):
                    #with gr.Column():
                    # Sketch-tool image input. img2img() reads the "image" and
                    # "mask" entries of its value (PIL images since type='pil').
                    img2img_image_mask = gr.Image(
                        value=None,
                        source="upload",
                        interactive=True,
                        tool="sketch",
                        type='pil',
                        elem_id="img2img_mask",
                        image_mode="RGBA"
                    )
                # Heading: "Edited images".
                gr.Markdown('#### 编辑后的图片')
                with gr.Row():
                    output_img2img_gallery = gr.Gallery(label="Images", elem_id="img2img_gallery_output").style(
                        grid=[4,4,4] )
                # Usage hints: select an image, paint over a region, and enter a
                # text description.
                with gr.Row():
                    gr.Markdown('提示(prompt):')
                with gr.Row():
                    gr.Markdown('请选择一张图像掩盖掉一部分区域,并输入文本描述')
                with gr.Row():
                    gr.Markdown('Please select an image to cover up a part of the area and enter a text description.')
                gr.Markdown('# 编辑设置',visible=False)


                # Copy the selected gallery image into the img2img input. The
                # tab switch and hint display happen in
                # ui_functions.copy_img_to_input (not shown).
                output_txt2img_copy_to_input_btn.click(
                    uifn.copy_img_to_input,
                    [gallery, img_choices],
                    [tabs, img2img_image_mask, move_prompt_zh, move_prompt_en, prompt]
                )


                img2img_func = img2img
                img2img_inputs = [img2img_prompt, img2img_image_mask]
                img2img_outputs = [output_img2img_gallery]

                img2img_btn_mask.click(
                    img2img_func,
                    img2img_inputs,
                    img2img_outputs
                )

                # Returns the (fn, inputs, outputs) triple for a click binding.
                def img2img_submit_params():
                    return (img2img_func,
                            img2img_inputs,
                            img2img_outputs)

                img2img_btn_editor.click(*img2img_submit_params())

                # GENERATE ON ENTER
                # Client-side only: no Python fn/inputs/outputs. call_JS comes
                # from css_and_js (not shown). By its name it presumably clicks
                # the first visible button in "prompt_row", i.e. the visible
                # Generate button.
                img2img_prompt.submit(None, None, None,
                                      _js=call_JS("clickFirstVisibleButton",
                                                  rowId="prompt_row"))

        gr.HTML(read_content("footer.html"))
        # gr.Image('./contributors.png')

    # Enable Gradio's request queue (up to 512 queued events, 256 concurrent
    # workers) and start the server.
    block.queue(max_size=512, concurrency_count=256).launch()