import gradio as gr
import torch
import logging
from collections import defaultdict, Counter
import time
import traceback

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def fix_transformers_compatibility():
    """Fix compatibility issues with the transformers library.

    Some transformers releases leave ALL_PARALLEL_STYLES unset (or None) in
    modules that iterate over it, which crashes model loading; patching in an
    empty list is a safe default.
    """
    try:
        import transformers.modeling_utils as modeling_utils
        if getattr(modeling_utils, 'ALL_PARALLEL_STYLES', None) is None:
            modeling_utils.ALL_PARALLEL_STYLES = []

        try:
            import transformers.models.llava_next.modeling_llava_next as llava_next
            if getattr(llava_next, 'ALL_PARALLEL_STYLES', None) is None:
                llava_next.ALL_PARALLEL_STYLES = []
        except ImportError:
            pass

        try:
            import transformers.models.mistral.modeling_mistral as mistral
            if getattr(mistral, 'ALL_PARALLEL_STYLES', None) is None:
                mistral.ALL_PARALLEL_STYLES = []
        except ImportError:
            pass

        logger.info("✅ Applied compatibility fixes")
        return True
    except Exception as e:
        logger.warning(f"⚠️ Could not apply compatibility fixes: {e}")
        return False


# Apply the patches before the model classes are imported
fix_transformers_compatibility()


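# Note: the shim above is a best-effort workaround observed to help on some
# transformers releases; on versions that never consult ALL_PARALLEL_STYLES
# during from_pretrained() it is a harmless no-op.

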
try:
    from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
    from PIL import Image
    logger.info("✅ Transformers imported successfully")
except Exception as e:
    logger.error(f"❌ Failed to import transformers: {e}")
    # Fall back to the older LLaVA classes, aliased to the LlavaNext names
    try:
        from transformers import LlavaForConditionalGeneration as LlavaNextForConditionalGeneration
        from transformers import AutoProcessor as LlavaNextProcessor
        logger.info("✅ Using fallback LLaVA imports")
    except Exception as e2:
        logger.error(f"❌ Fallback imports also failed: {e2}")


class UsageTracker:
    """Track analysis counts, success rate, and a running average latency."""

    def __init__(self):
        self.stats = {
            'total_analyses': 0,
            'successful_analyses': 0,
            'failed_analyses': 0,
            'average_processing_time': 0.0,
            'question_types': Counter()
        }

    def log_analysis(self, success, duration, question_type=None):
        self.stats['total_analyses'] += 1
        if success:
            self.stats['successful_analyses'] += 1
        else:
            self.stats['failed_analyses'] += 1

        # Incremental mean: reconstruct the previous total time, add the new
        # duration, and divide by the new count.
        total_time = self.stats['average_processing_time'] * (self.stats['total_analyses'] - 1)
        self.stats['average_processing_time'] = (total_time + duration) / self.stats['total_analyses']

        if question_type:
            self.stats['question_types'][question_type] += 1


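# Sanity trace for the running average in log_analysis (illustrative values):
# durations 2.0s then 4.0s -> after call 1: (0.0*0 + 2.0)/1 = 2.0;
# after call 2: (2.0*1 + 4.0)/2 = 3.0, the arithmetic mean as expected.

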
class RateLimiter:
    """Simple sliding-window limiter: at most N requests per user per hour."""

    def __init__(self, max_requests_per_hour=20):
        self.max_requests_per_hour = max_requests_per_hour
        self.requests = defaultdict(list)

    def is_allowed(self, user_id="default"):
        current_time = time.time()
        hour_ago = current_time - 3600
        # Drop timestamps older than one hour, then check the remaining count
        self.requests[user_id] = [req_time for req_time in self.requests[user_id] if req_time > hour_ago]
        if len(self.requests[user_id]) < self.max_requests_per_hour:
            self.requests[user_id].append(current_time)
            return True
        return False


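# Illustrative behavior with the default limit of 20 requests/hour: the first
# 20 calls to is_allowed() within a rolling hour return True; the 21st returns
# False until one of the stored timestamps ages past the hour window.

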
usage_tracker = UsageTracker()
rate_limiter = RateLimiter()

MODEL_ID = "llava-hf/llava-v1.6-mistral-7b-hf"

# Populated by load_llava_safe()
model = None
processor = None

def load_llava_safe():
    """Load the LLaVA model, trying progressively more permissive options."""
    global model, processor

    try:
        logger.info(f"Loading LLaVA model: {MODEL_ID}")

        # Each entry pairs a label with a loader; the first that succeeds wins
        loading_methods = [
            ("Standard LlavaNext", lambda: (
                LlavaNextProcessor.from_pretrained(MODEL_ID),
                LlavaNextForConditionalGeneration.from_pretrained(
                    MODEL_ID,
                    torch_dtype=torch.float32,
                    device_map=None,
                    low_cpu_mem_usage=True,
                    attn_implementation="eager"
                )
            )),
            ("Auto Processor Fallback", lambda: (
                LlavaNextProcessor.from_pretrained(MODEL_ID),
                LlavaNextForConditionalGeneration.from_pretrained(
                    MODEL_ID,
                    torch_dtype=torch.float32,
                    trust_remote_code=True,
                    use_safetensors=True
                )
            )),
        ]

        for method_name, method_func in loading_methods:
            try:
                logger.info(f"Trying {method_name}...")
                processor, model = method_func()
                logger.info(f"✅ LLaVA loaded successfully using {method_name}!")
                return True
            except Exception as e:
                logger.warning(f"❌ {method_name} failed: {e}")
                continue

        logger.error("❌ All loading methods failed")
        return False

    except Exception as e:
        logger.error(f"❌ Error loading LLaVA: {e}")
        logger.error(f"Full traceback: {traceback.format_exc()}")
        return False


llava_ready = load_llava_safe()

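# Rough memory estimate for the load above: 7B parameters in float32 is about
# 7e9 * 4 bytes ≈ 28 GB of RAM, which is why low_cpu_mem_usage=True is set and
# why CPU-only environments may fail to load this checkpoint.
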
def analyze_medical_image_llava(image, clinical_question, patient_history=""):
    """Analyze a medical image using LLaVA with robust error handling."""
    start_time = time.time()

    # Enforce the per-user rate limit before doing any work
    if not rate_limiter.is_allowed():
        usage_tracker.log_analysis(False, time.time() - start_time)
        return "⚠️ Rate limit exceeded. Please wait before trying again."

    if not llava_ready or model is None:
        usage_tracker.log_analysis(False, time.time() - start_time)
        return """❌ **LLaVA Model Loading Issue**

The LLaVA model failed to load due to compatibility issues. This is often caused by:

1. **Library Version Conflicts**: Try refreshing the page - we've applied compatibility fixes
2. **Memory Constraints**: The 7B model requires significant resources
3. **Transformers Version**: Some versions have compatibility issues

**Suggested Solutions:**
- **Refresh the page** and wait 2-3 minutes for model loading
- **Upgrade to GPU hardware** for better performance and stability
- **Try a different image** if the issue persists

**Technical Info**: There may be version conflicts in the transformers library. The model files downloaded successfully but initialization failed."""

    if image is None:
        return "⚠️ Please upload a medical image first."

    if not clinical_question.strip():
        return "⚠️ Please provide a clinical question."

    try:
        logger.info("Starting LLaVA medical analysis...")

        # Structured prompt; patient history is included only when provided
        medical_prompt = f"""You are an expert medical AI assistant analyzing medical images. Please provide a comprehensive medical analysis.

{f"Patient History: {patient_history}" if patient_history.strip() else ""}

Clinical Question: {clinical_question}

Please analyze this medical image systematically:

1. **Image Quality**: Assess technical quality and diagnostic adequacy
2. **Anatomical Structures**: Identify visible normal structures
3. **Abnormal Findings**: Describe any pathological changes
4. **Clinical Significance**: Explain the importance of findings
5. **Assessment**: Provide clinical interpretation
6. **Recommendations**: Suggest next steps if appropriate

Provide detailed, educational medical analysis suitable for learning purposes."""

        # Candidate prompt formats, tried in order until one yields text
        prompt_formats = [
            # Classic LLaVA format with an explicit <image> placeholder token
            lambda: f"USER: <image>\n{medical_prompt}\nASSISTANT:",

            # Chat-template format (only supported by newer processors)
            lambda: processor.apply_chat_template([
                {"role": "user", "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": medical_prompt}
                ]}
            ], add_generation_prompt=True),

            # Bare prompt as a last resort
            lambda: medical_prompt
        ]

        generated_text = None
        for i, prompt_func in enumerate(prompt_formats):
            try:
                logger.info(f"Trying prompt format {i+1}...")
                prompt = prompt_func()

                # Keyword arguments avoid positional-order differences
                # between transformers versions
                inputs = processor(images=image, text=prompt, return_tensors='pt')

                logger.info("Generating medical analysis...")
                with torch.inference_mode():
                    output = model.generate(
                        **inputs,
                        max_new_tokens=1000,
                        do_sample=True,
                        temperature=0.3,
                        top_p=0.9,
                        repetition_penalty=1.1,
                        use_cache=False
                    )

                # Decode only the newly generated tokens, not the prompt
                generated_text = processor.decode(output[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)

                if generated_text and generated_text.strip():
                    break

            except Exception as e:
                logger.warning(f"Prompt format {i+1} failed: {e}")
                if i == len(prompt_formats) - 1:
                    raise
                continue

        response = generated_text.strip() if generated_text else "Analysis completed."

        formatted_response = f"""# 🏥 **LLaVA Medical Analysis**

## **Clinical Question:** {clinical_question}
{f"## **Patient History:** {patient_history}" if patient_history.strip() else ""}

---

## 🔍 **Medical Analysis Results**

{response}

---

## 📋 **Clinical Summary**

This analysis was generated using LLaVA (Large Language and Vision Assistant) for educational purposes. The findings should be interpreted by qualified medical professionals and correlated with clinical presentation.

**Key Points:**
- Analysis based on visual medical image interpretation
- Systematic approach to medical imaging assessment
- Educational tool for medical learning and training
- Requires professional medical validation
"""

        disclaimer = """
---
## ⚠️ **MEDICAL DISCLAIMER**

**FOR EDUCATIONAL PURPOSES ONLY**

- **Not Diagnostic**: This AI analysis is not a medical diagnosis
- **Professional Review**: All findings require validation by healthcare professionals
- **Emergency Care**: Contact emergency services for urgent medical concerns
- **Educational Tool**: Designed for medical education and training
- **No PHI**: Do not upload patient identifiable information

---
**Powered by**: LLaVA (Large Language and Vision Assistant)
"""

        duration = time.time() - start_time
        question_type = classify_question(clinical_question)
        usage_tracker.log_analysis(True, duration, question_type)

        logger.info("✅ LLaVA medical analysis completed successfully")
        return formatted_response + disclaimer

    except Exception as e:
        duration = time.time() - start_time
        usage_tracker.log_analysis(False, duration)
        logger.error(f"❌ LLaVA analysis error: {e}")

        return f"""❌ **Analysis Error**

The analysis failed with error: {str(e)}

**Common Solutions:**
- **Try again**: Sometimes temporary processing issues occur
- **Smaller image**: Try with a smaller or different format image
- **Simpler question**: Use a more straightforward clinical question
- **Refresh page**: Reload the page if model seems unstable

**Technical Details:** {str(e)[:200]}"""


def classify_question(question):
    """Classify clinical question type by keyword matching."""
    question_lower = question.lower()
    if any(word in question_lower for word in ['describe', 'findings', 'observe']):
        return 'descriptive'
    elif any(word in question_lower for word in ['diagnosis', 'differential', 'condition']):
        return 'diagnostic'
    elif any(word in question_lower for word in ['abnormal', 'pathology', 'disease']):
        return 'pathological'
    else:
        return 'general'


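# Example classifications (first matching keyword list wins):
#   classify_question("Describe the findings")        -> 'descriptive'
#   classify_question("What is the differential?")    -> 'diagnostic'
#   classify_question("Is there any pathology here?") -> 'pathological'
#   classify_question("Is this image adequate?")      -> 'general'

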
def get_usage_stats():
    """Format usage statistics as markdown."""
    stats = usage_tracker.stats
    if stats['total_analyses'] == 0:
        return "📊 **Usage Statistics**\n\nNo analyses performed yet."

    success_rate = (stats['successful_analyses'] / stats['total_analyses']) * 100

    return f"""📊 **LLaVA Usage Statistics**

**Performance:**
- Total Analyses: {stats['total_analyses']}
- Success Rate: {success_rate:.1f}%
- Avg Processing Time: {stats['average_processing_time']:.2f}s

**Popular Question Types:**
{chr(10).join([f"- {qtype}: {count}" for qtype, count in stats['question_types'].most_common(3)])}

**Model Status**: {'🟢 Ready' if llava_ready else '🔴 Loading Issues'}
"""


def create_interface():
    with gr.Blocks(
        title="LLaVA Medical Analysis",
        theme=gr.themes.Soft(),
        css="""
        .gradio-container { max-width: 1200px !important; }
        .disclaimer { background-color: #fef2f2; border: 1px solid #fecaca; border-radius: 8px; padding: 16px; margin: 16px 0; }
        .success { background-color: #f0f9ff; border: 1px solid #bae6fd; border-radius: 8px; padding: 16px; margin: 16px 0; }
        .warning { background-color: #fffbeb; border: 1px solid #fed7aa; border-radius: 8px; padding: 16px; margin: 16px 0; }
        """
    ) as demo:

        gr.Markdown("""
        # 🏥 LLaVA Medical Image Analysis

        **Advanced Medical AI powered by LLaVA (Large Language and Vision Assistant)**

        **Medical Capabilities:** 🫁 Radiology • 🔬 Pathology • 🩺 Dermatology • 👁️ Ophthalmology
        """)

        # Status banner reflects whether the model loaded at startup
        if llava_ready:
            gr.Markdown("""
            <div class="success">
            ✅ <strong>LLAVA MEDICAL AI READY</strong><br>
            LLaVA model loaded successfully with compatibility fixes. Ready for medical image analysis.
            </div>
            """)
        else:
            gr.Markdown("""
            <div class="warning">
            ⚠️ <strong>MODEL LOADING ISSUE</strong><br>
            LLaVA model had loading problems. Try refreshing the page or contact support for assistance.
            </div>
            """)

        gr.Markdown("""
        <div class="disclaimer">
        ⚠️ <strong>MEDICAL DISCLAIMER</strong><br>
        This AI provides medical analysis for <strong>educational purposes only</strong>.
        Do not upload real patient data. Always consult healthcare professionals for medical decisions.
        </div>
        """)

        with gr.Row():
            with gr.Column(scale=2):
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("## 📤 Medical Image")
                        image_input = gr.Image(
                            label="Upload Medical Image",
                            type="pil",
                            height=300
                        )

                    with gr.Column():
                        gr.Markdown("## 💬 Clinical Information")
                        clinical_question = gr.Textbox(
                            label="Clinical Question *",
                            placeholder="Examples:\n• Analyze this medical image\n• What abnormalities are visible?\n• Describe the findings\n• Provide medical interpretation",
                            lines=4
                        )

                        patient_history = gr.Textbox(
                            label="Patient History (Optional)",
                            placeholder="e.g., 45-year-old with chest pain",
                            lines=2
                        )

                with gr.Row():
                    clear_btn = gr.Button("🗑️ Clear", variant="secondary")
                    analyze_btn = gr.Button("🔍 Analyze with LLaVA", variant="primary", size="lg")

                gr.Markdown("## 📋 Medical Analysis Results")
                output = gr.Textbox(
                    label="LLaVA Medical Analysis",
                    lines=20,
                    show_copy_button=True,
                    placeholder="Upload a medical image and clinical question..." if llava_ready else "Model loading issues - please refresh the page"
                )

            with gr.Column(scale=1):
                gr.Markdown("## ℹ️ System Status")

                status = "✅ Ready" if llava_ready else "⚠️ Loading Issues"

                gr.Markdown(f"""
                **Model Status:** {status}
                **AI Model:** LLaVA-v1.6-Mistral-7B
                **Device:** {'GPU' if torch.cuda.is_available() else 'CPU'}
                **Compatibility:** Fixed for stability
                **Rate Limit:** 20 requests/hour
                """)

                gr.Markdown("## 📊 Usage Statistics")
                stats_display = gr.Markdown("")
                refresh_stats_btn = gr.Button("🔄 Refresh Stats", size="sm")

                # Quick-example buttons are only useful when the model is up
                if llava_ready:
                    gr.Markdown("## 🎯 Quick Examples")
                    general_btn = gr.Button("General Analysis", size="sm")
                    findings_btn = gr.Button("Find Abnormalities", size="sm")
                    interpret_btn = gr.Button("Medical Interpretation", size="sm")

        if llava_ready:
            with gr.Accordion("📋 Example Cases", open=False):
                examples = gr.Examples(
                    examples=[
                        [
                            "https://upload.wikimedia.org/wikipedia/commons/c/c8/Chest_Xray_PA_3-8-2010.png",
                            "Please analyze this chest X-ray and describe any findings. Assess the image quality, identify normal structures, and note any abnormalities.",
                            "Adult patient with respiratory symptoms"
                        ]
                    ],
                    inputs=[image_input, clinical_question, patient_history]
                )

        analyze_btn.click(
            fn=analyze_medical_image_llava,
            inputs=[image_input, clinical_question, patient_history],
            outputs=output,
            show_progress=True
        )

        clear_btn.click(
            fn=lambda: (None, "", "", ""),
            outputs=[image_input, clinical_question, patient_history, output]
        )

        refresh_stats_btn.click(
            fn=get_usage_stats,
            outputs=stats_display
        )

        if llava_ready:
            general_btn.click(
                fn=lambda: ("Analyze this medical image comprehensively. Describe what you observe and provide medical interpretation.", ""),
                outputs=[clinical_question, patient_history]
            )

            findings_btn.click(
                fn=lambda: ("What abnormalities or pathological findings are visible in this medical image?", ""),
                outputs=[clinical_question, patient_history]
            )

            interpret_btn.click(
                fn=lambda: ("Provide medical interpretation of this image including clinical significance of any findings.", ""),
                outputs=[clinical_question, patient_history]
            )

        gr.Markdown("""
        ---
        ### 🤖 LLaVA Medical AI

        **Large Language and Vision Assistant** optimized for medical image analysis with compatibility fixes for stable operation.

        **Features:**
        - Advanced medical image interpretation
        - Systematic clinical analysis approach
        - Educational medical explanations
        - Comprehensive error handling

        **Model:** LLaVA-v1.6-Mistral-7B | **Purpose:** Medical Education & Research
        """)

    return demo


if __name__ == "__main__":
    demo = create_interface()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True
    )
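
# To run locally (assuming gradio, torch, transformers, and pillow are
# installed, and that this file is saved as app.py): `python app.py`,
# then open http://localhost:7860 in a browser.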