# app.py - Main Gradio application
import gradio as gr
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText
from PIL import Image
import os
# Model configuration
MODEL_ID = "google/medgemma-4b-it"
# Load model and processor
def load_model():
model = AutoModelForImageTextToText.from_pretrained(
MODEL_ID,
torch_dtype=torch.bfloat16,
device_map="auto",
trust_remote_code=True
)
processor = AutoProcessor.from_pretrained(MODEL_ID)
return model, processor
# Initialize model once at module import so it is ready before the UI launches
model, processor = load_model()
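# Assumption: this Space runs on GPU hardware. The 4B model in bfloat16 needs
# roughly 8-10 GB of accelerator memory, and CPU-only inference will be very slow.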
def analyze_medical_image(image, clinical_question, patient_history=""):
"""
Analyze medical image with clinical context
"""
    if image is None:
        return "Please upload an image first."
    if not clinical_question or not clinical_question.strip():
        return "Please enter a clinical question."
# Prepare the conversation
messages = [
{
"role": "system",
"content": [{"type": "text", "text": "You are an expert medical AI assistant. Provide detailed analysis while emphasizing that this is for educational purposes only and should not replace professional medical diagnosis."}]
}
]
    # Build the single user turn. Gemma-style chat templates expect strictly
    # alternating user/model roles, so the optional patient history is folded
    # into the same message as the clinical question instead of being sent as
    # a separate user turn.
    question_text = clinical_question
    if patient_history and patient_history.strip():
        question_text = f"Patient History: {patient_history.strip()}\n\n{clinical_question}"

    user_content = [{"type": "text", "text": question_text}]
    if image is not None:
        user_content.append({"type": "image", "image": image})

    messages.append({
        "role": "user",
        "content": user_content
    })
try:
# Process inputs
inputs = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt"
).to(model.device, dtype=torch.bfloat16)
input_len = inputs["input_ids"].shape[-1]
# Generate response
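        # do_sample=True with temperature/top_p yields varied phrasing between runs;
        # switch to do_sample=False for deterministic output if reproducibility matters.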
with torch.inference_mode():
generation = model.generate(
**inputs,
max_new_tokens=1000,
do_sample=True,
temperature=0.7,
top_p=0.9
)
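            # Keep only the newly generated tokens; the prompt tokens are sliced
            # off before decoding.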
generation = generation[0][input_len:]
# Decode response
response = processor.decode(generation, skip_special_tokens=True)
# Add disclaimer
disclaimer = "\n\n⚠️ IMPORTANT DISCLAIMER: This analysis is for educational and research purposes only. Always consult qualified healthcare professionals for medical diagnosis and treatment decisions."
return response + disclaimer
except Exception as e:
return f"Error processing request: {str(e)}"
# Create Gradio interface
def create_interface():
with gr.Blocks(title="MedGemma Medical Image Analysis", theme=gr.themes.Soft()) as demo:
gr.Markdown("""
# πŸ₯ MedGemma Medical Image Analysis
**Educational Medical AI Assistant powered by Google's MedGemma-4B**
⚠️ **Important**: This tool is for educational and research purposes only.
Do not use real patient data. Always consult healthcare professionals for medical decisions.
""")
with gr.Row():
with gr.Column(scale=1):
image_input = gr.Image(
label="Medical Image",
type="pil",
height=400
)
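                # type="pil" hands a PIL.Image object to analyze_medical_image,
                # which is the format the processor's image handling expects.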
clinical_question = gr.Textbox(
label="Clinical Question",
placeholder="e.g., 'Describe the findings in this chest X-ray' or 'What pathological changes do you observe?'",
lines=3
)
patient_history = gr.Textbox(
label="Patient History (Optional)",
placeholder="e.g., '45-year-old male with chronic cough and shortness of breath'",
lines=3
)
                analyze_btn = gr.Button("🔍 Analyze Image", variant="primary", size="lg")
with gr.Column(scale=1):
output = gr.Textbox(
label="Medical Analysis",
lines=20,
max_lines=30
)
# Example cases
gr.Markdown("## πŸ“‹ Example Cases")
examples = gr.Examples(
examples=[
[
"https://upload.wikimedia.org/wikipedia/commons/c/c8/Chest_Xray_PA_3-8-2010.png",
"Describe the findings in this chest X-ray and identify any abnormalities.",
"Adult patient with respiratory symptoms"
],
[
None, # User will upload their own
"What pathological changes are visible in this medical image?",
""
],
[
None,
"Provide a differential diagnosis based on the imaging findings.",
"Patient presenting with acute symptoms"
]
],
inputs=[image_input, clinical_question, patient_history]
)
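        # The first example points at a public Wikimedia chest X-ray by URL, which
        # Gradio can load into the image component; rows with a None image only
        # prefill the text fields.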
# Event handlers
analyze_btn.click(
fn=analyze_medical_image,
inputs=[image_input, clinical_question, patient_history],
outputs=output
)
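        # Generation runs synchronously inside the click handler, so each request
        # blocks until the model finishes; expect long waits on CPU-only hardware.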
# Footer
gr.Markdown("""
---
**Data Privacy Notice**: Do not upload real patient data or personally identifiable information.
Use only synthetic, anonymized, or publicly available medical images for demonstration purposes.
**Model**: Google MedGemma-4B | **Purpose**: Educational and Research Use Only
""")
return demo
# Launch the app
if __name__ == "__main__":
demo = create_interface()
    # share=True has no effect on Hugging Face Spaces, so it is omitted here;
    # 0.0.0.0:7860 is the address/port a Space expects the app to listen on.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860
    )