# lumenex2 / app.py — Hugging Face Space by walaa2022
# Commit 699420d ("Update app.py", verified) · 15.6 kB
import functools
import json
import os
import tempfile
from datetime import datetime
from typing import Dict, List, Tuple, Optional

import gradio as gr
import google.generativeai as genai
import librosa
import librosa.display
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from PIL import Image
# Configure Gemini AI.
# Set your API key before launching: export GOOGLE_API_KEY="your_api_key_here"
# configure() runs even when the key is unset; Gemini-backed handlers below
# catch the resulting failures and return the error text to the UI.
genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
# The model name may be overridden with GEMINI_MODEL without editing code;
# an unset or empty value falls back to the original default.
gemini_model = genai.GenerativeModel(os.getenv("GEMINI_MODEL") or 'gemini-1.5-flash')
# Load the pre-trained ResNet heart-sound classifier once per process.
# functools.lru_cache provides the load-once memoization; gr.utils.cache is not
# a documented public Gradio decorator.
@functools.lru_cache(maxsize=None)
def load_heartbeat_model(model_path: str = 'Heart_ResNet.h5'):
    """Load the Keras heart-sound model, or return None if it cannot be loaded.

    Args:
        model_path: Path to the saved Keras model file.

    Returns:
        The loaded model, or None. With None, analyze_heartbeat falls back to
        fixed mock predictions.
    """
    try:
        return tf.keras.models.load_model(model_path)
    except Exception as exc:
        # A missing or unreadable model must not stop the app from starting,
        # but report the actual reason instead of always claiming "not found".
        print(f"Warning: could not load {model_path} ({exc}). Using mock predictions.")
        return None


heartbeat_model = load_heartbeat_model()
# Most recently saved consultation record. It lives at module level, so every
# request handled by this process reads and overwrites the same dict
# (in production, use a proper database).
patient_data: Dict = dict()
def process_audio(
    file_path: str,
    *,
    sample_rate: int = 22050,
    duration: float = 10,
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[int]]:
    """Load a recording and summarise it as time-averaged MFCC features.

    The clip is resampled to *sample_rate* and at most *duration* seconds are
    loaded. Shorter clips are zero-padded at the end, so every file yields
    ``int(sample_rate * duration)`` samples.

    Args:
        file_path: Path to an audio file readable by ``librosa.load``.
        sample_rate: Target sampling rate in Hz.
        duration: Clip length in seconds.

    Returns:
        ``(mfccs, waveform, sr)``: the 52 MFCC coefficients averaged over
        time, the (padded) waveform, and its sampling rate. Returns
        ``(None, None, None)`` if the file could not be processed.
    """
    input_length = int(sample_rate * duration)
    try:
        signal, sr = librosa.load(file_path, sr=sample_rate, duration=duration)
        if len(signal) < input_length:
            signal = np.pad(signal, (0, input_length - len(signal)), mode='constant')
        # 52 coefficients: analyze_heartbeat reshapes them to (1, 52, 1) for the model.
        mfccs = np.mean(
            librosa.feature.mfcc(y=signal, sr=sr, n_mfcc=52, n_fft=512, hop_length=256).T,
            axis=0,
        )
        return mfccs, signal, sr
    except Exception as e:
        print(f"Error processing audio: {e}")
        return None, None, None
