# Uspark / app/chatbot.py
# (Header scraped from the Hugging Face file viewer: "Last commit not found",
#  "raw", "history blame", "6.24 kB".)
import pandas as pd
import json
import numpy as np
import faiss
from sklearn.feature_extraction.text import TfidfVectorizer
from transformers import pipeline
# -------------------------------
# Load disease data and preprocess
# -------------------------------
def load_disease_data(csv_path):
    """Load per-disease symptom and precaution lists from a CSV file.

    The CSV must contain ``disease``, ``symptoms`` and ``precautions``
    columns. Header matching ignores case and surrounding whitespace. The
    symptom and precaution cells hold comma-separated items.

    Args:
        csv_path: Path (or anything ``pandas.read_csv`` accepts) of the CSV.

    Returns:
        ``(disease_symptoms, disease_precautions)``: two dicts keyed by the
        stripped disease name. Symptoms are lower-cased, while precautions keep
        their original case. Empty cells yield empty lists. Rows without a
        disease name are skipped. If a disease appears on several rows, the
        last row wins.

    Raises:
        KeyError: if one of the required columns is missing.
    """
    required = ("disease", "symptoms", "precautions")
    df = pd.read_csv(csv_path)
    df.columns = df.columns.str.strip().str.lower()
    absent = [col for col in required if col not in df.columns]
    if absent:
        raise KeyError(f"CSV {csv_path!r} is missing column(s): {', '.join(absent)}")
    # Blank cells become "" and every cell is coerced to str, so columns that
    # pandas parsed as purely numeric can still be stripped and split.
    df = df[list(required)].fillna("").astype(str)

    disease_symptoms = {}
    disease_precautions = {}
    for disease, symptoms, precautions in zip(df["disease"], df["symptoms"], df["precautions"]):
        disease = disease.strip()
        if not disease:
            # A row without a disease name cannot be predicted or looked up.
            continue
        disease_symptoms[disease] = [s.strip().lower() for s in symptoms.split(",") if s.strip()]
        disease_precautions[disease] = [p.strip() for p in precautions.split(",") if p.strip()]
    return disease_symptoms, disease_precautions
# Load the disease table. The path is relative to the current working
# directory, so run the app from the folder that holds this CSV.
disease_symptoms, disease_precautions = load_disease_data("disease_sympts_prec_full.csv")
# Every distinct symptom mentioned for any disease.
known_symptoms = {symptom for symptom_list in disease_symptoms.values() for symptom in symptom_list}
# -------------------------------
# Build symptom vectorizer and FAISS index
# -------------------------------
# One "document" per disease: its symptoms joined by spaces. Row i of the
# TF-IDF matrix and of the FAISS index corresponds to disease_list[i].
disease_list = list(disease_symptoms.keys())
symptom_texts = [" ".join(disease_symptoms[name]) for name in disease_list]
vectorizer = TfidfVectorizer()
tfidf_matrix = vectorizer.fit_transform(symptom_texts).toarray()
# Exact (brute-force) L2 index over the dense TF-IDF rows; FAISS needs float32.
index = faiss.IndexFlatL2(tfidf_matrix.shape[1])
index.add(tfidf_matrix.astype(np.float32))
def find_closest_disease(user_symptoms):
    """Return the disease whose symptom profile is nearest to *user_symptoms*.

    Args:
        user_symptoms: Iterable of symptom strings reported by the user.

    Returns:
        The name of the closest disease in ``disease_list``, or ``None`` when
        there is nothing meaningful to match. That happens when no symptoms
        were given, none of their words are in the TF-IDF vocabulary, or FAISS
        returned no neighbour.
    """
    if not user_symptoms:
        return None
    query = vectorizer.transform([" ".join(user_symptoms)]).toarray().astype("float32")
    # No known vocabulary word gives the zero vector. Every disease row is then
    # (nearly) equidistant, so any "match" would be arbitrary and would drive
    # follow-up questions about an unrelated disease.
    if not query.any():
        return None
    _, indices = index.search(query, k=1)
    best = int(indices[0][0])
    # FAISS pads missing results with -1, which would silently select the
    # last disease in the list.
    if best < 0:
        return None
    return disease_list[best]
