Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -18,6 +18,7 @@ from sklearn.ensemble import RandomForestClassifier
|
|
18 |
from sklearn.naive_bayes import GaussianNB
|
19 |
from sklearn.metrics import accuracy_score
|
20 |
|
|
|
21 |
# Suppress TensorFlow warnings
|
22 |
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
|
23 |
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
@@ -56,19 +57,18 @@ def load_data():
|
|
56 |
try:
|
57 |
df = pd.read_csv("Training.csv")
|
58 |
tr = pd.read_csv("Testing.csv")
|
59 |
-
except FileNotFoundError
|
60 |
raise RuntimeError("Data files not found. Please ensure `Training.csv` and `Testing.csv` are uploaded correctly.")
|
61 |
-
|
62 |
disease_dict = {
|
63 |
-
# Example disease encoding dictionary, update accordingly
|
64 |
'Fungal infection': 0, 'Allergy': 1, 'GERD': 2, 'Chronic cholestasis': 3, 'Drug Reaction': 4,
|
65 |
'Peptic ulcer diseae': 5, 'AIDS': 6, 'Diabetes': 7, 'Gastroenteritis': 8, 'Bronchial Asthma': 9,
|
66 |
'Hypertension': 10, 'Migraine': 11, 'Cervical spondylosis': 12, 'Paralysis': 13,
|
67 |
'Jaundice': 14, 'Malaria': 15, 'Chicken pox': 16, 'Dengue': 17, 'Typhoid': 18,
|
68 |
'Hepatitis A': 19, 'Hepatitis B': 20, 'Hepatitis C': 21, 'Hepatitis D': 22, 'Hepatitis E': 23,
|
69 |
'Alcoholic hepatitis': 24, 'Tuberculosis': 25, 'Common Cold': 26, 'Pneumonia': 27,
|
70 |
-
'Heart attack':
|
71 |
-
'Hypoglycemia':
|
72 |
}
|
73 |
|
74 |
df.replace({'prognosis': disease_dict}, inplace=True)
|
@@ -80,7 +80,7 @@ def load_data():
|
|
80 |
return df, tr, disease_dict
|
81 |
|
82 |
df, tr, disease_dict = load_data()
|
83 |
-
l1 = list(df.columns[:-1])
|
84 |
X = df[l1]
|
85 |
y = df['prognosis']
|
86 |
X_test = tr[l1]
|
@@ -107,19 +107,31 @@ def predict_disease(model, symptoms):
|
|
107 |
if symptom in l1:
|
108 |
input_test[l1.index(symptom)] = 1
|
109 |
prediction = model.predict([input_test])[0]
|
110 |
-
|
|
|
|
|
|
|
|
|
111 |
|
112 |
def disease_prediction_interface(symptoms):
|
113 |
symptoms_selected = [s for s in symptoms if s != "None"]
|
114 |
|
115 |
if len(symptoms_selected) < 3:
|
116 |
-
return "Please select at least 3 symptoms for accurate prediction."
|
117 |
|
118 |
results = []
|
119 |
for model_name, (model, acc) in trained_models.items():
|
120 |
-
|
121 |
-
|
122 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
123 |
return results
|
124 |
|
125 |
# Helper Functions (for chatbot)
|
@@ -138,11 +150,7 @@ def generate_chatbot_response(message, history):
|
|
138 |
try:
|
139 |
result = chatbot_model.predict([bag_of_words(message, words)])
|
140 |
tag = labels[np.argmax(result)]
|
141 |
-
response = "I'm sorry, I didn't understand that. 🤔"
|
142 |
-
for intent in intents_data["intents"]:
|
143 |
-
if intent["tag"] == tag:
|
144 |
-
response = random.choice(intent["responses"])
|
145 |
-
break
|
146 |
except Exception as e:
|
147 |
response = f"Error: {e}"
|
148 |
history.append((message, response))
|
|
|
18 |
from sklearn.naive_bayes import GaussianNB
|
19 |
from sklearn.metrics import accuracy_score
|
20 |
|
21 |
+
|
22 |
# Suppress TensorFlow warnings
|
23 |
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
|
24 |
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
|
|
57 |
try:
|
58 |
df = pd.read_csv("Training.csv")
|
59 |
tr = pd.read_csv("Testing.csv")
|
60 |
+
except FileNotFoundError:
|
61 |
raise RuntimeError("Data files not found. Please ensure `Training.csv` and `Testing.csv` are uploaded correctly.")
|
62 |
+
|
63 |
disease_dict = {
|
|
|
64 |
'Fungal infection': 0, 'Allergy': 1, 'GERD': 2, 'Chronic cholestasis': 3, 'Drug Reaction': 4,
|
65 |
'Peptic ulcer diseae': 5, 'AIDS': 6, 'Diabetes': 7, 'Gastroenteritis': 8, 'Bronchial Asthma': 9,
|
66 |
'Hypertension': 10, 'Migraine': 11, 'Cervical spondylosis': 12, 'Paralysis': 13,
|
67 |
'Jaundice': 14, 'Malaria': 15, 'Chicken pox': 16, 'Dengue': 17, 'Typhoid': 18,
|
68 |
'Hepatitis A': 19, 'Hepatitis B': 20, 'Hepatitis C': 21, 'Hepatitis D': 22, 'Hepatitis E': 23,
|
69 |
'Alcoholic hepatitis': 24, 'Tuberculosis': 25, 'Common Cold': 26, 'Pneumonia': 27,
|
70 |
+
'Heart attack': 28, 'Varicose veins': 29, 'Hypothyroidism': 30, 'Hyperthyroidism': 31,
|
71 |
+
'Hypoglycemia': 32, 'Osteoarthritis': 33, 'Arthritis': 34
|
72 |
}
|
73 |
|
74 |
df.replace({'prognosis': disease_dict}, inplace=True)
|
|
|
80 |
return df, tr, disease_dict
|
81 |
|
82 |
df, tr, disease_dict = load_data()
|
83 |
+
l1 = list(df.columns[:-1])
|
84 |
X = df[l1]
|
85 |
y = df['prognosis']
|
86 |
X_test = tr[l1]
|
|
|
107 |
if symptom in l1:
|
108 |
input_test[l1.index(symptom)] = 1
|
109 |
prediction = model.predict([input_test])[0]
|
110 |
+
confidence = model.predict_proba([input_test])[0][prediction] if hasattr(model, 'predict_proba') else None
|
111 |
+
return {
|
112 |
+
"disease": list(disease_dict.keys())[list(disease_dict.values()).index(prediction)],
|
113 |
+
"confidence": confidence
|
114 |
+
}
|
115 |
|
116 |
def disease_prediction_interface(symptoms):
|
117 |
symptoms_selected = [s for s in symptoms if s != "None"]
|
118 |
|
119 |
if len(symptoms_selected) < 3:
|
120 |
+
return ["Please select at least 3 symptoms for accurate prediction."]
|
121 |
|
122 |
results = []
|
123 |
for model_name, (model, acc) in trained_models.items():
|
124 |
+
prediction_info = predict_disease(model, symptoms_selected)
|
125 |
+
predicted_disease = prediction_info["disease"]
|
126 |
+
confidence_score = prediction_info["confidence"]
|
127 |
+
|
128 |
+
result = f"{model_name} Prediction: Predicted Disease: **{predicted_disease}**"
|
129 |
+
if confidence_score is not None:
|
130 |
+
result += f" (Confidence: {confidence_score:.2f})"
|
131 |
+
result += f" (Accuracy: {acc * 100:.2f}%)"
|
132 |
+
|
133 |
+
results.append(result)
|
134 |
+
|
135 |
return results
|
136 |
|
137 |
# Helper Functions (for chatbot)
|
|
|
150 |
try:
|
151 |
result = chatbot_model.predict([bag_of_words(message, words)])
|
152 |
tag = labels[np.argmax(result)]
|
153 |
+
response = next((random.choice(intent["responses"]) for intent in intents_data["intents"] if intent["tag"] == tag), "I'm sorry, I didn't understand that. 🤔")
|
|
|
|
|
|
|
|
|
154 |
except Exception as e:
|
155 |
response = f"Error: {e}"
|
156 |
history.append((message, response))
|