|
import gradio as gr |
|
import librosa |
|
import numpy as np |
|
import tensorflow as tf |
|
import matplotlib.pyplot as plt |
|
from datetime import datetime |
|
import json |
|
import os |
|
from PIL import Image |
|
import google.generativeai as genai |
|
from typing import Dict, List, Tuple, Optional |
|
|
|
|
|
|
|
# Configure the Gemini client once at import time. The API key is read from
# the environment (None when GOOGLE_API_KEY is unset).
genai.configure(api_key=os.environ.get("GOOGLE_API_KEY"))

# Shared model used for both image analysis and the text assessment.
gemini_model = genai.GenerativeModel(model_name="gemini-1.5-flash")
|
|
|
|
|
def load_heartbeat_model(model_path: str = "Heart_ResNet.h5"):
    """Load the Keras heart-sound classifier stored at ``model_path``.

    Returns the loaded model, or ``None`` when the file is missing or cannot
    be loaded; ``analyze_heartbeat`` then falls back to fixed mock scores.

    No caching decorator is used: the function runs once, at import time, to
    populate the module-level ``heartbeat_model`` below.

    Args:
        model_path: Path of the saved Keras model (defaults to the file the
            app has always used).
    """
    if not os.path.exists(model_path):
        print(f"Warning: {model_path} model not found. Using mock predictions.")
        return None
    try:
        return tf.keras.models.load_model(model_path)
    except Exception as exc:
        # Deliberate best-effort: a corrupt/incompatible model file should not
        # stop the app, but report the real cause instead of "not found".
        print(f"Warning: could not load {model_path} ({exc}). Using mock predictions.")
        return None


# Loaded once at import; None means mock predictions are used.
heartbeat_model = load_heartbeat_model()

# Most recently saved consultation (rebound by save_patient_data).
# NOTE(review): a single module-level dict is shared by every user/session.
patient_data = {}
|
|
|
def process_audio(
    file_path: str,
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[int]]:
    """Load a heart-sound recording and summarise it as time-averaged MFCCs.

    The clip is loaded at 22050 Hz, limited to the first 10 s, and zero-padded
    at the end to exactly 10 s so every recording yields the same-sized input.

    Args:
        file_path: Path to an audio file readable by ``librosa.load``.

    Returns:
        ``(mfccs, waveform, sr)``: the 52 MFCC coefficients averaged over
        frames, the padded waveform, and its sample rate. On any error the
        error is printed and ``(None, None, None)`` is returned (the previous
        annotation did not admit this).
    """
    sample_rate = 22050
    duration = 10
    input_length = int(sample_rate * duration)

    try:
        waveform, sr = librosa.load(file_path, sr=sample_rate, duration=duration)

        # Short clips are right-padded with zeros up to the fixed length.
        if len(waveform) < input_length:
            waveform = np.pad(waveform, (0, input_length - len(waveform)), mode="constant")

        # mfcc() yields (n_mfcc, frames); transpose and average across frames.
        mfcc_matrix = librosa.feature.mfcc(
            y=waveform, sr=sr, n_mfcc=52, n_fft=512, hop_length=256
        )
        mfccs = np.mean(mfcc_matrix.T, axis=0)
        return mfccs, waveform, sr
    except Exception as e:
        print(f"Error processing audio: {e}")
        return None, None, None
