# app.py - Medical AI with Proper Vision Analysis
import logging
import time
from collections import Counter, defaultdict
from io import BytesIO

import gradio as gr
import requests
import torch
from PIL import Image
from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    BlipForConditionalGeneration,
    BlipProcessor,
    pipeline,
)

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class UsageTracker:
    """In-memory counters describing how the analysis endpoint has been used.

    ``stats`` holds the total/successful/failed call counts, a running mean
    of the per-call duration in seconds, and a ``Counter`` of question types.
    """

    def __init__(self):
        self.stats = dict(
            total_analyses=0,
            successful_analyses=0,
            failed_analyses=0,
            average_processing_time=0.0,
            question_types=Counter(),
        )

    def log_analysis(self, success, duration, question_type=None):
        """Record one call.

        Args:
            success: whether the call produced a report.
            duration: elapsed time of the call, folded into the running mean.
            question_type: optional label counted in ``stats['question_types']``.
        """
        stats = self.stats
        stats['total_analyses'] += 1
        outcome_key = 'successful_analyses' if success else 'failed_analyses'
        stats[outcome_key] += 1

        # Incremental mean: undo the previous average over (count - 1) calls,
        # add the new duration, divide by the new count.
        count = stats['total_analyses']
        previous_total = stats['average_processing_time'] * (count - 1)
        stats['average_processing_time'] = (previous_total + duration) / count

        if question_type:
            stats['question_types'][question_type] += 1


class RateLimiter:
    """Sliding one-hour request limiter keyed by a caller-supplied user id."""

    def __init__(self, max_requests_per_hour=60):
        self.max_requests_per_hour = max_requests_per_hour
        # user_id -> timestamps (time.time()) of accepted requests
        self.requests = defaultdict(list)

    def is_allowed(self, user_id="default"):
        """Return True and record the request if *user_id* is under the limit."""
        now = time.time()
        window_start = now - 3600
        recent = [stamp for stamp in self.requests[user_id] if stamp > window_start]
        self.requests[user_id] = recent
        if len(recent) >= self.max_requests_per_hour:
            return False
        recent.append(now)
        return True


# Shared singletons used by the request handlers
usage_tracker = UsageTracker()
rate_limiter = RateLimiter()

# Candidate models, tried in order by load_best_model(); the first that loads wins.
MODELS_TO_TRY = [
    "microsoft/git-base-coco",  # Better for detailed descriptions
    "Salesforce/blip2-opt-2.7b",  # More capable BLIP2 model
    "Salesforce/blip-image-captioning-large",  # Fallback
]

# Populated by load_best_model()
model = None
processor = None
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
current_model_name = None

# Question keywords that make the GIT path prefix its caption with "Medical analysis:".
_GIT_MEDICAL_KEYWORDS = ("abnormal", "diagnosis", "condition", "pathology")

# Prompts tried in order for BLIP question-conditioned captioning; the first
# one that yields a usable answer wins.
_BLIP_PROMPT_TEMPLATES = (
    "Question: {question} Answer:",
    "Medical analysis: {question}",
    "Describe the medical findings: {question}",
)

# (description, response) pair returned by get_detailed_medical_analysis when
# the model raised; analyze_medical_image uses it to record a failed analysis.
_ANALYSIS_FAILED = ("Unable to analyze image", "Analysis failed")

# Keywords that route enhance_medical_description to the chest X-ray framework.
_CHEST_DESCRIPTION_TERMS = ("chest", "lung", "rib", "heart", "x-ray", "radiograph")
_CHEST_QUESTION_TERMS = ("chest", "lung", "respiratory", "cough")

_APP_CSS = """
.gradio-container { max-width: 1400px !important; }
.disclaimer { background-color: #fef2f2; border: 1px solid #fecaca; border-radius: 8px; padding: 16px; margin: 16px 0; }
.success { background-color: #f0f9ff; border: 1px solid #bae6fd; border-radius: 8px; padding: 16px; margin: 16px 0; }
.enhanced { background-color: #f0fdf4; border: 1px solid #bbf7d0; border-radius: 8px; padding: 16px; margin: 16px 0; }
"""