def analyze_heartbeat(audio_file) -> Tuple[Dict, str]:
    """Classify a heart-sound recording and save a waveform plot.

    Args:
        audio_file: Path to the recording (the UI uses
            ``gr.Audio(type="filepath")``), or None.

    Returns:
        On success, ``(results, plot_path)``. *results* maps "artifact",
        "murmur" and "normal" to float scores; these are fixed mock scores
        when no model is loaded. *plot_path* is the path of a PNG of the
        waveform. On failure, ``({}, message)``, where *message* describes
        the problem.
    """
    if audio_file is None:
        return {}, "No audio file provided"
    try:
        mfccs, waveform, sr = process_audio(audio_file)
        if mfccs is None:
            return {}, "Error processing audio file"
        if heartbeat_model is not None:
            # The model takes the 52 averaged MFCCs as a (1, 52, 1) input.
            features = mfccs.reshape(1, 52, 1)
            preds = heartbeat_model.predict(features)
            class_names = ["artifact", "murmur", "normal"]
            results = {name: float(preds[0][i]) for i, name in enumerate(class_names)}
        else:
            # Mock results for demonstration when no model is loaded.
            results = {"artifact": 0.15, "murmur": 0.25, "normal": 0.60}

        # librosa.display must be imported explicitly (see the librosa docs);
        # a bare ``import librosa`` may not expose it.
        fig, ax = plt.subplots(figsize=(12, 4))
        try:
            librosa.display.waveshow(waveform, sr=sr, ax=ax)
            ax.set_title("Heartbeat Waveform Analysis")
            ax.set_xlabel("Time (seconds)")
            ax.set_ylabel("Amplitude")
            fig.tight_layout()
            # mkstemp gives every call its own file. A seconds-resolution
            # timestamp name collides when two analyses finish in the same second.
            fd, plot_path = tempfile.mkstemp(prefix="temp_waveform_", suffix=".png")
            os.close(fd)
            fig.savefig(plot_path, dpi=150, bbox_inches='tight')
        finally:
            # Close this figure even if plotting fails, so a long-running
            # server does not accumulate open pyplot figures.
            plt.close(fig)
        return results, plot_path
    except Exception as e:
        return {}, f"Error analyzing heartbeat: {str(e)}"
def analyze_medical_image(image) -> str:
    """Ask Gemini to describe an uploaded investigation image.

    Args:
        image: A PIL image, or anything ``Image.fromarray`` accepts, or None.

    Returns:
        Gemini's response text. Returns "No image provided" for None, or an
        "Error analyzing image: ..." message if conversion or generation fails.
    """
    if image is None:
        return "No image provided"
    prompt_lines = (
        "",
        "Analyze this medical image/investigation result. Please provide:",
        "1. Type of investigation/scan",
        "2. Key findings visible in the image",
        "3. Any abnormalities or areas of concern",
        "4. Recommendations for follow-up if needed",
        "Please be thorough but remember this is for educational purposes and should not replace professional medical diagnosis.",
        "",
    )
    try:
        pil_image = image if isinstance(image, Image.Image) else Image.fromarray(image)
        reply = gemini_model.generate_content(["\n".join(prompt_lines), pil_image])
        return reply.text
    except Exception as err:
        return f"Error analyzing image: {err!s}"
def generate_comprehensive_assessment(patient_info: Dict, model=None) -> str:
    """Ask Gemini for a structured assessment of one consultation.

    Args:
        patient_info: Consultation fields keyed 'name', 'age', 'sex',
            'weight', 'height', 'complaint', 'medical_history',
            'examination', 'heartbeat_analysis' and 'investigation_analysis'.
            A missing key, a None value and a blank string are all reported
            to the model as not provided / not performed.
        model: Object with ``generate_content(prompt)`` returning a response
            with a ``.text`` attribute. Defaults to the module-level
            ``gemini_model``.

    Returns:
        The model's response text, or an "Error generating assessment: ..."
        message if building the prompt or generating fails.
    """
    def field(key, default='Not provided', unit=''):
        # Callers store every key, often with "" or None, so a plain
        # dict.get default would never apply; treat blanks as missing.
        value = (patient_info or {}).get(key)
        if value is None:
            return default
        text = str(value).strip()
        if not text:
            return default
        # Only attach the unit to real values (avoid "Not provided kg").
        return f"{text} {unit}" if unit else text

    try:
        if model is None:
            model = gemini_model
        prompt = "\n".join([
            "",
            "Based on the following comprehensive patient data, provide a detailed medical assessment:",
            "PATIENT DEMOGRAPHICS:",
            f"- Name: {field('name')}",
            f"- Age: {field('age')}",
            f"- Sex: {field('sex')}",
            f"- Weight: {field('weight', unit='kg')}",
            f"- Height: {field('height', unit='cm')}",
            "CHIEF COMPLAINT:",
            field('complaint'),
            "MEDICAL HISTORY:",
            field('medical_history'),
            "PHYSICAL EXAMINATION:",
            field('examination'),
            "HEART SOUNDS ANALYSIS:",
            field('heartbeat_analysis', 'Not performed'),
            "INVESTIGATIONS:",
            field('investigation_analysis'),
            "Please provide a comprehensive medical assessment including:",
            "1. Clinical Summary",
            "2. Differential Diagnosis (list possible conditions)",
            "3. Risk Factors Assessment",
            "4. Recommended Treatment Plan",
            "5. Follow-up Recommendations",
            "6. Patient Education Points",
            "7. Prognosis",
            "Please structure your response professionally and remember this is for educational purposes.",
            "",
        ])
        response = model.generate_content(prompt)
        return response.text
    except Exception as e:
        return f"Error generating assessment: {str(e)}"
