from __future__ import annotations
import gradio as gr
import logging
# Root logging setup: INFO and above, each record tagged with its source
# file and line number. `logger` is this module's named logger.
_LOG_FORMAT = '%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
import subprocess
def runcmd(command, timeout=60):
    """Run a shell command, print its outcome, and return the result.

    The command string is executed with ``shell=True``, so only pass trusted,
    constant strings (never user input).

    Args:
        command: Shell command line to execute.
        timeout: Seconds to wait before the command is killed (default 60,
            as before).

    Returns:
        The ``subprocess.CompletedProcess`` (stdout/stderr captured as text),
        or ``None`` if the command timed out.
    """
    try:
        ret = subprocess.run(
            command,
            shell=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            encoding="utf-8",
            # Tool output is not guaranteed to be valid UTF-8; don't let a
            # stray byte raise UnicodeDecodeError at start-up.
            errors="replace",
            timeout=timeout,
        )
    except subprocess.TimeoutExpired as exc:
        # A slow network install used to propagate this and abort the whole
        # app at import time; treat it as a failed best-effort command instead.
        print("timeout:", exc)
        return None
    if ret.returncode == 0:
        print("success:", ret)
    else:
        print("error:", ret)
    return ret
# Upgrade the clueai SDK on every start-up; this must run before the import
# below so the freshly installed version is the one loaded.
# NOTE(review): "pip3" is whichever pip is first on PATH and may not belong to
# the interpreter running this script; `python -m pip` would be more reliable.
runcmd("pip3 install --upgrade clueai")

import clueai
# API client used by inference_gen. Built with an empty key string and
# check_api_key=False (presumably skips key validation — confirm in clueai docs).
cl = clueai.Client("", check_api_key=False)

# NOTE(review): the triple-quoted string below is a bare expression that is
# never assigned or passed to gradio, so these luck_* button rules are NOT
# applied (the "手气不错" button is created with elem_id="luck_t2i_btn_1").
# They are identical to the live import_* rules in `css`; kept for reference.
'''
#luck_t2i_btn_1, #luck_s2i_btn_1, #luck_i2i_btn_1, #luck_ici_btn_1{
            color: #fff;
            --tw-gradient-from: #BED336;
            --tw-gradient-stops: var(--tw-gradient-from), var(--tw-gradient-to);
            --tw-gradient-to: #BED336;
            border-color: #BED336;
        }

        #luck_easy_btn_1, #luck_iti_btn_1, #luck_tsi_btn_1, #luck_isi_btn_1{
            color: #fff;
            --tw-gradient-from: #BED336;
            --tw-gradient-stops: var(--tw-gradient-from), var(--tw-gradient-to);
            --tw-gradient-to: #BED336;
            border-color: #BED336;
        }
'''
# Custom stylesheet passed to gr.Blocks(css=...). Rules select components by
# elem_id: gen_btn_1 (text "生成" button) and t2i_btn_1 (image "生成图像"
# button) get an orange gradient, and txt2img_tab is the image tab.
# Other ids here (s2i_*, i2i_*, import_*, record_btn, ...) are not created by
# the UI in this file — presumably leftovers from a larger app; verify before
# removing them.
css='''
        .container { max-width: 800px; margin: auto; }
        #gen_btn_1{
            color: #fff;
            --tw-gradient-from: #f44336;
            --tw-gradient-stops: var(--tw-gradient-from), var(--tw-gradient-to);
            --tw-gradient-to: #ff9800;
            border-color: #ff9800; 
        }
        #t2i_btn_1, #s2i_btn_1, #i2i_btn_1, #ici_btn_1, #easy_btn_1, #iti_btn_1, #tsi_btn_1, #isi_btn_1{
            color: #fff;
            --tw-gradient-from: #f44336;
            --tw-gradient-stops: var(--tw-gradient-from), var(--tw-gradient-to);
            --tw-gradient-to: #ff9800;
            border-color: #ff9800;
        }
        

        #import_t2i_btn_1, #import_s2i_btn_1, #import_i2i_btn_1, #import_ici_btn_1{
            color: #fff;
            --tw-gradient-from: #BED336;
            --tw-gradient-stops: var(--tw-gradient-from), var(--tw-gradient-to);
            --tw-gradient-to: #BED336;
            border-color: #BED336;
        }

        #import_easy_btn_1, #import_iti_btn_1, #import_tsi_btn_1, #import_isi_btn_1{
            color: #fff;
            --tw-gradient-from: #BED336;
            --tw-gradient-stops: var(--tw-gradient-from), var(--tw-gradient-to);
            --tw-gradient-to: #BED336;
            border-color: #BED336;
        }

        #record_btn{
            
        }
        #record_btn > div > button > span {
            width: 2.375rem;
            height: 2.375rem;
        }
        #record_btn > div > button > span > span {
            width: 2.375rem;
            height: 2.375rem;
        }
        audio {
            margin-bottom: 10px;
        }
        div#record_btn > .mt-6{
            margin-top: 0!important;
        }
        div#record_btn > .mt-6 button {
            font-size: 1em;
            width: 100%;
            padding: 20px;
            height: 60px;
        }

        div#txt2img_tab {
            color: #BED336;
        }

'''