def load_best_model():
    """Load the first model in ``MODELS_TO_TRY`` that initialises successfully.

    On success the module globals ``model``, ``processor`` and
    ``current_model_name`` are assigned together, so a candidate that fails
    half-way (e.g. processor loaded, weights not) never leaves a mismatched
    pair behind.

    Returns:
        bool: True if a model was loaded, False if every candidate failed.
    """
    global model, processor, current_model_name

    use_cuda = torch.cuda.is_available()
    dtype = torch.float16 if use_cuda else torch.float32

    for model_name in MODELS_TO_TRY:
        try:
            logger.info(f"Trying to load: {model_name}")
            if "git-base" in model_name:
                # GIT is driven through the high-level image-to-text pipeline.
                loaded_model = pipeline(
                    "image-to-text", model=model_name, device=0 if use_cuda else -1
                )
                loaded_processor = None
                family = "GIT"
            elif "blip2" in model_name:
                # BLIP-2 is a vision-to-text model, not a causal LM: it needs its
                # own model class (AutoModelForCausalLM rejects its config).
                from transformers import Blip2ForConditionalGeneration

                loaded_processor = AutoProcessor.from_pretrained(model_name)
                # No device_map: that option requires `accelerate`; moving the
                # whole model explicitly works with plain transformers.
                loaded_model = Blip2ForConditionalGeneration.from_pretrained(
                    model_name, torch_dtype=dtype
                ).to(device)
                family = "BLIP2"
            else:
                loaded_processor = BlipProcessor.from_pretrained(model_name)
                loaded_model = BlipForConditionalGeneration.from_pretrained(
                    model_name, torch_dtype=dtype
                ).to(device)
                family = "BLIP"
        except Exception as e:
            logger.warning(f"Failed to load {model_name}: {e}")
            continue

        model, processor, current_model_name = loaded_model, loaded_processor, model_name
        logger.info(f"✅ Successfully loaded {family} model: {model_name}")
        return True

    logger.error("❌ Failed to load any model")
    return False


# Load model at startup
model_ready = load_best_model()


def _to_model_inputs(inputs):
    """Move processor output onto ``device`` and match the model's dtype.

    The BLIP/BLIP-2 models are loaded in float16 on CUDA while their
    processors emit float32 ``pixel_values``; feeding those unchanged raises a
    dtype-mismatch error. Floating tensors are therefore cast to the model's
    dtype, while integer tensors (token ids, masks) only change device.
    """
    model_dtype = getattr(model, "dtype", None)
    prepared = {}
    for key, value in dict(inputs).items():
        if torch.is_tensor(value):
            if value.is_floating_point() and model_dtype is not None:
                value = value.to(device=device, dtype=model_dtype)
            else:
                value = value.to(device)
        prepared[key] = value
    return prepared


def _strip_prompt(text, prompt):
    """Remove an echoed *prompt* prefix (case-insensitive) from decoded output."""
    text = text.strip()
    prompt = prompt.strip()
    if prompt and text.lower().startswith(prompt.lower()):
        return text[len(prompt):].strip()
    return text


def _analyze_with_git(image, question):
    """Caption *image* with the GIT pipeline; returns (description, response)."""
    results = model(image, max_new_tokens=200)
    description = results[0]["generated_text"] if results else "Unable to analyze image"
    # For medical questions, frame the basic description as a medical analysis.
    if any(word in question.lower() for word in _GIT_MEDICAL_KEYWORDS):
        return description, f"Medical analysis: {description}"
    return description, description


def _analyze_with_blip2(image, question):
    """Answer *question* and caption *image* with BLIP-2; returns (caption, answer)."""
    # BLIP-2 OPT checkpoints are prompted in "Question: ... Answer:" form.
    prompt = f"Question: {question} Answer:"
    qa_inputs = processor(images=image, text=prompt, return_tensors="pt")
    with torch.no_grad():
        answer_ids = model.generate(
            **_to_model_inputs(qa_inputs), max_new_tokens=150, do_sample=False
        )
    # Depending on the transformers version the prompt may be echoed; drop it.
    answer = _strip_prompt(
        processor.batch_decode(answer_ids, skip_special_tokens=True)[0], prompt
    )

    # Also get an unconditional description
    caption_inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        caption_ids = model.generate(
            **_to_model_inputs(caption_inputs), max_new_tokens=100, do_sample=False
        )
    caption = processor.batch_decode(caption_ids, skip_special_tokens=True)[0].strip()
    return caption, answer


