File size: 21,086 Bytes
9aeb50f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
313e620
9aeb50f
 
 
 
 
313e620
9aeb50f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
313e620
9aeb50f
313e620
9aeb50f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
import functools
import io
import os
import re
import tempfile
from datetime import datetime
from urllib.parse import urljoin

import cv2
import matplotlib.pyplot as plt
import nltk
import numpy as np
import pandas as pd
import requests
import streamlit as st
import supervision as sv
import tensorflow as tf
from bs4 import BeautifulSoup
from inference_sdk import InferenceHTTPClient
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
from PIL import Image
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import train_test_split

from utils.ai71_utils import get_ai71_response

# --- Preprocess text function (moved outside session state) ---
@functools.lru_cache(maxsize=1)
def _english_stop_words():
    """Return NLTK's English stop words as a frozenset, built only on first use."""
    return frozenset(stopwords.words('english'))


def preprocess_text(text):
    """Normalise free-text symptoms before they are vectorised.

    Lowercases *text*, replaces every character other than ASCII letters,
    digits, whitespace and commas with a space, tokenises the result with
    NLTK's ``word_tokenize`` and drops English stop words.

    Args:
        text: Raw user-entered symptom string.

    Returns:
        The remaining tokens joined by single spaces.
    """
    text = text.lower()

    # Commas are deliberately kept by the character class below.
    cleaned_text = re.sub(r'[^a-zA-Z0-9\s\,]', ' ', text)
    tokens = word_tokenize(cleaned_text)

    # Cached set: avoids re-reading the NLTK stop-word corpus on every call.
    stop_words = _english_stop_words()
    tokens = [word for word in tokens if word not in stop_words]

    return ' '.join(tokens)

st.title("Medi Scape Dashboard")

# --- Session State Initialization ---
# The Keras chest X-ray classifier is loaded once per browser session and kept
# in st.session_state; None marks it as unavailable for the Disease Detection page.
if 'disease_model' not in st.session_state:
    model_path = 'FINAL_MODEL.keras'
    model_exists = os.path.exists(model_path)
    print(f"Attempting to load disease model from: {model_path}")
    print(f"Model file exists: {model_exists}")
    if not model_exists:
        # Checked explicitly: load_model does not reliably raise
        # FileNotFoundError for a missing path (Keras 2 raises IOError,
        # Keras 3 raises ValueError), so catching that alone is not enough.
        st.error("Disease classification model not found. Please ensure 'FINAL_MODEL.keras' is in the same directory as this app.")
        st.session_state.disease_model = None
    else:
        try:
            st.session_state.disease_model = tf.keras.models.load_model(model_path)
            print("Disease model loaded successfully!")
        except (OSError, ValueError) as exc:
            # A corrupt or incompatible model file should disable only X-ray
            # detection, not stop every page of the dashboard from rendering.
            st.error(f"Failed to load the disease classification model: {exc}")
            st.session_state.disease_model = None

# Load the fitted text vectorizer once per session. On a missing file an error
# is shown and None is stored, which predict_disease() treats as "unavailable".
# NOTE: pd.read_pickle unpickles arbitrary objects - only ship trusted files.
if 'vectorizer' not in st.session_state:
    vectorizer_path = "vectorizer.pkl"
    print(f"Attempting to load vectorizer from: {vectorizer_path}")
    print(f"Vectorizer file exists: {os.path.exists(vectorizer_path)}")
    try:
        loaded_vectorizer = pd.read_pickle(vectorizer_path)
    except FileNotFoundError:
        st.error("Vectorizer file not found. Please ensure 'vectorizer.pkl' is in the same directory as this app.")
        loaded_vectorizer = None
    else:
        print("Vectorizer loaded successfully!")
    st.session_state.vectorizer = loaded_vectorizer

