import gradio as gr
from PIL import Image
import torch
import soundfile as sf
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
from urllib.request import urlopen
import spaces

# Hugging Face Hub checkpoint served by this demo.
model_path = "microsoft/Phi-4-multimodal-instruct"

# Both loaders must opt in to running the checkpoint's bundled Python code,
# so the flag is shared rather than repeated.
_remote_code = {"trust_remote_code": True}

processor = AutoProcessor.from_pretrained(model_path, **_remote_code)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype="auto",  # let transformers pick the dtype recorded for the checkpoint
    device_map="auto",   # let transformers/accelerate place the weights on available devices
    # NOTE(review): eager attention presumably avoids requiring flash-attn on the host — confirm.
    _attn_implementation="eager",
    **_remote_code,
)

# Chat-template markers used to build prompts of the form
#   <|user|>{media placeholder}{question}<|end|><|assistant|>
user_prompt, assistant_prompt, prompt_suffix = '<|user|>', '<|assistant|>', '<|end|>'

# Inference entry point shared by the Submit button and the Examples panel.
@spaces.GPU
def process_input(input_type, file, question):
    """Answer ``question`` about an uploaded image or audio file.

    Args:
        input_type: ``"Image"`` or ``"Audio"`` (the choices offered by the UI radio).
        file: The uploaded file as delivered by ``gr.File``; it is passed
            unchanged to ``Image.open`` / ``sf.read``.
        question: Free-text question; must contain non-whitespace text.

    Returns:
        The decoded model response. If the inputs are missing, the input type
        is unknown, or the file cannot be decoded as the selected media type,
        a human-readable message is returned instead.
    """
    # A whitespace-only question is treated the same as an empty one.
    if not file or not question or not question.strip():
        return "Please upload a file and provide a question."
    question = question.strip()

    # Prepare the prompt and the processor inputs for the selected modality.
    if input_type == "Image":
        prompt = f'{user_prompt}<|image_1|>{question}{prompt_suffix}{assistant_prompt}'
        try:
            # Image.open is lazy and keeps the file open; the context manager
            # closes it. convert("RGB") forces the pixel data to load and
            # normalises RGBA / palette / grayscale uploads to 3 channels.
            with Image.open(file) as img:
                image = img.convert("RGB")
        except OSError:  # includes PIL.UnidentifiedImageError and truncated files
            return "Could not read the uploaded file as an image. Check the selected input type."
        inputs = processor(text=prompt, images=image, return_tensors='pt').to(model.device)
    elif input_type == "Audio":
        prompt = f'{user_prompt}<|audio_1|>{question}{prompt_suffix}{assistant_prompt}'
        try:
            audio, samplerate = sf.read(file)
        except RuntimeError:  # soundfile raises RuntimeError (LibsndfileError) on undecodable input
            return "Could not read the uploaded file as audio. Check the selected input type."
        inputs = processor(text=prompt, audios=[(audio, samplerate)], return_tensors='pt').to(model.device)
    else:
        return "Invalid input type selected."

    # Generate response
    with torch.no_grad():
        generate_ids = model.generate(
            **inputs,
            max_new_tokens=200,
            # NOTE(review): 0 conventionally means "logits for every position"
            # in transformers; confirm against the checkpoint's remote code.
            num_logits_to_keep=0,
        )
    # Drop the echoed prompt tokens so only newly generated text is decoded.
    generate_ids = generate_ids[:, inputs['input_ids'].shape[1]:]
    response = processor.batch_decode(
        generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]

    return response

# Gradio interface: a two-column layout (inputs on the left, model output on
# the right), a collapsible examples panel, and the Submit wiring. The module
# ends by launching the app.
_demo_theme = gr.themes.Soft(
    primary_hue="blue",
    secondary_hue="gray",
    radius_size="lg",
)

# Example rows, in the same column order as the interface inputs:
# (input type, file URL, question).
_example_rows = [
    ["Image", "https://www.ilankelman.org/stopsigns/australia.jpg", "What is shown in this image?"],
    [
        "Audio",
        "https://upload.wikimedia.org/wikipedia/commons/b/b0/Barbara_Sahakian_BBC_Radio4_The_Life_Scientific_29_May_2012_b01j5j24.flac",
        "Transcribe the audio to text.",
    ],
]

with gr.Blocks(title="Phi-4 Multimodal Demo", theme=_demo_theme) as demo:
    gr.Markdown(
        """
        # Phi-4 Multimodal Demo
        Upload an **image** or **audio** file, ask a question, and get a response from the model!  
        
        Built with the `microsoft/Phi-4-multimodal-instruct` model by Microsoft.
        
        Credits: Grok from xAI helped me in making this demo.
        """
    )

    with gr.Row():
        # Left column: media type, upload, question, and the submit button.
        with gr.Column(scale=1):
            input_type = gr.Radio(
                label="Select Input Type",
                choices=["Image", "Audio"],
                value="Image",
            )
            file_input = gr.File(
                file_types=["image", "audio"],
                label="Upload Your File",
            )
            question_input = gr.Textbox(
                lines=2,
                label="Your Question",
                placeholder="e.g., 'What is shown in this image?' or 'Transcribe this audio.'",
            )
            submit_btn = gr.Button("Submit", variant="primary")

        # Right column (twice as wide): read-only model output.
        with gr.Column(scale=2):
            output_text = gr.Textbox(
                interactive=False,
                lines=10,
                label="Model Response",
                placeholder="Response will appear here...",
            )

    # The examples panel and the Submit button both feed these components,
    # in this order, to process_input.
    _interface_inputs = [input_type, file_input, question_input]

    with gr.Accordion("Examples", open=False):
        gr.Markdown("Try these examples:")
        gr.Examples(
            examples=_example_rows,
            inputs=_interface_inputs,
            outputs=output_text,
            fn=process_input,
            cache_examples=False,
        )

    submit_btn.click(
        fn=process_input,
        inputs=_interface_inputs,
        outputs=output_text,
    )

# Start the web server.
demo.launch()