Create app.py
Browse files — app.py for project
app.py
ADDED
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import FastAPI, HTTPException
|
2 |
+
from pydantic import BaseModel
|
3 |
+
from typing import List, Optional, Dict
|
4 |
+
from llama_cpp import Llama
|
5 |
+
import gradio as gr
|
6 |
+
import json
|
7 |
+
from enum import Enum
|
8 |
+
import re
|
9 |
+
|
10 |
+
class ConsultationState(Enum):
    """Phases of the Nurse Oge consultation state machine."""
    INITIAL = "initial"                # no medical query handled yet
    GATHERING_INFO = "gathering_info"  # collecting answers to the standard questions
    DIAGNOSIS = "diagnosis"            # all answers in; model generates the assessment
|
14 |
+
|
15 |
+
class Message(BaseModel):
    """A single chat message in role/content form."""
    role: str     # e.g. "user", "assistant", or "system"
    content: str  # the message text
|
18 |
+
|
19 |
+
class ChatRequest(BaseModel):
    """Request body for POST /chat: full conversation history, newest message last."""
    messages: List[Message]
|
21 |
+
|
22 |
+
class ChatResponse(BaseModel):
    """Assistant reply for one turn."""
    response: str   # text shown to the patient
    finished: bool  # False while more assessment questions remain for this consultation
|
25 |
+
|
26 |
+
# Standard health assessment questions that Nurse Oge always asks.
# Asked in this exact order, one per turn, before any advice is given;
# the answers are zipped back against this list to build the diagnosis prompt.
HEALTH_ASSESSMENT_QUESTIONS = [
    "What are your current symptoms and how long have you been experiencing them?",
    "Do you have any pre-existing medical conditions or chronic illnesses?",
    "Are you currently taking any medications? If yes, please list them.",
    "Is there any relevant family medical history I should know about?",
    "Have you had any similar symptoms in the past? If yes, what treatments worked?"
]
|
34 |
+
|
35 |
+
# Personality prompts for Nurse Oge.
# Used verbatim as the system message for the final model call in
# NurseOgeAssistant.process_message.
NURSE_OGE_IDENTITY = """
You are Nurse Oge, a medical AI assistant focused on serving patients in Nigeria. Always be empathetic,
professional, and thorough in your assessments. When asked about your identity, explain that you are
Nurse Oge, a medical AI assistant serving Nigerian communities. Remember that you must gather complete
health information before providing any medical advice.
"""
|
42 |
+
|
43 |
+
class NurseOgeAssistant:
    """Stateful consultation flow for Nurse Oge.

    Answers identity/location questions directly, walks each conversation
    through the standard health-assessment questions one turn at a time,
    and only then asks the underlying model for a final assessment.
    """

    def __init__(self):
        # Quantized Med42 (Llama 3 medical finetune) loaded through llama-cpp.
        # NOTE(review): this fetches/loads the model eagerly at construction.
        self.llm = Llama.from_pretrained(
            repo_id="mradermacher/Llama3-Med42-8B-GGUF",
            filename="Llama3-Med42-8B.IQ3_M.gguf",
            verbose=False
        )
        self.consultation_states = {}  # conversation_id -> ConsultationState
        self.gathered_info = {}        # conversation_id -> list of patient answers
        # Fix: remember each conversation's presenting complaint so the final
        # prompt can quote the actual query (previously the *last assessment
        # answer* was mislabelled as the "Original query").
        self.original_queries = {}     # conversation_id -> first medical message

    def _is_identity_question(self, message: str) -> bool:
        """Return True when the message (case-insensitively) asks who the assistant is."""
        identity_patterns = [
            r"who are you",
            r"what are you",
            r"your name",
            # Fix: pattern must be lowercase -- the message is lowercased below,
            # so the original r"what should I call you" could never match.
            r"what should i call you",
            r"tell me about yourself"
        ]
        return any(re.search(pattern, message.lower()) for pattern in identity_patterns)

    def _is_location_question(self, message: str) -> bool:
        """Return True when the message (case-insensitively) asks where the assistant is."""
        location_patterns = [
            r"where are you",
            r"which country",
            r"your location",
            r"where do you work",
            r"where are you based"
        ]
        return any(re.search(pattern, message.lower()) for pattern in location_patterns)

    def _get_next_assessment_question(self, conversation_id: str) -> Optional[str]:
        """Return the next unanswered assessment question, or None when all are answered."""
        answers = self.gathered_info.setdefault(conversation_id, [])
        if len(answers) < len(HEALTH_ASSESSMENT_QUESTIONS):
            return HEALTH_ASSESSMENT_QUESTIONS[len(answers)]
        return None

    async def process_message(self, conversation_id: str, message: str, history: List[Dict]) -> "ChatResponse":
        """Advance one consultation turn and return the assistant's reply.

        State machine: INITIAL -> GATHERING_INFO (one question per turn)
        -> DIAGNOSIS (single model call) -> reset to INITIAL.  `history` is
        accepted for interface compatibility; state is tracked internally
        per conversation_id.
        """
        # Initialize state if this is a new conversation.
        if conversation_id not in self.consultation_states:
            self.consultation_states[conversation_id] = ConsultationState.INITIAL

        # Identity questions are answered directly, without touching the
        # consultation state machine.
        if self._is_identity_question(message):
            return ChatResponse(
                response="I am Nurse Oge, a medical AI assistant dedicated to helping patients in Nigeria. "
                         "I'm here to provide medical guidance while ensuring I gather all necessary health information "
                         "for accurate assessments.",
                finished=True
            )

        # Location questions likewise bypass the assessment flow.
        if self._is_location_question(message):
            return ChatResponse(
                response="I am based in Nigeria and specifically trained to serve Nigerian communities, "
                         "taking into account local healthcare contexts and needs.",
                finished=True
            )

        # First medical message: record the presenting complaint and start
        # the structured health assessment.
        if self.consultation_states[conversation_id] == ConsultationState.INITIAL:
            self.consultation_states[conversation_id] = ConsultationState.GATHERING_INFO
            self.original_queries[conversation_id] = message
            next_question = self._get_next_assessment_question(conversation_id)
            return ChatResponse(
                response=f"Before I can provide any medical advice, I need to gather some important health information. "
                         f"{next_question}",
                finished=False
            )

        # Mid-assessment: store the answer, then either ask the next question
        # or, once everything is answered, produce the final assessment.
        if self.consultation_states[conversation_id] == ConsultationState.GATHERING_INFO:
            self.gathered_info[conversation_id].append(message)
            next_question = self._get_next_assessment_question(conversation_id)

            if next_question:
                return ChatResponse(
                    response=f"Thank you for that information. {next_question}",
                    finished=False
                )
            else:
                self.consultation_states[conversation_id] = ConsultationState.DIAGNOSIS
                # Pair each standard question with the patient's answer.
                context = "\n".join([
                    f"Q: {q}\nA: {a}" for q, a in
                    zip(HEALTH_ASSESSMENT_QUESTIONS, self.gathered_info[conversation_id])
                ])
                # Fix: quote the true presenting complaint, not the last answer.
                original_query = self.original_queries.get(conversation_id, message)

                # Generate the final response using the model.
                messages = [
                    {"role": "system", "content": NURSE_OGE_IDENTITY},
                    {"role": "user", "content": f"Based on the following patient information, provide a thorough assessment, diagnosis and recommendations:\n\n{context}\n\nOriginal query: {original_query}"}
                ]

                response = self.llm.create_chat_completion(
                    messages=messages,
                    max_tokens=1024,
                    temperature=0.7
                )

                # Reset state so the same conversation can start a fresh consultation.
                self.consultation_states[conversation_id] = ConsultationState.INITIAL
                self.gathered_info[conversation_id] = []
                self.original_queries.pop(conversation_id, None)

                return ChatResponse(
                    response=response['choices'][0]['message']['content'],
                    finished=True
                )