def _build_reference_dataset():
    """Combine the two symptom CSVs into one ``label``/``text`` DataFrame.

    For ``training_data.csv`` the text of each row is the comma-joined names of
    the columns whose value equals 1 (presumably one-hot symptom flags), paired
    with the ``prognosis`` column as the label. ``Symptom2Disease.csv`` already
    supplies ``label`` and ``text`` columns and is appended as-is.

    Raises:
        FileNotFoundError: If either CSV is missing.
    """
    dataset_1 = pd.read_csv("Symptoms_Detection/training_data.csv")
    dataset_2 = pd.read_csv("Symptoms_Detection/Symptom2Disease.csv")

    dataset_1['symptoms_text'] = dataset_1.apply(
        lambda row: ','.join([col for col in dataset_1.columns if row[col] == 1]), axis=1
    )
    final_dataset = pd.DataFrame(dataset_1[["prognosis", "symptoms_text"]])
    final_dataset.columns = ['label', 'text']

    return pd.concat([final_dataset, dataset_2[['label', 'text']]], axis=0, ignore_index=True)


if 'model_llm' not in st.session_state:
    # --- Load pre-trained symptom classifier ---
    # Session state is written only after the load attempt has finished, so an
    # unexpected failure leaves the key unset (and the load is retried on the
    # next rerun) instead of caching an unfitted placeholder model.
    llm_model_path = "logistic_regression_model.pkl"
    print(f"Attempting to load LLM model from: {llm_model_path}")
    print(f"LLM Model file exists: {os.path.exists(llm_model_path)}")
    try:
        # NOTE: pd.read_pickle unpickles arbitrary objects - only ship trusted files.
        loaded_model = pd.read_pickle(llm_model_path)
    except FileNotFoundError:
        st.error("LLM model file not found. Please ensure 'logistic_regression_model.pkl' is in the same directory.")
        loaded_model = None
    else:
        print("LLM model loaded successfully!")
    st.session_state.model_llm = loaded_model

    # Reference corpus only (not used for training or prediction); its absence
    # should not take the whole dashboard down.
    try:
        st.session_state.df_combined = _build_reference_dataset()
    except FileNotFoundError as exc:
        st.warning(f"Reference symptom datasets could not be loaded: {exc}")
# --- End of Session State Initialization ---

# Module-level handle to the disease classifier. It reuses the model already
# cached in session state instead of calling load_model() again on every
# Streamlit rerun; None means the model is unavailable.
disease_model = st.session_state.get('disease_model')

# Sidebar navigation: the selected label is stored in `page`, which drives the
# page dispatch further down.
NAV_PAGES = ["Home", "AI Chatbot Diagnosis", "Drug Identification", "Disease Detection", "Outbreak Alert"]
st.sidebar.title("Navigation")
page = st.sidebar.radio("Go to", NAV_PAGES)

# Inference-client and image-helper setup (CLIENT, preprocess_image, get_x1)
# happens exactly once, in the secrets-guarded block below. It is deliberately
# not duplicated here: a second copy built the client twice and rendered the
# missing-secrets error twice.

# Access secrets using st.secrets.
# Both keys are required. All remaining helpers and the entire page dispatch
# live inside the `else` branch below, so if either secret is missing only
# this error is rendered and no page content is shown.
if "INFERENCE_API_URL" not in st.secrets or "INFERENCE_API_KEY" not in st.secrets:
    st.error("Please make sure to set your secrets in the Streamlit secrets settings.")
