walaa2022 commited on
Commit
699420d
·
verified ·
1 Parent(s): 997d728

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +375 -151
app.py CHANGED
@@ -1,172 +1,396 @@
1
- import streamlit as st
2
  import librosa
3
  import numpy as np
4
  import tensorflow as tf
5
  import matplotlib.pyplot as plt
6
  from datetime import datetime
 
 
 
 
 
 
 
 
 
 
7
 
8
  # Load the pre-trained ResNet model
9
- @st.cache_resource
10
- def load_model():
11
- model = tf.keras.models.load_model('Heart_ResNet.h5')
12
- return model
13
-
14
- model = load_model()
15
-
16
- # Initialize session state
17
- if 'page' not in st.session_state:
18
- st.session_state.page = '🏠 Home'
19
- if 'theme' not in st.session_state:
20
- st.session_state.theme = 'Light Green'
21
- if 'history' not in st.session_state:
22
- st.session_state.history = []
23
-
24
- # Custom CSS for theme
25
- def apply_theme():
26
- if st.session_state.theme == "Light Green":
27
- st.markdown("""
28
- <style>
29
- body, .stApp { background-color: #e8f5e9; }
30
- .stApp { color: #004d40; }
31
- .stButton > button, .stFileUpload > div {
32
- background-color: #004d40;
33
- color: white;
34
- }
35
- .stButton > button:hover, .stFileUpload > div:hover {
36
- background-color: #00332c;
37
- }
38
- </style>
39
- """, unsafe_allow_html=True)
40
- else:
41
- st.markdown("""
42
- <style>
43
- body, .stApp { background-color: #e0f7fa; }
44
- .stApp { color: #006064; }
45
- .stButton > button, .stFileUpload > div {
46
- background-color: #006064;
47
- color: white;
48
- }
49
- .stButton > button:hover, .stFileUpload > div:hover {
50
- background-color: #004d40;
51
- }
52
- </style>
53
- """, unsafe_allow_html=True)
54
-
55
- # Sidebar navigation
56
- with st.sidebar:
57
- st.title("Heartbeat Analysis 🩺")
58
- st.session_state.page = st.radio(
59
- "Navigation",
60
- ["🏠 Home", "βš™οΈ Settings", "πŸ‘€ Profile"],
61
- index=["🏠 Home", "βš™οΈ Settings", "πŸ‘€ Profile"].index(st.session_state.page)
62
- )
63
-
64
- # Audio processing function
65
def process_audio(file_path):
    """Load up to 10 s of audio at 22.05 kHz and return (mean MFCC, waveform, sr).

    Short clips are zero-padded to the full 10-second window before the
    52-coefficient MFCCs are averaged over time.
    """
    SAMPLE_RATE = 22050
    DURATION = 10
    input_length = int(SAMPLE_RATE * DURATION)

    X, sr = librosa.load(file_path, sr=SAMPLE_RATE, duration=DURATION)

    shortfall = input_length - len(X)
    if shortfall > 0:
        X = np.pad(X, (0, shortfall), mode='constant')

    mfcc_frames = librosa.feature.mfcc(
        y=X, sr=sr, n_mfcc=52, n_fft=512, hop_length=256
    )
    mfccs = np.mean(mfcc_frames.T, axis=0)
    return mfccs, X, sr
79
-
80
def classify_audio(filepath):
    """Classify one recording and log the outcome to session history.

    Returns ({class_name: probability}, waveform, sample_rate).
    """
    mfccs, waveform, sr = process_audio(filepath)
    preds = model.predict(mfccs.reshape(1, 52, 1))

    result = {}
    for idx, class_name in enumerate(["artifact", "murmur", "normal"]):
        result[class_name] = float(preds[0][idx])

    # Store in history
    st.session_state.history.append({
        'date': datetime.now().strftime("%Y-%m-%d %H:%M"),
        'file': filepath,
        'result': result
    })

    return result, waveform, sr
95
 
