import cv2
from mtcnn.mtcnn import MTCNN
from utils import *


# Example galleries and upload-tip images, all supplied by the project's utils module.
# NOTE(review): `hr` presumably selects the high-resolution clothing set — confirm in utils.
cloth_examples, cloth_hr_examples = get_cloth_examples(hr=0), get_cloth_examples(hr=1)
pose_examples = get_pose_examples()
tip1, tip2 = get_tips()

# MTCNN face detector, built once at import time and shared by every request;
# onClick uses it to reject photos with no face or with an oversized face.
face_detector = MTCNN()

# Page header (HTML rendered through gr.Markdown).
title = r"""
<h1 align="center">Outfit Anyway: Best customer try-on You ever See</h1>
"""

# Blurb under the title linking to the project's Discord server.
description = r"""
<b>Join discord to know more about </b> <a href='https://discord.com/invite/QgJWCtSG58' target='_blank'><b> heybeauty prebuy vton solution</b></a>.<br>
"""


def _get_client_ip(request):
    """Return the best-guess originating IP address for a Gradio request.

    Defaults to the socket peer address. When an ``X-Forwarded-For`` header
    is present, it may hold a comma-separated chain ``client, proxy1, ...``.
    Only the left-most (originating) entry is returned, so downstream helpers
    receive a single IP instead of the whole chain.
    """
    client_ip = request.client.host
    x_forwarded_for = dict(request.headers).get('x-forwarded-for')
    if x_forwarded_for:
        first_hop = x_forwarded_for.split(",")[0].strip()
        if first_hop:
            client_ip = first_hop
    return client_ip


def _pose_image_error(pose_image, client_ip):
    """Validate the person photo at path *pose_image*.

    Returns a user-facing error message when the photo cannot be used:
    the file is unreadable, no face is detected, or the first detected face
    is so large relative to the photo that it counts as a headshot.
    Returns ``None`` when the photo is acceptable.
    """
    pose_np = cv2.imread(pose_image)
    if pose_np is None:
        # cv2.imread signals an unreadable/unsupported file by returning None.
        print(client_ip, 'cannot read pose image! ', flush=True)
        return "Fatal Error !!! Cannot read the uploaded photo!!! Please upload a valid image file!!!"
    # cv2.imread returns BGR; reverse the channel axis before face detection.
    faces = face_detector.detect_faces(pose_np[:, :, ::-1])
    if not faces:
        print(client_ip, 'faces num is 0! ', flush=True)
        return "Fatal Error !!! No face detected !!! You must upload a human photo!!! Not clothing photo!!!"
    _, _, w, h = faces[0]["box"]
    H, W = pose_np.shape[:2]
    # A face wider or taller than ~30% of the photo is treated as a headshot.
    max_face_ratio = 1/3.3
    if w/W > max_face_ratio or h/H > max_face_ratio:
        return "Fatal Error !!! Headshot is not allowed !!! You must upload a full-body or half-body photo!!!"
    return None


