benardo0 commited on
Commit
a444494
·
verified ·
1 Parent(s): 0c375b5

Create app.py

Browse files

app.py for project

Files changed (1) hide show
  1. app.py +194 -0
app.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
3
+ from typing import List, Optional, Dict
4
+ from llama_cpp import Llama
5
+ import gradio as gr
6
+ import json
7
+ from enum import Enum
8
+ import re
9
+
10
+ class ConsultationState(Enum):
11
+ INITIAL = "initial"
12
+ GATHERING_INFO = "gathering_info"
13
+ DIAGNOSIS = "diagnosis"
14
+
15
+ class Message(BaseModel):
16
+ role: str
17
+ content: str
18
+
19
+ class ChatRequest(BaseModel):
20
+ messages: List[Message]
21
+
22
+ class ChatResponse(BaseModel):
23
+ response: str
24
+ finished: bool
25
+
26
+ # Standard health assessment questions that Nurse Oge always asks
27
+ HEALTH_ASSESSMENT_QUESTIONS = [
28
+ "What are your current symptoms and how long have you been experiencing them?",
29
+ "Do you have any pre-existing medical conditions or chronic illnesses?",
30
+ "Are you currently taking any medications? If yes, please list them.",
31
+ "Is there any relevant family medical history I should know about?",
32
+ "Have you had any similar symptoms in the past? If yes, what treatments worked?"
33
+ ]
34
+
35
+ # Personality prompts for Nurse Oge
36
+ NURSE_OGE_IDENTITY = """
37
+ You are Nurse Oge, a medical AI assistant focused on serving patients in Nigeria. Always be empathetic,
38
+ professional, and thorough in your assessments. When asked about your identity, explain that you are
39
+ Nurse Oge, a medical AI assistant serving Nigerian communities. Remember that you must gather complete
40
+ health information before providing any medical advice.
41
+ """
42
+
43
+ class NurseOgeAssistant:
44
+ def __init__(self):
45
+ self.llm = Llama.from_pretrained(
46
+ repo_id="mradermacher/Llama3-Med42-8B-GGUF",
47
+ filename="Llama3-Med42-8B.IQ3_M.gguf",
48
+ verbose=False
49
+ )
50
+ self.consultation_states = {} # Tracks state for each conversation
51
+ self.gathered_info = {} # Stores gathered health information
52
+
53
+ def _is_identity_question(self, message: str) -> bool:
54
+ identity_patterns = [
55
+ r"who are you",
56
+ r"what are you",
57
+ r"your name",
58
+ r"what should I call you",
59
+ r"tell me about yourself"
60
+ ]
61
+ return any(re.search(pattern, message.lower()) for pattern in identity_patterns)
62
+
63
+ def _is_location_question(self, message: str) -> bool:
64
+ location_patterns = [
65
+ r"where are you",
66
+ r"which country",
67
+ r"your location",
68
+ r"where do you work",
69
+ r"where are you based"
70
+ ]
71
+ return any(re.search(pattern, message.lower()) for pattern in location_patterns)
72
+
73
+ def _get_next_assessment_question(self, conversation_id: str) -> Optional[str]:
74
+ if conversation_id not in self.gathered_info:
75
+ self.gathered_info[conversation_id] = []
76
+
77
+ questions_asked = len(self.gathered_info[conversation_id])
78
+ if questions_asked < len(HEALTH_ASSESSMENT_QUESTIONS):
79
+ return HEALTH_ASSESSMENT_QUESTIONS[questions_asked]
80
+ return None
81
+
82
+ async def process_message(self, conversation_id: str, message: str, history: List[Dict]) -> ChatResponse:
83
+ # Initialize state if new conversation
84
+ if conversation_id not in self.consultation_states:
85
+ self.consultation_states[conversation_id] = ConsultationState.INITIAL
86
+
87
+ # Handle identity questions
88
+ if self._is_identity_question(message):
89
+ return ChatResponse(
90
+ response="I am Nurse Oge, a medical AI assistant dedicated to helping patients in Nigeria. "
91
+ "I'm here to provide medical guidance while ensuring I gather all necessary health information "
92
+ "for accurate assessments.",
93
+ finished=True
94
+ )
95
+
96
+ # Handle location questions
97
+ if self._is_location_question(message):
98
+ return ChatResponse(
99
+ response="I am based in Nigeria and specifically trained to serve Nigerian communities, "
100
+ "taking into account local healthcare contexts and needs.",
101
+ finished=True
102
+ )
103
+
104
+ # Start health assessment if it's a medical query
105
+ if self.consultation_states[conversation_id] == ConsultationState.INITIAL:
106
+ self.consultation_states[conversation_id] = ConsultationState.GATHERING_INFO
107
+ next_question = self._get_next_assessment_question(conversation_id)
108
+ return ChatResponse(
109
+ response=f"Before I can provide any medical advice, I need to gather some important health information. "
110
+ f"{next_question}",
111
+ finished=False
112
+ )
113
+
114
+ # Continue gathering information
115
+ if self.consultation_states[conversation_id] == ConsultationState.GATHERING_INFO:
116
+ self.gathered_info[conversation_id].append(message)
117
+ next_question = self._get_next_assessment_question(conversation_id)
118
+
119
+ if next_question:
120
+ return ChatResponse(
121
+ response=f"Thank you for that information. {next_question}",
122
+ finished=False
123
+ )
124
+ else:
125
+ self.consultation_states[conversation_id] = ConsultationState.DIAGNOSIS
126
+ # Prepare complete context for final response
127
+ context = "\n".join([
128
+ f"Q: {q}\nA: {a}" for q, a in
129
+ zip(HEALTH_ASSESSMENT_QUESTIONS, self.gathered_info[conversation_id])
130
+ ])
131
+
132
+ # Generate final response using the model
133
+ messages = [
134
+ {"role": "system", "content": NURSE_OGE_IDENTITY},
135
+ {"role": "user", "content": f"Based on the following patient information, provide a thorough assessment, diagnosis and recommendations:\n\n{context}\n\nOriginal query: {message}"}
136
+ ]
137
+
138
+ response = self.llm.create_chat_completion(
139
+ messages=messages,
140
+ max_tokens=1024,
141
+ temperature=0.7
142
+ )
143
+
144
+ # Reset state for next consultation
145
+ self.consultation_states[conversation_id] = ConsultationState.INITIAL
146
+ self.gathered_info[conversation_id] = []
147
+
148
+ return ChatResponse(
149
+ response=response['choices'][0]['message']['content'],
150
+ finished=True
151
+ )
152
+
153
+ # Initialize FastAPI and Nurse Oge
154
+ app = FastAPI()
155
+ nurse_oge = NurseOgeAssistant()
156
+
157
+ @app.post("/chat")
158
+ async def chat_endpoint(request: ChatRequest):
159
+ # Generate a conversation ID (in a real app, you'd want to manage these better)
160
+ conversation_id = "default"
161
+
162
+ # Extract the latest message
163
+ if not request.messages:
164
+ raise HTTPException(status_code=400, detail="No messages provided")
165
+
166
+ latest_message = request.messages[-1].content
167
+
168
+ # Process the message
169
+ response = await nurse_oge.process_message(
170
+ conversation_id=conversation_id,
171
+ message=latest_message,
172
+ history=request.messages[:-1]
173
+ )
174
+
175
+ return response
176
+
177
+ # Initialize Gradio interface (optional, for testing)
178
+ def gradio_chat(message, history):
179
+ response = nurse_oge.process_message("gradio_user", message, history)
180
+ return response.response
181
+
182
+ demo = gr.ChatInterface(
183
+ fn=gradio_chat,
184
+ title="Nurse Oge",
185
+ description="Finetuned llama 3.0 for medical diagnosis and all. This is just a demo",
186
+ theme="soft"
187
+ )
188
+
189
+ # Mount both FastAPI and Gradio
190
+ app = gr.mount_gradio_app(app, demo, path="/gradio")
191
+
192
+ if __name__ == "__main__":
193
+ import uvicorn
194
+ uvicorn.run(app, host="0.0.0.0", port=8000)