File size: 6,219 Bytes
b458509 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 |
# app.py - Main Gradio application
import gradio as gr
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText
from PIL import Image
import os
# Hugging Face Hub repository id of the MedGemma checkpoint this app loads.
MODEL_ID: str = "google/medgemma-4b-it"
# Load model and processor
def load_model(model_id=MODEL_ID):
    """Load the MedGemma model and its processor from the Hugging Face Hub.

    Args:
        model_id: Hub repository id to load; defaults to ``MODEL_ID``.

    Returns:
        A ``(model, processor)`` tuple.
    """
    # bfloat16 weights; device_map="auto" lets transformers/accelerate decide
    # device placement.
    # NOTE: trust_remote_code is deliberately not set: it would permit running
    # Python code shipped inside the Hub repo, and the processor below already
    # loads without it.
    model = AutoModelForImageTextToText.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    processor = AutoProcessor.from_pretrained(model_id)
    return model, processor


# Loaded once at import time; every request handler reuses these module-level
# objects, which is what makes the model effectively "cached". (A decorator is
# not needed for this, and an async wrapper would break the tuple unpacking.)
model, processor = load_model()
def _build_messages(image, clinical_question, patient_history):
    """Return the chat messages (system + one user turn) for a MedGemma request.

    Gemma-style chat templates reject two consecutive ``user`` turns
    ("Conversation roles must alternate ..."), so the optional patient history
    is folded into the same user turn as the question and image instead of
    being sent as a separate user message.
    """
    question_text = clinical_question
    if patient_history:
        question_text = f"Patient History: {patient_history}\n\n{clinical_question}"
    return [
        {
            "role": "system",
            "content": [{"type": "text", "text": "You are an expert medical AI assistant. Provide detailed analysis while emphasizing that this is for educational purposes only and should not replace professional medical diagnosis."}],
        },
        {
            "role": "user",
            "content": [
                {"type": "text", "text": question_text},
                {"type": "image", "image": image},
            ],
        },
    ]


def analyze_medical_image(image, clinical_question, patient_history=""):
    """Analyze a medical image with MedGemma, given a question and optional history.

    Args:
        image: The uploaded image (``gr.Image(type="pil")`` supplies a PIL image),
            or ``None`` when nothing was uploaded.
        clinical_question: The user's question about the image.
        patient_history: Optional free-text clinical context.

    Returns:
        The model's answer followed by a disclaimer, or a human-readable
        message when input is missing or processing fails.
    """
    if image is None:
        return "Please upload an image first."
    # Gradio textboxes normally yield strings, but tolerate None defensively.
    clinical_question = (clinical_question or "").strip()
    if not clinical_question:
        return "Please enter a clinical question."
    patient_history = (patient_history or "").strip()

    messages = _build_messages(image, clinical_question, patient_history)

    try:
        # BatchFeature.to(..., dtype=...) only casts floating-point tensors,
        # so token ids stay integer while pixel values become bfloat16.
        inputs = processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(model.device, dtype=torch.bfloat16)
        input_len = inputs["input_ids"].shape[-1]

        with torch.inference_mode():
            generation = model.generate(
                **inputs,
                max_new_tokens=1000,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
            )
        # generate() returns prompt + continuation; keep only the new tokens.
        response = processor.decode(generation[0][input_len:], skip_special_tokens=True)
    except Exception as e:
        # UI boundary: report the failure in the output box instead of crashing.
        return f"Error processing request: {str(e)}"

    disclaimer = "\n\n⚠️ IMPORTANT DISCLAIMER: This analysis is for educational and research purposes only. Always consult qualified healthcare professionals for medical diagnosis and treatment decisions."
    return response + disclaimer
# Create Gradio interface
def create_interface():
    """Build the Gradio Blocks UI for MedGemma medical image analysis.

    Returns:
        The unlaunched ``gr.Blocks`` app, whose Analyze button is wired to
        ``analyze_medical_image``.
    """
    # Markdown needs a blank line between paragraphs; lines separated by a
    # single newline are soft breaks and would render as one run-on paragraph.
    # Plain (non-triple-quoted) strings keep the text free of code indentation.
    header_md = (
        "# 🏥 MedGemma Medical Image Analysis\n\n"
        "**Educational Medical AI Assistant powered by Google's MedGemma-4B**\n\n"
        "⚠️ **Important**: This tool is for educational and research purposes only.\n"
        "Do not use real patient data. Always consult healthcare professionals for medical decisions."
    )
    footer_md = (
        "---\n\n"
        "**Data Privacy Notice**: Do not upload real patient data or personally identifiable information.\n"
        "Use only synthetic, anonymized, or publicly available medical images for demonstration purposes.\n\n"
        "**Model**: Google MedGemma-4B | **Purpose**: Educational and Research Use Only"
    )

    with gr.Blocks(title="MedGemma Medical Image Analysis", theme=gr.themes.Soft()) as demo:
        gr.Markdown(header_md)

        with gr.Row():
            with gr.Column(scale=1):
                image_input = gr.Image(
                    label="Medical Image",
                    type="pil",
                    height=400,
                )
                clinical_question = gr.Textbox(
                    label="Clinical Question",
                    placeholder="e.g., 'Describe the findings in this chest X-ray' or 'What pathological changes do you observe?'",
                    lines=3,
                )
                patient_history = gr.Textbox(
                    label="Patient History (Optional)",
                    placeholder="e.g., '45-year-old male with chronic cough and shortness of breath'",
                    lines=3,
                )
                analyze_btn = gr.Button("🔍 Analyze Image", variant="primary", size="lg")
            with gr.Column(scale=1):
                output = gr.Textbox(
                    label="Medical Analysis",
                    lines=20,
                    max_lines=30,
                )

        # Example cases (rows with a None image leave the upload to the user).
        gr.Markdown("## 📋 Example Cases")
        gr.Examples(
            examples=[
                [
                    "https://upload.wikimedia.org/wikipedia/commons/c/c8/Chest_Xray_PA_3-8-2010.png",
                    "Describe the findings in this chest X-ray and identify any abnormalities.",
                    "Adult patient with respiratory symptoms",
                ],
                [
                    None,  # User will upload their own
                    "What pathological changes are visible in this medical image?",
                    "",
                ],
                [
                    None,
                    "Provide a differential diagnosis based on the imaging findings.",
                    "Patient presenting with acute symptoms",
                ],
            ],
            inputs=[image_input, clinical_question, patient_history],
        )

        # Event handlers
        analyze_btn.click(
            fn=analyze_medical_image,
            inputs=[image_input, clinical_question, patient_history],
            outputs=output,
        )

        # Footer
        gr.Markdown(footer_md)
    return demo
# Entry point: build the UI and start the server only when run as a script.
if __name__ == "__main__":
    # Bind on all interfaces at port 7860; share=True asks Gradio to also
    # create a public share link.
    launch_kwargs = dict(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
    )
    demo = create_interface()
    demo.launch(**launch_kwargs)