def onClick(cloth_image, pose_image, high_resolution, request: gr.Request):
    """Run one try-on task and stream its progress to the UI.

    This is a generator bound to the "Run" button. Every ``yield`` is a
    ``(result_image, status_text, markdown)`` triple that is pushed to the
    three output components. A generator's return value is not shown to the
    user, so every terminal state (success, failure, timeout) is reported
    with a ``yield`` before returning.

    Args:
        cloth_image: file path of the chosen clothing image. Its base name
            without extension must be the numeric cloth id.
        pose_image: file path of the person photo.
        high_resolution: truthy to request a high-resolution result.
        request: the incoming Gradio request, used to find the client IP.

    Raises:
        Any exception raised during processing is printed and re-raised.
    """
    if pose_image is None:
        yield None, "no pose image found !", ""
        return
    if cloth_image is None:
        yield None, "no cloth image found !", ""
        return

    cloth_id = int(os.path.basename(cloth_image).split(".")[0])

    try:
        client_ip = _get_client_ip(request)

        error = _pose_image_error(pose_image, client_ip)
        if error is not None:
            yield None, error, ""
            return

        if not check_region_warp(client_ip):
            yield None, "Failed !!! Our server is under maintenance, please try again later", ""
            return

        yield None, "begin to upload ", ""

        # Pseudo-unique upload id: the digits of the current timestamp plus random jitter.
        timeId = int(str(time.time()).replace(".", "")) + random.randint(1000, 9999)
        upload_url = upload_pose_img(client_ip, timeId, pose_image)
        yield None, "begin to public task ", ""

        # `not` covers both an empty string and None from the upload helper.
        if not upload_url:
            yield None, "fail to upload", ""
            return

        public_res = publicClothSwap(upload_url, cloth_id, is_hr=1 if high_resolution else 0)
        if public_res is None:
            yield None, "fail to public you task", ""
            return

        print(client_ip, public_res['mid_result'])
        yield public_res['mid_result'], f"task is processing, task id: {public_res['id']}, {public_res['msg']}", ""

        # Poll up to max_try times, sleeping wait_s between queries (~3 minutes of sleep).
        max_try = 120*3
        wait_s = 0.5
        for i in range(max_try):
            time.sleep(wait_s)
            state = getInfRes(public_res['id'])
            # Changing query parameter, presumably so the browser re-fetches the image URL.
            timestamp = int(time.time() * 1000)
            if state is None:
                yield public_res['mid_result'] + f"?t={timestamp}", "task query failed,", ""
            elif state['status'] == 'PROCESSING':
                yield public_res['mid_result'] + f"?t={timestamp}", f"task is processing, query {i}", ""
            elif state['status'] == 'SUCCEED':
                yield state['output1'] + f"?t={timestamp}", f"task finished, {state['msg']}", ""
                return
            elif state['status'] == 'FAILED':
                yield None, f"task failed, {state['msg']}", ""
                return
            else:
                yield public_res['mid_result'] + f"?t={timestamp}", f"task is on processing, query {i}", ""

        # Polling budget exhausted. This must be yielded; a returned value would
        # never reach the UI.
        yield None, "no machine...", ""
    except Exception as e:
        print(e, flush=True)
        raise

# Build the Gradio page. The order in which components are created inside the
# nested `with` blocks defines the on-page layout, so these statements are
# order-sensitive.
with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(description)

    # Collapsed-by-default panel showing the two upload-tip images side by side.
    with gr.Accordion('upload tips', open=False):
        with gr.Row():
            gr.HTML(f"<img src=\"{tip1}\" >")
            gr.HTML(f"<img src=\"{tip2}\" >")

    # Three columns: clothing picker | person photo | options, run button and result.
    with gr.Row():
        # Clothing column. The image is non-interactive, so clothing can only be
        # chosen by clicking an entry in one of the example galleries below.
        with gr.Column():
            cloth_image = gr.Image(value=None, interactive=False, type="filepath", label="choose a clothing")
            example = gr.Examples(inputs=cloth_image,examples_per_page=20,examples=cloth_examples, label="clothing")
            # NOTE(review): this gallery (cloth_hr_examples) is labelled "invalid clothing" — confirm the label is intended.
            hr_example = gr.Examples(inputs=cloth_image,examples_per_page=9,examples=cloth_hr_examples, label="invalid clothing")

        # Person-photo column: upload a photo or pick one of the pose examples.
        with gr.Column():
            pose_image = gr.Image(value=None, type="filepath", label="choose/upload a photo")
            example_pose = gr.Examples(inputs=pose_image,
                                      examples_per_page=20,
                                      examples=pose_examples)

        # Control/result column: resolution toggle, run button, status text, result image.
        with gr.Column():
            with gr.Column():
                # size_slider = gr.Slider(-3, 3, value=1, interactive=True, label="clothes size")
                high_resolution = gr.Checkbox(value=False, label="high resolution", interactive=True)

                run_button = gr.Button(value="Run")
                info_text = gr.Textbox(value="", interactive=False, 
                    label='runtime information')
                res_image = gr.Image(label="result image", value=None, type="filepath")
                # Third output of onClick; the visible handler always sends "" here.
                MK01 = gr.Markdown()

    # onClick is a generator. Each (image, status, markdown) triple it yields is
    # written to (res_image, info_text, MK01) in that order.
    run_button.click(fn=onClick, inputs=[cloth_image, pose_image, high_resolution], 
                     outputs=[res_image, info_text, MK01])


def main():
    """Start the Gradio server for the try-on demo.

    Enables the request queue, capped at 50 pending events, then launches the
    app bound to all network interfaces (0.0.0.0) on Gradio's default port.
    """
    demo.queue(max_size=50)
    demo.launch(server_name='0.0.0.0')


if __name__ == "__main__":
    main()