def _analyze_with_blip(image, question):
    """Caption *image* with BLIP, then try prompt templates for a question answer.

    Returns (caption, answer); answer falls back to the caption when no
    template produced a usable response.
    """
    caption_inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        caption_ids = model.generate(
            **_to_model_inputs(caption_inputs), max_length=100, num_beams=3, do_sample=False
        )
    basic_description = processor.decode(caption_ids[0], skip_special_tokens=True)

    for template in _BLIP_PROMPT_TEMPLATES:
        prompt = template.format(question=question)
        try:
            qa_inputs = processor(images=image, text=prompt, return_tensors="pt")
            with torch.no_grad():
                qa_ids = model.generate(
                    **_to_model_inputs(qa_inputs),
                    max_length=200,
                    num_beams=3,
                    do_sample=False,
                    early_stopping=True,
                )
            # BLIP's generate() continues from the prompt minus its trailing
            # [SEP], so slicing by input length would drop the first generated
            # token. Decode both with the same tokenizer and strip the prompt.
            prompt_text = processor.decode(qa_inputs["input_ids"][0], skip_special_tokens=True)
            full_text = processor.decode(qa_ids[0], skip_special_tokens=True)
            qa_response = _strip_prompt(full_text, prompt_text)
        except Exception as e:
            logger.warning(f"BLIP prompt {prompt!r} failed: {e}")
            continue

        if len(qa_response) > 20 and not qa_response.lower().startswith("question"):
            return basic_description, qa_response

    return basic_description, basic_description


def get_detailed_medical_analysis(image, question):
    """Run the loaded vision model on *image* for *question*.

    Returns:
        tuple[str, str]: (basic description, question-specific response), or
        ``("Unable to analyze image", "Analysis failed")`` if the model raised.
    """
    try:
        if "git-base" in current_model_name:
            return _analyze_with_git(image, question)
        if "blip2" in current_model_name:
            return _analyze_with_blip2(image, question)
        return _analyze_with_blip(image, question)
    except Exception as e:
        logger.exception(f"Analysis failed: {e}")
        return _ANALYSIS_FAILED


def _chest_xray_framework(history):
    """Return the chest X-ray teaching checklist, citing *history* when given."""
    if history:
        correlation_intro = f"Given the patient's presentation ({history}), key considerations include:"
    else:
        correlation_intro = "For a patient with respiratory symptoms, key considerations include:"
    return f"""
**Systematic Chest X-ray Analysis:**

**Technical Quality:**
- Confirm the projection (PA, AP, or lateral) and patient positioning
- Check penetration and inspiration before interpreting findings

**Anatomical Review:**
- **Heart**: Cardiac silhouette evaluation for size and contour
- **Lungs**: Assessment of lung fields for opacity, consolidation, or air trapping
- **Pleura**: Examination for pleural effusion or pneumothorax
- **Bones**: Rib cage and spine alignment assessment
- **Soft Tissues**: Evaluation of surrounding structures

**Clinical Correlation Needed:**
{correlation_intro}
- **Pneumonia**: Look for consolidation, air bronchograms, or infiltrates
- **Viral vs Bacterial**: Pattern recognition for different infectious etiologies
- **Atelectasis**: Collapsed lung segments that might appear as increased opacity
- **Pleural Changes**: Fluid collection that could indicate infection complications

**Educational Points:**
- Chest X-rays are the first-line imaging for respiratory symptoms
- Clinical correlation is essential - symptoms guide interpretation
- Follow-up imaging may be needed based on treatment response
"""


def enhance_medical_description(basic_desc, clinical_question, patient_history):
    """Wrap the model's description in an educational interpretation framework.

    A chest X-ray checklist is returned when the description or question
    mentions chest-related terms; otherwise a generic framework embedding the
    description, history and question is returned.

    Args:
        basic_desc: model-generated image description.
        clinical_question: the user's question.
        patient_history: optional clinical context ("" or None if absent).

    Returns:
        str: Markdown text.
    """
    basic_desc = basic_desc or ""
    clinical_question = clinical_question or ""
    history = (patient_history or "").strip()

    description_lower = basic_desc.lower()
    question_lower = clinical_question.lower()
    is_chest_study = any(term in description_lower for term in _CHEST_DESCRIPTION_TERMS) or any(
        term in question_lower for term in _CHEST_QUESTION_TERMS
    )
    if is_chest_study:
        return _chest_xray_framework(history)

    return f"""
**Medical Image Analysis Framework:**

**Image Description:** {basic_desc}

**Clinical Context Integration:**
- Patient presentation: {history if history else 'Not provided'}
- Imaging indication: {clinical_question}

**Systematic Approach:**
1. **Technical Assessment**: Image quality and acquisition parameters
2. **Anatomical Review**: Systematic evaluation of visible structures
3. **Pathological Assessment**: Identification of any abnormal findings
4. **Clinical Correlation**: Integration with patient symptoms and history

**Educational Considerations:**
- Medical imaging interpretation requires systematic approach
- Clinical context significantly influences interpretation priorities
- Multiple imaging modalities may be complementary for diagnosis
- Professional radiological review is essential for clinical decisions
"""