def save_patient_data(name, age, sex, weight, height, complaint, medical_history,
                      examination, heartbeat_results, investigation_analysis):
    """Replace the module-level ``patient_data`` record with this consultation.

    Every argument is stored as-is under a fixed key, followed by a
    ``timestamp`` of the current local time (``YYYY-mm-dd HH:MM:SS``).

    Returns:
        A confirmation message for display.
    """
    global patient_data
    record = dict(
        name=name,
        age=age,
        sex=sex,
        weight=weight,
        height=height,
        complaint=complaint,
        medical_history=medical_history,
        examination=examination,
        heartbeat_analysis=heartbeat_results,
        investigation_analysis=investigation_analysis,
    )
    record['timestamp'] = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    patient_data = record
    return "Patient data saved successfully!"
def process_complete_consultation(name, age, sex, weight, height, complaint,
                                  medical_history, examination, audio_file,
                                  investigation_image):
    """Run every available analysis for one consultation and assess it.

    Heart sounds and the investigation image are analysed only when provided.
    The combined record is saved via ``save_patient_data`` (for the summary
    tab) and sent to Gemini for a comprehensive assessment.

    Returns:
        ``(assessment, waveform_plot, heartbeat_results, investigation_analysis)``.
        *waveform_plot* is a PNG path or None. *heartbeat_results* is empty
        when no recording was given or its analysis failed.
        *investigation_analysis* is empty when no image was given.
    """
    heartbeat_results = ""
    waveform_plot = None
    if audio_file is not None:
        results, plot_path = analyze_heartbeat(audio_file)
        # analyze_heartbeat returns ({}, message) on failure.
        if results:
            heartbeat_results = "\n".join([
                "Heartbeat Analysis Results:",
                f"- Normal: {results.get('normal', 0) * 100:.1f}%",
                f"- Murmur: {results.get('murmur', 0) * 100:.1f}%",
                f"- Artifact: {results.get('artifact', 0) * 100:.1f}%",
            ])
            waveform_plot = plot_path

    investigation_analysis = ""
    if investigation_image is not None:
        investigation_analysis = analyze_medical_image(investigation_image)

    # Update the shared record shown on the "Patient Summary" tab.
    save_patient_data(name, age, sex, weight, height, complaint, medical_history,
                      examination, heartbeat_results, investigation_analysis)

    # Build the assessment from this request's own data rather than re-reading
    # the module-level patient_data. Another session's request can overwrite
    # that global between the save above and the read, which would leak one
    # patient's data into another patient's assessment.
    consultation = {
        'name': name,
        'age': age,
        'sex': sex,
        'weight': weight,
        'height': height,
        'complaint': complaint,
        'medical_history': medical_history,
        'examination': examination,
        'heartbeat_analysis': heartbeat_results,
        'investigation_analysis': investigation_analysis,
    }
    comprehensive_assessment = generate_comprehensive_assessment(consultation)
    return comprehensive_assessment, waveform_plot, heartbeat_results, investigation_analysis