# Generation settings that correspond to the untouched "高级操作" controls in
# the text tab. inference_gen compares the user's settings against this dict
# to decide whether a stock example can be sent without a generate_config.
default_generate_config = dict(
    do_sample=False,
    top_p=0,
    top_k=50,
    max_length=64,
    temperature=1,
    num_beams=1,
    length_penalty=0.6,
)

# Registries filled in place by read_examples() from the examples CSV.
task_styles = []                # task names offered in the "任务" dropdown (no duplicates)
examples_list = []              # [task_style, example] rows shown by gr.Examples
task_style_to_task_prefix = {}  # task name -> prompt prefix used by preprocess()
import csv
examples_set = set()            # (task_style, example) pairs for O(1) "stock example?" checks


def read_examples(input_file):
    """Load task styles, prompt prefixes and examples from a CSV file.

    The file must have a header row followed by rows of exactly three
    columns: ``task_style, task_prefix, example``. The module-level
    registries above are extended in place (they are not reset).

    Args:
        input_file: Path to a UTF-8 encoded CSV file.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If a non-blank row does not have exactly three columns.
    """
    # Explicit UTF-8: the data is Chinese text, and the platform default
    # encoding (e.g. GBK/cp1252 on Windows) would fail or mis-decode it.
    # newline="" is what the csv module requires for correct quoting/newlines.
    with open(input_file, encoding="utf-8", newline="") as finput:
        reader = csv.reader(finput)
        next(reader, None)  # skip the header row (tolerates an empty file)
        for row in reader:
            if not row:
                continue  # csv yields [] for blank lines; don't crash on them
            task_style, task_prefix, example = row
            # Several examples may share a task; list each task only once so
            # the dropdown has no duplicate choices.
            if task_style not in task_styles:
                task_styles.append(task_style)
            task_style_to_task_prefix[task_style] = task_prefix
            examples_list.append([task_style, example])
            examples_set.add((task_style, example))
# Populate task styles, prompt prefixes and examples at import time.
# NOTE(review): the path is relative, so it resolves against the current
# working directory rather than this file's directory — launch from the app dir.
read_examples("./examples.csv")
#print(task_styles)
def preprocess(text, task):
    """Build the model prompt for ``text`` under the given task style.

    For the "问答" (Q&A) task, every question mark is turned into ":" and a
    trailing ":" is appended. The prompt is then the task's prefix, the text
    and the literal "答案:" marker, separated by newlines.

    Args:
        text: Raw user input.
        task: Task style name; must be a key of ``task_style_to_task_prefix``
            (loaded from the examples CSV).

    Returns:
        The assembled prompt string.

    Raises:
        KeyError: If ``task`` has no registered prefix.
    """
    if task == "问答":
        # Normalise both the ASCII "?" and the full-width "？" typed with
        # Chinese input methods (previously the ASCII replacement was simply
        # applied twice, leaving "？" untouched).
        text = text.replace("？", ":").replace("?", ":")
        text = text + ":"

    return task_style_to_task_prefix[task] + "\n" + text + "\n答案:"
      
def inference_gen(text, task, do_sample, top_p, top_k, max_token, temperature, beam_size, length_penalty):
    """Generate text with the ``clueai-base`` model for the text tab.

    Args:
        text: User input, before task-specific prompt formatting.
        task: Task style name (see ``preprocess``).
        do_sample, top_p, top_k, temperature, length_penalty: Decoding
            settings forwarded under the same names.
        max_token: Forwarded as ``max_length``.
        beam_size: Forwarded as ``num_beams``.

    Returns:
        The text of the first generation, or ``None`` if the API call raised
        (the error is logged with its traceback). If the service answers with
        the moderation placeholder "含有违规词,不予展示", the request is tried up
        to 3 times in total and the last answer is returned.
    """
    # Checked on the raw (task, text) pair, i.e. before prompt formatting.
    is_stock_example = (task, text) in examples_set
    prompt = preprocess(text, task)
    generate_config = {
        "do_sample": do_sample,
        "top_p": top_p,
        "top_k": top_k,
        "max_length": max_token,
        "temperature": temperature,
        "num_beams": beam_size,
        "length_penalty": length_penalty
    }

    request = {"model_name": 'clueai-base', "prompt": prompt}
    # A stock example run with untouched advanced settings is sent without a
    # generate_config, so the service's own defaults apply.
    if not (is_stock_example and generate_config == default_generate_config):
        request["generate_config"] = generate_config

    for _ in range(3):
        try:
            prediction = cl.generate(**request)
        except Exception as e:
            # UI-callback boundary: keep the traceback, leave the output empty.
            logger.exception("error, %s", e)
            return None
        if prediction.generations[0].text != "含有违规词,不予展示":
            break

    return prediction.generations[0].text
     
