DjPapzin committed on
Commit 5d08f2e · 1 Parent(s): 324d008

fix: Handle model loading for both local and Huggingface deployments
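
In short, the app now branches on an environment variable to decide whether to load the zipped model (Huggingface Spaces) or the already-extracted .keras file (local). A minimal sketch of that pattern, with the variable name 'HUGGINGFACE_SPACES' and the file names taken from the diff below; whether the deployment actually sets that variable is an assumption to verify:

import os

# Sketch only: mirrors the branch introduced in this commit.
# Assumption: the deployment environment defines 'HUGGINGFACE_SPACES'.
RUNNING_ON_SPACES = 'HUGGINGFACE_SPACES' in os.environ
MODEL_SOURCE = 'FINAL_MODEL.zip' if RUNNING_ON_SPACES else 'FINAL_MODEL.keras'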

Files changed (1)
  1. frontend/app.py +12 -432
frontend/app.py CHANGED
@@ -56,443 +56,23 @@ st.title("Medi Scape Dashboard")
  # --- Session State Initialization ---
  if 'disease_model' not in st.session_state:
      try:
-         model_path = 'FINAL_MODEL.zip' # Updated path to zip file
-         print(f"Attempting to load disease model from: {model_path}")
-         print(f"Model file exists: {os.path.exists(model_path)}")
-         with tf.keras.utils.get_file('FINAL_MODEL.keras', model_path, extract=True) as extracted_model_path:
-             model_dir = os.path.dirname(extracted_model_path)
-             model_path = os.path.join(model_dir, 'FINAL_MODEL.keras')
          st.session_state.disease_model = tf.keras.models.load_model(model_path)
          print("Disease model loaded successfully!")
      except FileNotFoundError:
-         st.error("Disease classification model not found. Please ensure 'FINAL_MODEL.zip' is in the same directory as this app.")
          st.session_state.disease_model = None
      except PermissionError:
          st.error("Permission error accessing 'model.weights.h5'. Please ensure the file is not being used by another process.")
          st.session_state.disease_model = None
 
- # Load the vectorizer
- if 'vectorizer' not in st.session_state:
-     try:
-         vectorizer_path = "vectorizer.pkl"
-         print(f"Attempting to load vectorizer from: {vectorizer_path}")
-         print(f"Vectorizer file exists: {os.path.exists(vectorizer_path)}")
-         st.session_state.vectorizer = pd.read_pickle(vectorizer_path)
-         print("Vectorizer loaded successfully!")
-     except FileNotFoundError:
-         st.error("Vectorizer file not found. Please ensure 'vectorizer.pkl' is in the same directory as this app.")
-         st.session_state.vectorizer = None
-     except Exception as e:
-         st.error(f"An error occurred while loading the vectorizer: {e}")
-         st.session_state.vectorizer = None
-
- if 'model_llm' not in st.session_state:
-     try:
-         llm_model_path = "logistic_regression_model.pkl" # Corrected path
-         print(f"Attempting to load LLM model from: {llm_model_path}")
-         print(f"LLM Model file exists: {os.path.exists(llm_model_path)}")
-         st.session_state.model_llm = pd.read_pickle(llm_model_path)
-         print("LLM model loaded successfully!")
-     except FileNotFoundError:
-         st.error("LLM model file not found. Please ensure 'logistic_regression_model.pkl' is in the 'frontend' directory.")
-         st.session_state.model_llm = None
-     except Exception as e:
-         st.error(f"An error occurred while loading the LLM model: {e}")
-         st.session_state.model_llm = None
-
- # --- End of Session State Initialization ---
-
- # Load the disease classification model
- try:
-     model_path = 'FINAL_MODEL.zip' # Updated path to zip file
-     with tf.keras.utils.get_file('FINAL_MODEL.keras', model_path, extract=True) as extracted_model_path:
-         model_dir = os.path.dirname(extracted_model_path)
-         model_path = os.path.join(model_dir, 'FINAL_MODEL.keras')
-         disease_model = tf.keras.models.load_model(model_path)
- except FileNotFoundError:
-     st.error("Disease classification model not found. Please ensure 'FINAL_MODEL.zip' is in the same directory as this app.")
-     disease_model = None
- except PermissionError:
-     st.error("Permission error accessing 'model.weights.h5'. Please ensure the file is not being used by another process.")
-     disease_model = None
-
- # Sidebar Navigation
- st.sidebar.title("Navigation")
- page = st.sidebar.radio("Go to", ["Home", "AI Chatbot Diagnosis", "Drug Identification", "Disease Detection", "Outbreak Alert"])
-
- # Access secrets using st.secrets
- if "INFERENCE_API_URL" not in st.secrets or "INFERENCE_API_KEY" not in st.secrets:
-     st.error("Please make sure to set your secrets in the Streamlit secrets settings.")
- else:
-     # Initialize the Inference Client
-     CLIENT = InferenceHTTPClient(
-         api_url=st.secrets["INFERENCE_API_URL"],
-         api_key=st.secrets["INFERENCE_API_KEY"]
-     )
-
- # Function to preprocess the image
- def preprocess_image(image_path):
-     # Load the image
-     image = cv2.imread(image_path)
-
-     # Convert to grayscale
-     gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
-
-     # Remove noise
-     blurred = cv2.GaussianBlur(gray, (5, 5), 0)
-
-     # Thresholding/Binarization
-     _, binary = cv2.threshold(blurred, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
-
-     # Dilation and Erosion
-     kernel = np.ones((1, 1), np.uint8)
-     dilated = cv2.dilate(binary, kernel, iterations=1)
-     eroded = cv2.erode(dilated, kernel, iterations=1)
-
-     # Edge detection
-     edges = cv2.Canny(eroded, 100, 200)
-
-     # Deskewing
-     coords = np.column_stack(np.where(edges > 0))
-     angle = cv2.minAreaRect(coords)[-1]
-     if angle < -45:
-         angle = -(90 + angle)
-     else:
-         angle = -angle
-
-     (h, w) = edges.shape[:2]
-     center = (w // 2, h // 2)
-     M = cv2.getRotationMatrix2D(center, angle, 1.0)
-     deskewed = cv2.warpAffine(edges, M, (w, h), flags=cv2.INTER_CUBIC, borderMode=cv2.BORDER_REPLICATE)
-
-     # Find contours
-     contours, _ = cv2.findContours(deskewed, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
-
-     # Draw contours on the original image
-     contour_image = image.copy()
-     cv2.drawContours(contour_image, contours, -1, (0, 255, 0), 2)
-
-     return contour_image
-
- def get_x1(detection):
-     return detection.xyxy[0][0]
-
- # --- Prediction function (using session state) ---
- def predict_disease(symptoms):
-     if st.session_state.vectorizer is not None and st.session_state.model_llm is not None:
-         preprocessed_symptoms = preprocess_text(symptoms)
-         symptoms_vectorized = st.session_state.vectorizer.transform([preprocessed_symptoms])
-         prediction = st.session_state.model_llm.predict(symptoms_vectorized)
-         return prediction[0]
-     else:
-         st.error("Unable to make prediction. Vectorizer or LLM model is not loaded.")
-         return None
-
- # --- New function to analyze X-ray with LLM ---
- def analyze_xray_with_llm(predicted_class):
-     prompt = f"""
-     Based on a chest X-ray analysis, the predicted condition is {predicted_class}.
-     Please provide a concise summary of this condition, including:
-     - A brief description of the condition.
-     - Common symptoms associated with it.
-     - Potential causes.
-     - General treatment approaches.
-     - Any other relevant information for a patient.
-     """
-     llm_response = get_ai71_response(prompt)
-     st.write("## LLM Analysis of X-ray Results:")
-     st.write(llm_response)
-
- # --- Functions for Symptom Detection ---
- def precaution(label):
-     dataset_precau = pd.read_csv("disease_precaution.csv", encoding='latin1')
-     label = str(label)
-     label = label.lower()
-
-     dataset_precau["Disease"] = dataset_precau["Disease"].str.lower()
-     # Filter the DataFrame for the given label
-     filtered_precautions = dataset_precau[dataset_precau["Disease"] == label]
-
-     # Extract precaution columns
-     precautions = filtered_precautions[["Precaution_1", "Precaution_2", "Precaution_3", "Precaution_4"]]
-     return precautions.values.tolist() # Convert DataFrame to a list of lists
-     # Return an empty list if no matching label is found
-
- def occurance(label):
-     dataset_occur = pd.read_csv("disease_riskFactors.csv", encoding='latin1')
-     label = str(label)
-     label = label.lower()
-
-     dataset_occur["DNAME"] = dataset_occur["DNAME"].str.lower()
-     # Filter the DataFrame for the given label
-     filtered_occurrence = dataset_occur[dataset_occur["DNAME"] == label]
-
-     occurrences = filtered_occurrence["OCCUR"].tolist() # Convert Series to list
-     return occurrences
-     # Return an empty list if no matching label is found
-
- if page == "Home":
-     st.markdown("## Welcome to Medi Scape")
-     st.write("Medi Scape is an AI-powered healthcare application designed to streamline the process of understanding and managing medical information. It leverages advanced AI models to provide features such as prescription analysis, disease detection from chest X-rays, and symptom-based diagnosis assistance.")
-
-     st.markdown("## Features")
-     st.write("Medi Scape provides various AI-powered tools for remote healthcare, including:")
-     features = [
-         "**AI Chatbot Diagnosis:** Interact with an AI chatbot for preliminary diagnosis and medical information.",
-         "**Drug Identification:** Upload a prescription image to identify medications and access relevant details.",
-         "**Doctor's Handwriting Identification:** Our system can accurately recognize and process doctor's handwriting.",
-         "**Disease Detection:** Upload a chest X-ray image to detect potential diseases.",
-         "**Outbreak Alert:** Stay informed about potential disease outbreaks in your area."
-     ]
-     for feature in features:
-         st.markdown(f"- {feature}")
-
-     st.markdown("## How it Works")
-     steps = [
-         "**Upload:** You can upload a prescription image for drug identification or a chest X-ray image for disease detection.",
-         "**Process:** Our AI models will analyze the image and extract relevant information.",
-         "**Results:** You will receive identified drug names, uses, side effects, and more, or a potential disease diagnosis."
-     ]
-     for i, step in enumerate(steps, 1):
-         st.markdown(f"{i}. {step}")
-
-     st.markdown("## Key Features")
-     key_features = [
-         "**AI-Powered:** Leverages advanced AI models for accurate analysis and diagnosis.",
-         "**User-Friendly:** Simple and intuitive interface for easy navigation and interaction.",
-         "**Secure:** Your data is protected and handled with confidentiality."
-     ]
-     for feature in key_features:
-         st.markdown(f"- {feature}")
-
-     st.markdown("Please use the sidebar to navigate to different features.")
-
- elif page == "AI Chatbot Diagnosis":
-     st.write("Enter your symptoms separated by commas:")
-     symptoms_input = st.text_area("Symptoms:")
-     col1, col2 = st.columns(2)
-     with col1:
-         if st.button("Diagnose with Regression Model"):
-             if symptoms_input:
-                 # --- Pipeline 1 Implementation ---
-                 # 1. Symptom Input (already done with st.text_area)
-                 # 2. Regression Prediction
-                 regression_prediction = predict_disease(symptoms_input)
-
-                 if regression_prediction is not None:
-                     st.write("## Logistic Regression Prediction:")
-                     st.write(regression_prediction)
-
-                     st.write("## Precautions:")
-                     precautions_names = precaution(regression_prediction)
-                     st.write(precautions_names)
-
-                     st.write("## Occurrence:")
-                     occurance_name = occurance(regression_prediction)
-                     st.write(occurance_name)
-
-             else:
-                 st.write("Please enter your symptoms.")
-
-     with col2:
-         if st.button("Diagnose with LLM"):
-             if symptoms_input:
-                 # --- Pipeline 2 Implementation (LLM Only) ---
-                 prompt = f"""The user is experiencing the following symptoms: {symptoms_input}.
-                 Based on these symptoms, provide a detailed explanation of possible conditions, including
-                 potential causes, common symptoms, and general treatment approaches. Also, suggest when
-                 a patient should consult a doctor."""
-
-                 llm_response = get_ai71_response(prompt)
-
-                 st.write("## LLM Diagnosis:")
-                 st.write(llm_response)
-             else:
-                 st.write("Please enter your symptoms.")
-
- elif page == "Drug Identification":
-     st.write("Upload a prescription image for drug identification.")
-     uploaded_file = st.file_uploader("Upload prescription", type=["png", "jpg", "jpeg"])
-
-     if uploaded_file is not None:
-         # Display the uploaded image
-         image = Image.open(uploaded_file)
-         st.image(image, caption="Uploaded Prescription", use_column_width=True)
-
-         if st.button("Process Prescription"):
-             # Save the image to a temporary file
-             temp_image_path = "temp_image.jpg"
-             image.save(temp_image_path)
-
-             # Preprocess the image
-             preprocessed_image = preprocess_image(temp_image_path)
-
-             # Perform inference
-             result_doch1 = CLIENT.infer(preprocessed_image, model_id="doctor-s-handwriting/1")
-
-             # Extract labels and detections
-             labels = [item["class"] for item in result_doch1["predictions"]]
-             detections = sv.Detections.from_inference(result_doch1)
-
-             # Sort detections and labels
-             sorted_indices = sorted(range(len(detections)), key=lambda i: get_x1(detections[i]))
-             sorted_detections = [detections[i] for i in sorted_indices]
-             sorted_labels = [labels[i] for i in sorted_indices]
-
-             # Convert list to string
-             resulting_string = ''.join(sorted_labels)
-
-             # Display results
-             st.subheader("Processed Prescription")
-             fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
-
-             # Plot bounding boxes
-             image_with_boxes = preprocessed_image.copy()
-             for detection in sorted_detections:
-                 x1, y1, x2, y2 = detection.xyxy[0]
-                 cv2.rectangle(image_with_boxes, (int(x1), int(y1)), (int(x2), int(y2)), (255, 0, 0), 2)
-             ax1.imshow(cv2.cvtColor(image_with_boxes, cv2.COLOR_BGR2RGB))
-             ax1.set_title("Bounding Boxes")
-             ax1.axis('off')
-
-             # Plot labels
-             image_with_labels = preprocessed_image.copy()
-             for i, detection in enumerate(sorted_detections):
-                 x1, y1, x2, y2 = detection.xyxy[0]
-                 label = sorted_labels[i]
-                 cv2.putText(image_with_labels, label, (int(x1), int(y1) - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
-             ax2.imshow(cv2.cvtColor(image_with_labels, cv2.COLOR_BGR2RGB))
-             ax2.set_title("Labels")
-             ax2.axis('off')
-
-             st.pyplot(fig)
-
-             st.write("Extracted Text from Prescription:", resulting_string)
-
-             # Prepare prompt for LLM
-             prompt = f"""Analyze the following prescription text:
-             {resulting_string}
-
-             Please provide:
-             1. Identified drug name(s)
-             2. Full name of each identified drug
-             3. Primary uses of each drug
-             4. Common side effects
-             5. Recommended dosage (if identifiable from the text)
-             6. Any warnings or precautions
-             7. Potential interactions with other medications (if multiple drugs are identified)
-             8. Any additional relevant information for the patient
-
-             If any part of the prescription is unclear or seems incomplete, please mention that and provide information about possible interpretations or matches. Always emphasize the importance of consulting a healthcare professional for accurate interpretation and advice."""
-
-             # Get LLM response
-             llm_response = get_ai71_response(prompt)
-
-             st.subheader("AI Analysis of the Prescription")
-             st.write(llm_response)
-
-             # Remove the temporary image file
-             os.remove(temp_image_path)
-
-     else:
-         st.info("Please upload a prescription image to proceed.")
-
- elif page == "Disease Detection":
-     st.write("Upload a chest X-ray image for disease detection.")
-     uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
-
-     if uploaded_image is not None and st.session_state.disease_model is not None:
-         # Display the image
-         img_opened = Image.open(uploaded_image).convert('RGB')
-         image_pred = np.array(img_opened)
-         image_pred = cv2.resize(image_pred, (150, 150))
-
-         # Convert the image to a numpy array
-         image_pred = np.array(image_pred)
-
-         # Rescale the image (if the model was trained with rescaling)
-         image_pred = image_pred / 255.0
-
-         # Add an extra dimension to match the input shape (1, 150, 150, 3)
-         image_pred = np.expand_dims(image_pred, axis=0)
-
-         # Predict using the model
-         prediction = st.session_state.disease_model.predict(image_pred)
-
-         # Get the predicted class
-         predicted_ = np.argmax(prediction)
-
-         # Decode the prediction
-         if predicted_ == 0:
-             predicted_class = "Covid"
-         elif predicted_ == 1:
-             predicted_class = "Normal Chest X-ray"
-         else:
-             predicted_class = "Pneumonia"
-
-         st.image(image_pred, caption='Input image by user', use_column_width=True)
-         st.write("Prediction Classes for different types:")
-         st.write("COVID: 0")
-         st.write("Normal Chest X-ray: 1")
-         st.write("Pneumonia: 2")
-         st.write("\n")
-         st.write("DETECTED DISEASE DISPLAY")
-         st.write(f"Predicted Class : {predicted_}")
-         st.write(predicted_class)
-
-         # Analyze X-ray results with LLM
-         analyze_xray_with_llm(predicted_class)
-     else:
-         st.write("Please upload an image file or ensure the disease model is loaded.")
-
- elif page == "Outbreak Alert":
-     st.markdown("## **Disease Outbreak News (from WHO)**")
-
-     # Fetch WHO news page
-     url = "https://www.who.int/news-room/events"
-     response = requests.get(url)
-     response.raise_for_status() # Raise an exception for bad status codes
-
-     soup = BeautifulSoup(response.content, 'html.parser')
-
-     # Find news articles (adjust selectors if WHO website changes)
-     articles = soup.find_all('div', class_='list-view--item')
-
-     for article in articles[:5]: # Display the top 5 news articles
-         title_element = article.find('a', class_='link-container')
-         if title_element:
-             title = title_element.text.strip()
-             link = title_element['href']
-             date_element = article.find('span', class_='date')
-             date = date_element.text.strip() if date_element else "Date not found"
-
-             # Format date
-             date_parts = date.split()
-             if len(date_parts) >= 3:
-                 try:
-                     formatted_date = datetime.strptime(date, "%d %B %Y").strftime("%Y-%m-%d")
-                 except ValueError:
-                     formatted_date = date # Keep the original date if formatting fails
-             else:
-                 formatted_date = date
-
-             # Display news item in a card-like container
-             with st.container():
-                 st.markdown(f"**{formatted_date}**")
-                 st.markdown(f"[{title}]({link})")
-                 st.markdown("---")
-         else:
-             st.write("Could not find article details.")
-
- # Auto-scroll to the bottom of the chat container
- st.markdown(
-     """
-     <script>
-     const chatContainer = document.querySelector('.st-chat-container');
-     if (chatContainer) {
-         chatContainer.scrollTop = chatContainer.scrollHeight;
-     }
-     </script>
-     """,
-     unsafe_allow_html=True,
- )
 
  # --- Session State Initialization ---
  if 'disease_model' not in st.session_state:
      try:
+         # Check if running on Huggingface Spaces
+         if 'HUGGINGFACE_SPACES' in os.environ:
+             model_path = 'FINAL_MODEL.zip'
+             with tf.keras.utils.get_file('FINAL_MODEL.keras', model_path, extract=True) as extracted_model_path:
+                 model_dir = os.path.dirname(extracted_model_path)
+                 model_path = os.path.join(model_dir, 'FINAL_MODEL.keras')
+                 st.session_state.disease_model = tf.keras.models.load_model(model_path)
+         else: # Running locally
+             model_path = 'FINAL_MODEL.keras'
          st.session_state.disease_model = tf.keras.models.load_model(model_path)
+
          print("Disease model loaded successfully!")
      except FileNotFoundError:
+         st.error("Disease classification model not found. Please ensure 'FINAL_MODEL.zip' (for Huggingface) or 'FINAL_MODEL.keras' (local) is in the same directory as this app.")
          st.session_state.disease_model = None
      except PermissionError:
          st.error("Permission error accessing 'model.weights.h5'. Please ensure the file is not being used by another process.")
          st.session_state.disease_model = None

+ # ... rest of your code ...
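
A caveat when reusing either version of this block: tf.keras.utils.get_file expects a download URL as its origin argument and returns a plain path string, so the "with tf.keras.utils.get_file(...) as ...:" construct (present in both the removed and the added code) fails at runtime, because a string is not a context manager. A hedged alternative for a zip that already sits next to app.py, assuming FINAL_MODEL.zip contains FINAL_MODEL.keras at its top level, is to extract it with the standard library and then load the result:

import os
import zipfile

import tensorflow as tf

def load_model_from_local_zip(zip_path='FINAL_MODEL.zip', member='FINAL_MODEL.keras'):
    """Sketch: extract a bundled Keras model from a local zip, then load it."""
    if not os.path.exists(member):
        # Assumption: the archive stores `member` at its top level.
        with zipfile.ZipFile(zip_path) as archive:
            archive.extract(member, path=os.path.dirname(zip_path) or '.')
    return tf.keras.models.load_model(member)

# Hypothetical usage inside the session-state block above:
# st.session_state.disease_model = load_model_from_local_zip()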