|
|
|
def analyze_heartbeat(audio_file) -> Tuple[Dict, str]:
    """Classify a heart-sound recording and render its waveform to a PNG file.

    Args:
        audio_file: Path to the recording (Gradio ``type="filepath"``), or None.

    Returns:
        ``(results, plot_path)`` on success, where ``results`` maps
        "artifact" / "murmur" / "normal" to the model's score (or fixed mock
        scores when no model was loaded) and ``plot_path`` is a PNG file.
        On failure returns ``({}, message)`` with an error message in place
        of the path.
    """
    if audio_file is None:
        return {}, "No audio file provided"

    try:
        import tempfile

        # ``import librosa`` does not load the display submodule; without this
        # the waveshow call below failed with AttributeError.
        import librosa.display
        # Object-oriented Figure API: no pyplot global state shared between
        # concurrent Gradio requests and nothing left open if plotting fails.
        from matplotlib.figure import Figure

        mfccs, waveform, sr = process_audio(audio_file)
        if mfccs is None:
            return {}, "Error processing audio file"

        if heartbeat_model is not None:
            # The classifier expects (batch, 52 coefficients, 1 channel).
            features = mfccs.reshape(1, 52, 1)
            preds = heartbeat_model.predict(features, verbose=0)
            class_names = ["artifact", "murmur", "normal"]
            results = {name: float(score) for name, score in zip(class_names, preds[0])}
        else:
            # NOTE(review): fixed placeholder scores used when the model file
            # is absent; they are indistinguishable from real output downstream.
            results = {"artifact": 0.15, "murmur": 0.25, "normal": 0.60}

        fig = Figure(figsize=(12, 4))
        ax = fig.subplots()
        librosa.display.waveshow(waveform, sr=sr, ax=ax)
        ax.set_title("Heartbeat Waveform Analysis")
        ax.set_xlabel("Time (seconds)")
        ax.set_ylabel("Amplitude")
        fig.tight_layout()

        # Unique file in the system temp dir: the previous per-second
        # timestamp name in the working directory collided between requests.
        fd, plot_path = tempfile.mkstemp(prefix="temp_waveform_", suffix=".png")
        os.close(fd)
        fig.savefig(plot_path, dpi=150, bbox_inches="tight")

        return results, plot_path

    except Exception as e:
        return {}, f"Error analyzing heartbeat: {str(e)}"
|
|
|
def analyze_medical_image(image) -> str:
    """Ask Gemini to describe an uploaded medical image or investigation result.

    Args:
        image: A PIL image, or an array that ``Image.fromarray`` accepts.

    Returns:
        Gemini's text answer, "No image provided" for a missing image, or an
        "Error analyzing image: ..." message if conversion or the call fails.
    """
    if image is None:
        return "No image provided"

    try:
        pil_image = image if isinstance(image, Image.Image) else Image.fromarray(image)

        prompt = """
        Analyze this medical image/investigation result. Please provide:
        1. Type of investigation/scan
        2. Key findings visible in the image
        3. Any abnormalities or areas of concern
        4. Recommendations for follow-up if needed

        Please be thorough but remember this is for educational purposes and should not replace professional medical diagnosis.
        """

        reply = gemini_model.generate_content([prompt, pil_image])
        return reply.text
    except Exception as e:
        return f"Error analyzing image: {str(e)}"
|
|
|
def _describe_field(value, default: str, unit: str = "") -> str:
    """Render one patient field for the prompt, substituting ``default`` when missing.

    None and blank/whitespace-only strings count as missing (empty Gradio
    inputs can arrive as either). Whole-number floats lose their trailing
    ".0", and ``unit`` is appended only to real values.
    """
    if value is None or (isinstance(value, str) and not value.strip()):
        return default
    if isinstance(value, float) and value.is_integer():
        value = int(value)
    return f"{value} {unit}" if unit else str(value)