# Initial contents of the image gallery: empty until the first generation.
t2i_default_img_path_list = list()
import base64, requests
from io import BytesIO
from PIL import Image
def luck_inference_image(text, n_text, guidance_scale, style, shape, clarity, steps, shape_scale):
    """Handler for the "手气不错" button: ``inference_image`` with ``luck=True``.

    Takes exactly the same UI inputs as ``inference_image`` and returns its result.
    """
    ui_args = (text, n_text, guidance_scale, style, shape, clarity, steps, shape_scale)
    return inference_image(*ui_args, luck=True)

def inference_image(text, n_text, guidance_scale, style, shape, clarity, steps, shape_scale, luck=False, timeout=300):
    """Request images from the ClueAI text-to-image endpoint and decode them.

    Args:
        text: Image description (sent as ``text``).
        n_text: Elements to avoid (sent as ``negative_prompt``).
        guidance_scale: How closely to follow the description.
        style: Style name, one of ``image_styles``.
        shape: Aspect option from the UI (e.g. "1x1").
        clarity: Clarity option from the UI (e.g. "标清").
        steps: Sent as ``num_inference_steps``.
        shape_scale: Upscaling factor option from the UI.
        luck: Sent as the ``luck`` parameter; the "手气不错" button passes True.
        timeout: Seconds to wait for the HTTP response (new; previously the
            request could block a worker forever).

    Returns:
        A list of ``PIL.Image.Image`` for the gallery, or ``None`` if the
        request, the response or the image decoding failed (logged with
        traceback).
    """
    params = {
        "text": text,
        "negative_prompt": n_text,
        "guidance_scale": guidance_scale,
        "num_inference_steps": steps,
        "style": style,
        "shape": shape,
        "clarity": clarity,
        "shape_scale": shape_scale,
        "luck": luck,
    }
    try:
        # Let requests URL-encode every value: interpolating raw user text into
        # the URL let "&", "#" or "=" in a prompt truncate the query or inject
        # extra parameters.
        res = requests.get(
            "https://www.clueai.cn/clueai/hf_text2image",
            params=params,
            timeout=timeout,
        )
        res.raise_for_status()
        encoded_images = res.json()["images"]
        # Each entry is a base64-encoded image file; PIL detects the format.
        images = [
            Image.open(BytesIO(base64.b64decode(encoded.encode('utf-8'))))
            for encoded in encoded_images
        ]
    except Exception as e:
        # UI-callback boundary: any failure leaves the gallery empty.
        logger.exception("error, %s", e)
        return None
    return images