def _build_report(clinical_question, patient_history, basic_description, detailed_response, enhanced_analysis):
    """Assemble the Markdown report body (without the disclaimer)."""
    history_section = f"## **Patient History:** {patient_history}" if patient_history.strip() else ""
    symptoms = patient_history if patient_history.strip() else "Not provided"
    return f"""# 🏥 **Enhanced Medical AI Analysis**

## **Clinical Question:** {clinical_question}

{history_section}

---

## 🔍 **AI Vision Analysis**

### **Image Description:**
{basic_description}

### **Question-Specific Analysis:**
{detailed_response}

---

## 📋 **Medical Assessment Framework**

{enhanced_analysis}

---

## 🎓 **Educational Summary**

**Learning Objectives:**
- Demonstrate systematic approach to medical image interpretation
- Integrate clinical history with imaging findings
- Understand the importance of professional validation in medical diagnosis

**Key Teaching Points:**
- Medical imaging is one component of comprehensive patient assessment
- Clinical correlation enhances diagnostic accuracy
- Multiple imaging modalities may provide complementary information
- Professional interpretation is essential for patient care decisions

**Clinical Decision Making:**
Based on the combination of:
- Patient symptoms: {symptoms}
- Imaging findings: As described above
- Clinical context: {clinical_question}

**Next Steps in Clinical Practice:**
- Professional radiological review
- Correlation with laboratory findings
- Consider additional imaging if clinically indicated
- Follow-up based on treatment response
"""


def _build_disclaimer():
    """Return the medical disclaimer footer naming the active model and device."""
    # f-string: the model/device placeholders must be interpolated, not shown literally.
    return f"""

---

## ⚠️ **IMPORTANT MEDICAL DISCLAIMER**

**FOR EDUCATIONAL AND RESEARCH PURPOSES ONLY**

- **🚫 AI Limitations**: AI analysis has significant limitations for medical diagnosis
- **👨‍⚕️ Professional Review Required**: All findings must be validated by qualified healthcare professionals
- **🚨 Emergency Care**: For urgent medical concerns, seek immediate medical attention
- **🏥 Clinical Integration**: AI findings are educational tools, not diagnostic conclusions
- **📋 Learning Tool**: Designed for medical education and training purposes
- **🔒 Privacy**: Do not upload real patient data or identifiable information

**This analysis demonstrates AI-assisted medical image interpretation concepts for educational purposes only.**

---
**Model**: {current_model_name} | **Device**: {device.upper()} | **Purpose**: Medical Education
"""


def analyze_medical_image(image, clinical_question, patient_history=""):
    """Gradio handler: analyze one uploaded image and return a Markdown report.

    Args:
        image: PIL image from the upload widget, or None if nothing was uploaded.
        clinical_question: required free-text question.
        patient_history: optional free-text clinical context.

    Returns:
        str: the report plus disclaimer, or a short warning/error message.
    """
    start_time = time.time()
    clinical_question = clinical_question or ""
    patient_history = patient_history or ""

    if not model_ready or model is None:
        usage_tracker.log_analysis(False, time.time() - start_time)
        return "❌ Medical AI model not loaded. Check the server logs and restart the application."

    # Validate input before the rate limiter so an empty submit does not
    # consume one of the hourly requests.
    if image is None:
        return "⚠️ Please upload a medical image first."
    if not clinical_question.strip():
        return "⚠️ Please provide a clinical question."

    if not rate_limiter.is_allowed():
        usage_tracker.log_analysis(False, time.time() - start_time)
        return "⚠️ Rate limit exceeded. Please wait before trying again."

    try:
        logger.info("Starting enhanced medical image analysis...")
        basic_description, detailed_response = get_detailed_medical_analysis(image, clinical_question)
        if (basic_description, detailed_response) == _ANALYSIS_FAILED:
            # The model raised; count it as a failure rather than a success.
            usage_tracker.log_analysis(False, time.time() - start_time)
            return (
                "❌ Enhanced analysis failed: the vision model could not process this image."
                "\n\nPlease try again with a different image."
            )

        enhanced_analysis = enhance_medical_description(basic_description, clinical_question, patient_history)
        report = _build_report(
            clinical_question, patient_history, basic_description, detailed_response, enhanced_analysis
        )

        duration = time.time() - start_time
        usage_tracker.log_analysis(True, duration, classify_question(clinical_question))
        logger.info(f"✅ Enhanced medical analysis completed in {duration:.2f}s")
        return report + _build_disclaimer()
    except Exception as e:
        duration = time.time() - start_time
        usage_tracker.log_analysis(False, duration)
        logger.error(f"❌ Analysis error: {str(e)}")
        return f"❌ Enhanced analysis failed: {str(e)}\n\nPlease try again with a different image."