def generate_comprehensive_assessment(patient_info: Dict) -> str:
    """Generate a structured medical assessment for ``patient_info`` via Gemini.

    ``save_patient_data`` always stores every key, so empty inputs used to
    reach the prompt as "None", "" or "Not provided kg"; blank values are now
    reported as "Not provided" / "Not performed".

    Args:
        patient_info: Record in the shape written by ``save_patient_data``.

    Returns:
        Gemini's text response, or "Error generating assessment: ..." on failure.
    """
    info = patient_info or {}

    def field(key: str, default: str = "Not provided", unit: str = "") -> str:
        return _describe_field(info.get(key), default, unit)

    try:
        prompt = f"""
        Based on the following comprehensive patient data, provide a detailed medical assessment:

        PATIENT DEMOGRAPHICS:
        - Name: {field('name')}
        - Age: {field('age')}
        - Sex: {field('sex')}
        - Weight: {field('weight', unit='kg')}
        - Height: {field('height', unit='cm')}

        CHIEF COMPLAINT:
        {field('complaint')}

        MEDICAL HISTORY:
        {field('medical_history')}

        PHYSICAL EXAMINATION:
        {field('examination')}

        HEART SOUNDS ANALYSIS:
        {field('heartbeat_analysis', 'Not performed')}

        INVESTIGATIONS:
        {field('investigation_analysis')}

        Please provide a comprehensive medical assessment including:
        1. Clinical Summary
        2. Differential Diagnosis (list possible conditions)
        3. Risk Factors Assessment
        4. Recommended Treatment Plan
        5. Follow-up Recommendations
        6. Patient Education Points
        7. Prognosis

        Please structure your response professionally and remember this is for educational purposes.
        """

        response = gemini_model.generate_content(prompt)
        return response.text

    except Exception as e:
        return f"Error generating assessment: {str(e)}"
|
|
|
def save_patient_data(name, age, sex, weight, height, complaint, medical_history,
                      examination, heartbeat_results, investigation_analysis):
    """Replace the module-level ``patient_data`` record with this consultation.

    ``heartbeat_results`` is stored under the ``'heartbeat_analysis'`` key, and
    a ``'timestamp'`` string (local time, ``%Y-%m-%d %H:%M:%S``) is added.

    Returns:
        A fixed confirmation message.
    """
    global patient_data

    # NOTE(review): one global record is shared by every user of the app.
    record = dict(
        name=name,
        age=age,
        sex=sex,
        weight=weight,
        height=height,
        complaint=complaint,
        medical_history=medical_history,
        examination=examination,
        heartbeat_analysis=heartbeat_results,
        investigation_analysis=investigation_analysis,
        timestamp=datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
    )
    patient_data = record

    return "Patient data saved successfully!"
|
|
|
def process_complete_consultation(name, age, sex, weight, height, complaint,
                                  medical_history, examination, audio_file,
                                  investigation_image):
    """Run all analyses for one consultation and generate the AI assessment.

    Heart sounds are analysed when ``audio_file`` is given, and the image is
    sent to Gemini when ``investigation_image`` is given. Everything is then
    saved with ``save_patient_data`` and summarised by
    ``generate_comprehensive_assessment``.

    Returns:
        ``(assessment, waveform_plot, heartbeat_results, investigation_analysis)``
        where ``waveform_plot`` is a PNG path or None, and the two text results
        are "" when the corresponding input was missing (or, for heart
        sounds, when analysis failed).
    """
    heartbeat_results, waveform_plot = "", None

    if audio_file is not None:
        scores, plot_path = analyze_heartbeat(audio_file)
        # An empty dict means failure; plot_path then holds an error message,
        # not a file, so both outputs are left empty.
        if scores:
            heartbeat_results = f"""
            Heartbeat Analysis Results:
            - Normal: {scores.get('normal', 0)*100:.1f}%
            - Murmur: {scores.get('murmur', 0)*100:.1f}%
            - Artifact: {scores.get('artifact', 0)*100:.1f}%
            """
            waveform_plot = plot_path

    investigation_analysis = (
        analyze_medical_image(investigation_image)
        if investigation_image is not None
        else ""
    )

    save_patient_data(name, age, sex, weight, height, complaint, medical_history,
                      examination, heartbeat_results, investigation_analysis)

    # Reads back the module-level record that save_patient_data just rebound.
    assessment = generate_comprehensive_assessment(patient_data)

    return assessment, waveform_plot, heartbeat_results, investigation_analysis
|
|
|
|
|
def _is_blank(value) -> bool:
    """Return True for None and for strings that are empty after stripping."""
    return value is None or (isinstance(value, str) and not value.strip())


def _display_value(value, unit: str = ""):
    """Summary-panel value: 'N/A' when blank, else the value (with ``unit`` appended if given)."""
    if _is_blank(value):
        return "N/A"
    return f"{value} {unit}" if unit else value