# Choices for the image tab's "风格" (style) dropdown; the chosen name is
# passed through to inference_image as its `style` argument.
image_styles = [
    '无',
    '细节大师', '对称美', '虚拟引擎', '空间感', '机械风格', '形状艺术',
    '治愈', '电影构图', '电影构图(治愈)', '荒芜感', '漫画', '逃离艺术',
    '斯皮尔伯格', '幻想', '杰作', '壁画', '朦胧', '黑白(3d)',
    '梵高', '毕加索', '莫奈', '丰子恺', '现代', '欧美',
]
# Gradio UI: a title banner, a text-generation tab, a text-to-image tab and a
# visitor-counter badge. Components render in the order they are created.
# NOTE(review): gr.TabItem is used directly under gr.Blocks (no gr.Tabs).
with gr.Blocks(css=css, title="ClueAI") as demo:
    
    gr.Markdown('<h1><center><font color=red style="font-size:50px;">ClueAI全能师</font></center></h1>')
    # --- Text-generation tab ---
    with gr.TabItem("文本生成", id='_tab'):   
        with gr.Row(variant="compact").style( equal_height=True):
            text = gr.Textbox("标题:俄天然气管道泄漏爆炸",
                label="编辑内容", show_label=False, max_lines=20, 
                placeholder="在这里输入...",
            )
        # NOTE(review): the default value must be one of task_styles (loaded
        # from examples.csv); nothing here checks that.
        task = gr.Dropdown(label="任务", show_label=True, choices=task_styles, value="标题生成文章")
        btn = gr.Button("生成",elem_id="gen_btn_1").style(full_width=False)
        # Advanced decoding settings; the initial values match
        # default_generate_config (max_token -> max_length, beam_size -> num_beams).
        with gr.Accordion("高级操作", open=False):
            do_sample = gr.Radio([True, False], label="是否采样", value=False)
            top_p = gr.Slider(0, 1, value=0, step=0.1, label="越大多样性越高, 按照概率采样")
            top_k = gr.Slider(1, 100, value=50, step=1, label="越大多样性越高,按照top k采样")
            max_token = gr.Slider(1, 512, value=64, step=1, label="生成的最大长度")
            temperature = gr.Slider(0,1, value=1, step=0.1, label="temperature, 越小下一个token预测概率越平滑")
            beam_size = gr.Slider(1, 4, value=1, step=1, label="beam size, 越大解码窗口越广,")
            length_penalty = gr.Slider(-1, 1, value=0.6, step=0.1, label="大于0鼓励长句子,小于0鼓励短句子")
        
        with gr.Row(variant="compact").style( equal_height=True):
            output_text = gr.Textbox(
                    label="输出", show_label=True, max_lines=50, 
                    placeholder="在这里展示结果",
                )
        # Each examples_list row is [task_style, example], matching [task, text].
        gr.Examples(examples_list, [task, text], label="示例")
        # Order must match inference_gen's parameter order.
        input_params = [text, task, do_sample, top_p, top_k, max_token, temperature, beam_size, length_penalty]
        # Enter-to-submit is disabled for this tab; only the button triggers generation.
        #text.submit(inference_gen, inputs=input_params, outputs=output_text)
        btn.click(inference_gen, inputs=input_params, outputs=output_text)

    # --- Text-to-image tab ---
    # Rebinding text/btn/input_params below is harmless: the text tab's
    # listeners were already registered with the earlier objects.
    with gr.TabItem("图像生成", id='txt2img_tab'):   
        with gr.Row(variant="compact").style( equal_height=True):
            text = gr.Textbox("美丽的风景",
                label="编辑内容", show_label=False, max_lines=2, 
                placeholder="在这里输入你的描述...",
            )
            btn = gr.Button("生成图像",elem_id="t2i_btn_1").style(full_width=False)
            
        with gr.Row().style( equal_height=True):
            # "I'm feeling lucky" button: routed to luck_inference_image (luck=True).
            generate_prompt_btn = gr.Button("手气不错", elem_id="luck_t2i_btn_1")

        style = gr.Dropdown(label="风格", show_label=True, choices=image_styles, value="无")
        with gr.Accordion("高级操作", open=False):
            n_text = gr.Textbox("",
                label="不想要生成的元素", show_label=True, max_lines=2, 
                placeholder="在这里输入你不需要包含的内容...",
            )     
            guidance_scale = gr.Slider(1, 20, value=7.5, step=0.5, label="和你的描述匹配程度,越大越匹配")
            shape = gr.Radio(["1x1", "16x9", "手机壁纸"], label="尺寸", value="1x1")
            shape_scale = gr.Radio([1, 2, 3], label="对图放大倍数", value=1)
            steps = gr.Slider(10, 150, value=50, step=1, label="越大质量越好,生成时间越长")
            clarity = gr.Radio(["标清", "高清"], label="清晰度", value="标清")

        gr.Examples(["秋日的晚霞", "星空", "室内装修", "婚礼鲜花"], text, label="示例")
        
        t2i_gallery = gr.Gallery(
            t2i_default_img_path_list,
            label="生成图像",
             show_label=False).style(
            grid=[2], height="auto"
        )

        # Order must match the positional parameters of inference_image /
        # luck_inference_image.
        input_params = [text, n_text, guidance_scale, style, shape, clarity, steps, shape_scale]
        generate_prompt_btn.click(luck_inference_image, inputs=input_params, outputs=[t2i_gallery])
        text.submit(inference_image, inputs=input_params, outputs=t2i_gallery)
        btn.click(inference_image, inputs=input_params, outputs=t2i_gallery)
    # Visitor counter: third-party clustrmaps badge image linking to the tracker page.
    gr.Markdown("""
                <center><a href="https://clustrmaps.com/site/1bsr8"  title="Visit tracker"><img src="//www.clustrmaps.com/map_v2.png?d=OBV_rLBLpgrXBPyk_STupM-rByau5s53eEWDitHdn_Q&cl=ffffff" /></a></center>
                """)
# Request queuing is left disabled; uncomment to enable gradio's queue.
#demo.queue(concurrency_count=3)
# Start the Gradio server for the UI defined above.
demo.launch()