Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -57,8 +57,20 @@ class MedicalTriageBot:
|
|
57 |
"999", "111", "emergency", "GP", "NHS",
|
58 |
"consult a doctor", "seek medical attention"
|
59 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
|
61 |
def setup_model(self, model_name="google/gemma-7b-it"):
|
|
|
62 |
if self.model is not None:
|
63 |
return
|
64 |
|
@@ -117,7 +129,16 @@ class MedicalTriageBot:
|
|
117 |
"triage_guidelines.txt": """Triage Guidelines:
|
118 |
- Persistent fever >48h: Contact 111
|
119 |
- Minor injuries: Visit urgent care
|
120 |
-
- Medication questions: Consult GP"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
121 |
}
|
122 |
|
123 |
os.makedirs("medical_knowledge", exist_ok=True)
|
@@ -152,17 +173,22 @@ class MedicalTriageBot:
|
|
152 |
@torch.inference_mode()
|
153 |
def generate_safe_response(self, message, history):
|
154 |
try:
|
155 |
-
|
156 |
-
|
157 |
-
|
|
|
158 |
|
159 |
-
#
|
160 |
-
|
161 |
-
|
162 |
-
)
|
163 |
|
164 |
-
#
|
165 |
-
|
|
|
|
|
|
|
|
|
|
|
166 |
|
167 |
# Create safety-focused prompt
|
168 |
prompt = f"""<start_of_turn>system
|
@@ -213,6 +239,99 @@ Guidelines:
|
|
213 |
logger.error(f"Generation error: {e}")
|
214 |
return "Please contact NHS 111 directly for urgent medical advice."
|
215 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
216 |
# ===========================
|
217 |
# π¬ SAFE GRADIO INTERFACE
|
218 |
# ===========================
|
@@ -221,10 +340,25 @@ def create_medical_interface():
|
|
221 |
bot.setup_model()
|
222 |
bot.setup_rag_system()
|
223 |
|
|
|
|
|
|
|
|
|
|
|
224 |
def handle_conversation(message, history):
|
225 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
226 |
response = bot.generate_safe_response(message, history)
|
227 |
return history + [(message, response)]
|
|
|
228 |
except Exception as e:
|
229 |
logger.error(f"Conversation error: {e}")
|
230 |
return history + [(message, "System error - please refresh the page")]
|
@@ -235,7 +369,7 @@ def create_medical_interface():
|
|
235 |
|
236 |
with gr.Row():
|
237 |
chatbot = gr.Chatbot(
|
238 |
-
value=[("", "Hello! I'm your
|
239 |
height=500,
|
240 |
label="Medical Triage Chat"
|
241 |
)
|
@@ -264,7 +398,7 @@ def create_medical_interface():
|
|
264 |
).then(lambda: "", None, [message_input])
|
265 |
|
266 |
clear_btn.click(
|
267 |
-
lambda: [("", "Hello! I'm your
|
268 |
None,
|
269 |
[chatbot]
|
270 |
)
|
|
|
57 |
"999", "111", "emergency", "GP", "NHS",
|
58 |
"consult a doctor", "seek medical attention"
|
59 |
]
|
60 |
+
self.current_case = None
|
61 |
+
self.location_services = {
|
62 |
+
"London": {"A&E": ["St Thomas' Hospital", "Royal London Hospital"],
|
63 |
+
"Urgent Care": ["UCLH Urgent Care Centre"]},
|
64 |
+
"Manchester": {"A&E": ["Manchester Royal Infirmary"]}
|
65 |
+
}
|
66 |
+
self.base_questions = [
|
67 |
+
("duration", "How long have you experienced this?"),
|
68 |
+
("severity", "On a scale of 1-10, how severe is it?"),
|
69 |
+
("emergency_signs", "Any difficulty breathing, chest pain, or confusion?")
|
70 |
+
]
|
71 |
|
72 |
def setup_model(self, model_name="google/gemma-7b-it"):
|
73 |
+
"""Initialize the medical AI model with 4-bit quantization"""
|
74 |
if self.model is not None:
|
75 |
return
|
76 |
|
|
|
129 |
"triage_guidelines.txt": """Triage Guidelines:
|
130 |
- Persistent fever >48h: Contact 111
|
131 |
- Minor injuries: Visit urgent care
|
132 |
+
- Medication questions: Consult GP""",
|
133 |
+
"gp_care.txt": """GP Management:
|
134 |
+
- Persistent cough >3 weeks: Book GP
|
135 |
+
- Medication reviews: Schedule appointment
|
136 |
+
- Chronic condition management""",
|
137 |
+
|
138 |
+
"urgent_care.txt": """Urgent Care:
|
139 |
+
- Minor burns: Visit urgent care
|
140 |
+
- Sprains: Same-day urgent care
|
141 |
+
- Ear infections: Walk-in centre"""
|
142 |
}
|
143 |
|
144 |
os.makedirs("medical_knowledge", exist_ok=True)
|
|
|
173 |
@torch.inference_mode()
|
174 |
def generate_safe_response(self, message, history):
|
175 |
try:
|
176 |
+
if self.current_case is None:
|
177 |
+
context = self.get_medical_context(message)
|
178 |
+
self._initialize_case(context)
|
179 |
+
return self._next_question()
|
180 |
|
181 |
+
# Existing state handling code
|
182 |
+
if self.current_case["stage"] == "investigation":
|
183 |
+
return self._handle_investigation(message, history)
|
|
|
184 |
|
185 |
+
# Add model inference HERE
|
186 |
+
inputs = self.tokenizer(
|
187 |
+
self._build_prompt(message, history),
|
188 |
+
return_tensors="pt",
|
189 |
+
truncation=True,
|
190 |
+
max_length=1024
|
191 |
+
).to(self.model.device)
|
192 |
|
193 |
# Create safety-focused prompt
|
194 |
prompt = f"""<start_of_turn>system
|
|
|
239 |
logger.error(f"Generation error: {e}")
|
240 |
return "Please contact NHS 111 directly for urgent medical advice."
|
241 |
|
242 |
+
def _initialize_case(self, message):
|
243 |
+
context = self.get_medical_context(message)
|
244 |
+
self.current_case = {
|
245 |
+
"symptoms": self._detect_symptoms(context),
|
246 |
+
"responses": {},
|
247 |
+
"stage": "investigation",
|
248 |
+
"step": 0,
|
249 |
+
"location": None,
|
250 |
+
"guidelines": context
|
251 |
+
}
|
252 |
+
|
253 |
+
def _detect_symptoms(self, context):
|
254 |
+
return list(set(
|
255 |
+
line.split(":")[0].strip()
|
256 |
+
for line in context.split("\n")
|
257 |
+
if ":" in line
|
258 |
+
))
|
259 |
+
|
260 |
+
def _next_question(self):
|
261 |
+
questions = self.base_questions + self._generate_custom_questions()
|
262 |
+
if self.current_case["step"] < len(questions):
|
263 |
+
return questions[self.current_case["step"]][1]
|
264 |
+
self.current_case["stage"] = "location"
|
265 |
+
return "Could you share your UK postcode for local service recommendations?"
|
266 |
+
|
267 |
+
def _generate_custom_questions(self):
|
268 |
+
return [(f"custom_{i}", q.split(":")[1].strip())
|
269 |
+
for q in self.current_case["guidelines"].split("\n")
|
270 |
+
if "?" in q]
|
271 |
+
|
272 |
+
def _handle_investigation(self, message, history):
|
273 |
+
self.current_case["responses"][self.base_questions[self.current_case["step"]][0]] = message
|
274 |
+
self.current_case["step"] += 1
|
275 |
+
return self._next_question()
|
276 |
+
|
277 |
+
def _handle_location(self, message):
|
278 |
+
self.current_case["location"] = self._get_location(message)
|
279 |
+
self.current_case["stage"] = "recommendation"
|
280 |
+
return self._final_recommendation()
|
281 |
+
|
282 |
+
def _final_recommendation(self):
|
283 |
+
action = self._determine_action()
|
284 |
+
location_info = self._get_location_services()
|
285 |
+
self.current_case = None
|
286 |
+
return f"{action}\n\n{location_info}"
|
287 |
+
|
288 |
+
def _determine_action(self):
|
289 |
+
if self._is_emergency():
|
290 |
+
return "π Call 999 immediately. I can stay on the line with you."
|
291 |
+
if self._needs_gp():
|
292 |
+
return "π
Please book a GP appointment. Would you like me to help with that?"
|
293 |
+
return "π₯ Visit your nearest urgent care centre:"
|
294 |
+
|
295 |
+
def _get_location_services(self):
|
296 |
+
if not self.current_case["location"]:
|
297 |
+
return "Find local services: https://www.nhs.uk/service-search"
|
298 |
+
return "\n".join([
|
299 |
+
f"{service_type}: {', '.join(services)}"
|
300 |
+
for service_type, services in
|
301 |
+
self.location_services.get(self.current_case["location"], {}).items()
|
302 |
+
])
|
303 |
+
|
304 |
+
def _is_emergency(self):
|
305 |
+
return any(keyword in self.current_case["guidelines"]
|
306 |
+
for keyword in ["999", "emergency", "stroke"])
|
307 |
+
|
308 |
+
def _needs_gp(self):
|
309 |
+
return any(keyword in self.current_case["guidelines"]
|
310 |
+
for keyword in ["GP", "appointment", "persistent"])
|
311 |
+
|
312 |
+
def _get_location(self, postcode):
|
313 |
+
return "London" if postcode.startswith("L") else "Manchester"
|
314 |
+
|
315 |
+
def _build_prompt(self, message, history):
|
316 |
+
conversation = "\n".join([f"User: {user}\nAssistant: {bot}" for user, bot in history[-3:]])
|
317 |
+
context = self.get_medical_context(message)
|
318 |
+
|
319 |
+
return f"""<start_of_turn>system
|
320 |
+
Context:
|
321 |
+
{context}
|
322 |
+
Conversation:
|
323 |
+
{conversation}
|
324 |
+
Guidelines:
|
325 |
+
1. Follow investigation flow
|
326 |
+
2. Consider location: {self.current_case.get('location', 'unknown')}
|
327 |
+
3. Maintain safety protocols
|
328 |
+
<end_of_turn>
|
329 |
+
<start_of_turn>user
|
330 |
+
{message}
|
331 |
+
<end_of_turn>
|
332 |
+
<start_of_turn>assistant"""
|
333 |
+
|
334 |
+
|
335 |
# ===========================
|
336 |
# π¬ SAFE GRADIO INTERFACE
|
337 |
# ===========================
|
|
|
340 |
bot.setup_model()
|
341 |
bot.setup_rag_system()
|
342 |
|
343 |
+
def create_medical_interface():
|
344 |
+
bot = MedicalTriageBot()
|
345 |
+
bot.setup_model()
|
346 |
+
bot.setup_rag_system()
|
347 |
+
|
348 |
def handle_conversation(message, history):
|
349 |
try:
|
350 |
+
# Handle GP booking requests
|
351 |
+
if "book gp" in message.lower():
|
352 |
+
return history + [(message, "Redirecting to GP booking system...")]
|
353 |
+
|
354 |
+
# Handle location input
|
355 |
+
if any(word in message.lower() for word in ["postcode", "zip code", "location"]):
|
356 |
+
return history + [(message, "Please enter your UK postcode:")]
|
357 |
+
|
358 |
+
# Normal symptom processing
|
359 |
response = bot.generate_safe_response(message, history)
|
360 |
return history + [(message, response)]
|
361 |
+
|
362 |
except Exception as e:
|
363 |
logger.error(f"Conversation error: {e}")
|
364 |
return history + [(message, "System error - please refresh the page")]
|
|
|
369 |
|
370 |
with gr.Row():
|
371 |
chatbot = gr.Chatbot(
|
372 |
+
value=[("", "Hello! I'm Pearly, your digital assistant. How can I help you today?")],
|
373 |
height=500,
|
374 |
label="Medical Triage Chat"
|
375 |
)
|
|
|
398 |
).then(lambda: "", None, [message_input])
|
399 |
|
400 |
clear_btn.click(
|
401 |
+
lambda: [("", "Hello! I'm Pearly, your digital assistant. How can I help you today?")],
|
402 |
None,
|
403 |
[chatbot]
|
404 |
)
|