def _shorten(value, limit: int = 100):
    """Like ``_display_value`` but cut strings longer than ``limit`` and append '...'."""
    shown = _display_value(value)
    if isinstance(shown, str) and len(shown) > limit:
        return shown[:limit] + "..."
    return shown


def _format_heartbeat_display(results: Dict) -> str:
    """Render heartbeat class scores as percentage lines for a Textbox."""
    return "\n".join([
        "Heartbeat Analysis Results:",
        f"- Normal: {results.get('normal', 0) * 100:.1f}%",
        f"- Murmur: {results.get('murmur', 0) * 100:.1f}%",
        f"- Artifact: {results.get('artifact', 0) * 100:.1f}%",
    ])


def _analyze_heartbeat_for_display(audio_file) -> Tuple[str, Optional[str]]:
    """Adapt ``analyze_heartbeat`` output to the (Textbox, Image) components.

    ``analyze_heartbeat`` returns ``(scores, png_path)`` or ``({}, message)``.
    Wiring it directly put a dict repr in the Textbox and, on failure, passed
    the error message to the Image component as if it were a file path.
    """
    results, plot_or_message = analyze_heartbeat(audio_file)
    if not results:
        return plot_or_message, None
    return _format_heartbeat_display(results), plot_or_message


def _refresh_patient_summary() -> Tuple[Dict, Dict, Dict]:
    """Build the (demographics, clinical, results) JSON panels from ``patient_data``."""
    if not patient_data:
        return {}, {}, {}

    demographics = {
        "Name": _display_value(patient_data.get('name')),
        "Age": _display_value(patient_data.get('age')),
        "Sex": _display_value(patient_data.get('sex')),
        "Weight": _display_value(patient_data.get('weight'), "kg"),
        "Height": _display_value(patient_data.get('height'), "cm"),
    }

    clinical = {
        "Chief Complaint": _display_value(patient_data.get('complaint')),
        "Medical History": _shorten(patient_data.get('medical_history')),
        "Examination": _shorten(patient_data.get('examination')),
    }

    results = {
        "Heartbeat Analysis": "Completed" if patient_data.get('heartbeat_analysis') else "Not performed",
        "Investigation Analysis": "Completed" if patient_data.get('investigation_analysis') else "Not performed",
        "Last Updated": patient_data.get('timestamp', 'N/A'),
    }

    return demographics, clinical, results


