DjPapzin commited on
Commit
398c97f
·
1 Parent(s): d0faa42

Undo last commit

Browse files
Files changed (1) hide show
  1. frontend/app.py +432 -12
frontend/app.py CHANGED
@@ -56,23 +56,443 @@ st.title("Medi Scape Dashboard")
56
  # --- Session State Initialization ---
57
  if 'disease_model' not in st.session_state:
58
  try:
59
- # Check if running on Huggingface Spaces
60
- if 'HUGGINGFACE_SPACES' in os.environ:
61
- model_path = 'FINAL_MODEL.zip'
62
- with tf.keras.utils.get_file('FINAL_MODEL.keras', model_path, extract=True) as extracted_model_path:
63
- model_dir = os.path.dirname(extracted_model_path)
64
- model_path = os.path.join(model_dir, 'FINAL_MODEL.keras')
65
- st.session_state.disease_model = tf.keras.models.load_model(model_path)
66
- else: # Running locally
67
- model_path = 'FINAL_MODEL.keras'
68
  st.session_state.disease_model = tf.keras.models.load_model(model_path)
69
-
70
  print("Disease model loaded successfully!")
71
  except FileNotFoundError:
72
- st.error("Disease classification model not found. Please ensure 'FINAL_MODEL.zip' (for Huggingface) or 'FINAL_MODEL.keras' (local) is in the same directory as this app.")
73
  st.session_state.disease_model = None
74
  except PermissionError:
75
  st.error("Permission error accessing 'model.weights.h5'. Please ensure the file is not being used by another process.")
76
  st.session_state.disease_model = None
77
 
78
- # ... rest of your code ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  # --- Session State Initialization ---
57
  if 'disease_model' not in st.session_state:
58
  try:
59
+ model_path = 'FINAL_MODEL.zip' # Updated path to zip file
60
+ print(f"Attempting to load disease model from: {model_path}")
61
+ print(f"Model file exists: {os.path.exists(model_path)}")
62
+ with tf.keras.utils.get_file('FINAL_MODEL.keras', model_path, extract=True) as extracted_model_path:
63
+ model_dir = os.path.dirname(extracted_model_path)
64
+ model_path = os.path.join(model_dir, 'FINAL_MODEL.keras')
 
 
 
65
  st.session_state.disease_model = tf.keras.models.load_model(model_path)
 
66
  print("Disease model loaded successfully!")
67
  except FileNotFoundError:
68
+ st.error("Disease classification model not found. Please ensure 'FINAL_MODEL.zip' is in the same directory as this app.")
69
  st.session_state.disease_model = None
70
  except PermissionError:
71
  st.error("Permission error accessing 'model.weights.h5'. Please ensure the file is not being used by another process.")
72
  st.session_state.disease_model = None
73
 
