Spaces:
Running
Running
File size: 6,243 Bytes
8dcd1f3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
import json
import re

import faiss
import numpy as np
import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
from transformers import pipeline
# -------------------------------
# Load disease data and preprocess
# -------------------------------
def load_disease_data(csv_path):
    """Read the disease CSV and build symptom / precaution lookups.

    The file must provide ``disease``, ``symptoms`` and ``precautions``
    columns; header names are matched after stripping whitespace and
    lower-casing.  The ``symptoms`` and ``precautions`` cells hold
    comma-separated lists.

    Args:
        csv_path: Anything ``pandas.read_csv`` accepts as its first argument.

    Returns:
        A ``(disease_symptoms, disease_precautions)`` tuple of dicts keyed by
        the stripped disease name.  Symptoms are stripped and lower-cased,
        precautions are stripped only; empty list items are dropped.  If a
        disease appears on several rows, the last row wins.

    Raises:
        KeyError: If one of the three required columns is missing.
    """
    # dtype=str keeps numeric-looking cells as text, so the .strip()/.split()
    # calls below never receive an int or float.
    df = pd.read_csv(csv_path, dtype=str)
    df.columns = df.columns.str.strip().str.lower()
    df = df.fillna("")

    disease_symptoms = {}
    disease_precautions = {}
    rows = zip(df["disease"], df["symptoms"], df["precautions"])
    for raw_name, raw_symptoms, raw_precautions in rows:
        disease = raw_name.strip()
        if not disease:
            # A row without a name would otherwise register a disease called "".
            continue
        disease_symptoms[disease] = [
            item.strip().lower() for item in raw_symptoms.split(",") if item.strip()
        ]
        disease_precautions[disease] = [
            item.strip() for item in raw_precautions.split(",") if item.strip()
        ]
    return disease_symptoms, disease_precautions
# Load the disease table once, at import time.  The path is relative, so the
# CSV is expected in the repository root.
disease_symptoms, disease_precautions = load_disease_data("disease_sympts_prec_full.csv")

# Every distinct symptom listed for any disease.
known_symptoms = {
    symptom
    for symptom_list in disease_symptoms.values()
    for symptom in symptom_list
}
# -------------------------------
# Build symptom vectorizer and FAISS index
# -------------------------------
# Disease names in dict order; row i of the index corresponds to disease_list[i].
disease_list = list(disease_symptoms.keys())

# One whitespace-joined "document" of symptoms per disease, in the same order.
symptom_texts = [" ".join(disease_symptoms[name]) for name in disease_list]

vectorizer = TfidfVectorizer()
tfidf_matrix = vectorizer.fit_transform(symptom_texts).toarray()

# Flat L2 index over the dense TF-IDF rows; FAISS requires float32 input.
index = faiss.IndexFlatL2(tfidf_matrix.shape[1])
index.add(tfidf_matrix.astype(np.float32))
def find_closest_disease(user_symptoms):
    """Return the disease whose symptom profile is nearest to *user_symptoms*.

    The symptoms are joined into one string, vectorised with the module-level
    ``vectorizer`` and matched against the module-level FAISS ``index``.

    Args:
        user_symptoms: Sequence of symptom strings reported by the user.

    Returns:
        The name of the nearest disease from ``disease_list``, or ``None``
        when no meaningful match can be made: no symptoms were given, none of
        their words are in the vectorizer's vocabulary, or the index returned
        no result.
    """
    if not user_symptoms:
        return None

    query = vectorizer.transform([" ".join(user_symptoms)]).toarray().astype("float32")
    # An all-zero query carries no information about the user's symptoms, so
    # whatever the nearest-neighbour search returned would be arbitrary.
    if not query.any():
        return None

    _, indices = index.search(query, k=1)
    best = int(indices[0][0])
    # FAISS reports -1 when it has no neighbour to return; indexing the list
    # with it would silently pick the last disease.
    if best < 0:
        return None
    return disease_list[best]
# -------------------------------
# Load Medical NER model for symptom extraction
# -------------------------------
# The same checkpoint supplies both the model weights and the tokenizer.
_MEDICAL_NER_CHECKPOINT = "blaze999/Medical-NER"

# Token-classification pipeline; "simple" aggregation merges word pieces into
# whole entities that carry an "entity_group" label.
medical_ner = pipeline(
    "ner",
    model=_MEDICAL_NER_CHECKPOINT,
    tokenizer=_MEDICAL_NER_CHECKPOINT,
    aggregation_strategy="simple",
)
def extract_symptoms_ner(text):
    """Extract the distinct symptom mentions that the NER model finds in *text*.

    Runs the module-level ``medical_ner`` pipeline and keeps every entity
    whose ``entity_group`` contains ``"SIGN_SYMPTOM"``.

    Args:
        text: Free-form user message.

    Returns:
        A list of lower-cased entity words without duplicates, in no
        particular order.
    """
    symptom_words = {
        entity["word"].lower()
        for entity in medical_ner(text)
        if "SIGN_SYMPTOM" in entity["entity_group"]
    }
    return list(symptom_words)
# Whole-word match only: a plain substring test would treat "eyes" or
# "yesterday" as containing "yes".
_AFFIRMATIVE_RE = re.compile(
    r"\b(?:yes|yeah|yep|certainly|sometimes|a\s+little)\b"
)