96
- # Page rendering functions
97
- def home_page():
98
- st.title("Heartbeat Analysis")
99
- uploaded_file = st.file_uploader("Upload your heartbeat audio", type=["wav", "mp3"])
100
 
101
- if uploaded_file is not None:
102
- st.audio(uploaded_file.read(), format='audio/wav')
103
- uploaded_file.seek(0)
104
-
105
- if st.button("Analyze Now"):
106
- with st.spinner('Analyzing...'):
107
- with open("temp.wav", "wb") as f:
108
- f.write(uploaded_file.getbuffer())
109
-
110
- results, waveform, sr = classify_audio("temp.wav")
111
-
112
- st.subheader("Analysis Results")
113
- cols = st.columns(3)
114
- labels = {
115
- 'artifact': "🚨 Artifact",
116
- 'murmur': "πŸ’” Murmur",
117
- 'normal': "❀️ Normal"
118
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
 
120
- for (label, value), col in zip(results.items(), cols):
121
- with col:
122
- st.metric(labels[label], f"{value*100:.2f}%")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
 
124
- st.subheader("Heartbeat Waveform")
125
- fig, ax = plt.subplots(figsize=(10, 3))
126
- librosa.display.waveshow(waveform, sr=sr, ax=ax)
127
- ax.set_title("Audio Waveform Analysis")
128
- st.pyplot(fig)
129
-
130
- def settings_page():
131
- st.title("Settings")
132
- new_theme = st.selectbox(
133
- "Select Theme",
134
- ["Light Green", "Light Blue"],
135
- index=0 if st.session_state.theme == "Light Green" else 1
136
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
 
138
- if new_theme != st.session_state.theme:
139
- st.session_state.theme = new_theme
140
- st.experimental_rerun()
141
-
142
- def profile_page():
143
- st.title("Medical Profile")
144
- with st.expander("Personal Information", expanded=True):
145
- col1, col2 = st.columns(2)
146
- with col1:
147
- st.write("**Name:** Kpetaa Patrick")
148
- st.write("**Age:** 35")
149
- with col2:
150
- st.write("**Blood Type:** O+")
151
- st.write("**Last Checkup:** 2025-06-17")
152
 
153
- st.subheader("Analysis History")
154
- if not st.session_state.history:
155
- st.write("No previous analyses found")
156
- else:
157
- for analysis in reversed(st.session_state.history):
158
- with st.expander(f"Analysis from {analysis['date']}"):
159
- st.write(f"File: {analysis['file']}")
160
- st.write("Results:")
161
- for label, value in analysis['result'].items():
162
- st.progress(value, text=f"{label.capitalize()}: {value*100:.2f}%")
163
-
164
- # Main app logic
165
- apply_theme()
166
-
167
- if st.session_state.page == "🏠 Home":
168
- home_page()
169
- elif st.session_state.page == "βš™οΈ Settings":
170
- settings_page()
171
- elif st.session_state.page == "πŸ‘€ Profile":
172
- profile_page()
 
1
+ import gradio as gr
2
  import librosa
3
  import numpy as np
4
  import tensorflow as tf
5
  import matplotlib.pyplot as plt
6
  from datetime import datetime
7
+ import json
8
+ import os
9
+ from PIL import Image
10
+ import google.generativeai as genai
11
+ from typing import Dict, List, Tuple, Optional
12
+
13
+ # Configure Gemini AI
14
+ # You'll need to set your API key: export GOOGLE_API_KEY="your_api_key_here"
15
+ genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
16
+ gemini_model = genai.GenerativeModel('gemini-1.5-flash')
17
 
18
  # Load the pre-trained ResNet model
19
# NOTE(fix): the original decorated this with `@gr.utils.cache`, which is not
# part of the public Gradio API and raises AttributeError at import time. The
# model is loaded exactly once at module import, so no cache is needed.
def load_heartbeat_model():
    """Load the heartbeat-classification ResNet from 'Heart_ResNet.h5'.

    Returns:
        The loaded Keras model, or None when the weights file is missing or
        unloadable (callers then fall back to mock predictions).
    """
    try:
        model = tf.keras.models.load_model('Heart_ResNet.h5')
        return model
    except Exception:  # narrow from bare `except:`; best-effort fallback
        print("Warning: Heart_ResNet.h5 model not found. Using mock predictions.")
        return None

heartbeat_model = load_heartbeat_model()

# Global storage for patient data (in production, use a proper database)
patient_data = {}
32
+
33
def process_audio(file_path: str) -> Tuple[np.ndarray, np.ndarray, int]:
    """Process audio file and extract MFCC features.

    Loads up to 10 s of audio resampled to 22050 Hz, zero-pads short clips to
    the full window, and returns (mean MFCC vector of length 52, raw waveform,
    sample rate). Any failure yields (None, None, None).
    """
    SAMPLE_RATE = 22050
    DURATION = 10
    input_length = int(SAMPLE_RATE * DURATION)

    try:
        signal, sr = librosa.load(file_path, sr=SAMPLE_RATE, duration=DURATION)

        shortfall = input_length - len(signal)
        if shortfall > 0:
            signal = np.pad(signal, (0, shortfall), mode='constant')

        mfcc_frames = librosa.feature.mfcc(
            y=signal, sr=sr, n_mfcc=52, n_fft=512, hop_length=256
        )
        return np.mean(mfcc_frames.T, axis=0), signal, sr
    except Exception as e:
        print(f"Error processing audio: {e}")
        return None, None, None
52
+
53
def analyze_heartbeat(audio_file) -> Tuple[Dict, str]:
    """Analyze heartbeat audio and return results with visualization.

    Args:
        audio_file: Path to the recording (Gradio `type="filepath"` audio).

    Returns:
        (results, plot_path): `results` maps "artifact"/"murmur"/"normal" to
        probabilities; `plot_path` is the saved waveform PNG. On failure
        `results` is {} and the second element is an error message instead.
    """
    if audio_file is None:
        return {}, "No audio file provided"

    try:
        # Fix: explicit submodule import — a bare `import librosa` does not
        # expose `librosa.display` on pre-lazy-loading librosa releases.
        import librosa.display

        mfccs, waveform, sr = process_audio(audio_file)
        if mfccs is None:
            return {}, "Error processing audio file"

        if heartbeat_model is not None:
            features = mfccs.reshape(1, 52, 1)
            preds = heartbeat_model.predict(features)
            class_names = ["artifact", "murmur", "normal"]
            results = {name: float(preds[0][i]) for i, name in enumerate(class_names)}
        else:
            # Mock results for demonstration
            results = {"artifact": 0.15, "murmur": 0.25, "normal": 0.60}

        # Create waveform visualization
        fig, ax = plt.subplots(figsize=(12, 4))
        librosa.display.waveshow(waveform, sr=sr, ax=ax)
        ax.set_title("Heartbeat Waveform Analysis")
        ax.set_xlabel("Time (seconds)")
        ax.set_ylabel("Amplitude")
        plt.tight_layout()

        # Save plot. NOTE(review): these timestamped PNGs are never deleted
        # and will accumulate on disk over long sessions.
        plot_path = f"temp_waveform_{datetime.now().strftime('%Y%m%d_%H%M%S')}.png"
        plt.savefig(plot_path, dpi=150, bbox_inches='tight')
        plt.close(fig)  # close this figure, not whatever pyplot deems current

        return results, plot_path

    except Exception as e:
        return {}, f"Error analyzing heartbeat: {str(e)}"
89
+
90
def analyze_medical_image(image) -> str:
    """Analyze medical images using Gemini Vision.

    Accepts a PIL image (or an array convertible to one) and returns the
    model's textual analysis; on any failure an error message is returned.
    """
    if image is None:
        return "No image provided"

    try:
        # Convert to PIL Image if needed
        pil_image = image if isinstance(image, Image.Image) else Image.fromarray(image)

        prompt = """
        Analyze this medical image/investigation result. Please provide:
        1. Type of investigation/scan
        2. Key findings visible in the image
        3. Any abnormalities or areas of concern
        4. Recommendations for follow-up if needed

        Please be thorough but remember this is for educational purposes and should not replace professional medical diagnosis.
        """

        response = gemini_model.generate_content([prompt, pil_image])
        return response.text

    except Exception as e:
        return f"Error analyzing image: {str(e)}"
115
+
116
def generate_comprehensive_assessment(patient_info: Dict) -> str:
    """Generate comprehensive medical assessment using Gemini AI.

    Builds one structured prompt from the collected patient record and
    returns Gemini's response text, or an error message on failure.
    """
    try:
        get = patient_info.get  # hoist the lookup; every field is optional
        prompt = f"""
        Based on the following comprehensive patient data, provide a detailed medical assessment:

        PATIENT DEMOGRAPHICS:
        - Name: {get('name', 'Not provided')}
        - Age: {get('age', 'Not provided')}
        - Sex: {get('sex', 'Not provided')}
        - Weight: {get('weight', 'Not provided')} kg
        - Height: {get('height', 'Not provided')} cm

        CHIEF COMPLAINT:
        {get('complaint', 'Not provided')}

        MEDICAL HISTORY:
        {get('medical_history', 'Not provided')}

        PHYSICAL EXAMINATION:
        {get('examination', 'Not provided')}

        HEART SOUNDS ANALYSIS:
        {get('heartbeat_analysis', 'Not performed')}

        INVESTIGATIONS:
        {get('investigation_analysis', 'Not provided')}

        Please provide a comprehensive medical assessment including:
        1. Clinical Summary
        2. Differential Diagnosis (list possible conditions)
        3. Risk Factors Assessment
        4. Recommended Treatment Plan
        5. Follow-up Recommendations
        6. Patient Education Points
        7. Prognosis

        Please structure your response professionally and remember this is for educational purposes.
        """

        return gemini_model.generate_content(prompt).text

    except Exception as e:
        return f"Error generating assessment: {str(e)}"
162
+
163
def save_patient_data(name, age, sex, weight, height, complaint, medical_history,
                      examination, heartbeat_results, investigation_analysis):
    """Save all patient data to global storage.

    Rebinds the module-level `patient_data` record (one patient at a time)
    and stamps it with the save time.
    """
    global patient_data

    patient_data = dict(
        name=name,
        age=age,
        sex=sex,
        weight=weight,
        height=height,
        complaint=complaint,
        medical_history=medical_history,
        examination=examination,
        heartbeat_analysis=heartbeat_results,
        investigation_analysis=investigation_analysis,
        timestamp=datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
    )

    return "Patient data saved successfully!"
183
 
184
def process_complete_consultation(name, age, sex, weight, height, complaint,
                                  medical_history, examination, audio_file,
                                  investigation_image):
    """Process complete medical consultation.

    Runs the optional heartbeat and image analyses, saves the record, and
    returns (assessment_text, waveform_plot_path, heartbeat_summary,
    investigation_summary).
    """
    heartbeat_results = ""
    waveform_plot = None

    # Analyze heartbeat if audio provided
    if audio_file is not None:
        results, plot_path = analyze_heartbeat(audio_file)
        if results:
            heartbeat_results = f"""
            Heartbeat Analysis Results:
            - Normal: {results.get('normal', 0)*100:.1f}%
            - Murmur: {results.get('murmur', 0)*100:.1f}%
            - Artifact: {results.get('artifact', 0)*100:.1f}%
            """
            waveform_plot = plot_path

    # Analyze investigation image if provided
    investigation_analysis = ""
    if investigation_image is not None:
        investigation_analysis = analyze_medical_image(investigation_image)

    # Persist everything, then ask Gemini for the overall assessment
    save_patient_data(name, age, sex, weight, height, complaint, medical_history,
                      examination, heartbeat_results, investigation_analysis)
    comprehensive_assessment = generate_comprehensive_assessment(patient_data)

    return comprehensive_assessment, waveform_plot, heartbeat_results, investigation_analysis
216
+
217
+ # Create Gradio interface
218
+ def create_interface():
219
+ with gr.Blocks(title="Comprehensive Medical Consultation System", theme=gr.themes.Soft()) as demo:
220
+
221
+ gr.Markdown("""
222
+ # πŸ₯ Comprehensive Medical Consultation System
223
+ ### Integrated AI-Powered Medical Assessment Platform
224
+ """)
225
+
226
+ with gr.Tab("πŸ“‹ Patient Information"):
227
+ gr.Markdown("## Patient Demographics")
228
+
229
+ with gr.Row():
230
+ with gr.Column():
231
+ name = gr.Textbox(label="Full Name", placeholder="Enter patient's full name")
232
+ age = gr.Number(label="Age (years)", minimum=0, maximum=120)
233
+ sex = gr.Radio(["Male", "Female", "Other"], label="Sex")
234
 
235
+ with gr.Column():
236
+ weight = gr.Number(label="Weight (kg)", minimum=0, maximum=300)
237
+ height = gr.Number(label="Height (cm)", minimum=0, maximum=250)
238
+
239
+ gr.Markdown("## Chief Complaint")
240
+ complaint = gr.Textbox(
241
+ label="Chief Complaint",
242
+ placeholder="Describe the main symptoms or reason for consultation...",
243
+ lines=3
244
+ )
245
+
246
+ gr.Markdown("## Medical History")
247
+ medical_history = gr.Textbox(
248
+ label="Past Medical History",
249
+ placeholder="Include previous illnesses, surgeries, medications, allergies, family history...",
250
+ lines=5
251
+ )
252
+
253
+ with gr.Tab("🩺 Physical Examination"):
254
+ gr.Markdown("## Physical Examination Findings")
255
+
256
+ examination = gr.Textbox(
257
+ label="Examination Findings",
258
+ placeholder="General appearance, vital signs, systemic examination findings...",
259
+ lines=6
260
+ )
261
+
262
+ gr.Markdown("## Heart Sounds Analysis")
263
+ audio_file = gr.Audio(
264
+ label="Heart Sounds Recording",
265
+ type="filepath",
266
+ sources=["upload", "microphone"]
267
+ )
268
+
269
+ heartbeat_analyze_btn = gr.Button("πŸ” Analyze Heart Sounds", variant="secondary")
270
+ heartbeat_results = gr.Textbox(label="Heart Sounds Analysis Results", lines=4)
271
+ waveform_plot = gr.Image(label="Heart Sounds Waveform")
272
+
273
+ heartbeat_analyze_btn.click(
274
+ fn=analyze_heartbeat,
275
+ inputs=[audio_file],
276
+ outputs=[heartbeat_results, waveform_plot]
277
+ )
278
+
279
+ with gr.Tab("πŸ”¬ Investigations"):
280
+ gr.Markdown("## Medical Investigations & Imaging")
281
+
282
+ investigation_image = gr.Image(
283
+ label="Upload Investigation Results (X-ray, ECG, Lab reports, etc.)",
284
+ type="pil"
285
+ )
286
+
287
+ investigate_btn = gr.Button("πŸ” Analyze Investigation", variant="secondary")
288
+ investigation_results = gr.Textbox(
289
+ label="Investigation Analysis",
290
+ lines=6,
291
+ placeholder="AI analysis of uploaded investigation will appear here..."
292
+ )
293
+
294
+ investigate_btn.click(
295
+ fn=analyze_medical_image,
296
+ inputs=[investigation_image],
297
+ outputs=[investigation_results]
298
+ )
299
+
300
+ with gr.Tab("πŸ€– AI Assessment"):
301
+ gr.Markdown("## Comprehensive Medical Assessment")
302
+
303
+ generate_btn = gr.Button(
304
+ "🧠 Generate Comprehensive Assessment",
305
+ variant="primary",
306
+ size="lg"
307
+ )
308
+
309
+ assessment_output = gr.Textbox(
310
+ label="AI-Generated Medical Assessment",
311
+ lines=15,
312
+ placeholder="Complete medical assessment will be generated here based on all provided information..."
313
+ )
314
+
315
+ # Hidden outputs to collect all data
316
+ hidden_heartbeat = gr.Textbox(visible=False)
317
+ hidden_investigation = gr.Textbox(visible=False)
318
+ hidden_waveform = gr.Image(visible=False)
319
+
320
+ generate_btn.click(
321
+ fn=process_complete_consultation,
322
+ inputs=[name, age, sex, weight, height, complaint, medical_history,
323
+ examination, audio_file, investigation_image],
324
+ outputs=[assessment_output, hidden_waveform, hidden_heartbeat,
325
+ hidden_investigation]
326
+ )
327
+
328
+ with gr.Tab("πŸ“Š Patient Summary"):
329
+ gr.Markdown("## Patient Data Summary")
330
+
331
+ refresh_btn = gr.Button("πŸ”„ Refresh Patient Data", variant="secondary")
332
+
333
+ with gr.Row():
334
+ with gr.Column():
335
+ summary_demographics = gr.JSON(label="Demographics")
336
+ summary_clinical = gr.JSON(label="Clinical Data")
337
 
338
+ with gr.Column():
339
+ summary_results = gr.JSON(label="Investigation Results")
340
+
341
+ def refresh_patient_summary():
342
+ if patient_data:
343
+ demographics = {
344
+ "Name": patient_data.get('name', 'N/A'),
345
+ "Age": patient_data.get('age', 'N/A'),
346
+ "Sex": patient_data.get('sex', 'N/A'),
347
+ "Weight": f"{patient_data.get('weight', 'N/A')} kg",
348
+ "Height": f"{patient_data.get('height', 'N/A')} cm"
349
+ }
350
+
351
+ clinical = {
352
+ "Chief Complaint": patient_data.get('complaint', 'N/A'),
353
+ "Medical History": patient_data.get('medical_history', 'N/A')[:100] + "..." if len(patient_data.get('medical_history', '')) > 100 else patient_data.get('medical_history', 'N/A'),
354
+ "Examination": patient_data.get('examination', 'N/A')[:100] + "..." if len(patient_data.get('examination', '')) > 100 else patient_data.get('examination', 'N/A')
355
+ }
356
+
357
+ results = {
358
+ "Heartbeat Analysis": "Completed" if patient_data.get('heartbeat_analysis') else "Not performed",
359
+ "Investigation Analysis": "Completed" if patient_data.get('investigation_analysis') else "Not performed",
360
+ "Last Updated": patient_data.get('timestamp', 'N/A')
361
+ }
362
+
363
+ return demographics, clinical, results
364
+ else:
365
+ return {}, {}, {}
366
+
367
+ refresh_btn.click(
368
+ fn=refresh_patient_summary,
369
+ outputs=[summary_demographics, summary_clinical, summary_results]
370
+ )
371
+
372
+ gr.Markdown("""
373
+ ---
374
+ ### πŸ“ Important Notes:
375
+ - This system is for educational and research purposes only
376
+ - Always consult qualified healthcare professionals for medical decisions
377
+ - Ensure patient privacy and data protection compliance
378
+ - AI assessments should supplement, not replace, clinical judgment
379
+ """)
380
 
381
+ return demo
382
+
383
+ # Launch the application
384
# Launch the application
if __name__ == "__main__":
    # Warn — but still launch — when the Gemini key is absent; the UI works,
    # only the AI features will fail at call time.
    if not os.getenv("GOOGLE_API_KEY"):
        print("Warning: GOOGLE_API_KEY not set. Gemini AI features will not work.")
        print("Set your API key with: export GOOGLE_API_KEY='your_api_key_here'")

    demo = create_interface()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
        debug=True,
    )