def classify_question(question):
    """Bucket a clinical question by keyword: descriptive/diagnostic/pathological/general."""
    question_lower = question.lower()
    if any(word in question_lower for word in ['describe', 'findings', 'observe', 'see', 'show']):
        return 'descriptive'
    if any(word in question_lower for word in ['diagnosis', 'differential', 'condition']):
        return 'diagnostic'
    if any(word in question_lower for word in ['abnormal', 'pathology', 'disease']):
        return 'pathological'
    return 'general'


def get_usage_stats():
    """Render the usage tracker's counters as Markdown."""
    stats = usage_tracker.stats
    if stats['total_analyses'] == 0:
        return "📊 **Usage Statistics**\n\nNo analyses performed yet."

    success_rate = (stats['successful_analyses'] / stats['total_analyses']) * 100
    question_lines = "\n".join(
        f"- **{qtype.title()}**: {count}" for qtype, count in stats['question_types'].most_common(3)
    )
    return f"""📊 **Enhanced Medical AI Statistics**

**Performance Metrics:**
- **Total Analyses**: {stats['total_analyses']}
- **Success Rate**: {success_rate:.1f}%
- **Average Processing Time**: {stats['average_processing_time']:.2f} seconds

**Question Types:**
{question_lines}

**System Status**: {'🟢 Enhanced Model Active' if model_ready else '🔴 Offline'}
**Current Model**: {current_model_name if current_model_name else 'None'}
**Device**: {device.upper()}
"""


def clear_all():
    """Clear all inputs and outputs"""
    return None, "", "", ""


def set_chest_example():
    """Set chest X-ray example"""
    return (
        "Describe this chest X-ray systematically and identify any abnormalities",
        "30-year-old patient with productive cough, fever, and shortness of breath",
    )


def set_pathology_example():
    """Set pathology example"""
    return (
        "What pathological findings are visible? Describe the tissue characteristics.",
        "Biopsy specimen for histopathological evaluation",
    )


def set_general_example():
    """Set general analysis example"""
    return (
        "Provide a systematic analysis of this medical image",
        "Patient requiring comprehensive imaging evaluation",
    )


