# DjPapzin's picture
# Renamed FINAL_MODEL.zip to FINAL_MODEL.keras
# 313e620
import streamlit as st
import requests
from utils.ai71_utils import get_ai71_response
from datetime import datetime
import cv2
import numpy as np
from PIL import Image
import supervision as sv
import matplotlib.pyplot as plt
import io
import os
from inference_sdk import InferenceHTTPClient
from bs4 import BeautifulSoup
import tensorflow as tf
import pandas as pd
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, classification_report
import nltk
import re
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
# --- Preprocess text function (moved outside session state) ---
def preprocess_text(text):
    """Normalize free text into a space-separated string of kept tokens.

    The text is lower-cased, every character other than ASCII letters,
    digits, whitespace and commas is replaced by a space, the result is
    tokenized with NLTK's ``word_tokenize`` and English stop words are
    dropped.

    Args:
        text: Raw input string.

    Returns:
        The remaining tokens joined by single spaces.
    """
    lowered = text.lower()
    # Anything outside [a-zA-Z0-9], whitespace and "," becomes a space.
    scrubbed = re.sub(r'[^a-zA-Z0-9\s\,]', ' ', lowered)
    candidate_tokens = word_tokenize(scrubbed)
    english_stop_words = set(stopwords.words('english'))
    kept_tokens = []
    for token in candidate_tokens:
        if token not in english_stop_words:
            kept_tokens.append(token)
    return ' '.join(kept_tokens)
st.title("Medi Scape Dashboard")

# --- Session State Initialization ---
# The Keras model is loaded only when the 'disease_model' key is absent from
# st.session_state, so an already-initialised session skips the load.
if 'disease_model' not in st.session_state:
    model_path = 'FINAL_MODEL.keras'
    print(f"Attempting to load disease model from: {model_path}")
    model_file_exists = os.path.exists(model_path)
    print(f"Model file exists: {model_file_exists}")
    try:
        st.session_state.disease_model = tf.keras.models.load_model(model_path)
    except Exception as exc:
        # App boundary: whatever goes wrong while loading, the key must end up
        # set (to None) so later `st.session_state.disease_model` reads work.
        # NOTE(review): the exception type for a missing model file appears to
        # vary across Keras versions (OSError / ValueError) - confirm.
        st.session_state.disease_model = None
        print(f"Failed to load disease model: {exc!r}")
        if model_file_exists:
            st.error(f"Disease classification model could not be loaded from '{model_path}': {exc}")
        else:
            st.error("Disease classification model not found. Please ensure 'FINAL_MODEL.keras' is in the same directory as this app.")
    else:
        print("Disease model loaded successfully!")
# Load the vectorizer (once per session; None when it cannot be loaded)
if 'vectorizer' not in st.session_state:
    vectorizer_path = "vectorizer.pkl"
    print(f"Attempting to load vectorizer from: {vectorizer_path}")
    vectorizer_file_exists = os.path.exists(vectorizer_path)
    print(f"Vectorizer file exists: {vectorizer_file_exists}")
    try:
        # SECURITY: read_pickle unpickles the file, which can execute arbitrary
        # code. Only ship a vectorizer.pkl that comes from a trusted source.
        st.session_state.vectorizer = pd.read_pickle(vectorizer_path)
    except Exception as exc:
        # App boundary: a corrupt or incompatible pickle must not leave the
        # session key unset, otherwise every rerun crashes here again.
        st.session_state.vectorizer = None
        print(f"Failed to load vectorizer: {exc!r}")
        if vectorizer_file_exists:
            st.error(f"Vectorizer file '{vectorizer_path}' could not be loaded: {exc}")
        else:
            st.error("Vectorizer file not found. Please ensure 'vectorizer.pkl' is in the same directory as this app.")
    else:
        print("Vectorizer loaded successfully!")
