# Nurses / app.py — author: benardo0, commit a444494 ("Create app.py", verified), 7.67 kB
# (header captured from the file's web view; not part of the program)
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Optional, Dict
from llama_cpp import Llama
import gradio as gr
import json
from enum import Enum
import re
class ConsultationState(Enum):
    """Stage a single conversation has reached in Nurse Oge's intake flow."""

    # No intake in progress; the next (non-identity/location) message starts one.
    INITIAL = "initial"
    # Intake questions are being asked and each reply is recorded as an answer.
    GATHERING_INFO = "gathering_info"
    # All answers collected; the final model assessment is being produced.
    DIAGNOSIS = "diagnosis"
class Message(BaseModel):
    """One chat turn in a /chat request transcript."""
    role: str  # speaker label; its allowed values are not validated by this model
    content: str  # the turn's text
class ChatRequest(BaseModel):
    """Body of POST /chat: the chat transcript, newest message last."""
    messages: List[Message]
    # Optional client-supplied key naming the consultation this transcript belongs
    # to. Defaults to None, so existing clients that omit it remain valid.
    conversation_id: Optional[str] = None
class ChatResponse(BaseModel):
    """Reply produced by NurseOgeAssistant.process_message for one user message."""
    response: str  # text to show the user
    # False while the assistant is still waiting for an intake answer; True for
    # standalone replies (identity/location) and for the final assessment.
    finished: bool
# Intake questions Nurse Oge asks, one per user message, before giving any advice.
# Order matters: the next question is picked by how many answers have already been
# collected, and the answers are later zipped back onto these questions.
HEALTH_ASSESSMENT_QUESTIONS = [
    "What are your current symptoms and how long have you been experiencing them?",
    "Do you have any pre-existing medical conditions or chronic illnesses?",
    "Are you currently taking any medications? If yes, please list them.",
    "Is there any relevant family medical history I should know about?",
    "Have you had any similar symptoms in the past? If yes, what treatments worked?",
]
# Personality / system prompt for Nurse Oge. It is sent as the "system" message
# when the final assessment is generated. The text is model input, so edit it
# deliberately: every character, including the leading and trailing newlines,
# reaches the model.
NURSE_OGE_IDENTITY = """
You are Nurse Oge, a medical AI assistant focused on serving patients in Nigeria. Always be empathetic,
professional, and thorough in your assessments. When asked about your identity, explain that you are
Nurse Oge, a medical AI assistant serving Nigerian communities. Remember that you must gather complete
health information before providing any medical advice.
"""
class NurseOgeAssistant:
    """Scripted health-intake flow in front of a local Llama chat model.

    Each conversation (keyed by ``conversation_id``) is asked every entry of
    ``HEALTH_ASSESSMENT_QUESTIONS`` in order, one per user message. Once all
    answers are in, the question/answer pairs plus the message that opened the
    consultation are sent to the model for a final assessment. The conversation
    is then reset, so the next message starts a new consultation. Identity and
    location questions get canned replies at any point and do not consume an
    intake question.
    """

    # Compiled once and matched case-insensitively. The message is not
    # lower-cased first, so patterns containing capitals (e.g. "I") still match.
    _IDENTITY_PATTERNS = tuple(
        re.compile(pattern, re.IGNORECASE)
        for pattern in (
            r"who are you",
            r"what are you",
            r"your name",
            r"what should I call you",
            r"tell me about yourself",
        )
    )
    _LOCATION_PATTERNS = tuple(
        re.compile(pattern, re.IGNORECASE)
        for pattern in (
            r"where are you",
            r"which country",
            r"your location",
            r"where do you work",
            r"where are you based",
        )
    )

    def __init__(self):
        # Load the quantized GGUF model named by repo_id/filename.
        self.llm = Llama.from_pretrained(
            repo_id="mradermacher/Llama3-Med42-8B-GGUF",
            filename="Llama3-Med42-8B.IQ3_M.gguf",
            verbose=False
        )
        self.consultation_states = {}  # conversation_id -> ConsultationState
        self.gathered_info = {}  # conversation_id -> answers, in question order
        self.original_queries = {}  # conversation_id -> message that opened the consultation

    def _is_identity_question(self, message: str) -> bool:
        """Return True if *message* asks who or what the assistant is."""
        return any(pattern.search(message) for pattern in self._IDENTITY_PATTERNS)

    def _is_location_question(self, message: str) -> bool:
        """Return True if *message* asks where the assistant is located."""
        return any(pattern.search(message) for pattern in self._LOCATION_PATTERNS)

    def _get_next_assessment_question(self, conversation_id: str) -> Optional[str]:
        """Return the first unanswered intake question, or None once all are answered."""
        answers = self.gathered_info.setdefault(conversation_id, [])
        if len(answers) < len(HEALTH_ASSESSMENT_QUESTIONS):
            return HEALTH_ASSESSMENT_QUESTIONS[len(answers)]
        return None

    async def process_message(self, conversation_id: str, message: str, history: List[Dict]) -> "ChatResponse":
        """Advance one conversation by a single user message.

        Args:
            conversation_id: Key for the per-conversation state and answers.
            message: The latest user message.
            history: Earlier messages; currently not used by this method.

        Returns:
            A ChatResponse with ``finished=False`` while intake questions remain.
            ``finished=True`` is set for the canned identity/location replies and
            for the final model-generated assessment.
        """
        state = self.consultation_states.setdefault(
            conversation_id, ConsultationState.INITIAL
        )

        # Canned replies take priority at any stage and do not consume an intake question.
        if self._is_identity_question(message):
            return ChatResponse(
                response="I am Nurse Oge, a medical AI assistant dedicated to helping patients in Nigeria. "
                         "I'm here to provide medical guidance while ensuring I gather all necessary health information "
                         "for accurate assessments.",
                finished=True
            )
        if self._is_location_question(message):
            return ChatResponse(
                response="I am based in Nigeria and specifically trained to serve Nigerian communities, "
                         "taking into account local healthcare contexts and needs.",
                finished=True
            )

        if state == ConsultationState.GATHERING_INFO:
            # This message answers the question that was asked last.
            self.gathered_info.setdefault(conversation_id, []).append(message)
            next_question = self._get_next_assessment_question(conversation_id)
            if next_question:
                return ChatResponse(
                    response=f"Thank you for that information. {next_question}",
                    finished=False
                )
            original_query = self.original_queries.get(conversation_id, message)
            return self._generate_assessment(conversation_id, original_query)

        # Any other state (normally INITIAL) begins a fresh consultation, so no
        # path falls through without returning a response.
        return self._start_assessment(conversation_id, message)

    def _start_assessment(self, conversation_id: str, message: str) -> "ChatResponse":
        """Begin a new intake: clear old answers, remember *message*, ask question one."""
        self.consultation_states[conversation_id] = ConsultationState.GATHERING_INFO
        # Drop answers left over from an interrupted consultation so the
        # questionnaire really restarts at the first question.
        self.gathered_info[conversation_id] = []
        # The opening message carries the patient's actual concern; it is
        # otherwise never recorded, so keep it for the final prompt.
        self.original_queries[conversation_id] = message
        next_question = self._get_next_assessment_question(conversation_id)
        return ChatResponse(
            response=f"Before I can provide any medical advice, I need to gather some important health information. "
                     f"{next_question}",
            finished=False
        )

    def _generate_assessment(self, conversation_id: str, original_query: str) -> "ChatResponse":
        """Ask the model for the final assessment, then reset the conversation.

        The reset runs whether or not generation succeeds, so a model error
        cannot leave the conversation stuck in DIAGNOSIS. The error itself
        still propagates to the caller.
        """
        self.consultation_states[conversation_id] = ConsultationState.DIAGNOSIS
        try:
            answers = self.gathered_info.get(conversation_id, [])
            context = "\n".join(
                f"Q: {question}\nA: {answer}"
                for question, answer in zip(HEALTH_ASSESSMENT_QUESTIONS, answers)
            )
            messages = [
                {"role": "system", "content": NURSE_OGE_IDENTITY},
                {"role": "user", "content": f"Based on the following patient information, provide a thorough assessment, diagnosis and recommendations:\n\n{context}\n\nOriginal query: {original_query}"}
            ]
            # NOTE(review): create_chat_completion is a synchronous call reached from
            # the async request path, so it blocks the event loop for the whole
            # generation. Moving it to a thread would presumably also need a lock,
            # since every conversation shares this one Llama instance. Confirm
            # before changing.
            completion = self.llm.create_chat_completion(
                messages=messages,
                max_tokens=1024,
                temperature=0.7
            )
        finally:
            self.consultation_states[conversation_id] = ConsultationState.INITIAL
            self.gathered_info[conversation_id] = []
            self.original_queries.pop(conversation_id, None)
        return ChatResponse(
            response=completion['choices'][0]['message']['content'],
            finished=True
        )