# Create Gradio interface
def create_interface():
    """Build the Gradio Blocks UI for the consultation workflow.

    Tabs cover patient information, physical examination with heart-sound
    analysis, investigations, the AI assessment and a patient summary.

    Returns:
        The ``gr.Blocks`` app (not launched).
    """

    def run_heartbeat_analysis(audio):
        """Adapt analyze_heartbeat's (dict, path-or-error) result to (text, image).

        analyze_heartbeat signals failure with ``({}, message)``. Sending that
        message to the gr.Image output would make Gradio treat it as a file
        path, so it is shown in the results textbox and the image is cleared.
        """
        results, plot_or_message = analyze_heartbeat(audio)
        if not results:
            return plot_or_message, None
        lines = ["Heartbeat Analysis Results:"]
        for key, label in (("normal", "Normal"), ("murmur", "Murmur"), ("artifact", "Artifact")):
            lines.append(f"- {label}: {results.get(key, 0) * 100:.1f}%")
        return "\n".join(lines), plot_or_message

    def preview(key, limit=100):
        """Return patient_data[key] ('N/A' if absent), cut to *limit* chars + '...'."""
        value = patient_data.get(key, 'N/A')
        if isinstance(value, str) and len(value) > limit:
            return value[:limit] + "..."
        return value

    def refresh_patient_summary():
        """Return (demographics, clinical, results) dicts for the summary tab."""
        if not patient_data:
            return {}, {}, {}
        demographics = {
            "Name": patient_data.get('name', 'N/A'),
            "Age": patient_data.get('age', 'N/A'),
            "Sex": patient_data.get('sex', 'N/A'),
            "Weight": f"{patient_data.get('weight', 'N/A')} kg",
            "Height": f"{patient_data.get('height', 'N/A')} cm",
        }
        clinical = {
            "Chief Complaint": patient_data.get('complaint', 'N/A'),
            "Medical History": preview('medical_history'),
            "Examination": preview('examination'),
        }
        results = {
            "Heartbeat Analysis": "Completed" if patient_data.get('heartbeat_analysis') else "Not performed",
            "Investigation Analysis": "Completed" if patient_data.get('investigation_analysis') else "Not performed",
            "Last Updated": patient_data.get('timestamp', 'N/A'),
        }
        return demographics, clinical, results

    with gr.Blocks(title="Comprehensive Medical Consultation System", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        # 🏥 Comprehensive Medical Consultation System
        ### Integrated AI-Powered Medical Assessment Platform
        """)

        with gr.Tab("📋 Patient Information"):
            gr.Markdown("## Patient Demographics")
            with gr.Row():
                with gr.Column():
                    name = gr.Textbox(label="Full Name", placeholder="Enter patient's full name")
                    age = gr.Number(label="Age (years)", minimum=0, maximum=120)
                    sex = gr.Radio(["Male", "Female", "Other"], label="Sex")
                with gr.Column():
                    weight = gr.Number(label="Weight (kg)", minimum=0, maximum=300)
                    height = gr.Number(label="Height (cm)", minimum=0, maximum=250)
            gr.Markdown("## Chief Complaint")
            complaint = gr.Textbox(
                label="Chief Complaint",
                placeholder="Describe the main symptoms or reason for consultation...",
                lines=3,
            )
            gr.Markdown("## Medical History")
            medical_history = gr.Textbox(
                label="Past Medical History",
                placeholder="Include previous illnesses, surgeries, medications, allergies, family history...",
                lines=5,
            )

        with gr.Tab("🩺 Physical Examination"):
            gr.Markdown("## Physical Examination Findings")
            examination = gr.Textbox(
                label="Examination Findings",
                placeholder="General appearance, vital signs, systemic examination findings...",
                lines=6,
            )
            gr.Markdown("## Heart Sounds Analysis")
            audio_file = gr.Audio(
                label="Heart Sounds Recording",
                type="filepath",
                sources=["upload", "microphone"],
            )
            heartbeat_analyze_btn = gr.Button("🔍 Analyze Heart Sounds", variant="secondary")
            heartbeat_results = gr.Textbox(label="Heart Sounds Analysis Results", lines=4)
            waveform_plot = gr.Image(label="Heart Sounds Waveform")
            heartbeat_analyze_btn.click(
                fn=run_heartbeat_analysis,
                inputs=[audio_file],
                outputs=[heartbeat_results, waveform_plot],
            )

        with gr.Tab("🔬 Investigations"):
            gr.Markdown("## Medical Investigations & Imaging")
            investigation_image = gr.Image(
                label="Upload Investigation Results (X-ray, ECG, Lab reports, etc.)",
                type="pil",
            )
            investigate_btn = gr.Button("🔍 Analyze Investigation", variant="secondary")
            investigation_results = gr.Textbox(
                label="Investigation Analysis",
                lines=6,
                placeholder="AI analysis of uploaded investigation will appear here...",
            )
            investigate_btn.click(
                fn=analyze_medical_image,
                inputs=[investigation_image],
                outputs=[investigation_results],
            )

        with gr.Tab("🤖 AI Assessment"):
            gr.Markdown("## Comprehensive Medical Assessment")
            generate_btn = gr.Button(
                "🧠 Generate Comprehensive Assessment",
                variant="primary",
                size="lg",
            )
            assessment_output = gr.Textbox(
                label="AI-Generated Medical Assessment",
                lines=15,
                placeholder="Complete medical assessment will be generated here based on all provided information...",
            )
            # Hidden outputs to collect all data
            hidden_heartbeat = gr.Textbox(visible=False)
            hidden_investigation = gr.Textbox(visible=False)
            hidden_waveform = gr.Image(visible=False)
            generate_btn.click(
                fn=process_complete_consultation,
                inputs=[name, age, sex, weight, height, complaint, medical_history,
                        examination, audio_file, investigation_image],
                outputs=[assessment_output, hidden_waveform, hidden_heartbeat,
                         hidden_investigation],
            )

        with gr.Tab("📊 Patient Summary"):
            gr.Markdown("## Patient Data Summary")
            refresh_btn = gr.Button("🔄 Refresh Patient Data", variant="secondary")
            with gr.Row():
                with gr.Column():
                    summary_demographics = gr.JSON(label="Demographics")
                    summary_clinical = gr.JSON(label="Clinical Data")
                with gr.Column():
                    summary_results = gr.JSON(label="Investigation Results")
            refresh_btn.click(
                fn=refresh_patient_summary,
                outputs=[summary_demographics, summary_clinical, summary_results],
            )

        gr.Markdown("""
        ---
        ### 📝 Important Notes:
        - This system is for educational and research purposes only
        - Always consult qualified healthcare professionals for medical decisions
        - Ensure patient privacy and data protection compliance
        - AI assessments should supplement, not replace, clinical judgment
        """)
    return demo
# Launch the application
if __name__ == "__main__":
    # Gemini-backed features need GOOGLE_API_KEY; warn, but still start the UI.
    if not os.getenv("GOOGLE_API_KEY"):
        for warning_line in (
            "Warning: GOOGLE_API_KEY not set. Gemini AI features will not work.",
            "Set your API key with: export GOOGLE_API_KEY='your_api_key_here'",
        ):
            print(warning_line)
    launch_options = {
        "server_name": "0.0.0.0",  # listen on all network interfaces
        "server_port": 7860,
        # NOTE(review): share=True requests a public Gradio link; reconsider
        # before handling real patient data.
        "share": True,
        "debug": True,
    }
    demo = create_interface()
    demo.launch(**launch_options)