# app.py - Working MedGemma with Correct Implementation

import gradio as gr
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText, pipeline
from PIL import Image
import os
import logging
from huggingface_hub import login

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# Authenticate with Hugging Face
def authenticate_hf():
    """Authenticate with Hugging Face using token"""
    try:
        hf_token = os.getenv('HF_TOKEN')
        if hf_token:
            login(token=hf_token)
            logger.info("✅ Authenticated with Hugging Face")
            return True
        else:
            logger.warning("⚠️ No HF_TOKEN found in environment")
            return False
    except Exception as e:
        logger.error(f"❌ Authentication failed: {e}")
        return False


# Model configuration
MODEL_ID = "google/medgemma-4b-it"

# Global variables
model = None
processor = None
pipeline_model = None


def load_model():
    """Load MedGemma model using the recommended approach"""
    global model, processor, pipeline_model

    try:
        # First authenticate
        auth_success = authenticate_hf()
        if not auth_success:
            logger.error("❌ Authentication required for MedGemma")
            return False

        logger.info(f"Loading MedGemma: {MODEL_ID}")

        # Method 1: Try using pipeline (recommended by HuggingFace)
        try:
            logger.info("Attempting to load using pipeline...")
            pipeline_model = pipeline(
                "image-text-to-text",
                model=MODEL_ID,
                torch_dtype=torch.float32,
                device_map="auto" if torch.cuda.is_available() else None,
                trust_remote_code=True
            )
            logger.info("✅ Pipeline model loaded successfully!")
            return True
        except Exception as e:
            logger.warning(f"Pipeline loading failed: {e}")

        # Method 2: Try direct model loading
        logger.info("Attempting direct model loading...")

        # Load processor
        processor = AutoProcessor.from_pretrained(
            MODEL_ID,
            trust_remote_code=True,
            token=True
        )
        logger.info("✅ Processor loaded")

        # Load model
        model = AutoModelForImageTextToText.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.float32,
            device_map="auto" if torch.cuda.is_available() else None,
            trust_remote_code=True,
            token=True
        )
        logger.info("✅ Model loaded successfully!")
        return True

    except Exception as e:
        logger.error(f"❌ Error loading model: {str(e)}")
        import traceback
        logger.error(f"Full traceback: {traceback.format_exc()}")
        return False


# Initialize model at startup
model_loaded = load_model()


def analyze_medical_image(image, clinical_question, patient_history=""):
    """Analyze medical image with clinical context"""
    global model, processor, pipeline_model

    # Check if model is loaded
    if not model_loaded:
        return """❌ **Model Loading Issue**

MedGemma failed to load. This is likely due to:

1. **Transformers version**: Make sure you're using transformers >= 4.52.0
2. **Authentication**: Ensure HF_TOKEN is properly set
3. **Model compatibility**: MedGemma requires the latest transformers library

**Status**: Model loading failed. Please try refreshing the page or contact support."""

    if image is None:
        return "⚠️ Please upload a medical image first."

    if not clinical_question.strip():
        return "⚠️ Please provide a clinical question."
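    # The two inference paths below mirror load_model(): prefer the
    # high-level pipeline when it loaded, otherwise fall back to the
    # directly loaded processor + model pair.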
    try:
        # Method 1: Use pipeline if available
        if pipeline_model is not None:
            logger.info("Using pipeline for analysis...")

            # Prepare message in the format expected by the pipeline
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "image": image},
                        {"type": "text", "text": f"Patient History: {patient_history}\n\nClinical Question: {clinical_question}\n\nAs MedGemma, provide a detailed medical analysis of this image for educational purposes only."}
                    ]
                }
            ]

            # Generate response using pipeline
            result = pipeline_model(messages, max_new_tokens=1000)

            # Extract response text; with chat-style input the pipeline returns
            # the full message list, so take the last (assistant) turn's content
            generated = result[0]["generated_text"] if isinstance(result, list) else result["generated_text"]
            response = generated[-1]["content"] if isinstance(generated, list) else generated

        # Method 2: Use direct model if pipeline failed
        elif model is not None and processor is not None:
            logger.info("Using direct model for analysis...")

            # Prepare messages for direct model
            messages = [
                {
                    "role": "system",
                    "content": [{"type": "text", "text": "You are MedGemma, an expert medical AI assistant. Provide detailed medical analysis for educational purposes only."}]
                },
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": f"Patient History: {patient_history}\n\nClinical Question: {clinical_question}"},
                        {"type": "image", "image": image}
                    ]
                }
            ]

            # Process inputs
            inputs = processor.apply_chat_template(
                messages,
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,
                return_tensors="pt"
            )
            # Move inputs to the model's device (required when device_map
            # has placed the model on GPU)
            inputs = inputs.to(model.device)

            # Generate response
            with torch.inference_mode():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=1000,
                    do_sample=True,
                    temperature=0.3,
                    top_p=0.9
                )

            # Decode only the newly generated tokens, skipping the prompt
            response = processor.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)

        else:
            return "❌ No model available for analysis. Please try refreshing the page."

        # Clean up response
        response = response.strip()

        # Add medical disclaimer
        disclaimer = """

---
### ⚠️ MEDICAL DISCLAIMER

**This analysis is for educational and research purposes only.**

- This AI assistant is not a substitute for professional medical advice
- Always consult qualified healthcare professionals for diagnosis and treatment
- Do not make medical decisions based solely on this analysis
- In case of medical emergency, contact emergency services immediately

---
"""

        logger.info("✅ Analysis completed successfully")
        return response + disclaimer

    except Exception as e:
        logger.error(f"❌ Error in analysis: {str(e)}")
        import traceback
        logger.error(f"Full traceback: {traceback.format_exc()}")
        return f"❌ Analysis failed: {str(e)}\n\nPlease try again with a different image or question."


# Create Gradio interface
def create_interface():
    with gr.Blocks(
        title="MedGemma Medical Analysis",
        theme=gr.themes.Soft(),
        css="""
        .gradio-container { max-width: 1200px !important; }
        .disclaimer { background-color: #fef2f2; border: 1px solid #fecaca; border-radius: 8px; padding: 16px; margin: 16px 0; }
        .success { background-color: #f0f9ff; border: 1px solid #bae6fd; border-radius: 8px; padding: 16px; margin: 16px 0; }
        .warning { background-color: #fffbeb; border: 1px solid #fed7aa; border-radius: 8px; padding: 16px; margin: 16px 0; }
        """
    ) as demo:
        # Header
        gr.Markdown("""
        # 🏥 MedGemma Medical Image Analysis

        **Advanced Medical AI Assistant powered by Google's MedGemma-4B**

        Specialized in medical imaging across multiple modalities:

        🫁 **Radiology** • 🔬 **Histopathology** • 👁️ **Ophthalmology** • 🩺 **Dermatology**
        """)

        # Status display
        if model_loaded:
            method = "Pipeline" if pipeline_model else "Direct Model"
            gr.Markdown(f"""