74
+ # Load the vectorizer
75
+ if 'vectorizer' not in st.session_state:
76
+ try:
77
+ vectorizer_path = "vectorizer.pkl"
78
+ print(f"Attempting to load vectorizer from: {vectorizer_path}")
79
+ print(f"Vectorizer file exists: {os.path.exists(vectorizer_path)}")
80
+ st.session_state.vectorizer = pd.read_pickle(vectorizer_path)
81
+ print("Vectorizer loaded successfully!")
82
+ except FileNotFoundError:
83
+ st.error("Vectorizer file not found. Please ensure 'vectorizer.pkl' is in the same directory as this app.")
84
+ st.session_state.vectorizer = None
85
+ except Exception as e:
86
+ st.error(f"An error occurred while loading the vectorizer: {e}")
87
+ st.session_state.vectorizer = None
88
+
89
+ if 'model_llm' not in st.session_state:
90
+ try:
91
+ llm_model_path = "logistic_regression_model.pkl" # Corrected path
92
+ print(f"Attempting to load LLM model from: {llm_model_path}")
93
+ print(f"LLM Model file exists: {os.path.exists(llm_model_path)}")
94
+ st.session_state.model_llm = pd.read_pickle(llm_model_path)
95
+ print("LLM model loaded successfully!")
96
+ except FileNotFoundError:
97
+ st.error("LLM model file not found. Please ensure 'logistic_regression_model.pkl' is in the 'frontend' directory.")
98
+ st.session_state.model_llm = None
99
+ except Exception as e:
100
+ st.error(f"An error occurred while loading the LLM model: {e}")
101
+ st.session_state.model_llm = None
102
+
103
+ # --- End of Session State Initialization ---
104
+
105
+ # Load the disease classification model
106
+ try:
107
+ model_path = 'FINAL_MODEL.zip' # Updated path to zip file
108
+ with tf.keras.utils.get_file('FINAL_MODEL.keras', model_path, extract=True) as extracted_model_path:
109
+ model_dir = os.path.dirname(extracted_model_path)
110
+ model_path = os.path.join(model_dir, 'FINAL_MODEL.keras')
111
+ disease_model = tf.keras.models.load_model(model_path)
112
+ except FileNotFoundError:
113
+ st.error("Disease classification model not found. Please ensure 'FINAL_MODEL.zip' is in the same directory as this app.")
114
+ disease_model = None
115
+ except PermissionError:
116
+ st.error("Permission error accessing 'model.weights.h5'. Please ensure the file is not being used by another process.")
117
+ disease_model = None
118
+
119
+ # Sidebar Navigation
120
+ st.sidebar.title("Navigation")
121
+ page = st.sidebar.radio("Go to", ["Home", "AI Chatbot Diagnosis", "Drug Identification", "Disease Detection", "Outbreak Alert"])
122
+
123
+ # Access secrets using st.secrets
124
+ if "INFERENCE_API_URL" not in st.secrets or "INFERENCE_API_KEY" not in st.secrets:
125
+ st.error("Please make sure to set your secrets in the Streamlit secrets settings.")
126
+ else:
127
+ # Initialize the Inference Client
128
+ CLIENT = InferenceHTTPClient(
129
+ api_url=st.secrets["INFERENCE_API_URL"],
130
+ api_key=st.secrets["INFERENCE_API_KEY"]
131
+ )
132
+
133
+ # Function to preprocess the image
134
+ def preprocess_image(image_path):
135
+ # Load the image
136
+ image = cv2.imread(image_path)
137
+
138
+ # Convert to grayscale
139
+ gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
140
+
141
+ # Remove noise
142
+ blurred = cv2.GaussianBlur(gray, (5, 5), 0)
143
+
144
+ # Thresholding/Binarization
145
+ _, binary = cv2.threshold(blurred, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
146
+
147
+ # Dilation and Erosion
148
+ kernel = np.ones((1, 1), np.uint8)
149
+ dilated = cv2.dilate(binary, kernel, iterations=1)
150
+ eroded = cv2.erode(dilated, kernel, iterations=1)
151
+
152
+ # Edge detection
153
+ edges = cv2.Canny(eroded, 100, 200)
154
+
155
+ # Deskewing
156
+ coords = np.column_stack(np.where(edges > 0))
157
+ angle = cv2.minAreaRect(coords)[-1]
158
+ if angle < -45:
159
+ angle = -(90 + angle)
160
+ else:
161
+ angle = -angle
162
+
163
+ (h, w) = edges.shape[:2]
164
+ center = (w // 2, h // 2)
165
+ M = cv2.getRotationMatrix2D(center, angle, 1.0)
166
+ deskewed = cv2.warpAffine(edges, M, (w, h), flags=cv2.INTER_CUBIC, borderMode=cv2.BORDER_REPLICATE)
167
+
168
+ # Find contours
169
+ contours, _ = cv2.findContours(deskewed, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
170
+
171
+ # Draw contours on the original image
172
+ contour_image = image.copy()
173
+ cv2.drawContours(contour_image, contours, -1, (0, 255, 0), 2)
174
+
175
+ return contour_image
176
+
177
+ def get_x1(detection):
178
+ return detection.xyxy[0][0]
179
+
180
+ # --- Prediction function (using session state) ---
181
+ def predict_disease(symptoms):
182
+ if st.session_state.vectorizer is not None and st.session_state.model_llm is not None:
183
+ preprocessed_symptoms = preprocess_text(symptoms)
184
+ symptoms_vectorized = st.session_state.vectorizer.transform([preprocessed_symptoms])
185
+ prediction = st.session_state.model_llm.predict(symptoms_vectorized)
186
+ return prediction[0]
187
+ else:
188
+ st.error("Unable to make prediction. Vectorizer or LLM model is not loaded.")
189
+ return None
190
+
191
+ # --- New function to analyze X-ray with LLM ---
192
+ def analyze_xray_with_llm(predicted_class):
193
+ prompt = f"""
194
+ Based on a chest X-ray analysis, the predicted condition is {predicted_class}.
195
+ Please provide a concise summary of this condition, including:
196
+ - A brief description of the condition.
197
+ - Common symptoms associated with it.
198
+ - Potential causes.
199
+ - General treatment approaches.
200
+ - Any other relevant information for a patient.
201
+ """
202
+ llm_response = get_ai71_response(prompt)
203
+ st.write("## LLM Analysis of X-ray Results:")
204
+ st.write(llm_response)
205
+
206
+ # --- Functions for Symptom Detection ---
207
+ def precaution(label):
208
+ dataset_precau = pd.read_csv("disease_precaution.csv", encoding='latin1')
209
+ label = str(label)
210
+ label = label.lower()
211
+
212
+ dataset_precau["Disease"] = dataset_precau["Disease"].str.lower()
213
+ # Filter the DataFrame for the given label
214
+ filtered_precautions = dataset_precau[dataset_precau["Disease"] == label]
215
+
216
+ # Extract precaution columns
217
+ precautions = filtered_precautions[["Precaution_1", "Precaution_2", "Precaution_3", "Precaution_4"]]
218
+ return precautions.values.tolist() # Convert DataFrame to a list of lists
219
+ # Return an empty list if no matching label is found
220
+
221
+ def occurance(label):
222
+ dataset_occur = pd.read_csv("disease_riskFactors.csv", encoding='latin1')
223
+ label = str(label)
224
+ label = label.lower()
225
+
226
+ dataset_occur["DNAME"] = dataset_occur["DNAME"].str.lower()
227
+ # Filter the DataFrame for the given label
228
+ filtered_occurrence = dataset_occur[dataset_occur["DNAME"] == label]
229
+
230
+ occurrences = filtered_occurrence["OCCUR"].tolist() # Convert Series to list
231
+ return occurrences
232
+ # Return an empty list if no matching label is found
233
+
234
+ if page == "Home":
235
+ st.markdown("## Welcome to Medi Scape")
236
+ st.write("Medi Scape is an AI-powered healthcare application designed to streamline the process of understanding and managing medical information. It leverages advanced AI models to provide features such as prescription analysis, disease detection from chest X-rays, and symptom-based diagnosis assistance.")
237
+
238
+ st.markdown("## Features")
239
+ st.write("Medi Scape provides various AI-powered tools for remote healthcare, including:")
240
+ features = [
241
+ "**AI Chatbot Diagnosis:** Interact with an AI chatbot for preliminary diagnosis and medical information.",
242
+ "**Drug Identification:** Upload a prescription image to identify medications and access relevant details.",
243
+ "**Doctor's Handwriting Identification:** Our system can accurately recognize and process doctor's handwriting.",
244
+ "**Disease Detection:** Upload a chest X-ray image to detect potential diseases.",
245
+ "**Outbreak Alert:** Stay informed about potential disease outbreaks in your area."
246
+ ]
247
+ for feature in features:
248
+ st.markdown(f"- {feature}")
249
+
250
+ st.markdown("## How it Works")
251
+ steps = [
252
+ "**Upload:** You can upload a prescription image for drug identification or a chest X-ray image for disease detection.",
253
+ "**Process:** Our AI models will analyze the image and extract relevant information.",
254
+ "**Results:** You will receive identified drug names, uses, side effects, and more, or a potential disease diagnosis."
255
+ ]
256
+ for i, step in enumerate(steps, 1):
257
+ st.markdown(f"{i}. {step}")
258
+
259
+ st.markdown("## Key Features")
260
+ key_features = [
261
+ "**AI-Powered:** Leverages advanced AI models for accurate analysis and diagnosis.",
262
+ "**User-Friendly:** Simple and intuitive interface for easy navigation and interaction.",
263
+ "**Secure:** Your data is protected and handled with confidentiality."
264
+ ]
265
+ for feature in key_features:
266
+ st.markdown(f"- {feature}")
267
+
268
+ st.markdown("Please use the sidebar to navigate to different features.")
269
+
270
+ elif page == "AI Chatbot Diagnosis":
271
+ st.write("Enter your symptoms separated by commas:")
272
+ symptoms_input = st.text_area("Symptoms:")
273
+ col1, col2 = st.columns(2)
274
+ with col1:
275
+ if st.button("Diagnose with Regression Model"):
276
+ if symptoms_input:
277
+ # --- Pipeline 1 Implementation ---
278
+ # 1. Symptom Input (already done with st.text_area)
279
+ # 2. Regression Prediction
280
+ regression_prediction = predict_disease(symptoms_input)
281
+
282
+ if regression_prediction is not None:
283
+ st.write("## Logistic Regression Prediction:")
284
+ st.write(regression_prediction)
285
+
286
+ st.write("## Precautions:")
287
+ precautions_names = precaution(regression_prediction)
288
+ st.write(precautions_names)
289
+
290
+ st.write("## Occurrence:")
291
+ occurance_name = occurance(regression_prediction)
292
+ st.write(occurance_name)
293
+
294
+ else:
295
+ st.write("Please enter your symptoms.")
296
+
297
+ with col2:
298
+ if st.button("Diagnose with LLM"):
299
+ if symptoms_input:
300
+ # --- Pipeline 2 Implementation (LLM Only) ---
301
+ prompt = f"""The user is experiencing the following symptoms: {symptoms_input}.
302
+ Based on these symptoms, provide a detailed explanation of possible conditions, including
303
+ potential causes, common symptoms, and general treatment approaches. Also, suggest when
304
+ a patient should consult a doctor."""
305
+
306
+ llm_response = get_ai71_response(prompt)
307
+
308
+ st.write("## LLM Diagnosis:")
309
+ st.write(llm_response)
310
+ else:
311
+ st.write("Please enter your symptoms.")
312
+
313
+ elif page == "Drug Identification":
314
+ st.write("Upload a prescription image for drug identification.")
315
+ uploaded_file = st.file_uploader("Upload prescription", type=["png", "jpg", "jpeg"])
316
+
317
+ if uploaded_file is not None:
318
+ # Display the uploaded image
319
+ image = Image.open(uploaded_file)
320
+ st.image(image, caption="Uploaded Prescription", use_column_width=True)
321
+
322
+ if st.button("Process Prescription"):
323
+ # Save the image to a temporary file
324
+ temp_image_path = "temp_image.jpg"
325
+ image.save(temp_image_path)
326
+
327
+ # Preprocess the image
328
+ preprocessed_image = preprocess_image(temp_image_path)
329
+
330
+ # Perform inference
331
+ result_doch1 = CLIENT.infer(preprocessed_image, model_id="doctor-s-handwriting/1")
332
+
333
+ # Extract labels and detections
334
+ labels = [item["class"] for item in result_doch1["predictions"]]
335
+ detections = sv.Detections.from_inference(result_doch1)
336
+
337
+ # Sort detections and labels
338
+ sorted_indices = sorted(range(len(detections)), key=lambda i: get_x1(detections[i]))
339
+ sorted_detections = [detections[i] for i in sorted_indices]
340
+ sorted_labels = [labels[i] for i in sorted_indices]
341
+
342
+ # Convert list to string
343
+ resulting_string = ''.join(sorted_labels)
344
+
345
+ # Display results
346
+ st.subheader("Processed Prescription")
347
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
348
+
349
+ # Plot bounding boxes
350
+ image_with_boxes = preprocessed_image.copy()
351
+ for detection in sorted_detections:
352
+ x1, y1, x2, y2 = detection.xyxy[0]
353
+ cv2.rectangle(image_with_boxes, (int(x1), int(y1)), (int(x2), int(y2)), (255, 0, 0), 2)
354
+ ax1.imshow(cv2.cvtColor(image_with_boxes, cv2.COLOR_BGR2RGB))
355
+ ax1.set_title("Bounding Boxes")
356
+ ax1.axis('off')
357
+
358
+ # Plot labels
359
+ image_with_labels = preprocessed_image.copy()
360
+ for i, detection in enumerate(sorted_detections):
361
+ x1, y1, x2, y2 = detection.xyxy[0]
362
+ label = sorted_labels[i]
363
+ cv2.putText(image_with_labels, label, (int(x1), int(y1) - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
364
+ ax2.imshow(cv2.cvtColor(image_with_labels, cv2.COLOR_BGR2RGB))
365
+ ax2.set_title("Labels")
366
+ ax2.axis('off')
367
+
368
+ st.pyplot(fig)
369
+
370
+ st.write("Extracted Text from Prescription:", resulting_string)
371
+
372
+ # Prepare prompt for LLM
373
+ prompt = f"""Analyze the following prescription text:
374
+ {resulting_string}
375
+
376
+ Please provide:
377
+ 1. Identified drug name(s)
378
+ 2. Full name of each identified drug
379
+ 3. Primary uses of each drug
380
+ 4. Common side effects
381
+ 5. Recommended dosage (if identifiable from the text)
382
+ 6. Any warnings or precautions
383
+ 7. Potential interactions with other medications (if multiple drugs are identified)
384
+ 8. Any additional relevant information for the patient
385
+
386
+ If any part of the prescription is unclear or seems incomplete, please mention that and provide information about possible interpretations or matches. Always emphasize the importance of consulting a healthcare professional for accurate interpretation and advice."""
387
+
388
+ # Get LLM response
389
+ llm_response = get_ai71_response(prompt)
390
+
391
+ st.subheader("AI Analysis of the Prescription")
392
+ st.write(llm_response)
393
+
394
+ # Remove the temporary image file
395
+ os.remove(temp_image_path)
396
+
397
+ else:
398
+ st.info("Please upload a prescription image to proceed.")
399
+
400
+ elif page == "Disease Detection":
401
+ st.write("Upload a chest X-ray image for disease detection.")
402
+ uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
403
+
404
+ if uploaded_image is not None and st.session_state.disease_model is not None:
405
+ # Display the image
406
+ img_opened = Image.open(uploaded_image).convert('RGB')
407
+ image_pred = np.array(img_opened)
408
+ image_pred = cv2.resize(image_pred, (150, 150))
409
+
410
+ # Convert the image to a numpy array
411
+ image_pred = np.array(image_pred)
412
+
413
+ # Rescale the image (if the model was trained with rescaling)
414
+ image_pred = image_pred / 255.0
415
+
416
+ # Add an extra dimension to match the input shape (1, 150, 150, 3)
417
+ image_pred = np.expand_dims(image_pred, axis=0)
418
+
419
+ # Predict using the model
420
+ prediction = st.session_state.disease_model.predict(image_pred)
421
+
422
+ # Get the predicted class
423
+ predicted_ = np.argmax(prediction)
424
+
425
+ # Decode the prediction
426
+ if predicted_ == 0:
427
+ predicted_class = "Covid"
428
+ elif predicted_ == 1:
429
+ predicted_class = "Normal Chest X-ray"
430
+ else:
431
+ predicted_class = "Pneumonia"
432
+
433
+ st.image(image_pred, caption='Input image by user', use_column_width=True)
434
+ st.write("Prediction Classes for different types:")
435
+ st.write("COVID: 0")
436
+ st.write("Normal Chest X-ray: 1")
437
+ st.write("Pneumonia: 2")
438
+ st.write("\n")
439
+ st.write("DETECTED DISEASE DISPLAY")
440
+ st.write(f"Predicted Class : {predicted_}")
441
+ st.write(predicted_class)
442
+
443
+ # Analyze X-ray results with LLM
444
+ analyze_xray_with_llm(predicted_class)
445
+ else:
446
+ st.write("Please upload an image file or ensure the disease model is loaded.")
447
+
448
+ elif page == "Outbreak Alert":
449
+ st.markdown("## **Disease Outbreak News (from WHO)**")
450
+
451
+ # Fetch WHO news page
452
+ url = "https://www.who.int/news-room/events"
453
+ response = requests.get(url)
454
+ response.raise_for_status() # Raise an exception for bad status codes
455
+
456
+ soup = BeautifulSoup(response.content, 'html.parser')
457
+
458
+ # Find news articles (adjust selectors if WHO website changes)
459
+ articles = soup.find_all('div', class_='list-view--item')
460
+
461
+ for article in articles[:5]: # Display the top 5 news articles
462
+ title_element = article.find('a', class_='link-container')
463
+ if title_element:
464
+ title = title_element.text.strip()
465
+ link = title_element['href']
466
+ date_element = article.find('span', class_='date')
467
+ date = date_element.text.strip() if date_element else "Date not found"
468
+
469
+ # Format date
470
+ date_parts = date.split()
471
+ if len(date_parts) >= 3:
472
+ try:
473
+ formatted_date = datetime.strptime(date, "%d %B %Y").strftime("%Y-%m-%d")
474
+ except ValueError:
475
+ formatted_date = date # Keep the original date if formatting fails
476
+ else:
477
+ formatted_date = date
478
+
479
+ # Display news item in a card-like container
480
+ with st.container():
481
+ st.markdown(f"**{formatted_date}**")
482
+ st.markdown(f"[{title}]({link})")
483
+ st.markdown("---")
484
+ else:
485
+ st.write("Could not find article details.")
486
+
487
+ # Auto-scroll to the bottom of the chat container
488
+ st.markdown(
489
+ """
490
+ <script>
491
+ const chatContainer = document.querySelector('.st-chat-container');
492
+ if (chatContainer) {
493
+ chatContainer.scrollTop = chatContainer.scrollHeight;
494
+ }
495
+ </script>
496
+ """,
497
+ unsafe_allow_html=True,
498
+ )