else:
    # Initialize the Inference Client; the Drug Identification page uses it to
    # run the "doctor-s-handwriting/1" model.
    CLIENT = InferenceHTTPClient(
        api_url=st.secrets["INFERENCE_API_URL"],
        api_key=st.secrets["INFERENCE_API_KEY"]
    )

    # Function to preprocess the image
    def preprocess_image(image_path):
        """Outline text-like regions of a prescription image.

        Reads the image at *image_path*, binarises it (Gaussian blur + Otsu
        threshold), runs Canny edge detection, rotates the edge map by the
        angle estimated with ``cv2.minAreaRect`` and draws the external
        contours of that rotated edge map in green on a copy of the original
        colour image.

        Args:
            image_path: Path of an image file readable by ``cv2.imread``.

        Returns:
            A copy of the original BGR image with the contours drawn on it.

        Raises:
            ValueError: If ``cv2.imread`` cannot read *image_path*.
        """
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise ValueError(f"Could not read image file: {image_path}")

        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

        # Remove noise, then binarise with an automatically chosen threshold.
        blurred = cv2.GaussianBlur(gray, (5, 5), 0)
        _, binary = cv2.threshold(blurred, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

        # NOTE(review): with a 1x1 kernel this dilate/erode pair leaves the
        # image unchanged; kept as-is so the output does not change.
        kernel = np.ones((1, 1), np.uint8)
        dilated = cv2.dilate(binary, kernel, iterations=1)
        eroded = cv2.erode(dilated, kernel, iterations=1)

        edges = cv2.Canny(eroded, 100, 200)

        # Deskewing. A blank edge map (e.g. a uniform image) has no points to
        # fit, so it is guarded explicitly and left unrotated.
        coords = np.column_stack(np.where(edges > 0))
        if len(coords) == 0:
            deskewed = edges
        else:
            angle = cv2.minAreaRect(coords)[-1]
            # NOTE(review): this normalisation assumes the older OpenCV angle
            # range [-90, 0); newer releases report (0, 90]. Also, coords are
            # (row, col) pairs from np.where, not (x, y) - confirm intent.
            if angle < -45:
                angle = -(90 + angle)
            else:
                angle = -angle

            (h, w) = edges.shape[:2]
            center = (w // 2, h // 2)
            M = cv2.getRotationMatrix2D(center, angle, 1.0)
            deskewed = cv2.warpAffine(edges, M, (w, h), flags=cv2.INTER_CUBIC, borderMode=cv2.BORDER_REPLICATE)

        contours, _ = cv2.findContours(deskewed, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        # Draw contours on a copy of the original (un-rotated) colour image.
        contour_image = image.copy()
        cv2.drawContours(contour_image, contours, -1, (0, 255, 0), 2)

        return contour_image

    def get_x1(detection):
        """Return the left edge (x1) of the first box in *detection*.

        Used as the sort key that orders handwriting detections left to right.
        """
        first_box = detection.xyxy[0]
        return first_box[0]

    # --- Prediction function (using session state) ---
    def predict_disease(symptoms):
        """Predict a condition label from free-text *symptoms*.

        Cleans the text with ``preprocess_text``, transforms it with the
        session's vectorizer and returns the first element of the session
        classifier's ``predict`` output. If either object is ``None``, an error
        is rendered in the app and ``None`` is returned.
        """
        vectorizer = st.session_state.vectorizer
        if vectorizer is None or st.session_state.model_llm is None:
            st.error("Unable to make prediction. Vectorizer or LLM model is not loaded.")
            return None

        features = vectorizer.transform([preprocess_text(symptoms)])
        return st.session_state.model_llm.predict(features)[0]

    # --- New function to analyze X-ray with LLM ---
    def analyze_xray_with_llm(predicted_class):
        """Ask the LLM to explain *predicted_class* and render its answer.

        Builds a patient-oriented prompt around the classifier's label, sends
        it through ``get_ai71_response`` and then writes a heading followed by
        the response to the page.
        """
        xray_prompt = f"""
        Based on a chest X-ray analysis, the predicted condition is {predicted_class}. 
        Please provide a concise summary of this condition, including:
        - A brief description of the condition.
        - Common symptoms associated with it.
        - Potential causes.
        - General treatment approaches.
        - Any other relevant information for a patient.
        """
        answer = get_ai71_response(xray_prompt)
        for chunk in ("## LLM Analysis of X-ray Results:", answer):
            st.write(chunk)

    if page == "Home":
        # Static landing page: intro, feature list, workflow steps, highlights.
        home_features = (
            "**AI Chatbot Diagnosis:** Interact with an AI chatbot for preliminary diagnosis and medical information.",
            "**Drug Identification:** Upload a prescription image to identify medications and access relevant details.",
            "**Doctor's Handwriting Identification:** Our system can accurately recognize and process doctor's handwriting.",
            "**Disease Detection:** Upload a chest X-ray image to detect potential diseases.",
            "**Outbreak Alert:** Stay informed about potential disease outbreaks in your area.",
        )
        how_it_works = (
            "**Upload:** You can upload a prescription image for drug identification or a chest X-ray image for disease detection.",
            "**Process:** Our AI models will analyze the image and extract relevant information.",
            "**Results:** You will receive identified drug names, uses, side effects, and more, or a potential disease diagnosis.",
        )
        highlights = (
            "**AI-Powered:** Leverages advanced AI models for accurate analysis and diagnosis.",
            "**User-Friendly:** Simple and intuitive interface for easy navigation and interaction.",
            "**Secure:** Your data is protected and handled with confidentiality.",
        )

        st.markdown("## Welcome to Medi Scape")
        st.write("Medi Scape is an AI-powered healthcare application designed to streamline the process of understanding and managing medical information.  It leverages advanced AI models to provide features such as prescription analysis, disease detection from chest X-rays, and symptom-based diagnosis assistance.")

        st.markdown("## Features")
        st.write("Medi Scape provides various AI-powered tools for remote healthcare, including:")
        for bullet in home_features:
            st.markdown("- " + bullet)

        st.markdown("## How it Works")
        for number, step_text in enumerate(how_it_works, start=1):
            st.markdown(f"{number}. {step_text}")

        st.markdown("## Key Features")
        for bullet in highlights:
            st.markdown("- " + bullet)

        st.markdown("Please use the sidebar to navigate to different features.")

    elif page == "AI Chatbot Diagnosis":
        st.write("Enter your symptoms separated by commas:")
        symptoms_input = st.text_area("Symptoms:")
        if st.button("Diagnose"):
            if symptoms_input:
                # --- Pipeline 1 Implementation ---
                # 1. Symptom Input (already done with st.text_area)
                # 2. Regression Prediction
                regression_prediction = predict_disease(symptoms_input)

                if regression_prediction is not None:
                    # 3. LLM Prompt Enhancement
                    prompt = f"""The predicted condition based on a symptom analysis is {regression_prediction}. 
                    Provide a detailed explanation of this condition, including possible causes, common symptoms, 
                    and general treatment approaches. Also, suggest when a patient should consult a doctor."""

                    # 4. LLM Output
                    llm_response = get_ai71_response(prompt)

                    # 5. Combined Output
                    st.write("## Logistic Regression Prediction:")
                    st.write(regression_prediction)

                    st.write("## LLM Explanation:")
                    st.write(llm_response)
                # --- End of Pipeline 1 Implementation ---

            else:
                st.write("Please enter your symptoms.")

    elif page == "Drug Identification":
        st.write("Upload a prescription image for drug identification.")
        uploaded_file = st.file_uploader("Upload prescription", type=["png", "jpg", "jpeg"])

        if uploaded_file is not None:
            # Display the uploaded image
            image = Image.open(uploaded_file)
            st.image(image, caption="Uploaded Prescription", use_column_width=True)

            if st.button("Process Prescription"):
                # Save the image to a temporary file
                temp_image_path = "temp_image.jpg"
                image.save(temp_image_path)

                # Preprocess the image
                preprocessed_image = preprocess_image(temp_image_path)

                # Perform inference
                result_doch1 = CLIENT.infer(preprocessed_image, model_id="doctor-s-handwriting/1")

                # Extract labels and detections
                labels = [item["class"] for item in result_doch1["predictions"]]
                detections = sv.Detections.from_inference(result_doch1)

                # Sort detections and labels
                sorted_indices = sorted(range(len(detections)), key=lambda i: get_x1(detections[i]))
                sorted_detections = [detections[i] for i in sorted_indices]
                sorted_labels = [labels[i] for i in sorted_indices]

                # Convert list to string
                resulting_string = ''.join(sorted_labels)

                # Display results
                st.subheader("Processed Prescription")
                fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))

                # Plot bounding boxes
                image_with_boxes = preprocessed_image.copy()
                for detection in sorted_detections:
                    x1, y1, x2, y2 = detection.xyxy[0]
                    cv2.rectangle(image_with_boxes, (int(x1), int(y1)), (int(x2), int(y2)), (255, 0, 0), 2)      
                ax1.imshow(cv2.cvtColor(image_with_boxes, cv2.COLOR_BGR2RGB))
                ax1.set_title("Bounding Boxes")
                ax1.axis('off')

                # Plot labels
                image_with_labels = preprocessed_image.copy()
                for i, detection in enumerate(sorted_detections):
                    x1, y1, x2, y2 = detection.xyxy[0]
                    label = sorted_labels[i]
                    cv2.putText(image_with_labels, label, (int(x1), int(y1) - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
                ax2.imshow(cv2.cvtColor(image_with_labels, cv2.COLOR_BGR2RGB))
                ax2.set_title("Labels")
                ax2.axis('off')

                st.pyplot(fig)

                st.write("Extracted Text from Prescription:", resulting_string)

                # Prepare prompt for LLM
                prompt = f"""Analyze the following prescription text:
                {resulting_string}

                Please provide:
                1. Identified drug name(s)
                2. Full name of each identified drug
                3. Primary uses of each drug
                4. Common side effects
                5. Recommended dosage (if identifiable from the text)
                6. Any warnings or precautions
                7. Potential interactions with other medications (if multiple drugs are identified)
                8. Any additional relevant information for the patient

                If any part of the prescription is unclear or seems incomplete, please mention that and provide information about possible interpretations or matches. Always emphasize the importance of consulting a healthcare professional for accurate interpretation and advice."""

                # Get LLM response
                llm_response = get_ai71_response(prompt)

                st.subheader("AI Analysis of the Prescription")
                st.write(llm_response)

                # Remove the temporary image file
                os.remove(temp_image_path)

        else:
            st.info("Please upload a prescription image to proceed.")

    elif page == "Disease Detection":
        st.write("Upload a chest X-ray image for disease detection.")
        uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

        if uploaded_image is not None and st.session_state.disease_model is not None:
            # Display the image
            img_opened = Image.open(uploaded_image).convert('RGB')
            image_pred = np.array(img_opened)
            image_pred = cv2.resize(image_pred, (150, 150))

            # Convert the image to a numpy array
            image_pred = np.array(image_pred)

            # Rescale the image (if the model was trained with rescaling)
            image_pred = image_pred / 255.0

            # Add an extra dimension to match the input shape (1, 150, 150, 3)
            image_pred = np.expand_dims(image_pred, axis=0)

            # Predict using the model
            prediction = st.session_state.disease_model.predict(image_pred)

            # Get the predicted class
            predicted_ = np.argmax(prediction)

            # Decode the prediction
            if predicted_ == 0:
                predicted_class = "Covid"
            elif predicted_ == 1:
                predicted_class = "Normal Chest X-ray"
            else:
                predicted_class = "Pneumonia"

            st.image(image_pred, caption='Input image by user', use_column_width=True)
            st.write("Prediction Classes for different types:")
            st.write("COVID: 0")
            st.write("Normal Chest X-ray: 1")
            st.write("Pneumonia: 2")
            st.write("\n")
            st.write("DETECTED DISEASE DISPLAY")
            st.write(f"Predicted Class : {predicted_}")
            st.write(predicted_class)

            # Analyze X-ray results with LLM
            analyze_xray_with_llm(predicted_class)
        else:
            st.write("Please upload an image file or ensure the disease model is loaded.")

    elif page == "Outbreak Alert":
        st.markdown("## **Disease Outbreak News (from WHO)**")

        # Fetch WHO news page
        url = "https://www.who.int/news-room/events"
        response = requests.get(url)
        response.raise_for_status()  # Raise an exception for bad status codes

        soup = BeautifulSoup(response.content, 'html.parser')

        # Find news articles (adjust selectors if WHO website changes)
        articles = soup.find_all('div', class_='list-view--item')

        for article in articles[:5]:  # Display the top 5 news articles
            title_element = article.find('a', class_='link-container')
            if title_element:
                title = title_element.text.strip()
                link = title_element['href']
                date_element = article.find('span', class_='date')
                date = date_element.text.strip() if date_element else "Date not found"
                
                # Format date
                date_parts = date.split()
                if len(date_parts) >= 3:
                    try:
                        formatted_date = datetime.strptime(date, "%d %B %Y").strftime("%Y-%m-%d")
                    except ValueError:
                        formatted_date = date  # Keep the original date if formatting fails
                else:
                    formatted_date = date
                
                # Display news item in a card-like container
                with st.container():
                    st.markdown(f"**{formatted_date}**")
                    st.markdown(f"[{title}]({link})")
                    st.markdown("---")
            else:
                st.write("Could not find article details.")

# Chat auto-scroll is intentionally not implemented here: Streamlit does not
# execute <script> tags passed to st.markdown(..., unsafe_allow_html=True), so
# an injected scroll script would never run. Use streamlit.components.v1.html
# if client-side JavaScript is ever required.