import streamlit as st
import requests
from utils.ai71_utils import get_ai71_response
from datetime import datetime
import cv2
import numpy as np
from PIL import Image
import supervision as sv
import matplotlib.pyplot as plt
import io
import os
from inference_sdk import InferenceHTTPClient
from bs4 import BeautifulSoup
import tensorflow as tf
import pandas as pd
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, classification_report
import nltk
import re
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
import joblib  # Import joblib for loading the logistic regression model

# --- Download NLTK 'punkt' resource if not already present ---
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    nltk.download('punkt')

try:
    nltk.data.find('corpora/stopwords')
except LookupError:
    nltk.download('stopwords')

# --- Preprocess text function (moved outside session state) ---
def preprocess_text(text):
    # Convert to lowercase
    text = text.lower()

    # Remove characters other than letters, digits, whitespace, and commas
    cleaned_text = re.sub(r'[^a-zA-Z0-9\s,]', ' ', text)
    # Tokenize text
    tokens = word_tokenize(cleaned_text)

    # Remove stop words
    stop_words = set(stopwords.words('english'))
    tokens = [word for word in tokens if word not in stop_words]

    # Rejoin tokens into a single string
    cleaned_text = ' '.join(tokens)

    return cleaned_text

st.title("Medi Scape Dashboard")

# --- Session State Initialization ---
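# Models are cached in st.session_state so they are loaded once per session
# rather than on every Streamlit rerun.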
if 'disease_model' not in st.session_state:
    try:
        # Assuming all models are in the root directory of your Hugging Face Space
        model_path = 'FINAL_MODEL.h5'  
        st.session_state.disease_model = tf.keras.models.load_model(model_path)
        print("Disease model loaded successfully!")
    except FileNotFoundError:
        st.error("Disease classification model not found. Please ensure 'FINAL_MODEL.h5' is in the correct directory.")
        st.session_state.disease_model = None
    except PermissionError:
        st.error("Permission error accessing 'FINAL_MODEL.h5'. Please ensure the file is not being used by another process.")
        st.session_state.disease_model = None
    except Exception as e:
        st.error(f"An error occurred while loading the disease classification model: {e}")
        st.session_state.disease_model = None

# Load the vectorizer 
if 'vectorizer' not in st.session_state:
    try:
        vectorizer_path = "vectorizer.pkl"
        st.session_state.vectorizer = pd.read_pickle(vectorizer_path)
        print("Vectorizer loaded successfully!")
    except FileNotFoundError:
        st.error("Vectorizer file not found. Please ensure 'vectorizer.pkl' is in the same directory as this app.")
        st.session_state.vectorizer = None
    except Exception as e:
        st.error(f"An error occurred while loading the vectorizer: {e}")
        st.session_state.vectorizer = None

# Load the logistic regression model using joblib
if 'model_llm' not in st.session_state:
    try:
        llm_model_path = "logistic_regression_model.pkl"
        st.session_state.model_llm = joblib.load(llm_model_path)
        print("Logistic regression model loaded successfully!")
    except FileNotFoundError:
        st.error("Logistic regression model file not found. Please ensure 'logistic_regression_model.pkl' is in the correct directory.")
        st.session_state.model_llm = None
    except Exception as e:
        st.error(f"An error occurred while loading the logistic regression model: {e}")
        st.session_state.model_llm = None

# --- End of Session State Initialization ---

# Load the disease classification model (outside session state for this example)
try:
    model_path = 'FINAL_MODEL.h5'  
    disease_model = tf.keras.models.load_model(model_path)
except FileNotFoundError:
    st.error("Disease classification model not found. Please ensure 'FINAL_MODEL.h5' is in the correct directory.")
    disease_model = None
except PermissionError:
    st.error("Permission error accessing 'FINAL_MODEL.h5'. Please ensure the file is not being used by another process.")
    disease_model = None
except Exception as e:
    st.error(f"An error occurred while loading the disease classification model: {e}")
    disease_model = None

# Sidebar Navigation
st.sidebar.title("Navigation")
page = st.sidebar.radio("Go to", ["Home", "AI Chatbot Diagnosis", "Drug Identification", "Disease Detection", "Outbreak Alert"])

# Access secrets using st.secrets
if "INFERENCE_API_URL" not in st.secrets or "INFERENCE_API_KEY" not in st.secrets:
    st.error("Please make sure to set your secrets in the Streamlit secrets settings.")
else:
    # Initialize the Inference Client
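    # The client performs handwriting recognition on prescription images via the
    # hosted "doctor-s-handwriting/1" model (used on the Drug Identification page).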
    CLIENT = InferenceHTTPClient(
        api_url=st.secrets["INFERENCE_API_URL"],
        api_key=st.secrets["INFERENCE_API_KEY"]
    )

    # Function to preprocess the image
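    # Pipeline: grayscale -> Gaussian blur -> Otsu threshold -> dilate/erode ->
    # Canny edges -> deskew -> contour detection; returns the original image
    # with the detected contours drawn on it.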
    def preprocess_image(image_path):
        # Load the image
        image = cv2.imread(image_path)

        # Convert to grayscale
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

        # Remove noise
        blurred = cv2.GaussianBlur(gray, (5, 5), 0)

        # Thresholding/Binarization
        _, binary = cv2.threshold(blurred, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

        # Dilation and Erosion
        kernel = np.ones((1, 1), np.uint8)
        dilated = cv2.dilate(binary, kernel, iterations=1)
        eroded = cv2.erode(dilated, kernel, iterations=1)

        # Edge detection
        edges = cv2.Canny(eroded, 100, 200)

        # Deskewing
        coords = np.column_stack(np.where(edges > 0))
        angle = cv2.minAreaRect(coords)[-1]
        if angle < -45:
            angle = -(90 + angle)
        else:
            angle = -angle

        (h, w) = edges.shape[:2]
        center = (w // 2, h // 2)
        M = cv2.getRotationMatrix2D(center, angle, 1.0)
        deskewed = cv2.warpAffine(edges, M, (w, h), flags=cv2.INTER_CUBIC, borderMode=cv2.BORDER_REPLICATE)

        # Find contours
        contours, _ = cv2.findContours(deskewed, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        # Draw contours on the original image
        contour_image = image.copy()
        cv2.drawContours(contour_image, contours, -1, (0, 255, 0), 2)

        return contour_image

    def get_x1(detection):
        # Left x-coordinate of a detection; used to sort detections left-to-right.
        return detection.xyxy[0][0]

    # --- Prediction function (using session state) ---
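    # Vectorizes the preprocessed symptom text and predicts a disease label with
    # the logistic regression model held in session state.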
    def predict_disease(symptoms):
        if st.session_state.vectorizer is not None and st.session_state.model_llm is not None:
            preprocessed_symptoms = preprocess_text(symptoms)
            symptoms_vectorized = st.session_state.vectorizer.transform([preprocessed_symptoms])
            # Predict the disease label with the loaded logistic regression model
            prediction = st.session_state.model_llm.predict(symptoms_vectorized)
            return prediction[0]  # Extract the string value from the array
        else:
            st.error("Unable to make prediction. Vectorizer or LLM model is not loaded.")
            return None

    # --- New function to analyze X-ray with LLM ---
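    # Sends the predicted X-ray class to the LLM and renders a patient-friendly
    # summary (description, symptoms, causes, treatment approaches).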
    def analyze_xray_with_llm(predicted_class):
        prompt = f"""
        Based on a chest X-ray analysis, the predicted condition is {predicted_class}. 
        Please provide a concise summary of this condition, including:
        - A brief description of the condition.
        - Common symptoms associated with it.
        - Potential causes.
        - General treatment approaches.
        - Any other relevant information for a patient.
        """
        llm_response = get_ai71_response(prompt)
        st.write("## LLM Analysis of X-ray Results:")
        st.write(llm_response)

    # --- Functions for Symptom Detection ---
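    # precaution() and occurance() look up precaution and occurrence/risk-factor
    # text for a predicted disease in the bundled CSV files.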
    def precaution(label):
        dataset_precau = pd.read_csv("disease_precaution.csv", encoding='latin1')  # Make sure this file is in the same directory
        label = str(label).lower()
    
        dataset_precau["Disease"] = dataset_precau["Disease"].str.lower()
        filtered_precautions = dataset_precau[dataset_precau["Disease"] == label]
    
        if not filtered_precautions.empty:
            precautions = filtered_precautions[["Precaution_1", "Precaution_2", "Precaution_3", "Precaution_4"]]
            precautions_list = precautions.values.flatten().tolist()  # Flatten the DataFrame to a list of strings
            return "\n".join(f"- {precaution}" for precaution in precautions_list)  # Join the list into a single string with bullet points
        else:
            return "No precautions found."

    def occurance(label):
        dataset_occur = pd.read_csv("disease_riskFactors.csv", encoding='latin1')
        label = str(label).lower()
    
        dataset_occur["DNAME"] = dataset_occur["DNAME"].str.lower()
        filtered_occurrence = dataset_occur[dataset_occur["DNAME"] == label]
    
        occurrences = filtered_occurrence["OCCUR"].tolist()  # Convert Series to list
        if occurrences:
            return "\n".join(occurrences)  # Join the list into a single string with newlines
        else:
            return "No occurrences found."

    if page == "Home":
        st.markdown("## Welcome to Medi Scape")
        st.write("Medi Scape is an AI-powered healthcare application designed to streamline the process of understanding and managing medical information.  It leverages advanced AI models to provide features such as prescription analysis, disease detection from chest X-rays, and symptom-based diagnosis assistance.")

        st.markdown("## Features")
        st.write("Medi Scape provides various AI-powered tools for remote healthcare, including:")
        features = [
            "**AI Chatbot Diagnosis:** Interact with an AI chatbot for preliminary diagnosis and medical information.",
            "**Drug Identification:** Upload a prescription image to identify medications and access relevant details.",
            "**Doctor's Handwriting Identification:** Our system can accurately recognize and process doctor's handwriting.",
            "**Disease Detection:** Upload a chest X-ray image to detect potential diseases.",
            "**Outbreak Alert:** Stay informed about potential disease outbreaks in your area."
        ]
        for feature in features:
            st.markdown(f"- {feature}")

        st.markdown("## How it Works")
        steps = [
            "**Upload:** You can upload a prescription image for drug identification or a chest X-ray image for disease detection.",
            "**Process:** Our AI models will analyze the image and extract relevant information.",
            "**Results:** You will receive identified drug names, uses, side effects, and more, or a potential disease diagnosis."
        ]
        for i, step in enumerate(steps, 1):
            st.markdown(f"{i}. {step}")

        st.markdown("## Key Features")
        key_features = [
            "**AI-Powered:** Leverages advanced AI models for accurate analysis and diagnosis.",
            "**User-Friendly:** Simple and intuitive interface for easy navigation and interaction.",
            "**Secure:** Your data is protected and handled with confidentiality."
        ]
        for feature in key_features:
            st.markdown(f"- {feature}")

        st.markdown("Please use the sidebar to navigate to different features.")

    elif page == "AI Chatbot Diagnosis":
        st.write("Enter your symptoms separated by commas:")
        symptoms_input = st.text_area("Symptoms:")
        col1, col2 = st.columns(2)
        with col1:
            if st.button("Diagnose with Regression Model"):
                if symptoms_input:
                    # --- Pipeline 1 Implementation ---
                    # 1. Symptom Input (already done with st.text_area)
                    # 2. Regression Prediction
                    regression_prediction = predict_disease(symptoms_input)

                    if regression_prediction is not None:
                        st.write("## Logistic Regression Prediction:")
                        st.write(regression_prediction)

                        st.write("## Precautions:")
                        precautions_names = precaution(regression_prediction)
                        st.write(precautions_names)

                        st.write("## Occurrence:")
                        occurance_name = occurance(regression_prediction)
                        st.write(occurance_name)

                else:
                    st.write("Please enter your symptoms.")

        with col2:
            if st.button("Diagnose with LLM"):
                if symptoms_input:
                    # --- Pipeline 2 Implementation (LLM Only) ---
                    prompt = f"""The user is experiencing the following symptoms: {symptoms_input}. 
                    Based on these symptoms, provide a detailed explanation of possible conditions, including 
                    potential causes, common symptoms, and general treatment approaches. Also, suggest when 
                    a patient should consult a doctor."""

                    llm_response = get_ai71_response(prompt)

                    st.write("## LLM Diagnosis:")
                    st.write(llm_response)
                else:
                    st.write("Please enter your symptoms.")

    elif page == "Drug Identification":
        st.write("Upload a prescription image for drug identification.")
        uploaded_file = st.file_uploader("Upload prescription", type=["png", "jpg", "jpeg"])

        if uploaded_file is not None:
            # Display the uploaded image
            image = Image.open(uploaded_file)
            st.image(image, caption="Uploaded Prescription", use_column_width=True)

            if st.button("Process Prescription"):
                # Save the image to a temporary file (convert to RGB so PNGs with
                # an alpha channel can be written as JPEG)
                temp_image_path = "temp_image.jpg"
                image.convert("RGB").save(temp_image_path)

                # Preprocess the image
                preprocessed_image = preprocess_image(temp_image_path)

                # Perform inference
                result_doch1 = CLIENT.infer(preprocessed_image, model_id="doctor-s-handwriting/1")

                # Extract labels and detections
                labels = [item["class"] for item in result_doch1["predictions"]]
                detections = sv.Detections.from_inference(result_doch1)

                # Sort detections and labels
                sorted_indices = sorted(range(len(detections)), key=lambda i: get_x1(detections[i]))
                sorted_detections = [detections[i] for i in sorted_indices]
                sorted_labels = [labels[i] for i in sorted_indices]

                # Convert list to string
                resulting_string = ''.join(sorted_labels)

                # Display results
                st.subheader("Processed Prescription")
                fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))

                # Plot bounding boxes
                image_with_boxes = preprocessed_image.copy()
                for detection in sorted_detections:
                    x1, y1, x2, y2 = detection.xyxy[0]
                    cv2.rectangle(image_with_boxes, (int(x1), int(y1)), (int(x2), int(y2)), (255, 0, 0), 2)      
                ax1.imshow(cv2.cvtColor(image_with_boxes, cv2.COLOR_BGR2RGB))
                ax1.set_title("Bounding Boxes")
                ax1.axis('off')

                # Plot labels
                image_with_labels = preprocessed_image.copy()
                for i, detection in enumerate(sorted_detections):
                    x1, y1, x2, y2 = detection.xyxy[0]
                    label = sorted_labels[i]
                    cv2.putText(image_with_labels, label, (int(x1), int(y1) - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
                ax2.imshow(cv2.cvtColor(image_with_labels, cv2.COLOR_BGR2RGB))
                ax2.set_title("Labels")
                ax2.axis('off')

                st.pyplot(fig)

                st.write("Extracted Text from Prescription:", resulting_string)

                # Prepare prompt for LLM
                prompt = f"""Analyze the following prescription text:
                {resulting_string}

                Please provide:
                1. Identified drug name(s)
                2. Full name of each identified drug
                3. Primary uses of each drug
                4. Common side effects
                5. Recommended dosage (if identifiable from the text)
                6. Any warnings or precautions
                7. Potential interactions with other medications (if multiple drugs are identified)
                8. Any additional relevant information for the patient

                If any part of the prescription is unclear or seems incomplete, please mention that and provide information about possible interpretations or matches. Always emphasize the importance of consulting a healthcare professional for accurate interpretation and advice."""

                # Get LLM response
                llm_response = get_ai71_response(prompt)

                st.subheader("AI Analysis of the Prescription")
                st.write(llm_response)

                # Remove the temporary image file
                os.remove(temp_image_path)

        else:
            st.info("Please upload a prescription image to proceed.")

    elif page == "Disease Detection":
        st.write("Upload a chest X-ray image for disease detection.")
        uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

        if uploaded_image is not None and disease_model is not None:  # Use disease_model directly
            # Load the uploaded image, resize to the model's input size, and scale to [0, 1]
            img_opened = Image.open(uploaded_image).convert('RGB')
            image_pred = np.array(img_opened)
            image_pred = cv2.resize(image_pred, (150, 150))
            image_pred = image_pred / 255.0

            # Add a batch dimension to match the model's input shape (1, 150, 150, 3)
            image_pred = np.expand_dims(image_pred, axis=0)

            # Predict using the model
            prediction = disease_model.predict(image_pred)  # Use disease_model directly

            # Get the predicted class
            predicted_ = np.argmax(prediction)

            # Decode the prediction
            if predicted_ == 0:
                predicted_class = "Covid"
            elif predicted_ == 1:
                predicted_class = "Normal Chest X-ray"
            else:
                predicted_class = "Pneumonia"

            # Show the original uploaded image (image_pred now carries a batch dimension)
            st.image(img_opened, caption='Input image by user', use_column_width=True)
            st.write("Prediction Classes for different types:")
            st.write("COVID: 0")
            st.write("Normal Chest X-ray: 1")
            st.write("Pneumonia: 2")
            st.write("\n")
            st.write("DETECTED DISEASE DISPLAY")
            st.write(f"Predicted Class : {predicted_}")
            st.write(predicted_class)

            # Analyze X-ray results with LLM
            analyze_xray_with_llm(predicted_class)
        else:
            st.write("Please upload an image file or ensure the disease model is loaded.")

    elif page == "Outbreak Alert":
        st.markdown("## **Disease Outbreak News (from WHO)**")

        # Fetch WHO news page
        url = "https://www.who.int/news-room/events"
        response = requests.get(url)
        response.raise_for_status()  # Raise an exception for bad status codes

        soup = BeautifulSoup(response.content, 'html.parser')

        # Find news articles (adjust selectors if WHO website changes)
        articles = soup.find_all('div', class_='list-view--item')

        for article in articles[:5]:  # Display the top 5 news articles
            title_element = article.find('a', class_='link-container')
            if title_element:
                title = title_element.text.strip()
                link = title_element['href']
                date_element = article.find('span', class_='date')
                date = date_element.text.strip() if date_element else "Date not found"
                
                # Format date
                date_parts = date.split()
                if len(date_parts) >= 3:
                    try:
                        formatted_date = datetime.strptime(date, "%d %B %Y").strftime("%Y-%m-%d")
                    except ValueError:
                        formatted_date = date  # Keep the original date if formatting fails
                else:
                    formatted_date = date
                
                # Display news item in a card-like container
                with st.container():
                    st.markdown(f"**{formatted_date}**")
                    st.markdown(f"[{title}]({link})")
                    st.markdown("---")
            else:
                st.write("Could not find article details.")

# Auto-scroll to the bottom of the chat container
st.markdown(
    """
    <script>
    const chatContainer = document.querySelector('.st-chat-container');
    if (chatContainer) {
        chatContainer.scrollTop = chatContainer.scrollHeight;
    }
    </script>
    """,
    unsafe_allow_html=True,
)