#!/usr/bin/env python

import os

import gradio as gr
import PIL.Image
import spaces
import torch
from transformers import InstructBlipForConditionalGeneration, InstructBlipProcessor

DESCRIPTION = "# InstructBLIP"

# Longest image side (pixels) accepted before `run` downscales the input;
# overridable through the MAX_IMAGE_SIZE environment variable.
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "1024"))

# Device the processed inputs are moved to; CUDA is preferred when present.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Processor and weights are loaded once at import time. The model is loaded
# in float16 with device_map="auto" so transformers decides weight placement.
model_id = "Salesforce/instructblip-vicuna-7b"
processor = InstructBlipProcessor.from_pretrained(model_id)
model = InstructBlipForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="auto",
)


@spaces.GPU
def run(
    image: PIL.Image.Image,
    prompt: str,
    text_decoding_method: str = "Nucleus sampling",
    num_beams: int = 5,
    max_length: int = 256,
    min_length: int = 1,
    top_p: float = 0.9,
    repetition_penalty: float = 1.5,
    length_penalty: float = 1.0,
    temperature: float = 1.0,
) -> str:
    """Generate InstructBLIP's response to ``prompt`` about ``image``.

    The image is first downscaled, preserving its aspect ratio, so that its
    longer side is at most ``MAX_IMAGE_SIZE`` pixels.

    Args:
        image: Image the prompt refers to.
        prompt: Instruction or question passed to the processor with the image.
        text_decoding_method: ``"Nucleus sampling"`` enables sampling
            (``do_sample=True``); any other value (the UI also offers
            ``"Beam search"``) disables it.
        num_beams: Beam count forwarded to ``model.generate``.
        max_length: ``max_length`` forwarded to ``model.generate``.
        min_length: ``min_length`` forwarded to ``model.generate``.
        top_p: Cumulative-probability cutoff used when sampling.
        repetition_penalty: Forwarded to ``model.generate``.
        length_penalty: Forwarded to ``model.generate``.
        temperature: Sampling temperature forwarded to ``model.generate``.

    Returns:
        The first generated sequence, decoded without special tokens and
        stripped of surrounding whitespace.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # The image input can be empty; report it clearly instead of failing
        # with an AttributeError on `image.size` below.
        raise gr.Error("Please upload an image.")

    # PIL's Image.size is (width, height). Unpacking it the other way round
    # swaps the axes and distorts non-square images when resizing.
    width, height = image.size
    scale = MAX_IMAGE_SIZE / max(width, height)
    if scale < 1:
        # Clamp to 1 px so very elongated images never get a zero-length side.
        new_size = (max(1, int(width * scale)), max(1, int(height * scale)))
        image = image.resize(new_size, resample=PIL.Image.Resampling.LANCZOS)

    # The float16 cast is meant to match the half-precision weights loaded above.
    inputs = processor(images=image, text=prompt, return_tensors="pt").to(device, torch.float16)
    generated_ids = model.generate(
        **inputs,
        do_sample=text_decoding_method == "Nucleus sampling",
        # Gradio sliders deliver floats; these generation settings must be ints.
        num_beams=int(num_beams),
        max_length=int(max_length),
        min_length=int(min_length),
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        length_penalty=length_penalty,
        temperature=temperature,
    )
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()


# Gradio UI: image, prompt, run button and advanced generation controls in
# the left column; the generated text in the right column. Components are laid
# out in the order they are created inside each `with` context.
# NOTE(review): css_paths presumably resolves style.css relative to the
# working directory — confirm the file ships alongside this script.
with gr.Blocks(css_paths="style.css") as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        with gr.Column():
            # type="pil" hands `run` a PIL image rather than a numpy array.
            input_image = gr.Image(type="pil")
            prompt = gr.Textbox(label="Prompt")
            run_button = gr.Button()
            # Generation hyperparameters, collapsed by default. Each initial
            # value matches the corresponding keyword default of `run`.
            with gr.Accordion(label="Advanced options", open=False):
                # "Nucleus sampling" sets do_sample=True in `run`; "Beam search" leaves it False.
                text_decoding_method = gr.Radio(
                    label="Text Decoding Method",
                    choices=["Beam search", "Nucleus sampling"],
                    value="Nucleus sampling",
                )
                num_beams = gr.Slider(
                    label="Number of Beams",
                    minimum=1,
                    maximum=10,
                    step=1,
                    value=5,
                )
                max_length = gr.Slider(
                    label="Max Length",
                    minimum=1,
                    maximum=512,
                    step=1,
                    value=256,
                )
                min_length = gr.Slider(
                    label="Minimum Length",
                    minimum=1,
                    maximum=64,
                    step=1,
                    value=1,
                )
                top_p = gr.Slider(
                    label="Top P",
                    minimum=0.1,
                    maximum=1.0,
                    step=0.1,
                    value=0.9,
                )
                repetition_penalty = gr.Slider(
                    label="Repetition Penalty",
                    info="Larger value prevents repetition.",
                    minimum=1.0,
                    maximum=5.0,
                    step=0.5,
                    value=1.5,
                )
                length_penalty = gr.Slider(
                    label="Length Penalty",
                    info="Set to larger for longer sequence, used with beam search.",
                    minimum=-1.0,
                    maximum=2.0,
                    step=0.2,
                    value=1.0,
                )
                temperature = gr.Slider(
                    label="Temperature",
                    info="Used with nucleus sampling.",
                    minimum=0.5,
                    maximum=1.0,
                    step=0.1,
                    value=1.0,
                )

        with gr.Column():
            output = gr.Textbox(label="Result")

    # Run generation when Enter is pressed in the prompt box or the button is
    # clicked. `inputs` is passed positionally, so its order must match the
    # parameter order of `run`. The handler is also exposed under the API
    # endpoint name "run".
    gr.on(
        triggers=[prompt.submit, run_button.click],
        fn=run,
        inputs=[
            input_image,
            prompt,
            text_decoding_method,
            num_beams,
            max_length,
            min_length,
            top_p,
            repetition_penalty,
            length_penalty,
            temperature,
        ],
        outputs=output,
        api_name="run",
    )

if __name__ == "__main__":
    # Enable request queuing (at most 20 pending requests), then start serving.
    queued_app = demo.queue(max_size=20)
    queued_app.launch()