# app.py - Fixed LLaVA Medical AI with NoneType Error Resolution import gradio as gr import torch import logging from collections import defaultdict, Counter import time import traceback # Configure logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) # Fix the NoneType compatibility issue def fix_transformers_compatibility(): """Fix compatibility issues with transformers library""" try: # Import and fix the parallel styles issue import transformers.modeling_utils as modeling_utils if not hasattr(modeling_utils, 'ALL_PARALLEL_STYLES'): modeling_utils.ALL_PARALLEL_STYLES = [] elif getattr(modeling_utils, 'ALL_PARALLEL_STYLES', None) is None: modeling_utils.ALL_PARALLEL_STYLES = [] # Fix in specific model files try: import transformers.models.llava_next.modeling_llava_next as llava_next if not hasattr(llava_next, 'ALL_PARALLEL_STYLES'): llava_next.ALL_PARALLEL_STYLES = [] elif getattr(llava_next, 'ALL_PARALLEL_STYLES', None) is None: llava_next.ALL_PARALLEL_STYLES = [] except ImportError: pass # Fix in mistral files if they exist try: import transformers.models.mistral.modeling_mistral as mistral if not hasattr(mistral, 'ALL_PARALLEL_STYLES'): mistral.ALL_PARALLEL_STYLES = [] elif getattr(mistral, 'ALL_PARALLEL_STYLES', None) is None: mistral.ALL_PARALLEL_STYLES = [] except ImportError: pass logger.info("✅ Applied compatibility fixes") return True except Exception as e: logger.warning(f"⚠️ Could not apply compatibility fixes: {e}") return False # Apply compatibility fix before imports fix_transformers_compatibility() # Now import transformers try: from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration from PIL import Image logger.info("✅ Transformers imported successfully") except Exception as e: logger.error(f"❌ Failed to import transformers: {e}") # Fallback imports try: from transformers import LlavaProcessor, LlavaForConditionalGeneration as LlavaNextForConditionalGeneration from transformers import AutoProcessor as LlavaNextProcessor logger.info("✅ Using fallback LLaVA imports") except Exception as e2: logger.error(f"❌ Fallback imports also failed: {e2}") # Usage tracking class UsageTracker: def __init__(self): self.stats = { 'total_analyses': 0, 'successful_analyses': 0, 'failed_analyses': 0, 'average_processing_time': 0.0, 'question_types': Counter() } def log_analysis(self, success, duration, question_type=None): self.stats['total_analyses'] += 1 if success: self.stats['successful_analyses'] += 1 else: self.stats['failed_analyses'] += 1 total_time = self.stats['average_processing_time'] * (self.stats['total_analyses'] - 1) self.stats['average_processing_time'] = (total_time + duration) / self.stats['total_analyses'] if question_type: self.stats['question_types'][question_type] += 1 # Rate limiting class RateLimiter: def __init__(self, max_requests_per_hour=20): self.max_requests_per_hour = max_requests_per_hour self.requests = defaultdict(list) def is_allowed(self, user_id="default"): current_time = time.time() hour_ago = current_time - 3600 self.requests[user_id] = [req_time for req_time in self.requests[user_id] if req_time > hour_ago] if len(self.requests[user_id]) < self.max_requests_per_hour: self.requests[user_id].append(current_time) return True return False # Initialize components usage_tracker = UsageTracker() rate_limiter = RateLimiter() # Model configuration MODEL_ID = "llava-hf/llava-v1.6-mistral-7b-hf" # Global variables model = None processor = None def load_llava_safe(): """Load LLaVA model with comprehensive error handling""" global model, processor try: logger.info(f"Loading LLaVA model: {MODEL_ID}") # Try different loading approaches loading_methods = [ ("Standard LlavaNext", lambda: ( LlavaNextProcessor.from_pretrained(MODEL_ID), LlavaNextForConditionalGeneration.from_pretrained( MODEL_ID, torch_dtype=torch.float32, # Use float32 for stability device_map=None, # Let PyTorch handle device placement low_cpu_mem_usage=True, attn_implementation="eager" # Use eager attention to avoid issues ) )), ("Auto Processor Fallback", lambda: ( LlavaNextProcessor.from_pretrained(MODEL_ID), LlavaNextForConditionalGeneration.from_pretrained( MODEL_ID, torch_dtype=torch.float32, trust_remote_code=True, use_safetensors=True ) )), ] for method_name, method_func in loading_methods: try: logger.info(f"Trying {method_name}...") processor, model = method_func() logger.info(f"✅ LLaVA loaded successfully using {method_name}!") return True except Exception as e: logger.warning(f"❌ {method_name} failed: {str(e)}") continue logger.error("❌ All loading methods failed") return False except Exception as e: logger.error(f"❌ Error loading LLaVA: {str(e)}") logger.error(f"Full traceback: {traceback.format_exc()}") return False # Load model at startup llava_ready = load_llava_safe() def analyze_medical_image_llava(image, clinical_question, patient_history=""): """Analyze medical image using LLaVA with robust error handling""" start_time = time.time() # Rate limiting if not rate_limiter.is_allowed(): usage_tracker.log_analysis(False, time.time() - start_time) return "⚠️ Rate limit exceeded. Please wait before trying again." if not llava_ready or model is None: usage_tracker.log_analysis(False, time.time() - start_time) return """❌ **LLaVA Model Loading Issue** The LLaVA model failed to load due to compatibility issues. This is often caused by: 1. **Library Version Conflicts**: Try refreshing the page - we've applied compatibility fixes 2. **Memory Constraints**: The 7B model requires significant resources 3. **Transformers Version**: Some versions have compatibility issues **Suggested Solutions:** - **Refresh the page** and wait 2-3 minutes for model loading - **Upgrade to GPU hardware** for better performance and stability - **Try a different image** if the issue persists **Technical Info**: There may be version conflicts in the transformers library. The model files downloaded successfully but initialization failed.""" if image is None: return "⚠️ Please upload a medical image first." if not clinical_question.strip(): return "⚠️ Please provide a clinical question." try: logger.info("Starting LLaVA medical analysis...") # Prepare medical prompt medical_prompt = f"""You are an expert medical AI assistant analyzing medical images. Please provide a comprehensive medical analysis. {f"Patient History: {patient_history}" if patient_history.strip() else ""} Clinical Question: {clinical_question} Please analyze this medical image systematically: 1. **Image Quality**: Assess technical quality and diagnostic adequacy 2. **Anatomical Structures**: Identify visible normal structures 3. **Abnormal Findings**: Describe any pathological changes 4. **Clinical Significance**: Explain the importance of findings 5. **Assessment**: Provide clinical interpretation 6. **Recommendations**: Suggest next steps if appropriate Provide detailed, educational medical analysis suitable for learning purposes.""" # Different prompt formats to try prompt_formats = [ # Format 1: Simple user message lambda: f"USER: \n{medical_prompt}\nASSISTANT:", # Format 2: Chat format lambda: processor.apply_chat_template([ {"role": "user", "content": [ {"type": "image", "image": image}, {"type": "text", "text": medical_prompt} ]} ], add_generation_prompt=True), # Format 3: Direct format lambda: medical_prompt ] # Try different prompt formats for i, prompt_func in enumerate(prompt_formats): try: logger.info(f"Trying prompt format {i+1}...") if i == 1: # Chat template format try: prompt = prompt_func() except: continue else: prompt = prompt_func() # Process inputs inputs = processor(prompt, image, return_tensors='pt') # Generate response with conservative settings logger.info("Generating medical analysis...") with torch.inference_mode(): output = model.generate( **inputs, max_new_tokens=1000, # Conservative limit do_sample=True, temperature=0.3, top_p=0.9, repetition_penalty=1.1, use_cache=False # Disable cache for stability ) # Decode response generated_text = processor.decode(output[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True) if generated_text and generated_text.strip(): break except Exception as e: logger.warning(f"Prompt format {i+1} failed: {e}") if i == len(prompt_formats) - 1: # Last attempt raise e continue # Clean up response response = generated_text.strip() if generated_text else "Analysis completed." # Format the response formatted_response = f"""# 🏥 **LLaVA Medical Analysis** ## **Clinical Question:** {clinical_question} {f"## **Patient History:** {patient_history}" if patient_history.strip() else ""} --- ## 🔍 **Medical Analysis Results** {response} --- ## 📋 **Clinical Summary** This analysis was generated using LLaVA (Large Language and Vision Assistant) for educational purposes. The findings should be interpreted by qualified medical professionals and correlated with clinical presentation. **Key Points:** - Analysis based on visual medical image interpretation - Systematic approach to medical imaging assessment - Educational tool for medical learning and training - Requires professional medical validation """ # Add medical disclaimer disclaimer = """ --- ## ⚠️ **MEDICAL DISCLAIMER** **FOR EDUCATIONAL PURPOSES ONLY** - **Not Diagnostic**: This AI analysis is not a medical diagnosis - **Professional Review**: All findings require validation by healthcare professionals - **Emergency Care**: Contact emergency services for urgent medical concerns - **Educational Tool**: Designed for medical education and training - **No PHI**: Do not upload patient identifiable information --- **Powered by**: LLaVA (Large Language and Vision Assistant) """ # Log successful analysis duration = time.time() - start_time question_type = classify_question(clinical_question) usage_tracker.log_analysis(True, duration, question_type) logger.info("✅ LLaVA medical analysis completed successfully") return formatted_response + disclaimer except Exception as e: duration = time.time() - start_time usage_tracker.log_analysis(False, duration) logger.error(f"❌ LLaVA analysis error: {str(e)}") return f"""❌ **Analysis Error** The analysis failed with error: {str(e)} **Common Solutions:** - **Try again**: Sometimes temporary processing issues occur - **Smaller image**: Try with a smaller or different format image - **Simpler question**: Use a more straightforward clinical question - **Refresh page**: Reload the page if model seems unstable **Technical Details:** {str(e)[:200]}""" def classify_question(question): """Classify clinical question type""" question_lower = question.lower() if any(word in question_lower for word in ['describe', 'findings', 'observe']): return 'descriptive' elif any(word in question_lower for word in ['diagnosis', 'differential', 'condition']): return 'diagnostic' elif any(word in question_lower for word in ['abnormal', 'pathology', 'disease']): return 'pathological' else: return 'general' def get_usage_stats(): """Get usage statistics""" stats = usage_tracker.stats if stats['total_analyses'] == 0: return "📊 **Usage Statistics**\n\nNo analyses performed yet." success_rate = (stats['successful_analyses'] / stats['total_analyses']) * 100 return f"""📊 **LLaVA Usage Statistics** **Performance:** - Total Analyses: {stats['total_analyses']} - Success Rate: {success_rate:.1f}% - Avg Processing Time: {stats['average_processing_time']:.2f}s **Popular Question Types:** {chr(10).join([f"- {qtype}: {count}" for qtype, count in stats['question_types'].most_common(3)])} **Model Status**: {'🟢 Ready' if llava_ready else '🔴 Loading Issues'} """ # Create Gradio interface def create_interface(): with gr.Blocks( title="LLaVA Medical Analysis", theme=gr.themes.Soft(), css=""" .gradio-container { max-width: 1200px !important; } .disclaimer { background-color: #fef2f2; border: 1px solid #fecaca; border-radius: 8px; padding: 16px; margin: 16px 0; } .success { background-color: #f0f9ff; border: 1px solid #bae6fd; border-radius: 8px; padding: 16px; margin: 16px 0; } .warning { background-color: #fffbeb; border: 1px solid #fed7aa; border-radius: 8px; padding: 16px; margin: 16px 0; } """ ) as demo: # Header gr.Markdown(""" # 🏥 LLaVA Medical Image Analysis **Advanced Medical AI powered by LLaVA (Large Language and Vision Assistant)** **Medical Capabilities:** 🫁 Radiology • 🔬 Pathology • 🩺 Dermatology • 👁️ Ophthalmology """) # Status display if llava_ready: gr.Markdown("""
LLAVA MEDICAL AI READY
LLaVA model loaded successfully with compatibility fixes. Ready for medical image analysis.
""") else: gr.Markdown("""
⚠️ MODEL LOADING ISSUE
LLaVA model had loading problems. Try refreshing the page or contact support for assistance.
""") # Medical disclaimer gr.Markdown("""
⚠️ MEDICAL DISCLAIMER
This AI provides medical analysis for educational purposes only. Do not upload real patient data. Always consult healthcare professionals for medical decisions.
""") with gr.Row(): # Left column with gr.Column(scale=2): with gr.Row(): with gr.Column(): gr.Markdown("## 📤 Medical Image") image_input = gr.Image( label="Upload Medical Image", type="pil", height=300 ) with gr.Column(): gr.Markdown("## 💬 Clinical Information") clinical_question = gr.Textbox( label="Clinical Question *", placeholder="Examples:\n• Analyze this medical image\n• What abnormalities are visible?\n• Describe the findings\n• Provide medical interpretation", lines=4 ) patient_history = gr.Textbox( label="Patient History (Optional)", placeholder="e.g., 45-year-old with chest pain", lines=2 ) with gr.Row(): clear_btn = gr.Button("🗑️ Clear", variant="secondary") analyze_btn = gr.Button("🔍 Analyze with LLaVA", variant="primary", size="lg") gr.Markdown("## 📋 Medical Analysis Results") output = gr.Textbox( label="LLaVA Medical Analysis", lines=20, show_copy_button=True, placeholder="Upload a medical image and clinical question..." if llava_ready else "Model loading issues - please refresh the page" ) # Right column with gr.Column(scale=1): gr.Markdown("## ℹ️ System Status") status = "✅ Ready" if llava_ready else "⚠️ Loading Issues" gr.Markdown(f""" **Model Status:** {status} **AI Model:** LLaVA-v1.6-Mistral-7B **Device:** {'GPU' if torch.cuda.is_available() else 'CPU'} **Compatibility:** Fixed for stability **Rate Limit:** 20 requests/hour """) gr.Markdown("## 📊 Usage Statistics") stats_display = gr.Markdown("") refresh_stats_btn = gr.Button("🔄 Refresh Stats", size="sm") if llava_ready: gr.Markdown("## 🎯 Quick Examples") general_btn = gr.Button("General Analysis", size="sm") findings_btn = gr.Button("Find Abnormalities", size="sm") interpret_btn = gr.Button("Medical Interpretation", size="sm") # Example cases if llava_ready: with gr.Accordion("📚 Example Cases", open=False): examples = gr.Examples( examples=[ [ "https://upload.wikimedia.org/wikipedia/commons/c/c8/Chest_Xray_PA_3-8-2010.png", "Please analyze this chest X-ray and describe any findings. Assess the image quality, identify normal structures, and note any abnormalities.", "Adult patient with respiratory symptoms" ] ], inputs=[image_input, clinical_question, patient_history] ) # Event handlers analyze_btn.click( fn=analyze_medical_image_llava, inputs=[image_input, clinical_question, patient_history], outputs=output, show_progress=True ) clear_btn.click( fn=lambda: (None, "", "", ""), outputs=[image_input, clinical_question, patient_history, output] ) refresh_stats_btn.click( fn=get_usage_stats, outputs=stats_display ) # Quick example handlers if llava_ready: general_btn.click( fn=lambda: ("Analyze this medical image comprehensively. Describe what you observe and provide medical interpretation.", ""), outputs=[clinical_question, patient_history] ) findings_btn.click( fn=lambda: ("What abnormalities or pathological findings are visible in this medical image?", ""), outputs=[clinical_question, patient_history] ) interpret_btn.click( fn=lambda: ("Provide medical interpretation of this image including clinical significance of any findings.", ""), outputs=[clinical_question, patient_history] ) # Footer gr.Markdown(""" --- ### 🤖 LLaVA Medical AI **Large Language and Vision Assistant** optimized for medical image analysis with compatibility fixes for stable operation. **Features:** - Advanced medical image interpretation - Systematic clinical analysis approach - Educational medical explanations - Comprehensive error handling **Model:** LLaVA-v1.6-Mistral-7B | **Purpose:** Medical Education & Research """) return demo # Launch the application if __name__ == "__main__": demo = create_interface() demo.launch( server_name="0.0.0.0", server_port=7860, show_error=True )