if 'model_llm' not in st.session_state:
    # --- Load pre-trained model ---
    # No unfitted LogisticRegression placeholder is stored first: if loading
    # failed part-way, an unfitted estimator would stay in session state and
    # fail later inside predict().
    llm_model_path = "logistic_regression_model.pkl"
    print(f"Attempting to load LLM model from: {llm_model_path}")
    llm_model_file_exists = os.path.exists(llm_model_path)
    print(f"LLM Model file exists: {llm_model_file_exists}")
    try:
        # SECURITY: read_pickle unpickles the file, which can execute arbitrary
        # code. Only ship a model pickle that comes from a trusted source.
        st.session_state.model_llm = pd.read_pickle(llm_model_path)
    except Exception as exc:
        # App boundary: always leave the key set (to None) on failure.
        st.session_state.model_llm = None
        print(f"Failed to load LLM model: {exc!r}")
        if llm_model_file_exists:
            st.error(f"LLM model file '{llm_model_path}' could not be loaded: {exc}")
        else:
            st.error("LLM model file not found. Please ensure 'logistic_regression_model.pkl' is in the same directory.")
    else:
        print("LLM model loaded successfully!")

# Load datasets (only for reference, not used for training).
# Guarded by its own key so the CSVs are read and combined once per session,
# independently of whether the model above loaded.
if 'df_combined' not in st.session_state:
    try:
        dataset_1 = pd.read_csv("Symptoms_Detection/training_data.csv")
        dataset_2 = pd.read_csv("Symptoms_Detection/Symptom2Disease.csv")
    except FileNotFoundError as exc:
        # The data is reference-only, so a missing file is reported but does
        # not stop the app.
        st.warning(f"Reference symptom datasets could not be loaded: {exc}")
        st.session_state.df_combined = None
    else:
        # Create symptoms_text column: comma-joined names of the columns whose
        # value equals 1 in that row.
        dataset_1['symptoms_text'] = dataset_1.apply(
            lambda row: ','.join([col for col in dataset_1.columns if row[col] == 1]),
            axis=1,
        )
        final_dataset = pd.DataFrame(dataset_1[["prognosis", "symptoms_text"]])
        final_dataset.columns = ['label', 'text']
        # Combine datasets into a single label/text frame.
        st.session_state.df_combined = pd.concat(
            [final_dataset, dataset_2[['label', 'text']]],
            axis=0,
            ignore_index=True,
        )
# Keep the module-level name bound on every run.
df_combined = st.session_state.df_combined
# --- End of Session State Initialization ---
# Reuse the disease classification model already cached in session state.
# Loading 'FINAL_MODEL.keras' a second time here would repeat the expensive
# load on every script run and could show the "not found" error twice, while
# the page code only ever uses st.session_state.disease_model.
disease_model = st.session_state.get('disease_model')
# Sidebar navigation: the selected entry decides which page branch runs below.
st.sidebar.title("Navigation")
page_options = [
    "Home",
    "AI Chatbot Diagnosis",
    "Drug Identification",
    "Disease Detection",
    "Outbreak Alert",
]
page = st.sidebar.radio("Go to", page_options)
# Access secrets using st.secrets.
# This section (secrets check, client construction, preprocess_image, get_x1)
# previously appeared twice verbatim; it is kept once.
if "INFERENCE_API_URL" not in st.secrets or "INFERENCE_API_KEY" not in st.secrets:
    st.error("Please make sure to set your secrets in the Streamlit secrets settings.")
    # Keep the name bound so later code can test for a missing client instead
    # of hitting a NameError.
    CLIENT = None
else:
    # Initialize the Inference Client
    CLIENT = InferenceHTTPClient(
        api_url=st.secrets["INFERENCE_API_URL"],
        api_key=st.secrets["INFERENCE_API_KEY"]
    )