# -------------------------------
# Load Medical NER model for symptom extraction
# -------------------------------
# Hugging Face model used for both the weights and the tokenizer.
_MEDICAL_NER_MODEL = "blaze999/Medical-NER"

# Token-classification pipeline. With aggregation_strategy="simple", each
# result is a merged entity span carrying "entity_group" and "word" keys.
medical_ner = pipeline(
    task="ner",
    model=_MEDICAL_NER_MODEL,
    tokenizer=_MEDICAL_NER_MODEL,
    aggregation_strategy="simple",
)
def extract_symptoms_ner(text, ner=None):
    """Extract symptom mentions from free text with a medical NER model.

    Args:
        text: The patient's message.
        ner: Optional callable mapping text to a list of entity dicts with
            ``entity_group`` and ``word`` keys. Defaults to the module-level
            ``medical_ner`` pipeline.

    Returns:
        Lower-cased, whitespace-stripped symptom strings, de-duplicated and in
        the order they first appear in the model output. The order is
        deterministic, unlike iterating a set.
    """
    if not text or not text.strip():
        # Nothing to analyse; skip the (expensive) model call.
        return []
    if ner is None:
        ner = medical_ner
    words = [
        entity["word"].strip().lower()
        for entity in ner(text)
        if "SIGN_SYMPTOM" in entity.get("entity_group", "")
    ]
    # dict.fromkeys de-duplicates while keeping first-seen order.
    return list(dict.fromkeys(word for word in words if word))
def is_affirmative(answer):
    """Return True if *answer* contains an affirmative word or phrase.

    Matching is case-insensitive and on whole words only. Plain substring
    matching would wrongly treat "no, my eyes are fine" ("eyes" contains
    "yes") or "not since yesterday" as a yes.

    Args:
        answer: The patient's reply.

    Returns:
        ``True`` if one of yes / yeah / yep / certainly / sometimes /
        a little appears as a whole word, otherwise ``False``.
    """
    import re  # local import keeps this helper self-contained

    pattern = r"\b(?:yes|yeah|yep|certainly|sometimes|a\s+little)\b"
    return re.search(pattern, answer, re.IGNORECASE) is not None