def is_affirmative(answer):
    """Return True if *answer* contains an affirmative word or phrase.

    Recognised (case-insensitively, as whole words): "yes", "yeah", "yep",
    "certainly", "sometimes" and "a little".

    Args:
        answer: The user's reply text.

    Returns:
        ``True`` when at least one affirmative expression is present,
        otherwise ``False``.
    """
    return _AFFIRMATIVE_RE.search(answer.lower()) is not None
# -------------------------------
# Chatbot session class
# -------------------------------
class ChatbotSession:
    """State machine for one scripted "virtual doctor" intake conversation.

    The session moves through three states, stored as strings in ``state``:

    * ``"symptom_collection"`` - symptoms are extracted from the user's
      messages, the closest disease is predicted, and the user is asked about
      that disease's symptoms they have not mentioned yet;
    * ``"pain"`` - the user's pain rating is recorded;
    * ``"medications"`` - recent medications are recorded and the session is
      marked ``finished``.

    Only the doctor's lines are appended to ``conversation_history``; user
    messages are not stored by this class.
    """

    _PAIN_PROMPT = (
        "Doctor: Do you experience any pain or aches? Please rate the pain "
        "on a scale of 1 to 10 (or type 'no' if none):"
    )
    _MEDICATIONS_PROMPT = (
        "Doctor: Have you taken any medications recently? If yes, please "
        "specify (or type 'no' if none):"
    )

    def __init__(self):
        """Start a new session with the greeting already in the history."""
        self.conversation_history = []
        self.reported_symptoms = set()
        # Symptoms the user has already been asked about (whatever the answer).
        self.asked_missing = set()
        # Symptom named in the last yes/no follow-up question, if any.
        self.awaiting_followup = None
        self.state = "symptom_collection"  # states: symptom_collection, pain, medications
        # Filled in by the "pain" and "medications" states respectively.
        self.pain_level = None
        self.medications = None
        self.finished = False
        # Initial greeting
        self.conversation_history.append(
            "Doctor: Hello, I am your virtual doctor. What brought you in today?"
        )

    def process_message(self, message: str) -> str:
        """Handle one user message and return the doctor's reply.

        Args:
            message: Raw text typed by the user.

        Returns:
            The next prompt or question.  Every reply is also appended to
            ``conversation_history`` with a ``"Doctor: "`` prefix, except the
            fallback for an unknown state, which is only returned.
        """
        if self.state == "symptom_collection":
            return self._collect_symptoms(message)
        if self.state == "pain":
            return self._record_pain(message)
        if self.state == "medications":
            return self._record_medications(message)
        return "Doctor: I'm sorry, I didn't understand that."

    def get_data(self):
        """Return a dict snapshot of the conversation, symptoms, pain level and medications."""
        return {
            "conversation": self.conversation_history,
            "symptoms": list(self.reported_symptoms),
            "pain_level": self.pain_level,
            "medications": self.medications,
        }

    def _say(self, line):
        """Record a doctor line in the history and return it unchanged."""
        self.conversation_history.append(line)
        return line

    def _collect_symptoms(self, message):
        """Process a message while in the symptom-collection state."""
        command = message.strip().lower()
        pending = self.awaiting_followup

        # "exit"/"quit" always end collection.  A bare "no" ends it only when
        # it answers the open "any other symptoms?" prompt; as a reply to a
        # yes/no follow-up it just means the user lacks that symptom.
        if command in ("exit", "quit") or (not pending and command == "no"):
            self.awaiting_followup = None
            self.state = "pain"
            return self._say(self._PAIN_PROMPT)

        if pending:
            # The message answers the follow-up about one specific symptom.
            if is_affirmative(message):
                self.reported_symptoms.add(pending)
            self.asked_missing.add(pending)
            self.awaiting_followup = None
        else:
            # Free-text message: extract symptoms from it.
            self.reported_symptoms.update(extract_symptoms_ner(message))

        symptom_to_ask = self._next_followup_symptom()
        if symptom_to_ask is not None:
            followup = f"Are you also experiencing {symptom_to_ask}?"
            self.conversation_history.append("Doctor: " + followup)
            self.awaiting_followup = symptom_to_ask
            # NOTE(review): unlike the other prompts, this reply is returned
            # without the "Doctor: " prefix; kept as-is for existing callers.
            return followup
        return self._say("Doctor: Do you have any other symptoms you'd like to mention?")

    def _next_followup_symptom(self):
        """Return the next unreported, not-yet-asked symptom of the predicted disease, or None."""
        if not self.reported_symptoms:
            return None
        predicted_disease = find_closest_disease(list(self.reported_symptoms))
        if not predicted_disease:
            return None
        # Walk the disease's own symptom list so the choice is deterministic
        # instead of depending on set iteration order.
        for symptom in disease_symptoms.get(predicted_disease, []):
            if symptom not in self.reported_symptoms and symptom not in self.asked_missing:
                return symptom
        return None

    def _record_pain(self, message):
        """Store the pain rating (int when numeric, else the raw text) and ask about medications."""
        try:
            self.pain_level = int(message)
        except ValueError:
            self.pain_level = message
        self.state = "medications"
        return self._say(self._MEDICATIONS_PROMPT)

    def _record_medications(self, message):
        """Store the medications answer ("None" for a no/none reply) and finish the session."""
        declined = message.strip().lower() in ("no", "none")
        self.medications = "None" if declined else message
        self.finished = True
        return self._say("Doctor: Thank you for providing all the information.")
|