import gradio as gr
import torch
from PIL import Image

from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path
from llava.eval.run_llava import eval_model

# Load LLaVA-Med v1.5 once at startup so every request reuses the same weights.
model_path = "microsoft/llava-med-v1.5-mistral-7b"
model_name = get_model_name_from_path(model_path)
tokenizer, model, image_processor, _ = load_pretrained_model(
    model_path=model_path,
    model_base=None,
    model_name=model_name,
    device_map="cpu",
    load_4bit=False,
)
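
# If a CUDA GPU is available, the same builder call can place the model there
# and quantize it to 4 bits to save memory. A minimal sketch, assuming
# bitsandbytes is installed (same parameters as the call above):
#
#   tokenizer, model, image_processor, _ = load_pretrained_model(
#       model_path=model_path,
#       model_base=None,
#       model_name=model_name,
#       device_map="auto",
#       load_4bit=True,
#   )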


def analyze_medical_image(image, question):
    # Gradio passes a PIL image when the input component uses type="pil";
    # also accept a file path or a NumPy array so the function can be called directly.
    if isinstance(image, str):
        image = Image.open(image)
    elif not isinstance(image, Image.Image):
        image = Image.fromarray(image)

    prompt = f"<image>\nUSER: {question}\nASSISTANT:"

    # Namespace of query and generation settings expected by eval_model.
    args = type("Args", (), {
        "model_name": model_name,
        "query": prompt,
        "conv_mode": None,
        "image_file": image,
        "sep": ",",
        "temperature": 0.2,
        "top_p": None,
        "num_beams": 1,
        "max_new_tokens": 512,
    })()

    # Passing the preloaded tokenizer, model, and image processor assumes
    # eval_model has been adapted to accept them; the upstream helper normally
    # takes only args and loads the model itself on every call.
    return eval_model(args, tokenizer, model, image_processor)
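
# Quick sanity check outside the UI, using one of the bundled example images
# (the path is illustrative and must exist locally):
#
#   print(analyze_medical_image("examples/xray.jpg",
#                               "Are there any signs of pneumonia in this chest X-ray?"))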


with gr.Blocks() as demo:
    gr.Markdown("# LLaVA-Med Medical Analysis")

    with gr.Row():
        image_input = gr.Image(type="pil", label="Input Image", elem_id="image")
        question_input = gr.Textbox(label="Question", placeholder="Ask about the medical image...")

    analyze_button = gr.Button("Analyze")
    result_output = gr.Textbox(label="Analysis Result", interactive=False)

    examples = [
        ["examples/xray.jpg", "Are there any signs of pneumonia in this chest X-ray?"],
        ["examples/mri.jpg", "What abnormalities are visible in this brain MRI?"],
    ]
    gr.Examples(examples=examples, inputs=[image_input, question_input])

    # Wire the button to the inference function.
    analyze_button.click(
        analyze_medical_image,
        inputs=[image_input, question_input],
        outputs=result_output,
    )

demo.launch()
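
# To serve the demo beyond localhost or get a temporary public link, Gradio's
# standard launch options can be passed instead, e.g.:
#
#   demo.launch(server_name="0.0.0.0", share=True)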