# app.py - Fixed LLaVA Medical AI with NoneType Error Resolution
import gradio as gr
import torch
import logging
from collections import defaultdict, Counter
import time
import traceback
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Fix the NoneType compatibility issue
def fix_transformers_compatibility():
    """Fix compatibility issues with the transformers library.

    Some transformers releases ship modules whose ALL_PARALLEL_STYLES
    attribute is missing or set to None, which crashes model
    initialization with a NoneType error. Patch the attribute to an
    empty list in every affected module that is importable.
    """
    import importlib

    module_names = [
        "transformers.modeling_utils",
        "transformers.models.llava_next.modeling_llava_next",
        "transformers.models.mistral.modeling_mistral",
    ]
    try:
        for name in module_names:
            try:
                module = importlib.import_module(name)
            except ImportError:
                continue
            if getattr(module, "ALL_PARALLEL_STYLES", None) is None:
                module.ALL_PARALLEL_STYLES = []
        logger.info("✅ Applied compatibility fixes")
        return True
    except Exception as e:
        logger.warning(f"⚠️ Could not apply compatibility fixes: {e}")
        return False
# Apply compatibility fix before imports
fix_transformers_compatibility()
# Now import transformers
try:
    from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
    from PIL import Image
    logger.info("✅ Transformers imported successfully")
except Exception as e:
    logger.error(f"❌ Failed to import transformers: {e}")
    # Fallback imports: expose the older LLaVA classes under the same names
    try:
        from transformers import AutoProcessor as LlavaNextProcessor
        from transformers import LlavaForConditionalGeneration as LlavaNextForConditionalGeneration
        from PIL import Image
        logger.info("✅ Using fallback LLaVA imports")
    except Exception as e2:
        logger.error(f"❌ Fallback imports also failed: {e2}")
# Usage tracking
class UsageTracker:
def __init__(self):
self.stats = {
'total_analyses': 0,
'successful_analyses': 0,
'failed_analyses': 0,
'average_processing_time': 0.0,
'question_types': Counter()
}
def log_analysis(self, success, duration, question_type=None):
self.stats['total_analyses'] += 1
if success:
self.stats['successful_analyses'] += 1
else:
self.stats['failed_analyses'] += 1
total_time = self.stats['average_processing_time'] * (self.stats['total_analyses'] - 1)
self.stats['average_processing_time'] = (total_time + duration) / self.stats['total_analyses']
if question_type:
self.stats['question_types'][question_type] += 1
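
# Illustrative sketch (hypothetical helper, not called by the app): the
# tracker maintains a running mean without storing every duration, using
# avg_n = (avg_{n-1} * (n - 1) + x_n) / n.
def _demo_usage_tracker():
    tracker = UsageTracker()
    for duration in (1.0, 2.0, 3.0):
        tracker.log_analysis(True, duration, question_type='descriptive')
    # Running mean of 1.0, 2.0, 3.0 is 2.0
    assert abs(tracker.stats['average_processing_time'] - 2.0) < 1e-9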
# Rate limiting
class RateLimiter:
def __init__(self, max_requests_per_hour=20):
self.max_requests_per_hour = max_requests_per_hour
self.requests = defaultdict(list)
def is_allowed(self, user_id="default"):
current_time = time.time()
hour_ago = current_time - 3600
self.requests[user_id] = [req_time for req_time in self.requests[user_id] if req_time > hour_ago]
if len(self.requests[user_id]) < self.max_requests_per_hour:
self.requests[user_id].append(current_time)
return True
return False
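
# Illustrative sketch (hypothetical helper, not called by the app): the
# limiter keeps a sliding one-hour window of request timestamps per user and
# rejects a request once the window already holds max_requests_per_hour entries.
def _demo_rate_limiter():
    limiter = RateLimiter(max_requests_per_hour=2)
    assert limiter.is_allowed("user")      # 1st request within the hour passes
    assert limiter.is_allowed("user")      # 2nd request passes
    assert not limiter.is_allowed("user")  # 3rd request within the hour is rejected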
# Initialize components
usage_tracker = UsageTracker()
rate_limiter = RateLimiter()
# Model configuration
MODEL_ID = "llava-hf/llava-v1.6-mistral-7b-hf"
# Global variables
model = None
processor = None
def load_llava_safe():
"""Load LLaVA model with comprehensive error handling"""
global model, processor
try:
logger.info(f"Loading LLaVA model: {MODEL_ID}")
# Try different loading approaches
loading_methods = [
("Standard LlavaNext", lambda: (
LlavaNextProcessor.from_pretrained(MODEL_ID),
LlavaNextForConditionalGeneration.from_pretrained(
MODEL_ID,
torch_dtype=torch.float32, # Use float32 for stability
device_map=None, # Let PyTorch handle device placement
low_cpu_mem_usage=True,
attn_implementation="eager" # Use eager attention to avoid issues
)
)),
("Auto Processor Fallback", lambda: (
LlavaNextProcessor.from_pretrained(MODEL_ID),
LlavaNextForConditionalGeneration.from_pretrained(
MODEL_ID,
torch_dtype=torch.float32,
trust_remote_code=True,
use_safetensors=True
)
)),
]
        for method_name, method_func in loading_methods:
            try:
                logger.info(f"Trying {method_name}...")
                processor, model = method_func()
                logger.info(f"✅ LLaVA loaded successfully using {method_name}!")
                return True
            except Exception as e:
                logger.warning(f"❌ {method_name} failed: {str(e)}")
                continue
        logger.error("❌ All loading methods failed")
        return False
    except Exception as e:
        logger.error(f"❌ Error loading LLaVA: {str(e)}")
        logger.error(f"Full traceback: {traceback.format_exc()}")
        return False
# Load model at startup
llava_ready = load_llava_safe()
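
# Optional sanity check (hypothetical helper, not called by the app): inspect
# the dtype and device of the loaded weights before serving requests, e.g. to
# confirm the float32/CPU configuration chosen above.
def _model_sanity_check():
    if model is None:
        return "model not loaded"
    first_param = next(model.parameters())
    return f"dtype={first_param.dtype}, device={first_param.device}"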
def analyze_medical_image_llava(image, clinical_question, patient_history=""):
"""Analyze medical image using LLaVA with robust error handling"""
start_time = time.time()
    # Rate limiting
    if not rate_limiter.is_allowed():
        usage_tracker.log_analysis(False, time.time() - start_time)
        return "⚠️ Rate limit exceeded. Please wait before trying again."
    if not llava_ready or model is None:
        usage_tracker.log_analysis(False, time.time() - start_time)
        return """❌ **LLaVA Model Loading Issue**
The LLaVA model failed to load. This is most often caused by:
1. **Library Version Conflicts**: Try refreshing the page - compatibility fixes are applied at startup
2. **Memory Constraints**: The 7B model requires significant resources
3. **Transformers Version**: Some versions have known incompatibilities
**Suggested Solutions:**
- **Refresh the page** and wait 2-3 minutes for model loading
- **Upgrade to GPU hardware** for better performance and stability
- **Try a different image** if the issue persists
**Technical Info**: The model files downloaded successfully, but initialization failed, most likely due to version conflicts in the transformers library."""
    if image is None:
        return "⚠️ Please upload a medical image first."
    if not clinical_question.strip():
        return "⚠️ Please provide a clinical question."
try:
logger.info("Starting LLaVA medical analysis...")
# Prepare medical prompt
medical_prompt = f"""You are an expert medical AI assistant analyzing medical images. Please provide a comprehensive medical analysis.
{f"Patient History: {patient_history}" if patient_history.strip() else ""}
Clinical Question: {clinical_question}
Please analyze this medical image systematically:
1. **Image Quality**: Assess technical quality and diagnostic adequacy
2. **Anatomical Structures**: Identify visible normal structures
3. **Abnormal Findings**: Describe any pathological changes
4. **Clinical Significance**: Explain the importance of findings
5. **Assessment**: Provide clinical interpretation
6. **Recommendations**: Suggest next steps if appropriate
Provide detailed, educational medical analysis suitable for learning purposes."""
# Different prompt formats to try
prompt_formats = [
# Format 1: Simple user message
lambda: f"USER: <image>\n{medical_prompt}\nASSISTANT:",
# Format 2: Chat format
lambda: processor.apply_chat_template([
{"role": "user", "content": [
{"type": "image", "image": image},
{"type": "text", "text": medical_prompt}
]}
], add_generation_prompt=True),
# Format 3: Direct format
lambda: medical_prompt
]
        # Try the prompt formats in order until one produces output
        generated_text = None
        for i, prompt_func in enumerate(prompt_formats):
            try:
                logger.info(f"Trying prompt format {i+1}...")
                if i == 1:  # Chat template format may be unsupported by this processor
                    try:
                        prompt = prompt_func()
                    except Exception:
                        continue
                else:
                    prompt = prompt_func()
                # Process inputs (keyword arguments avoid the text/image
                # positional-order differences between transformers versions)
                inputs = processor(text=prompt, images=image, return_tensors='pt')
# Generate response with conservative settings
logger.info("Generating medical analysis...")
with torch.inference_mode():
output = model.generate(
**inputs,
max_new_tokens=1000, # Conservative limit
do_sample=True,
temperature=0.3,
top_p=0.9,
repetition_penalty=1.1,
use_cache=False # Disable cache for stability
)
# Decode response
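                # generate() returns the prompt tokens followed by the newly
                # generated ones, so slice off the first input_ids.shape[-1]
                # positions to keep only the model's continuation.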
generated_text = processor.decode(output[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
if generated_text and generated_text.strip():
break
            except Exception as e:
                logger.warning(f"Prompt format {i+1} failed: {e}")
                if i == len(prompt_formats) - 1:  # Last attempt, give up
                    raise
                continue
        # Clean up response
        response = generated_text.strip() if generated_text else "No analysis text was generated."
        # Format the response
        formatted_response = f"""# 🏥 **LLaVA Medical Analysis**
## **Clinical Question:** {clinical_question}
{f"## **Patient History:** {patient_history}" if patient_history.strip() else ""}
---
## 🔍 **Medical Analysis Results**
{response}
---
## 📋 **Clinical Summary**
This analysis was generated by LLaVA (Large Language and Vision Assistant) for educational purposes. The findings should be interpreted by qualified medical professionals and correlated with the clinical presentation.
**Key Points:**
- Analysis based on visual interpretation of the medical image
- Systematic approach to medical imaging assessment
- Educational tool for medical learning and training
- Requires professional medical validation
"""
        # Add medical disclaimer
        disclaimer = """
---
## ⚠️ **MEDICAL DISCLAIMER**
**FOR EDUCATIONAL PURPOSES ONLY**
- **Not Diagnostic**: This AI analysis is not a medical diagnosis
- **Professional Review**: All findings require validation by healthcare professionals
- **Emergency Care**: Contact emergency services for urgent medical concerns
- **Educational Tool**: Designed for medical education and training
- **No PHI**: Do not upload patient-identifiable information
---
**Powered by**: LLaVA (Large Language and Vision Assistant)
"""
        # Log successful analysis
        duration = time.time() - start_time
        question_type = classify_question(clinical_question)
        usage_tracker.log_analysis(True, duration, question_type)
        logger.info("✅ LLaVA medical analysis completed successfully")
        return formatted_response + disclaimer
    except Exception as e:
        duration = time.time() - start_time
        usage_tracker.log_analysis(False, duration)
        logger.error(f"❌ LLaVA analysis error: {str(e)}")
        return f"""❌ **Analysis Error**
The analysis failed with error: {str(e)}
**Common Solutions:**
- **Try again**: Temporary processing issues sometimes occur
- **Smaller image**: Try a smaller image or a different format
- **Simpler question**: Use a more straightforward clinical question
- **Refresh page**: Reload the page if the model seems unstable
**Technical Details:** {str(e)[:200]}"""
def classify_question(question):
"""Classify clinical question type"""
question_lower = question.lower()
if any(word in question_lower for word in ['describe', 'findings', 'observe']):
return 'descriptive'
elif any(word in question_lower for word in ['diagnosis', 'differential', 'condition']):
return 'diagnostic'
elif any(word in question_lower for word in ['abnormal', 'pathology', 'disease']):
return 'pathological'
else:
return 'general'
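
# Illustrative sketch (hypothetical helper, not called by the app): the
# classifier is a plain keyword match, checked in order, with 'general' as
# the fallback.
def _demo_classify_question():
    assert classify_question("Describe the findings") == 'descriptive'
    assert classify_question("What is the differential diagnosis?") == 'diagnostic'
    assert classify_question("Any pathology visible?") == 'pathological'
    assert classify_question("Is this image adequate?") == 'general'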
def get_usage_stats():
"""Get usage statistics"""
    stats = usage_tracker.stats
    if stats['total_analyses'] == 0:
        return "📊 **Usage Statistics**\n\nNo analyses performed yet."
    success_rate = (stats['successful_analyses'] / stats['total_analyses']) * 100
    return f"""📊 **LLaVA Usage Statistics**
**Performance:**
- Total Analyses: {stats['total_analyses']}
- Success Rate: {success_rate:.1f}%
- Avg Processing Time: {stats['average_processing_time']:.2f}s
**Popular Question Types:**
{chr(10).join([f"- {qtype}: {count}" for qtype, count in stats['question_types'].most_common(3)])}
**Model Status**: {'🟢 Ready' if llava_ready else '🔴 Loading Issues'}
"""
# Create Gradio interface
def create_interface():
with gr.Blocks(
title="LLaVA Medical Analysis",
theme=gr.themes.Soft(),
css="""
.gradio-container { max-width: 1200px !important; }
.disclaimer { background-color: #fef2f2; border: 1px solid #fecaca; border-radius: 8px; padding: 16px; margin: 16px 0; }
.success { background-color: #f0f9ff; border: 1px solid #bae6fd; border-radius: 8px; padding: 16px; margin: 16px 0; }
.warning { background-color: #fffbeb; border: 1px solid #fed7aa; border-radius: 8px; padding: 16px; margin: 16px 0; }
"""
) as demo:
        # Header
        gr.Markdown("""
        # 🏥 LLaVA Medical Image Analysis
        **Advanced Medical AI powered by LLaVA (Large Language and Vision Assistant)**
        **Medical Capabilities:** 🫁 Radiology • 🔬 Pathology • 🩺 Dermatology • 👁️ Ophthalmology
        """)
        # Status display
        if llava_ready:
            gr.Markdown("""
            <div class="success">
            ✅ <strong>LLAVA MEDICAL AI READY</strong><br>
            LLaVA model loaded successfully with compatibility fixes. Ready for medical image analysis.
            </div>
            """)
        else:
            gr.Markdown("""
            <div class="warning">
            ⚠️ <strong>MODEL LOADING ISSUE</strong><br>
            The LLaVA model had loading problems. Try refreshing the page or contact support for assistance.
            </div>
            """)
        # Medical disclaimer
        gr.Markdown("""
        <div class="disclaimer">
        ⚠️ <strong>MEDICAL DISCLAIMER</strong><br>
        This AI provides medical analysis for <strong>educational purposes only</strong>.
        Do not upload real patient data. Always consult healthcare professionals for medical decisions.
        </div>
        """)
with gr.Row():
# Left column
with gr.Column(scale=2):
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("## 📤 Medical Image")
                        image_input = gr.Image(
                            label="Upload Medical Image",
                            type="pil",
                            height=300
                        )
                    with gr.Column():
                        gr.Markdown("## 💬 Clinical Information")
                        clinical_question = gr.Textbox(
                            label="Clinical Question *",
                            placeholder="Examples:\n• Analyze this medical image\n• What abnormalities are visible?\n• Describe the findings\n• Provide medical interpretation",
                            lines=4
                        )
                        patient_history = gr.Textbox(
                            label="Patient History (Optional)",
                            placeholder="e.g., 45-year-old with chest pain",
                            lines=2
                        )
                with gr.Row():
                    clear_btn = gr.Button("🗑️ Clear", variant="secondary")
                    analyze_btn = gr.Button("🔍 Analyze with LLaVA", variant="primary", size="lg")
                gr.Markdown("## 📋 Medical Analysis Results")
                output = gr.Textbox(
                    label="LLaVA Medical Analysis",
                    lines=20,
                    show_copy_button=True,
                    placeholder="Upload a medical image and clinical question..." if llava_ready else "Model loading issues - please refresh the page"
                )
            # Right column
            with gr.Column(scale=1):
                gr.Markdown("## ℹ️ System Status")
                status = "✅ Ready" if llava_ready else "⚠️ Loading Issues"
                gr.Markdown(f"""
                **Model Status:** {status}
                **AI Model:** LLaVA-v1.6-Mistral-7B
                **Device:** {'GPU' if torch.cuda.is_available() else 'CPU'}
                **Compatibility:** Fixed for stability
                **Rate Limit:** 20 requests/hour
                """)
                gr.Markdown("## 📊 Usage Statistics")
                stats_display = gr.Markdown("")
                refresh_stats_btn = gr.Button("🔄 Refresh Stats", size="sm")
                if llava_ready:
                    gr.Markdown("## 🎯 Quick Examples")
                    general_btn = gr.Button("General Analysis", size="sm")
                    findings_btn = gr.Button("Find Abnormalities", size="sm")
                    interpret_btn = gr.Button("Medical Interpretation", size="sm")
        # Example cases
        if llava_ready:
            with gr.Accordion("📚 Example Cases", open=False):
                examples = gr.Examples(
                    examples=[
                        [
                            "https://upload.wikimedia.org/wikipedia/commons/c/c8/Chest_Xray_PA_3-8-2010.png",
                            "Please analyze this chest X-ray and describe any findings. Assess the image quality, identify normal structures, and note any abnormalities.",
                            "Adult patient with respiratory symptoms"
                        ]
                    ],
                    inputs=[image_input, clinical_question, patient_history]
                )
# Event handlers
analyze_btn.click(
fn=analyze_medical_image_llava,
inputs=[image_input, clinical_question, patient_history],
outputs=output,
show_progress=True
)
clear_btn.click(
fn=lambda: (None, "", "", ""),
outputs=[image_input, clinical_question, patient_history, output]
)
refresh_stats_btn.click(
fn=get_usage_stats,
outputs=stats_display
)
# Quick example handlers
if llava_ready:
general_btn.click(
fn=lambda: ("Analyze this medical image comprehensively. Describe what you observe and provide medical interpretation.", ""),
outputs=[clinical_question, patient_history]
)
findings_btn.click(
fn=lambda: ("What abnormalities or pathological findings are visible in this medical image?", ""),
outputs=[clinical_question, patient_history]
)
interpret_btn.click(
fn=lambda: ("Provide medical interpretation of this image including clinical significance of any findings.", ""),
outputs=[clinical_question, patient_history]
)
        # Footer
        gr.Markdown("""
        ---
        ### 🤖 LLaVA Medical AI
        **Large Language and Vision Assistant** optimized for medical image analysis, with compatibility fixes for stable operation.
        **Features:**
        - Advanced medical image interpretation
        - Systematic clinical analysis approach
        - Educational medical explanations
        - Comprehensive error handling
        **Model:** LLaVA-v1.6-Mistral-7B | **Purpose:** Medical Education & Research
        """)
return demo
# Launch the application
if __name__ == "__main__":
demo = create_interface()
demo.launch(
server_name="0.0.0.0",
server_port=7860,
show_error=True
)