# File size: 7,674 Bytes
# a444494
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Optional, Dict
from llama_cpp import Llama
import gradio as gr
import json
from enum import Enum
import re
class ConsultationState(Enum):
    """Stage of a single consultation, listed in the order they occur."""

    # No assessment question has been asked yet.
    INITIAL = "initial"
    # The scripted health-assessment questions are being asked.
    GATHERING_INFO = "gathering_info"
    # All answers are in and the model is producing its assessment.
    DIAGNOSIS = "diagnosis"
# One chat turn as sent by the client.
class Message(BaseModel):
    # Speaker label for the turn; the code in this file never inspects it.
    role: str
    # Text of the turn.
    content: str
# Request body for POST /chat.
class ChatRequest(BaseModel):
    # Conversation so far; the last entry is treated as the new user message.
    messages: List[Message]
# Reply returned for every processed message.
class ChatResponse(BaseModel):
    # Text to show to the user.
    response: str
    # False while an assessment question is still awaiting an answer,
    # True for canned replies and for the final model-generated assessment.
    finished: bool
# Fixed intake script: the questions are asked in this order, one per turn,
# and all must be answered before the model is asked for an assessment.
HEALTH_ASSESSMENT_QUESTIONS = [
    "What are your current symptoms and how long have you been experiencing them?",
    "Do you have any pre-existing medical conditions or chronic illnesses?",
    "Are you currently taking any medications? If yes, please list them.",
    "Is there any relevant family medical history I should know about?",
    "Have you had any similar symptoms in the past? If yes, what treatments worked?",
]
# System prompt sent to the model when the final assessment is generated;
# it fixes the assistant's persona and tone.
NURSE_OGE_IDENTITY = """
You are Nurse Oge, a medical AI assistant focused on serving patients in Nigeria. Always be empathetic,
professional, and thorough in your assessments. When asked about your identity, explain that you are
Nurse Oge, a medical AI assistant serving Nigerian communities. Remember that you must gather complete
health information before providing any medical advice.
"""
class NurseOgeAssistant:
    """Scripted intake flow placed in front of a medical LLM.

    Each conversation starts in INITIAL, moves to GATHERING_INFO while the
    questions in HEALTH_ASSESSMENT_QUESTIONS are asked one per turn, and
    enters DIAGNOSIS only for the duration of the model call that produces
    the assessment; afterwards it is reset to INITIAL. Questions about the
    assistant's identity or location are answered with canned text and do
    not advance the flow.
    """

    # Patterns are searched in the lower-cased message, so every pattern
    # must itself be lower case or it can never match.
    _IDENTITY_PATTERNS = (
        r"who are you",
        r"what are you",
        r"your name",
        r"what should i call you",
        r"tell me about yourself",
    )
    _LOCATION_PATTERNS = (
        r"where are you",
        r"which country",
        r"your location",
        r"where do you work",
        r"where are you based",
    )

    def __init__(self):
        """Load the model and create the empty per-conversation stores."""
        self.llm = Llama.from_pretrained(
            repo_id="mradermacher/Llama3-Med42-8B-GGUF",
            filename="Llama3-Med42-8B.IQ3_M.gguf",
            verbose=False,
        )
        # conversation_id -> ConsultationState
        self.consultation_states = {}
        # conversation_id -> answers given so far, in question order
        self.gathered_info = {}
        # conversation_id -> the message that opened the consultation
        self.original_queries = {}

    def _is_identity_question(self, message: str) -> bool:
        """Return True if *message* asks who or what the assistant is."""
        lowered = message.lower()
        return any(
            re.search(pattern, lowered) for pattern in self._IDENTITY_PATTERNS
        )

    def _is_location_question(self, message: str) -> bool:
        """Return True if *message* asks where the assistant is located."""
        lowered = message.lower()
        return any(
            re.search(pattern, lowered) for pattern in self._LOCATION_PATTERNS
        )

    def _get_next_assessment_question(self, conversation_id: str) -> Optional[str]:
        """Return the next unanswered assessment question, or None when done.

        The number of stored answers is the index of the next question. An
        empty answer list is created for conversations seen for the first
        time.
        """
        answers = self.gathered_info.setdefault(conversation_id, [])
        if len(answers) < len(HEALTH_ASSESSMENT_QUESTIONS):
            return HEALTH_ASSESSMENT_QUESTIONS[len(answers)]
        return None

    def _generate_assessment(
        self, conversation_id: str, latest_message: str
    ) -> "ChatResponse":
        """Ask the model for an assessment and reset the conversation.

        The conversation is returned to INITIAL whether or not the model
        call succeeds, so a failed call cannot leave it stuck in DIAGNOSIS.
        Any exception raised by the model call propagates to the caller.
        """
        self.consultation_states[conversation_id] = ConsultationState.DIAGNOSIS

        answers = self.gathered_info[conversation_id]
        context = "\n".join(
            f"Q: {question}\nA: {answer}"
            for question, answer in zip(HEALTH_ASSESSMENT_QUESTIONS, answers)
        )
        # The message that started the consultation; falls back to the
        # latest message if none was recorded for this conversation.
        original_query = self.original_queries.get(conversation_id, latest_message)
        messages = [
            {"role": "system", "content": NURSE_OGE_IDENTITY},
            {"role": "user", "content": f"Based on the following patient information, provide a thorough assessment, diagnosis and recommendations:\n\n{context}\n\nOriginal query: {original_query}"},
        ]

        try:
            # NOTE: this is a synchronous call made from an async code path.
            completion = self.llm.create_chat_completion(
                messages=messages,
                max_tokens=1024,
                temperature=0.7,
            )
        finally:
            # Reset state for the next consultation.
            self.consultation_states[conversation_id] = ConsultationState.INITIAL
            self.gathered_info[conversation_id] = []
            self.original_queries.pop(conversation_id, None)

        return ChatResponse(
            response=completion["choices"][0]["message"]["content"],
            finished=True,
        )

    async def process_message(
        self, conversation_id: str, message: str, history: List[Dict]
    ) -> "ChatResponse":
        """Advance one conversation by a single user message.

        Args:
            conversation_id: Key identifying the conversation's stored state.
            message: The latest user message.
            history: Earlier turns of the conversation. Accepted for
                interface compatibility; not currently used.

        Returns:
            A ChatResponse. ``finished`` is False while an assessment
            question is awaiting an answer, and True for canned replies and
            for the final assessment.
        """
        # Register new conversations before anything else so that even a
        # canned reply leaves them in a known state.
        state = self.consultation_states.setdefault(
            conversation_id, ConsultationState.INITIAL
        )

        if self._is_identity_question(message):
            return ChatResponse(
                response="I am Nurse Oge, a medical AI assistant dedicated to helping patients in Nigeria. "
                "I'm here to provide medical guidance while ensuring I gather all necessary health information "
                "for accurate assessments.",
                finished=True,
            )

        if self._is_location_question(message):
            return ChatResponse(
                response="I am based in Nigeria and specifically trained to serve Nigerian communities, "
                "taking into account local healthcare contexts and needs.",
                finished=True,
            )

        if state == ConsultationState.INITIAL:
            # First medical message: remember it and open the intake script.
            self.consultation_states[conversation_id] = ConsultationState.GATHERING_INFO
            self.original_queries[conversation_id] = message
            next_question = self._get_next_assessment_question(conversation_id)
            return ChatResponse(
                response=f"Before I can provide any medical advice, I need to gather some important health information. "
                f"{next_question}",
                finished=False,
            )

        # Otherwise the message is the answer to the question asked last turn.
        self.gathered_info.setdefault(conversation_id, []).append(message)
        next_question = self._get_next_assessment_question(conversation_id)
        if next_question is not None:
            return ChatResponse(
                response=f"Thank you for that information. {next_question}",
                finished=False,
            )

        return self._generate_assessment(conversation_id, message)