|
152 |
+
|
153 |
+
# Initialize FastAPI and Nurse Oge.
app = FastAPI()
# Module-level singleton shared by all requests; constructing it loads the
# GGUF model immediately, so import of this module blocks until that finishes.
nurse_oge = NurseOgeAssistant()
|
156 |
+
|
157 |
+
@app.post("/chat")
|
158 |
+
async def chat_endpoint(request: ChatRequest):
|
159 |
+
# Generate a conversation ID (in a real app, you'd want to manage these better)
|
160 |
+
conversation_id = "default"
|
161 |
+
|
162 |
+
# Extract the latest message
|
163 |
+
if not request.messages:
|
164 |
+
raise HTTPException(status_code=400, detail="No messages provided")
|
165 |
+
|
166 |
+
latest_message = request.messages[-1].content
|
167 |
+
|
168 |
+
# Process the message
|
169 |
+
response = await nurse_oge.process_message(
|
170 |
+
conversation_id=conversation_id,
|
171 |
+
message=latest_message,
|
172 |
+
history=request.messages[:-1]
|
173 |
+
)
|
174 |
+
|
175 |
+
return response
|
176 |
+
|
177 |
+
# Initialize Gradio interface (optional, for testing)
def gradio_chat(message, history):
    """Bridge Gradio's synchronous callback to the async NurseOgeAssistant.

    Fix: the original called the coroutine `process_message(...)` without
    awaiting it, so Gradio received a coroutine object and `.response`
    raised AttributeError. Run the coroutine to completion here instead.
    """
    import asyncio  # local import: only needed for this sync->async bridge

    result = asyncio.run(
        nurse_oge.process_message("gradio_user", message, history)
    )
    return result.response
|
181 |
+
|
182 |
+
# Build the Gradio chat UI around the gradio_chat callback.
demo = gr.ChatInterface(
    fn=gradio_chat,
    title="Nurse Oge",
    description="Finetuned llama 3.0 for medical diagnosis and all. This is just a demo",
    theme="soft"
)

# Mount both FastAPI and Gradio on one ASGI app:
# the Gradio UI is served under /gradio, the JSON API stays at /chat.
app = gr.mount_gradio_app(app, demo, path="/gradio")

if __name__ == "__main__":
    # Run the combined app directly with uvicorn when executed as a script.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
|