def preprocess_image(image_path):
    """Return a copy of the image at *image_path* with detected contours drawn.

    Pipeline: grayscale -> 5x5 Gaussian blur -> Otsu binarization ->
    dilate/erode with a 1x1 kernel -> Canny edges -> rotation of the edge map
    by an angle derived from ``cv2.minAreaRect`` -> external contours, drawn
    in green on a copy of the loaded image.

    Note that the contours are computed on the rotated edge map but drawn on
    the unrotated image.

    Args:
        image_path: Path of the image file to read with ``cv2.imread``.

    Returns:
        A copy of the loaded image array with the contours drawn on it. When
        the edge map is empty the copy is returned without any drawing.

    Raises:
        ValueError: If ``cv2.imread`` cannot read *image_path*.
    """
    # Load the image; imread signals failure by returning None, not raising.
    image = cv2.imread(image_path)
    if image is None:
        raise ValueError(f"Could not read image at {image_path!r}")
    # Convert to grayscale
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    # Remove noise
    blurred = cv2.GaussianBlur(gray, (5, 5), 0)
    # Thresholding/Binarization
    _, binary = cv2.threshold(blurred, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    # Dilation and Erosion
    kernel = np.ones((1, 1), np.uint8)
    dilated = cv2.dilate(binary, kernel, iterations=1)
    eroded = cv2.erode(dilated, kernel, iterations=1)
    # Edge detection
    edges = cv2.Canny(eroded, 100, 200)
    # Deskewing
    coords = np.column_stack(np.where(edges > 0))
    if coords.size == 0:
        # No edge pixels: there is nothing to fit a rectangle to and no
        # contour to draw, so hand back the untouched copy.
        return image.copy()
    angle = cv2.minAreaRect(coords)[-1]
    # NOTE(review): this mapping assumes minAreaRect reports angles in
    # [-90, 0); confirm against the OpenCV version in use.
    if angle < -45:
        angle = -(90 + angle)
    else:
        angle = -angle
    (h, w) = edges.shape[:2]
    center = (w // 2, h // 2)
    M = cv2.getRotationMatrix2D(center, angle, 1.0)
    deskewed = cv2.warpAffine(edges, M, (w, h), flags=cv2.INTER_CUBIC, borderMode=cv2.BORDER_REPLICATE)
    # Find contours
    contours, _ = cv2.findContours(deskewed, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    # Draw contours on the original image
    contour_image = image.copy()
    cv2.drawContours(contour_image, contours, -1, (0, 255, 0), 2)
    return contour_image


def get_x1(detection):
    """Return ``detection.xyxy[0][0]``, used as the left-to-right sort key.

    Presumably this is the left x coordinate of the first box - verify
    against the detections type passed by the caller.
    """
    return detection.xyxy[0][0]
# --- Prediction function (using session state) ---
def predict_disease(symptoms):
    """Predict a label for *symptoms* with the session's vectorizer and model.

    Args:
        symptoms: Free-text symptom description.

    Returns:
        The first element of the model's prediction, or ``None`` (after
        showing an error in the app) when the vectorizer or the model stored
        in ``st.session_state`` is ``None``.
    """
    # Guard clause: both artifacts must be available before predicting.
    if st.session_state.vectorizer is None or st.session_state.model_llm is None:
        st.error("Unable to make prediction. Vectorizer or LLM model is not loaded.")
        return None

    cleaned_symptoms = preprocess_text(symptoms)
    feature_matrix = st.session_state.vectorizer.transform([cleaned_symptoms])
    predicted_labels = st.session_state.model_llm.predict(feature_matrix)
    return predicted_labels[0]
# --- New function to analyze X-ray with LLM ---
def analyze_xray_with_llm(predicted_class):
    """Ask the LLM for a patient-oriented summary and render it in the app.

    Args:
        predicted_class: Name of the condition predicted from the X-ray; it
            is interpolated into the prompt sent to ``get_ai71_response``.

    Returns:
        None. The heading and the LLM reply are written with ``st.write``.
    """
    xray_prompt = f"""
Based on a chest X-ray analysis, the predicted condition is {predicted_class}.
Please provide a concise summary of this condition, including:
- A brief description of the condition.
- Common symptoms associated with it.
- Potential causes.
- General treatment approaches.
- Any other relevant information for a patient.
"""
    condition_summary = get_ai71_response(xray_prompt)
    for rendered_part in ("## LLM Analysis of X-ray Results:", condition_summary):
        st.write(rendered_part)
if page == "Home":
    # Static landing page: intro, feature list, workflow and highlights.
    st.markdown("## Welcome to Medi Scape")
    st.write("Medi Scape is an AI-powered healthcare application designed to streamline the process of understanding and managing medical information. It leverages advanced AI models to provide features such as prescription analysis, disease detection from chest X-rays, and symptom-based diagnosis assistance.")

    st.markdown("## Features")
    st.write("Medi Scape provides various AI-powered tools for remote healthcare, including:")
    feature_bullets = (
        "**AI Chatbot Diagnosis:** Interact with an AI chatbot for preliminary diagnosis and medical information.",
        "**Drug Identification:** Upload a prescription image to identify medications and access relevant details.",
        "**Doctor's Handwriting Identification:** Our system can accurately recognize and process doctor's handwriting.",
        "**Disease Detection:** Upload a chest X-ray image to detect potential diseases.",
        "**Outbreak Alert:** Stay informed about potential disease outbreaks in your area.",
    )
    for bullet in feature_bullets:
        st.markdown(f"- {bullet}")

    st.markdown("## How it Works")
    workflow_steps = (
        "**Upload:** You can upload a prescription image for drug identification or a chest X-ray image for disease detection.",
        "**Process:** Our AI models will analyze the image and extract relevant information.",
        "**Results:** You will receive identified drug names, uses, side effects, and more, or a potential disease diagnosis.",
    )
    # Numbered list, counting from 1.
    for number, description in enumerate(workflow_steps, start=1):
        st.markdown(f"{number}. {description}")

    st.markdown("## Key Features")
    highlights = (
        "**AI-Powered:** Leverages advanced AI models for accurate analysis and diagnosis.",
        "**User-Friendly:** Simple and intuitive interface for easy navigation and interaction.",
        "**Secure:** Your data is protected and handled with confidentiality.",
    )
    for highlight in highlights:
        st.markdown(f"- {highlight}")

    st.markdown("Please use the sidebar to navigate to different features.")
elif page == "AI Chatbot Diagnosis":
st.write("Enter your symptoms separated by commas:")
symptoms_input = st.text_area("Symptoms:")
if st.button("Diagnose"):
if symptoms_input:
# --- Pipeline 1 Implementation ---
# 1. Symptom Input (already done with st.text_area)
# 2. Regression Prediction
regression_prediction = predict_disease(symptoms_input)
if regression_prediction is not None:
# 3. LLM Prompt Enhancement
prompt = f"""The predicted condition based on a symptom analysis is {regression_prediction}.
Provide a detailed explanation of this condition, including possible causes, common symptoms,
and general treatment approaches. Also, suggest when a patient should consult a doctor."""
# 4. LLM Output
llm_response = get_ai71_response(prompt)
# 5. Combined Output
st.write("## Logistic Regression Prediction:")
st.write(regression_prediction)
st.write("## LLM Explanation:")
st.write(llm_response)
# --- End of Pipeline 1 Implementation ---
else:
st.write("Please enter your symptoms.")
elif page == "Drug Identification":
st.write("Upload a prescription image for drug identification.")
uploaded_file = st.file_uploader("Upload prescription", type=["png", "jpg", "jpeg"])
if uploaded_file is not None:
# Display the uploaded image
image = Image.open(uploaded_file)
st.image(image, caption="Uploaded Prescription", use_column_width=True)
if st.button("Process Prescription"):
# Save the image to a temporary file
temp_image_path = "temp_image.jpg"
image.save(temp_image_path)
# Preprocess the image
preprocessed_image = preprocess_image(temp_image_path)
# Perform inference
result_doch1 = CLIENT.infer(preprocessed_image, model_id="doctor-s-handwriting/1")
# Extract labels and detections
labels = [item["class"] for item in result_doch1["predictions"]]
detections = sv.Detections.from_inference(result_doch1)
# Sort detections and labels
sorted_indices = sorted(range(len(detections)), key=lambda i: get_x1(detections[i]))
sorted_detections = [detections[i] for i in sorted_indices]
sorted_labels = [labels[i] for i in sorted_indices]
# Convert list to string
resulting_string = ''.join(sorted_labels)
# Display results
st.subheader("Processed Prescription")
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
# Plot bounding boxes
image_with_boxes = preprocessed_image.copy()
for detection in sorted_detections:
x1, y1, x2, y2 = detection.xyxy[0]
cv2.rectangle(image_with_boxes, (int(x1), int(y1)), (int(x2), int(y2)), (255, 0, 0), 2)
ax1.imshow(cv2.cvtColor(image_with_boxes, cv2.COLOR_BGR2RGB))
ax1.set_title("Bounding Boxes")
ax1.axis('off')
# Plot labels
image_with_labels = preprocessed_image.copy()
for i, detection in enumerate(sorted_detections):
x1, y1, x2, y2 = detection.xyxy[0]
label = sorted_labels[i]
cv2.putText(image_with_labels, label, (int(x1), int(y1) - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
ax2.imshow(cv2.cvtColor(image_with_labels, cv2.COLOR_BGR2RGB))
ax2.set_title("Labels")
ax2.axis('off')
st.pyplot(fig)
st.write("Extracted Text from Prescription:", resulting_string)
# Prepare prompt for LLM
prompt = f"""Analyze the following prescription text:
{resulting_string}
Please provide:
1. Identified drug name(s)
2. Full name of each identified drug
3. Primary uses of each drug
4. Common side effects
5. Recommended dosage (if identifiable from the text)
6. Any warnings or precautions
7. Potential interactions with other medications (if multiple drugs are identified)
8. Any additional relevant information for the patient
If any part of the prescription is unclear or seems incomplete, please mention that and provide information about possible interpretations or matches. Always emphasize the importance of consulting a healthcare professional for accurate interpretation and advice."""
# Get LLM response
llm_response = get_ai71_response(prompt)
st.subheader("AI Analysis of the Prescription")
st.write(llm_response)
# Remove the temporary image file
os.remove(temp_image_path)
else:
st.info("Please upload a prescription image to proceed.")
elif page == "Disease Detection":
st.write("Upload a chest X-ray image for disease detection.")
uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
if uploaded_image is not None and st.session_state.disease_model is not None:
# Display the image
img_opened = Image.open(uploaded_image).convert('RGB')
image_pred = np.array(img_opened)
image_pred = cv2.resize(image_pred, (150, 150))
# Convert the image to a numpy array
image_pred = np.array(image_pred)
# Rescale the image (if the model was trained with rescaling)
image_pred = image_pred / 255.0
# Add an extra dimension to match the input shape (1, 150, 150, 3)
image_pred = np.expand_dims(image_pred, axis=0)
# Predict using the model
prediction = st.session_state.disease_model.predict(image_pred)
# Get the predicted class
predicted_ = np.argmax(prediction)
# Decode the prediction
if predicted_ == 0:
predicted_class = "Covid"
elif predicted_ == 1:
predicted_class = "Normal Chest X-ray"
else:
predicted_class = "Pneumonia"
st.image(image_pred, caption='Input image by user', use_column_width=True)
st.write("Prediction Classes for different types:")
st.write("COVID: 0")
st.write("Normal Chest X-ray: 1")
st.write("Pneumonia: 2")
st.write("\n")
st.write("DETECTED DISEASE DISPLAY")
st.write(f"Predicted Class : {predicted_}")
st.write(predicted_class)
# Analyze X-ray results with LLM
analyze_xray_with_llm(predicted_class)
else:
st.write("Please upload an image file or ensure the disease model is loaded.")
elif page == "Outbreak Alert":
st.markdown("## **Disease Outbreak News (from WHO)**")
# Fetch WHO news page
url = "https://www.who.int/news-room/events"
response = requests.get(url)
response.raise_for_status() # Raise an exception for bad status codes
soup = BeautifulSoup(response.content, 'html.parser')
# Find news articles (adjust selectors if WHO website changes)
articles = soup.find_all('div', class_='list-view--item')
for article in articles[:5]: # Display the top 5 news articles
title_element = article.find('a', class_='link-container')
if title_element:
title = title_element.text.strip()
link = title_element['href']
date_element = article.find('span', class_='date')
date = date_element.text.strip() if date_element else "Date not found"
# Format date
date_parts = date.split()
if len(date_parts) >= 3:
try:
formatted_date = datetime.strptime(date, "%d %B %Y").strftime("%Y-%m-%d")
except ValueError:
formatted_date = date # Keep the original date if formatting fails
else:
formatted_date = date
# Display news item in a card-like container
with st.container():
st.markdown(f"**{formatted_date}**")
st.markdown(f"[{title}]({link})")
st.markdown("---")
else:
st.write("Could not find article details.")
# Auto-scroll to the bottom of the chat container.
# The snippet looks up an element with class 'st-chat-container' and, when it
# exists, sets its scroll position to its full scroll height.
# NOTE(review): presumably script tags injected through st.markdown are not
# executed by the browser - confirm whether this has any effect.
auto_scroll_snippet = """
<script>
const chatContainer = document.querySelector('.st-chat-container');
if (chatContainer) {
chatContainer.scrollTop = chatContainer.scrollHeight;
}
</script>
"""
st.markdown(auto_scroll_snippet, unsafe_allow_html=True)