# Module-level singletons: the web application and the single assistant
# instance shared by the /chat route and the Gradio demo. Constructing the
# assistant loads the model, so this happens once, at import time.
app = FastAPI()
nurse_oge = NurseOgeAssistant()
@app.post("/chat")
async def chat_endpoint(request: ChatRequest):
    """Handle POST /chat: feed the newest message to Nurse Oge.

    Raises:
        HTTPException: 400 when the request contains no messages.
    """
    if not request.messages:
        raise HTTPException(status_code=400, detail="No messages provided")

    # Every request currently shares this one conversation ID (in a real
    # app, you'd want to manage these better).
    conversation_id = "default"

    # Split the transcript into the earlier turns and the newest message.
    *earlier_turns, newest = request.messages

    return await nurse_oge.process_message(
        conversation_id=conversation_id,
        message=newest.content,
        history=earlier_turns,
    )
# Gradio chat handler (optional, for testing)
def gradio_chat(message, history):
    """Return Nurse Oge's reply text for one Gradio chat turn.

    Args:
        message: The text the user just submitted.
        history: Prior turns as supplied by Gradio; passed through unchanged.

    Returns:
        The ``response`` text of the resulting ChatResponse.
    """
    # Local import keeps this fix self-contained within the function.
    import asyncio

    # process_message is a coroutine function; calling it without running
    # the coroutine yields a coroutine object, which has no `.response`.
    # asyncio.run drives it to completion from this synchronous handler.
    # NOTE(review): asyncio.run raises RuntimeError if an event loop is
    # already running in the calling thread - confirm Gradio invokes plain
    # (non-async) handlers outside the event-loop thread.
    response = asyncio.run(
        nurse_oge.process_message("gradio_user", message, history)
    )
    return response.response
# Browser chat UI backed by gradio_chat, intended for manual testing.
demo = gr.ChatInterface(
    fn=gradio_chat,
    theme="soft",
    title="Nurse Oge",
    description="Finetuned llama 3.0 for medical diagnosis and all. This is just a demo",
)
# Attach the Gradio demo to the FastAPI app under /gradio; `app` is rebound
# to whatever mount_gradio_app returns.
app = gr.mount_gradio_app(
    app,
    demo,
    path="/gradio",
)
if __name__ == "__main__":
    # Imported here so uvicorn is only required when the file is run as a
    # script.
    import uvicorn

    # Listen on all interfaces, port 8000.
    uvicorn.run(app, host="0.0.0.0", port=8000)