import torch
import spaces
import gradio as gr
from transformers import AutoModelForCausalLM, AutoProcessor
from PIL import ImageDraw


device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Load the model and processor once at startup; trust_remote_code is required
# because PTA-1 ships custom (Florence-2 style) modeling code.
models = {
    "AskUI/PTA-1": AutoModelForCausalLM.from_pretrained("AskUI/PTA-1", trust_remote_code=True),
}

processors = {
    "AskUI/PTA-1": AutoProcessor.from_pretrained("AskUI/PTA-1", trust_remote_code=True)
}
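
# More Florence-2 compatible checkpoints could be registered in the dicts
# above; the model dropdown below is populated from these keys.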


def draw_bounding_boxes(image, bounding_boxes, outline_color="red", line_width=3):
    """Draw each (xmin, ymin, xmax, ymax) box onto the image in place and return it."""
    draw = ImageDraw.Draw(image)
    for box in bounding_boxes:
        xmin, ymin, xmax, ymax = box
        draw.rectangle([xmin, ymin, xmax, ymax], outline=outline_color, width=line_width)
    return image
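
# Usage sketch (illustrative coordinates): draw_bounding_boxes(img, [[10, 10, 100, 50]])
# draws one red box and returns the same, mutated image object.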


def florence_output_to_box(output):
    """Reduce a Florence-2 style detection result to one [xmin, ymin, xmax, ymax] box.

    Prefers polygon output over bbox output; returns None if neither is present
    or parsing fails.
    """
    try:
        if "polygons" in output and len(output["polygons"]) > 0:
            polygons = output["polygons"]
            # Polygons arrive as flattened [x0, y0, x1, y1, ...] vertex lists.
            # Using the first vertex (indices 0, 1) and the third vertex
            # (indices 4, 5) as opposite corners assumes an axis-aligned
            # rectangular polygon.
            target_polygon = polygons[0][0]
            target_polygon = [int(el) for el in target_polygon]
            return [
                target_polygon[0],
                target_polygon[1],
                target_polygon[4],
                target_polygon[5],
            ]
        if "bboxes" in output and len(output["bboxes"]) > 0:
            # Fall back to the first predicted bounding box.
            bboxes = output["bboxes"]
            target_bbox = bboxes[0]
            target_bbox = [int(el) for el in target_bbox]
            return target_bbox
    except Exception as e:
        print(f"Error parsing model output: {e}")
    return None
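
# Illustrative shape of the parsed result this function consumes (coordinates
# are made up; only the keys the function actually reads are shown):
#   {"polygons": [[[10, 10, 90, 10, 90, 40, 10, 40]]], ...}
#   {"bboxes": [[10, 10, 90, 40]], ...}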


@spaces.GPU
def run_example(image, text_input, model_id="AskUI/PTA-1"):
    """Locate the UI element described by text_input; return its box and an annotated image."""
    model = models[model_id].to(device, torch_dtype)
    processor = processors[model_id]
    # PTA-1 uses the Florence-2 prompt format: a task token followed by the free-text query.
    task_prompt = "<OPEN_VOCABULARY_DETECTION>"
    prompt = task_prompt + text_input

    image = image.convert("RGB")

    # The dtype cast applies to the floating-point pixel_values; the integer
    # input_ids only move to the device.
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(device, torch_dtype)

    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        do_sample=False,
        num_beams=3,
    )
    # Keep special tokens: post_process_generation parses the coordinate tokens from them.
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed_answer = processor.post_process_generation(
        generated_text, task=task_prompt, image_size=(image.width, image.height)
    )
    target_box = florence_output_to_box(parsed_answer[task_prompt])
    # Guard against parse failures so drawing does not crash when no box was found.
    boxes = [target_box] if target_box is not None else []
    return target_box, draw_bounding_boxes(image, boxes)
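
# Programmatic usage sketch (illustrative; assumes an image exists at this path):
#   from PIL import Image
#   box, annotated = run_example(Image.open("assets/sample.png"), "search box")
#   print(box)  # e.g. [412, 56, 980, 102], or None if nothing was parsed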


css = """
  #output {
    height: 500px; 
    overflow: auto; 
    border: 1px solid #ccc; 
  }
"""
with gr.Blocks(css=css) as demo:
    gr.Markdown("# PTA-1: Controlling Computers with Small Models")
    gr.Markdown("Check out the model [AskUI/PTA-1](https://huggingface.co/AskUI/PTA-1).")
    with gr.Row():
        with gr.Column():
            input_img = gr.Image(label="Input Image", type="pil")
            model_selector = gr.Dropdown(choices=list(models.keys()), label="Model", value="AskUI/PTA-1")
            text_input = gr.Textbox(label="User Prompt")
            submit_btn = gr.Button(value="Submit")
        with gr.Column():
            model_output_text = gr.Textbox(label="Model Output Text")
            annotated_image = gr.Image(label="Annotated Image", elem_id="output")

    gr.Examples(
        examples=[
            ["assets/sample.png", "search box"],
            ["assets/sample.png", "Query Service"],
            ["assets/ipad.png", "App Store icon"],
            ["assets/ipad.png", 'colorful icon with letter "S"'],
            ["assets/phone.jpg", "password field"],
            ["assets/phone.jpg", "back arrow icon"],
            ["assets/windows.jpg", "icon with letter S"],
            ["assets/windows.jpg", "Settings"],
        ],
        inputs=[input_img, text_input],
        outputs=[model_output_text, annotated_image],
        fn=run_example,
        cache_examples=False,
        label="Try examples"
    )
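
    # Clicking an example populates the inputs above; results are not
    # precomputed (cache_examples=False).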

    submit_btn.click(run_example, [input_img, text_input, model_selector], [model_output_text, annotated_image])

demo.launch(debug=False)