"""Gradio demo: interactive medical-image question answering with LLaVA-Med.

Loads microsoft/llava-med-v1.5-mistral-7b on CPU at import time, then serves
a Blocks UI where the user uploads an image, asks a question, and receives
the model's answer.
"""

import gradio as gr
from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path
from llava.eval.run_llava import eval_model
from PIL import Image
import torch

# Load model configuration (CPU, full precision — no 4-bit quantization).
model_path = "microsoft/llava-med-v1.5-mistral-7b"
model_name = get_model_name_from_path(model_path)
tokenizer, model, image_processor, _ = load_pretrained_model(
    model_path=model_path,
    model_base=None,
    model_name=model_name,
    device_map="cpu",
    load_4bit=False,
)


def analyze_medical_image(image, question):
    """Run LLaVA-Med inference on *image* with *question*; return the answer.

    Args:
        image: a filesystem path (str), a PIL.Image, or a numpy array as
            delivered by the Gradio Image component.
        question: free-text question about the image.

    Returns:
        Whatever ``eval_model`` returns for this prompt (the model's answer).
    """
    # Normalize the Gradio input to a PIL Image.
    # FIX: with gr.Image(type="pil") the component already yields a PIL image;
    # the original unconditionally called Image.fromarray on it, which raises.
    if isinstance(image, str):
        image = Image.open(image)
    elif not isinstance(image, Image.Image):
        image = Image.fromarray(image)

    # Prepare prompt in the USER/ASSISTANT chat format.
    prompt = f"\nUSER: {question}\nASSISTANT:"

    # Build an ad-hoc args namespace for eval_model.
    # NOTE(review): upstream llava's eval_model takes a single ``args``
    # parameter; the extra positionals below follow the original call —
    # verify against the installed llava version.
    args = type("Args", (), {
        "model_name": model_name,
        "query": prompt,
        "conv_mode": None,
        "image_file": image,
        "sep": ",",
        "temperature": 0.2,
        "top_p": None,
        "num_beams": 1,
        "max_new_tokens": 512,
    })()
    return eval_model(args, tokenizer, model, image_processor)


# Gradio interface.
with gr.Blocks() as demo:
    gr.Markdown("# LLaVA-Med Medical Analysis")
    with gr.Row():
        # FIX: components must be bound to names so they can be wired to the
        # handler and referenced by gr.Examples — the original left them
        # unbound, so ``inputs=[image, question]`` raised NameError.
        # FIX: the ``source="upload"`` kwarg was removed from gr.Image in
        # Gradio 4.x; uploading is the default behavior.
        image_input = gr.Image(type="pil", label="Input Image", elem_id="image")
        question_input = gr.Textbox(
            label="Question",
            placeholder="Ask about the medical image...",
        )
    result_output = gr.Textbox(label="Analysis Result", interactive=False)
    submit_btn = gr.Button("Analyze")

    # FIX: the handler was never connected to the UI in the original, so the
    # result box could never be populated.
    submit_btn.click(
        analyze_medical_image,
        inputs=[image_input, question_input],
        outputs=result_output,
    )

    examples = [
        ["examples/xray.jpg", "Are there any signs of pneumonia in this chest X-ray?"],
        ["examples/mri.jpg", "What abnormalities are visible in this brain MRI?"],
    ]
    gr.Examples(examples=examples, inputs=[image_input, question_input])


if __name__ == "__main__":
    demo.launch()