|
|
|
import logging
import os

import gradio as gr
import torch
from PIL import Image
from transformers import AutoModelForImageTextToText, AutoProcessor
|
|
|
|
|
# Repository id handed to ``from_pretrained`` when loading both the model and its processor.
MODEL_ID: str = "google/medgemma-4b-it"
|
|
|
|
|
def load_model(model_id=MODEL_ID):
    """Load the image-text-to-text model and its processor.

    The model weights are loaded in ``torch.bfloat16`` with
    ``device_map="auto"`` so that placement is left to Transformers/Accelerate.

    Args:
        model_id: Repository id or local path passed to ``from_pretrained``
            for both the model and the processor. Defaults to ``MODEL_ID``.

    Returns:
        A ``(model, processor)`` tuple.
    """
    # This is called synchronously at import time and its result is
    # tuple-unpacked below, so it must stay a plain function that returns the
    # pair directly -- do not wrap it in an async/coroutine decorator.
    model = AutoModelForImageTextToText.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        # NOTE(review): trust_remote_code=True lets Transformers execute Python
        # code shipped inside the model repository; only point model_id at
        # repositories you trust.
        trust_remote_code=True,
    )
    processor = AutoProcessor.from_pretrained(model_id)
    return model, processor


# Load once at import time; every request handled by the UI reuses these objects.
model, processor = load_model()
|
|
|
def _build_messages(clinical_question, patient_history, image):
    """Build the chat message list (system turn + one user turn) for one request."""
    system_prompt = (
        "You are an expert medical AI assistant. Provide detailed analysis while "
        "emphasizing that this is for educational purposes only and should not "
        "replace professional medical diagnosis."
    )

    text = clinical_question or ""
    history = (patient_history or "").strip()
    if history:
        # Fold the history into the single user turn: Gemma-3-style chat
        # templates reject two consecutive user messages ("roles must
        # alternate"), and separate text items are concatenated without a
        # separator, so one combined text item is used.
        text = f"Patient History: {history}\n\n{text}"

    return [
        {"role": "system", "content": [{"type": "text", "text": system_prompt}]},
        {
            "role": "user",
            "content": [
                {"type": "text", "text": text},
                {"type": "image", "image": image},
            ],
        },
    ]


def analyze_medical_image(image, clinical_question, patient_history=""):
    """Analyze a medical image with optional clinical context.

    Uses the module-level ``model`` and ``processor`` to generate a free-text
    analysis for the given image and question.

    Args:
        image: Image from the Gradio ``Image`` component (configured with
            ``type="pil"``), or ``None`` if nothing was uploaded.
        clinical_question: Question to ask about the image; ``None`` is
            treated as an empty string.
        patient_history: Optional free-text history; blank or ``None`` is
            ignored.

    Returns:
        One of three strings:

        - the generated analysis followed by a disclaimer;
        - a prompt to upload an image when ``image`` is ``None``;
        - an ``"Error processing request: ..."`` message if inference fails.
    """
    if image is None:
        return "Please upload an image first."

    messages = _build_messages(clinical_question, patient_history, image)

    try:
        # dtype here only casts floating-point tensors (e.g. pixel values);
        # token ids keep their integer dtype.
        inputs = processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(model.device, dtype=torch.bfloat16)

        # Prompt length, so only the newly generated tokens get decoded.
        input_len = inputs["input_ids"].shape[-1]

        with torch.inference_mode():
            generation = model.generate(
                **inputs,
                max_new_tokens=1000,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
            )

        response = processor.decode(
            generation[0][input_len:], skip_special_tokens=True
        )
    except Exception as e:
        # UI boundary: show the failure to the user instead of crashing the
        # handler, but keep the full traceback in the server log.
        logging.getLogger(__name__).exception("MedGemma image analysis failed")
        return f"Error processing request: {e}"

    disclaimer = (
        "\n\n⚠️ IMPORTANT DISCLAIMER: This analysis is for educational and "
        "research purposes only. Always consult qualified healthcare "
        "professionals for medical diagnosis and treatment decisions."
    )
    return response + disclaimer
|
|
|
|
|
def create_interface():
    """Build the Gradio Blocks UI for MedGemma medical image analysis.

    Layout:

    - an input column (image, clinical question, optional patient history,
      Analyze button) next to an output text box;
    - a set of example cases;
    - a data-privacy notice.

    The button is wired to ``analyze_medical_image``.

    Returns:
        The constructed (not yet launched) ``gr.Blocks`` app.
    """
    with gr.Blocks(title="MedGemma Medical Image Analysis", theme=gr.themes.Soft()) as demo:
        # gr.Markdown dedents its value, so the indented literal renders as-is.
        gr.Markdown(
            """
            # 🏥 MedGemma Medical Image Analysis

            **Educational Medical AI Assistant powered by Google's MedGemma-4B**

            ⚠️ **Important**: This tool is for educational and research purposes only.
            Do not use real patient data. Always consult healthcare professionals for medical decisions.
            """
        )

        with gr.Row():
            with gr.Column(scale=1):
                image_input = gr.Image(
                    label="Medical Image",
                    type="pil",  # handler receives a PIL image (or None)
                    height=400,
                )
                clinical_question = gr.Textbox(
                    label="Clinical Question",
                    placeholder="e.g., 'Describe the findings in this chest X-ray' or 'What pathological changes do you observe?'",
                    lines=3,
                )
                patient_history = gr.Textbox(
                    label="Patient History (Optional)",
                    placeholder="e.g., '45-year-old male with chronic cough and shortness of breath'",
                    lines=3,
                )
                analyze_btn = gr.Button("🔍 Analyze Image", variant="primary", size="lg")

            with gr.Column(scale=1):
                output = gr.Textbox(
                    label="Medical Analysis",
                    lines=20,
                    max_lines=30,
                )

        gr.Markdown("## 📋 Example Cases")
        # Each row fills [image, clinical question, patient history].
        gr.Examples(
            examples=[
                [
                    "https://upload.wikimedia.org/wikipedia/commons/c/c8/Chest_Xray_PA_3-8-2010.png",
                    "Describe the findings in this chest X-ray and identify any abnormalities.",
                    "Adult patient with respiratory symptoms",
                ],
                [
                    None,
                    "What pathological changes are visible in this medical image?",
                    "",
                ],
                [
                    None,
                    "Provide a differential diagnosis based on the imaging findings.",
                    "Patient presenting with acute symptoms",
                ],
            ],
            inputs=[image_input, clinical_question, patient_history],
        )

        analyze_btn.click(
            fn=analyze_medical_image,
            inputs=[image_input, clinical_question, patient_history],
            outputs=output,
        )

        gr.Markdown(
            """
            ---
            **Data Privacy Notice**: Do not upload real patient data or personally identifiable information.
            Use only synthetic, anonymized, or publicly available medical images for demonstration purposes.

            **Model**: Google MedGemma-4B | **Purpose**: Educational and Research Use Only
            """
        )

    return demo
|
|
|
|
|
if __name__ == "__main__":
    demo = create_interface()

    # Defaults: listen on all interfaces, port 7860, with a public share link.
    # Each can be overridden per deployment via the standard Gradio env vars.
    share_enabled = os.environ.get("GRADIO_SHARE", "true").strip().lower() in {
        "1",
        "true",
        "yes",
    }
    demo.launch(
        server_name=os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0"),
        server_port=int(os.environ.get("GRADIO_SERVER_PORT", "7860")),
        # share=True publishes a public gradio.live URL; set GRADIO_SHARE=false
        # to keep the app reachable only on the configured host/port.
        share=share_enabled,
    )