def create_interface():
    """Build the Gradio Blocks UI for the consultation workflow.

    Tabs: patient information, physical examination (with heart-sound
    analysis), investigations (image analysis), AI assessment, and a patient
    summary read from the module-level ``patient_data``.

    Returns:
        The ``gr.Blocks`` app (not launched).
    """
    with gr.Blocks(title="Comprehensive Medical Consultation System", theme=gr.themes.Soft()) as demo:

        gr.Markdown("""
        # 🏥 Comprehensive Medical Consultation System
        ### Integrated AI-Powered Medical Assessment Platform
        """)

        with gr.Tab("📋 Patient Information"):
            gr.Markdown("## Patient Demographics")

            with gr.Row():
                with gr.Column():
                    name = gr.Textbox(label="Full Name", placeholder="Enter patient's full name")
                    age = gr.Number(label="Age (years)", minimum=0, maximum=120)
                    sex = gr.Radio(["Male", "Female", "Other"], label="Sex")

                with gr.Column():
                    weight = gr.Number(label="Weight (kg)", minimum=0, maximum=300)
                    height = gr.Number(label="Height (cm)", minimum=0, maximum=250)

            gr.Markdown("## Chief Complaint")
            complaint = gr.Textbox(
                label="Chief Complaint",
                placeholder="Describe the main symptoms or reason for consultation...",
                lines=3,
            )

            gr.Markdown("## Medical History")
            medical_history = gr.Textbox(
                label="Past Medical History",
                placeholder="Include previous illnesses, surgeries, medications, allergies, family history...",
                lines=5,
            )

        with gr.Tab("🩺 Physical Examination"):
            gr.Markdown("## Physical Examination Findings")

            examination = gr.Textbox(
                label="Examination Findings",
                placeholder="General appearance, vital signs, systemic examination findings...",
                lines=6,
            )

            gr.Markdown("## Heart Sounds Analysis")
            audio_file = gr.Audio(
                label="Heart Sounds Recording",
                type="filepath",
                sources=["upload", "microphone"],
            )

            heartbeat_analyze_btn = gr.Button("🔍 Analyze Heart Sounds", variant="secondary")
            heartbeat_results = gr.Textbox(label="Heart Sounds Analysis Results", lines=4)
            waveform_plot = gr.Image(label="Heart Sounds Waveform")

            heartbeat_analyze_btn.click(
                fn=_analyze_heartbeat_for_display,
                inputs=[audio_file],
                outputs=[heartbeat_results, waveform_plot],
            )

        with gr.Tab("🔬 Investigations"):
            gr.Markdown("## Medical Investigations & Imaging")

            investigation_image = gr.Image(
                label="Upload Investigation Results (X-ray, ECG, Lab reports, etc.)",
                type="pil",
            )

            investigate_btn = gr.Button("🔍 Analyze Investigation", variant="secondary")
            investigation_results = gr.Textbox(
                label="Investigation Analysis",
                lines=6,
                placeholder="AI analysis of uploaded investigation will appear here...",
            )

            investigate_btn.click(
                fn=analyze_medical_image,
                inputs=[investigation_image],
                outputs=[investigation_results],
            )

        with gr.Tab("🤖 AI Assessment"):
            gr.Markdown("## Comprehensive Medical Assessment")

            generate_btn = gr.Button(
                "🧠 Generate Comprehensive Assessment",
                variant="primary",
                size="lg",
            )

            assessment_output = gr.Textbox(
                label="AI-Generated Medical Assessment",
                lines=15,
                placeholder="Complete medical assessment will be generated here based on all provided information...",
            )

            # The per-section results now go to the visible components in the
            # other tabs; they previously went to hidden, never-shown ones.
            generate_btn.click(
                fn=process_complete_consultation,
                inputs=[name, age, sex, weight, height, complaint, medical_history,
                        examination, audio_file, investigation_image],
                outputs=[assessment_output, waveform_plot, heartbeat_results,
                         investigation_results],
            )

        with gr.Tab("📊 Patient Summary"):
            gr.Markdown("## Patient Data Summary")

            refresh_btn = gr.Button("🔄 Refresh Patient Data", variant="secondary")

            with gr.Row():
                with gr.Column():
                    summary_demographics = gr.JSON(label="Demographics")
                    summary_clinical = gr.JSON(label="Clinical Data")

                with gr.Column():
                    summary_results = gr.JSON(label="Investigation Results")

            refresh_btn.click(
                fn=_refresh_patient_summary,
                outputs=[summary_demographics, summary_clinical, summary_results],
            )

        gr.Markdown("""
        ---
        ### 📌 Important Notes:
        - This system is for educational and research purposes only
        - Always consult qualified healthcare professionals for medical decisions
        - Ensure patient privacy and data protection compliance
        - AI assessments should supplement, not replace, clinical judgment
        """)

    return demo
|
|
|
|
|
if __name__ == "__main__":
    # Warn up front: every Gemini call needs GOOGLE_API_KEY (unset or empty).
    api_key_missing = not os.getenv("GOOGLE_API_KEY")
    if api_key_missing:
        print("Warning: GOOGLE_API_KEY not set. Gemini AI features will not work.")
        print("Set your API key with: export GOOGLE_API_KEY='your_api_key_here'")

    demo = create_interface()

    # NOTE(review): share=True creates a public share link and 0.0.0.0 listens
    # on all interfaces; with patient data involved, confirm this is intended.
    launch_options = dict(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
        debug=True,
    )
    demo.launch(**launch_options)