def create_interface():
    """Build the Gradio Blocks UI and wire its event handlers.

    Returns:
        gr.Blocks: the assembled (not yet launched) application.
    """
    model_label = current_model_name if current_model_name else "None"

    with gr.Blocks(
        title="Enhanced Medical AI Analysis",
        theme=gr.themes.Soft(),
        css=_APP_CSS,
    ) as demo:
        # Header
        gr.Markdown("""
# 🏥 Enhanced Medical AI Image Analysis

**Advanced Medical AI with Better Vision Models - Educational Analysis**

**Enhanced Features:** 🧠 Multiple AI Models • 🔬 Systematic Analysis • 📋 Educational Framework • 🎓 Clinical Integration
""")

        # Status banner
        if model_ready:
            gr.Markdown(f"""
<div class="enhanced">
<h3>ENHANCED MEDICAL AI READY</h3>
<p><strong>Advanced model loaded:</strong> {current_model_name}</p>
<p>Now provides detailed medical image analysis with systematic framework and educational content.</p>
</div>
""")
        else:
            # Models load synchronously at import, so reaching this branch means
            # every candidate failed; waiting or refreshing will not help.
            gr.Markdown("""
<div class="disclaimer">
<h3>⚠️ MODEL UNAVAILABLE</h3>
<p>No vision model could be loaded. Check the server logs and restart the application.</p>
</div>
""")

        # Medical disclaimer
        gr.Markdown("""
<div class="disclaimer">
<h3>⚠️ MEDICAL DISCLAIMER</h3>
<p>This enhanced tool provides AI-assisted medical analysis for educational purposes only. Uses advanced vision models for better image understanding. <strong>Do not upload real patient data.</strong></p>
</div>
""")

        with gr.Row():
            # Left column - Main interface
            with gr.Column(scale=2):
                gr.Markdown("## 📤 Medical Image Upload")
                image_input = gr.Image(
                    label="Upload Medical Image (Enhanced Analysis)",
                    type="pil",
                    height=300,
                )

                gr.Markdown("## 💬 Clinical Information")
                clinical_question = gr.Textbox(
                    label="Clinical Question *",
                    placeholder=(
                        "Enhanced examples:\n"
                        "• Systematically describe this chest X-ray and identify abnormalities\n"
                        "• What pathological findings are visible in this image?\n"
                        "• Provide detailed analysis of anatomical structures\n"
                        "• Analyze this medical scan for educational purposes"
                    ),
                    lines=4,
                )
                patient_history = gr.Textbox(
                    label="Patient History & Clinical Context",
                    placeholder=(
                        "Detailed example: 35-year-old female with 3-day history of productive cough, "
                        "fever (38.5°C), shortness of breath, and left-sided chest pain"
                    ),
                    lines=3,
                )

                with gr.Row():
                    clear_btn = gr.Button("🗑️ Clear All", variant="secondary")
                    analyze_btn = gr.Button("🔍 Enhanced Medical Analysis", variant="primary", size="lg")

                gr.Markdown("## 📋 Enhanced Medical Analysis Results")
                output = gr.Textbox(
                    label="Advanced AI Medical Analysis (Multiple Models)",
                    lines=25,
                    show_copy_button=True,
                    placeholder="Upload a medical image and provide detailed clinical question for comprehensive AI analysis...",
                )

            # Right column - Status and controls
            with gr.Column(scale=1):
                gr.Markdown("## ℹ️ Enhanced System Status")
                gr.Markdown(f"""
**Status**: {'✅ Advanced Models Active' if model_ready else '🔴 Offline'}
**Primary Model**: {model_label}
**Device**: {device.upper()}
**Enhancement**: 🧠 Multiple AI Models
**Analysis**: 📋 Systematic Framework
""")

                gr.Markdown("## 📊 Usage Analytics")
                stats_display = gr.Markdown(get_usage_stats())
                refresh_stats_btn = gr.Button("🔄 Refresh Stats", size="sm")

                # Quick examples are only useful when a model is available.
                if model_ready:
                    gr.Markdown("## 🎯 Enhanced Examples")
                    chest_btn = gr.Button("🫁 Chest X-ray Analysis", size="sm")
                    pathology_btn = gr.Button("🔬 Pathology Study", size="sm")
                    general_btn = gr.Button("📋 Systematic Analysis", size="sm")

                    chest_btn.click(fn=set_chest_example, outputs=[clinical_question, patient_history])
                    pathology_btn.click(fn=set_pathology_example, outputs=[clinical_question, patient_history])
                    general_btn.click(fn=set_general_example, outputs=[clinical_question, patient_history])

                gr.Markdown("## 🚀 Enhancements")
                gr.Markdown(f"""
✅ **Advanced Vision Models**
✅ **Systematic Medical Framework**
✅ **Educational Integration**
✅ **Clinical Context Analysis**
✅ **Model**: {current_model_name.split('/')[-1] if current_model_name else 'Enhanced'}
""")

        # Event handlers
        analyze_btn.click(
            fn=analyze_medical_image,
            inputs=[image_input, clinical_question, patient_history],
            outputs=output,
            show_progress="full",
        )
        clear_btn.click(
            fn=clear_all,
            outputs=[image_input, clinical_question, patient_history, output],
        )
        refresh_stats_btn.click(fn=get_usage_stats, outputs=stats_display)

        # Footer
        gr.Markdown(f"""
---

## 🚀 **Enhanced Medical AI Features**

### **Advanced Vision Models:**
- **Microsoft GIT**: Enhanced image-to-text capabilities
- **BLIP2**: Advanced vision-language understanding
- **Multi-Model Fallback**: Automatic best model selection
- **Better Descriptions**: More detailed and accurate analysis

### **Medical Framework Integration:**
- **Systematic Analysis**: Structured medical image interpretation
- **Clinical Correlation**: Integration of symptoms with imaging
- **Educational Content**: Teaching points and learning objectives
- **Professional Guidelines**: Follows medical education standards

**Current Model**: {model_label} | **Purpose**: Enhanced Medical Education
""")

    return demo


# Launch the application
if __name__ == "__main__":
    demo = create_interface()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
        share=False,
    )