# Initialize FastAPI and Nurse Oge.
# Both are module-level singletons. NurseOgeAssistant() loads the model when this
# module is imported, and the /chat handler and the Gradio callback below share
# this one instance, and with it all per-conversation state.
app = FastAPI()
nurse_oge = NurseOgeAssistant()
@app.post("/chat")
async def chat_endpoint(request: ChatRequest):
    """Handle one chat turn: pass the newest message to Nurse Oge and return her reply.

    Raises:
        HTTPException: 400 if the request contains no messages.
    """
    if not request.messages:
        raise HTTPException(status_code=400, detail="No messages provided")

    # Honor a client-supplied conversation_id so that separate patients don't
    # share a single intake; fall back to the shared "default" conversation
    # otherwise. getattr keeps this working with request models that do not
    # declare the field.
    conversation_id = getattr(request, "conversation_id", None) or "default"

    latest_message = request.messages[-1].content
    # process_message declares history as List[Dict]; convert the models to match.
    history = [
        {"role": turn.role, "content": turn.content}
        for turn in request.messages[:-1]
    ]
    return await nurse_oge.process_message(
        conversation_id=conversation_id,
        message=latest_message,
        history=history,
    )
# Initialize Gradio interface (optional, for testing)
async def gradio_chat(message, history):
    """Gradio ChatInterface callback: route a UI message through Nurse Oge.

    ``process_message`` is a coroutine function, so its result must be awaited
    before ``.response`` can be read. Gradio accepts async callbacks, so this
    callback is itself declared ``async``. All Gradio users share the
    "gradio_user" conversation.
    """
    reply = await nurse_oge.process_message("gradio_user", message, history)
    return reply.response
# Optional Gradio chat UI for manual testing, driven by gradio_chat.
demo = gr.ChatInterface(
    fn=gradio_chat,
    theme="soft",
    title="Nurse Oge",
    description="Finetuned llama 3.0 for medical diagnosis and all. This is just a demo",
)

# Mount the UI under /gradio and rebind `app` to the value mount_gradio_app
# returns, so the server below serves both the API and the UI.
app = gr.mount_gradio_app(app, demo, path="/gradio")

if __name__ == "__main__":
    # Local entry point; uvicorn is only needed when run as a script.
    import uvicorn

    uvicorn.run(app, port=8000, host="0.0.0.0")