# app.py - Working MedGemma with Correct Implementation
import gradio as gr
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText, pipeline
from PIL import Image
import os
import logging
from huggingface_hub import login
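# Rough dependency sketch (an assumption for local runs, not a pinned requirements file;
# the version floor matches the "transformers >= 4.52.0" note used elsewhere in this app):
#   pip install "transformers>=4.52.0" torch gradio pillow huggingface_hub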
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Authenticate with Hugging Face
def authenticate_hf():
"""Authenticate with Hugging Face using token"""
try:
hf_token = os.getenv('HF_TOKEN')
if hf_token:
login(token=hf_token)
logger.info("βœ… Authenticated with Hugging Face")
return True
else:
logger.warning("⚠️ No HF_TOKEN found in environment")
return False
except Exception as e:
logger.error(f"❌ Authentication failed: {e}")
return False
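# Illustrative token setup (assumption, not part of this app's code): HF_TOKEN is read
# from the environment, e.g. a Space secret named HF_TOKEN or, for a local run:
#   export HF_TOKEN=hf_xxxxxxxxxxxxxxxx
#   python app.py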
# Model configuration
MODEL_ID = "google/medgemma-4b-it"
# Global variables
model = None
processor = None
pipeline_model = None
def load_model():
"""Load MedGemma model using the recommended approach"""
global model, processor, pipeline_model
try:
# First authenticate
auth_success = authenticate_hf()
if not auth_success:
logger.error("❌ Authentication required for MedGemma")
return False
logger.info(f"Loading MedGemma: {MODEL_ID}")
# Method 1: Try using pipeline (recommended by HuggingFace)
try:
logger.info("Attempting to load using pipeline...")
pipeline_model = pipeline(
"image-text-to-text",
model=MODEL_ID,
torch_dtype=torch.float32,
device_map="auto" if torch.cuda.is_available() else None,
trust_remote_code=True
)
logger.info("βœ… Pipeline model loaded successfully!")
return True
except Exception as e:
logger.warning(f"Pipeline loading failed: {e}")
# Method 2: Try direct model loading
logger.info("Attempting direct model loading...")
# Load processor
processor = AutoProcessor.from_pretrained(
MODEL_ID,
trust_remote_code=True,
token=True
)
logger.info("βœ… Processor loaded")
# Load model
model = AutoModelForImageTextToText.from_pretrained(
MODEL_ID,
torch_dtype=torch.float32,
device_map="auto" if torch.cuda.is_available() else None,
trust_remote_code=True,
token=True
)
logger.info("βœ… Model loaded successfully!")
return True
except Exception as e:
logger.error(f"❌ Error loading model: {str(e)}")
import traceback
logger.error(f"Full traceback: {traceback.format_exc()}")
return False
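# Optional tweak (assumption, not required by this app): on a CUDA GPU, MedGemma is
# commonly loaded with torch_dtype=torch.bfloat16 in either loading path above, which
# roughly halves memory use; float32 is kept here as the safer CPU default.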
# Initialize model at startup
model_loaded = load_model()
def analyze_medical_image(image, clinical_question, patient_history=""):
"""Analyze medical image with clinical context"""
global model, processor, pipeline_model
# Check if model is loaded
if not model_loaded:
return """❌ **Model Loading Issue**
MedGemma failed to load. This is likely due to:
1. **Transformers version**: Make sure you're using transformers >= 4.52.0
2. **Authentication**: Ensure HF_TOKEN is properly set
3. **Model compatibility**: MedGemma requires the latest transformers library
**Status**: Model loading failed. Please try refreshing the page or contact support."""
if image is None:
return "⚠️ Please upload a medical image first."
if not clinical_question.strip():
return "⚠️ Please provide a clinical question."
try:
# Method 1: Use pipeline if available
if pipeline_model is not None:
logger.info("Using pipeline for analysis...")
# Prepare message in the format expected by pipeline
messages = [
{
"role": "user",
"content": [
{"type": "image", "image": image},
{"type": "text", "text": f"Patient History: {patient_history}\n\nClinical Question: {clinical_question}\n\nAs MedGemma, provide a detailed medical analysis of this image for educational purposes only."}
]
}
]
# Generate response using pipeline
result = pipeline_model(text=messages, max_new_tokens=1000)
# Extract the assistant reply: with chat-style input the pipeline returns the
# full conversation in "generated_text", so take the last message's content
generated = result[0]["generated_text"] if isinstance(result, list) else result["generated_text"]
response = generated[-1]["content"] if isinstance(generated, list) else generated
# Method 2: Use direct model if pipeline failed
elif model is not None and processor is not None:
logger.info("Using direct model for analysis...")
# Prepare messages for direct model
messages = [
{
"role": "system",
"content": [{"type": "text", "text": "You are MedGemma, an expert medical AI assistant. Provide detailed medical analysis for educational purposes only."}]
},
{
"role": "user",
"content": [
{"type": "text", "text": f"Patient History: {patient_history}\n\nClinical Question: {clinical_question}"},
{"type": "image", "image": image}
]
}
]
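# The nested "content" lists follow the chat format expected by the processor's chat
# template: apply_chat_template (below) combines the text entries and the PIL image
# into a single batch of model inputs.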
# Process inputs
inputs = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt"
)
# Move the tensors to the model's device (needed when device_map places the model on GPU)
inputs = inputs.to(model.device)
# Generate response
with torch.inference_mode():
outputs = model.generate(
**inputs,
max_new_tokens=1000,
do_sample=True,
temperature=0.3,
top_p=0.9
)
# Decode only the newly generated tokens (skip the prompt portion of the sequence)
response = processor.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
else:
return "❌ No model available for analysis. Please try refreshing the page."
# Clean up response
response = response.strip()
# Add medical disclaimer
disclaimer = """
---
### ⚠️ MEDICAL DISCLAIMER
**This analysis is for educational and research purposes only.**
- This AI assistant is not a substitute for professional medical advice
- Always consult qualified healthcare professionals for diagnosis and treatment
- Do not make medical decisions based solely on this analysis
- In case of medical emergency, contact emergency services immediately
---
"""
logger.info("βœ… Analysis completed successfully")
return response + disclaimer
except Exception as e:
logger.error(f"❌ Error in analysis: {str(e)}")
import traceback
logger.error(f"Full traceback: {traceback.format_exc()}")
return f"❌ Analysis failed: {str(e)}\n\nPlease try again with a different image or question."
# Create Gradio interface
def create_interface():
with gr.Blocks(
title="MedGemma Medical Analysis",
theme=gr.themes.Soft(),
css="""
.gradio-container { max-width: 1200px !important; }
.disclaimer { background-color: #fef2f2; border: 1px solid #fecaca; border-radius: 8px; padding: 16px; margin: 16px 0; }
.success { background-color: #f0f9ff; border: 1px solid #bae6fd; border-radius: 8px; padding: 16px; margin: 16px 0; }
.warning { background-color: #fffbeb; border: 1px solid #fed7aa; border-radius: 8px; padding: 16px; margin: 16px 0; }
"""
) as demo:
# Header
gr.Markdown("""
# πŸ₯ MedGemma Medical Image Analysis
**Advanced Medical AI Assistant powered by Google's MedGemma-4B**
Specialized in medical imaging across multiple modalities:
🫁 **Radiology** • 🔬 **Histopathology** • 👁️ **Ophthalmology** • 🩺 **Dermatology**
""")
# Status display
if model_loaded:
method = "Pipeline" if pipeline_model else "Direct Model"
gr.Markdown(f"""
<div class="success">
✅ <strong>MEDGEMMA READY</strong><br>
Model loaded successfully using {method} method. Ready for medical image analysis.
</div>
""")
else:
gr.Markdown("""
<div class="warning">
⚠️ <strong>MODEL LOADING FAILED</strong><br>
MedGemma failed to load. Please ensure you have the latest transformers library and proper authentication.
</div>
""")
# Medical disclaimer
gr.Markdown("""
<div class="disclaimer">
⚠️ <strong>IMPORTANT MEDICAL DISCLAIMER</strong><br>
This tool is for <strong>educational and research purposes only</strong>.
Do not upload real patient data. Always consult qualified healthcare professionals.
</div>
""")
with gr.Row():
# Left column
with gr.Column(scale=1):
gr.Markdown("## πŸ“€ Medical Image Upload")
image_input = gr.Image(
label="Medical Image",
type="pil",
height=300
)
clinical_question = gr.Textbox(
label="Clinical Question *",
placeholder="Examples:\nβ€’ Describe findings in this chest X-ray\nβ€’ What pathological changes are visible?\nβ€’ Provide differential diagnosis\nβ€’ Identify abnormalities",
lines=4
)
patient_history = gr.Textbox(
label="Patient History (Optional)",
placeholder="e.g., 65-year-old male with chronic cough",
lines=2
)
with gr.Row():
clear_btn = gr.Button("🗑️ Clear", variant="secondary")
analyze_btn = gr.Button("🔍 Analyze", variant="primary", size="lg")
# System info
gr.Markdown(f"""
**Status:** {'✅ Ready' if model_loaded else '❌ Failed'}
**Method:** {'Pipeline' if pipeline_model else 'Direct' if model else 'None'}
**Device:** {'CUDA' if torch.cuda.is_available() else 'CPU'}
**Transformers:** {getattr(__import__('transformers'), '__version__', 'Unknown')}
""")
# Right column
with gr.Column(scale=1):
gr.Markdown("## πŸ“‹ Medical Analysis Results")
output = gr.Textbox(
label="AI Medical Analysis",
lines=20,
show_copy_button=True,
placeholder="Upload a medical image and ask a clinical question..." if model_loaded else "Model unavailable - please check system status"
)
# Examples
if model_loaded:
with gr.Accordion("📚 Example Cases", open=False):
examples = gr.Examples(
examples=[
[
"https://upload.wikimedia.org/wikipedia/commons/c/c8/Chest_Xray_PA_3-8-2010.png",
"Analyze this chest X-ray systematically. Comment on heart size, lung fields, and any abnormalities.",
"Adult patient with respiratory symptoms"
]
],
inputs=[image_input, clinical_question, patient_history]
)
# Event handlers
analyze_btn.click(
fn=analyze_medical_image,
inputs=[image_input, clinical_question, patient_history],
outputs=output,
show_progress=True
)
clear_btn.click(
fn=lambda: (None, "", "", ""),
outputs=[image_input, clinical_question, patient_history, output]
)
# Footer
gr.Markdown("""
---
### 🔬 About MedGemma
MedGemma-4B is Google's specialized medical AI model requiring transformers >= 4.52.0.
### 🔒 Privacy & Ethics
- Real-time processing, no data storage
- Educational and research purposes only
- No patient data should be uploaded
**Model:** Google MedGemma-4B | **License:** Apache 2.0
""")
return demo
# Launch the app
if __name__ == "__main__":
demo = create_interface()
demo.launch(
server_name="0.0.0.0",
server_port=7860,
show_error=True
)
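# Optional (assumption, not enabled above): for long generations on shared hardware,
# Gradio's request queue can be turned on before launching, e.g.:
#   demo.queue(max_size=8).launch(server_name="0.0.0.0", server_port=7860, show_error=True)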