# -------------------------------
# Chatbot session class
# -------------------------------
class ChatbotSession:
    """One scripted doctor/patient interview.

    ``process_message`` walks through three states in order:

    * ``"symptom_collection"``: symptoms are pulled from free text with
      ``extract_symptoms_ner``, and the closest disease is predicted with
      ``find_closest_disease``. The patient is then asked, one at a time,
      about that disease's symptoms they have neither mentioned nor been
      asked about. "exit" or "quit" at any time, or a bare "no" when no
      follow-up question is pending, moves on.
    * ``"pain"``: the reply is stored in ``pain_level`` as an ``int`` when it
      parses as one, otherwise as the raw text.
    * ``"medications"``: the reply is stored in ``medications`` ("None" for
      "no"/"none"), and the session is marked ``finished``.

    NOTE(review): only the doctor's lines are recorded in
    ``conversation_history``; patient messages are not stored.
    """

    _GREETING = "Doctor: Hello, I am your virtual doctor. What brought you in today?"
    _PAIN_PROMPT = (
        "Doctor: Do you experience any pain or aches? Please rate the pain "
        "on a scale of 1 to 10 (or type 'no' if none):"
    )
    _MEDICATION_PROMPT = (
        "Doctor: Have you taken any medications recently? If yes, please "
        "specify (or type 'no' if none):"
    )
    _OTHER_SYMPTOMS_PROMPT = "Doctor: Do you have any other symptoms you'd like to mention?"
    _CLOSING = "Doctor: Thank you for providing all the information."
    _NOT_UNDERSTOOD = "Doctor: I'm sorry, I didn't understand that."

    # Replies that end symptom collection even while a follow-up is pending.
    _EXIT_REPLIES = frozenset({"exit", "quit"})
    # Replies recorded as "None" for the medications question.
    _NO_MEDICATION_REPLIES = frozenset({"no", "none"})

    def __init__(self):
        self.conversation_history = []
        self.reported_symptoms = set()
        # Symptoms already asked about via a follow-up question (yes or no).
        self.asked_missing = set()
        # Symptom named in the pending follow-up question, or None.
        self.awaiting_followup = None
        self.state = "symptom_collection"  # states: symptom_collection, pain, medications
        self.pain_level = None
        self.medications = None
        self.finished = False
        self._say(self._GREETING)

    def _say(self, line):
        """Append a doctor line to the history and return it."""
        self.conversation_history.append(line)
        return line

    @staticmethod
    def _normalize(message):
        """Lower-case *message* and drop surrounding spaces and trailing '.'/'!'."""
        return message.strip().lower().rstrip(".!")

    def process_message(self, message: str) -> str:
        """Consume one patient message and return the doctor's reply.

        Every reply is prefixed with ``"Doctor: "``. Once the session is
        finished, further messages are ignored, and the closing line is
        returned without changing any stored answers.
        """
        if self.finished:
            return self._CLOSING
        if self.state == "symptom_collection":
            return self._handle_symptoms(message)
        if self.state == "pain":
            return self._handle_pain(message)
        if self.state == "medications":
            return self._handle_medications(message)
        return self._NOT_UNDERSTOOD

    def _handle_symptoms(self, message):
        """Record symptoms; return the next follow-up or move to the pain question."""
        reply = self._normalize(message)
        # A bare "no" while a follow-up is pending answers that question. It
        # must not end symptom collection.
        if reply in self._EXIT_REPLIES or (reply == "no" and not self.awaiting_followup):
            self.awaiting_followup = None
            self.state = "pain"
            return self._say(self._PAIN_PROMPT)

        if self.awaiting_followup:
            if is_affirmative(message):
                self.reported_symptoms.add(self.awaiting_followup)
            self.asked_missing.add(self.awaiting_followup)
            self.awaiting_followup = None
        else:
            self.reported_symptoms.update(extract_symptoms_ner(message))

        symptom = self._next_followup_symptom()
        if symptom is not None:
            self.awaiting_followup = symptom
            return self._say(f"Doctor: Are you also experiencing {symptom}?")
        return self._say(self._OTHER_SYMPTOMS_PROMPT)

    def _next_followup_symptom(self):
        """Return the predicted disease's first unreported, unasked symptom, or None."""
        if not self.reported_symptoms:
            return None
        predicted = find_closest_disease(sorted(self.reported_symptoms))
        if not predicted:
            return None
        # Walk the CSV order so the question sequence is deterministic.
        for symptom in disease_symptoms.get(predicted, []):
            if symptom not in self.reported_symptoms and symptom not in self.asked_missing:
                return symptom
        return None

    def _handle_pain(self, message):
        """Store the pain rating and ask about medications."""
        try:
            self.pain_level = int(message)
        except ValueError:
            self.pain_level = message
        self.state = "medications"
        return self._say(self._MEDICATION_PROMPT)

    def _handle_medications(self, message):
        """Store the medications answer and finish the session."""
        if self._normalize(message) in self._NO_MEDICATION_REPLIES:
            self.medications = "None"
        else:
            self.medications = message
        self.finished = True
        return self._say(self._CLOSING)

    def get_data(self):
        """Return a snapshot of the collected information.

        The snapshot holds the doctor-side transcript, the reported symptoms
        (sorted), the pain level and the medications (``None`` if not asked yet).
        """
        return {
            "conversation": list(self.conversation_history),
            "symptoms": sorted(self.reported_symptoms),
            "pain_level": self.pain_level,
            "medications": self.medications,
        }