benardo0 committed on
Commit
e53bd9c
·
verified ·
1 Parent(s): 94dc8bb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -35
app.py CHANGED
@@ -8,15 +8,17 @@ import re
8
  import os
9
  import time
10
  import gc
 
11
  from huggingface_hub import hf_hub_download
12
  from llama_cpp import Llama
13
 
14
  # Configuration variables that can be set through environment variables
 
15
  MODEL_REPO_ID = os.getenv("MODEL_REPO_ID", "mradermacher/Llama3-Med42-8B-GGUF")
16
  MODEL_FILENAME = os.getenv("MODEL_FILENAME", "Llama3-Med42-8B.Q5_K_M.gguf")
17
  N_THREADS = int(os.getenv("N_THREADS", "4"))
18
 
19
- # Define our data models for API requests and responses
20
  class ConsultationState(Enum):
21
  INITIAL = "initial"
22
  GATHERING_INFO = "gathering_info"
@@ -33,7 +35,7 @@ class ChatResponse(BaseModel):
33
  response: str
34
  finished: bool
35
 
36
- # Define our standard health assessment questions
37
  HEALTH_ASSESSMENT_QUESTIONS = [
38
  "What are your current symptoms and how long have you been experiencing them?",
39
  "Do you have any pre-existing medical conditions or chronic illnesses?",
@@ -42,7 +44,7 @@ HEALTH_ASSESSMENT_QUESTIONS = [
42
  "Have you had any similar symptoms in the past? If yes, what treatments worked?"
43
  ]
44
 
45
- # Define the AI assistant's identity and role
46
  NURSE_OGE_IDENTITY = """
47
  You are Nurse Oge, a medical AI assistant focused on serving patients in Nigeria. Always be empathetic,
48
  professional, and thorough in your assessments. When asked about your identity, explain that you are
@@ -51,32 +53,29 @@ health information before providing any medical advice.
51
  """
52
 
53
  class NurseOgeAssistant:
 
 
 
54
  def __init__(self):
55
  try:
56
- # Download the model file from Hugging Face
57
- model_path = hf_hub_download(
58
  repo_id=MODEL_REPO_ID,
59
  filename=MODEL_FILENAME,
60
- resume_download=True
61
- )
62
-
63
- # Initialize the Llama model with appropriate parameters
64
- self.llm = Llama(
65
- model_path=model_path,
66
  n_ctx=2048, # Context window size
67
  n_threads=N_THREADS, # CPU threads to use
68
- n_gpu_layers=0, # CPU-only inference
69
- verbose=False # Set to True for debugging
70
  )
71
 
72
  except Exception as e:
73
  raise RuntimeError(f"Failed to initialize the model: {str(e)}")
74
 
75
- # Initialize conversation state management
76
  self.consultation_states = {}
77
  self.gathered_info = {}
78
 
79
  def _is_identity_question(self, message: str) -> bool:
 
80
  identity_patterns = [
81
  r"who are you",
82
  r"what are you",
@@ -87,6 +86,7 @@ class NurseOgeAssistant:
87
  return any(re.search(pattern, message.lower()) for pattern in identity_patterns)
88
 
89
  def _is_location_question(self, message: str) -> bool:
 
90
  location_patterns = [
91
  r"where are you",
92
  r"which country",
@@ -97,6 +97,7 @@ class NurseOgeAssistant:
97
  return any(re.search(pattern, message.lower()) for pattern in location_patterns)
98
 
99
  def _get_next_assessment_question(self, conversation_id: str) -> Optional[str]:
 
100
  if conversation_id not in self.gathered_info:
101
  self.gathered_info[conversation_id] = []
102
 
@@ -106,6 +107,9 @@ class NurseOgeAssistant:
106
  return None
107
 
108
  async def process_message(self, conversation_id: str, message: str, history: List[Dict]) -> ChatResponse:
 
 
 
109
  try:
110
  # Initialize state for new conversations
111
  if conversation_id not in self.consultation_states:
@@ -159,7 +163,7 @@ class NurseOgeAssistant:
159
  # Prepare messages for the model
160
  messages = [
161
  {"role": "system", "content": NURSE_OGE_IDENTITY},
162
- {"role": "user", "content": f"Based on the following patient information, provide a thorough assessment and recommendations:\n\n{context}\n\nOriginal query: {message}"}
163
  ]
164
 
165
  # Implement retry logic for model inference
@@ -200,37 +204,41 @@ class NurseOgeAssistant:
200
  finished=True
201
  )
202
 
203
- # Initialize FastAPI
204
- app = FastAPI()
 
 
 
 
 
 
 
 
 
 
205
 
206
- # Create a global variable for our assistant
207
- nurse_oge = None
208
 
209
  # Add memory management middleware
210
  @app.middleware("http")
211
  async def add_memory_management(request: Request, call_next):
 
212
  gc.collect()
213
  response = await call_next(request)
214
  gc.collect()
215
  return response
216
 
217
- # Initialize the assistant during startup
218
- @app.on_event("startup")
219
- async def startup_event():
220
- global nurse_oge
221
- try:
222
- nurse_oge = NurseOgeAssistant()
223
- except Exception as e:
224
- print(f"Failed to initialize NurseOgeAssistant: {e}")
225
-
226
  # Health check endpoint
227
  @app.get("/health")
228
  async def health_check():
 
229
  return {"status": "healthy", "model_loaded": nurse_oge is not None}
230
 
231
  # Chat endpoint
232
  @app.post("/chat")
233
  async def chat_endpoint(request: ChatRequest):
 
234
  if nurse_oge is None:
235
  raise HTTPException(
236
  status_code=503,
@@ -251,14 +259,15 @@ async def chat_endpoint(request: ChatRequest):
251
  return response
252
 
253
  # Gradio chat interface function
254
- def gradio_chat(message, history):
 
255
  if nurse_oge is None:
256
  return "The medical assistant is not available at the moment. Please try again later."
257
 
258
- response = nurse_oge.process_message("gradio_user", message, history)
259
  return response.response
260
 
261
- # Create and configure Gradio interface with enhanced styling
262
  demo = gr.ChatInterface(
263
  fn=gradio_chat,
264
  title="Nurse Oge - Medical Assistant",
@@ -272,10 +281,7 @@ demo = gr.ChatInterface(
272
  theme=gr.themes.Soft(
273
  primary_hue="blue",
274
  secondary_hue="purple",
275
- ),
276
- retry_btn="Try Again",
277
- undo_btn="Undo Last",
278
- clear_btn="Clear Chat"
279
  )
280
 
281
  # Add custom CSS for better appearance
 
8
  import os
9
  import time
10
  import gc
11
+ from contextlib import asynccontextmanager
12
  from huggingface_hub import hf_hub_download
13
  from llama_cpp import Llama
14
 
15
  # Configuration variables that can be set through environment variables
16
+ # These allow for flexible deployment configuration without code changes
17
  MODEL_REPO_ID = os.getenv("MODEL_REPO_ID", "mradermacher/Llama3-Med42-8B-GGUF")
18
  MODEL_FILENAME = os.getenv("MODEL_FILENAME", "Llama3-Med42-8B.Q5_K_M.gguf")
19
  N_THREADS = int(os.getenv("N_THREADS", "4"))
20
 
21
+ # Data models for API request/response handling
22
  class ConsultationState(Enum):
23
  INITIAL = "initial"
24
  GATHERING_INFO = "gathering_info"
 
35
  response: str
36
  finished: bool
37
 
38
+ # Standardized health assessment questions for consistent patient evaluation
39
  HEALTH_ASSESSMENT_QUESTIONS = [
40
  "What are your current symptoms and how long have you been experiencing them?",
41
  "Do you have any pre-existing medical conditions or chronic illnesses?",
 
44
  "Have you had any similar symptoms in the past? If yes, what treatments worked?"
45
  ]
46
 
47
+ # AI assistant's identity and role definition
48
  NURSE_OGE_IDENTITY = """
49
  You are Nurse Oge, a medical AI assistant focused on serving patients in Nigeria. Always be empathetic,
50
  professional, and thorough in your assessments. When asked about your identity, explain that you are
 
53
  """
54
 
55
  class NurseOgeAssistant:
56
+ """
57
+ Main assistant class that handles conversation management and medical consultations
58
+ """
59
  def __init__(self):
60
  try:
61
+ # Initialize the Llama model using from_pretrained as per documentation
62
+ self.llm = Llama.from_pretrained(
63
  repo_id=MODEL_REPO_ID,
64
  filename=MODEL_FILENAME,
 
 
 
 
 
 
65
  n_ctx=2048, # Context window size
66
  n_threads=N_THREADS, # CPU threads to use
67
+ n_gpu_layers=0 # CPU-only inference
 
68
  )
69
 
70
  except Exception as e:
71
  raise RuntimeError(f"Failed to initialize the model: {str(e)}")
72
 
73
+ # State management for multiple concurrent conversations
74
  self.consultation_states = {}
75
  self.gathered_info = {}
76
 
77
  def _is_identity_question(self, message: str) -> bool:
78
+ """Detect if the user is asking about the assistant's identity"""
79
  identity_patterns = [
80
  r"who are you",
81
  r"what are you",
 
86
  return any(re.search(pattern, message.lower()) for pattern in identity_patterns)
87
 
88
  def _is_location_question(self, message: str) -> bool:
89
+ """Detect if the user is asking about the assistant's location"""
90
  location_patterns = [
91
  r"where are you",
92
  r"which country",
 
97
  return any(re.search(pattern, message.lower()) for pattern in location_patterns)
98
 
99
  def _get_next_assessment_question(self, conversation_id: str) -> Optional[str]:
100
+ """Get the next health assessment question based on conversation progress"""
101
  if conversation_id not in self.gathered_info:
102
  self.gathered_info[conversation_id] = []
103
 
 
107
  return None
108
 
109
  async def process_message(self, conversation_id: str, message: str, history: List[Dict]) -> ChatResponse:
110
+ """
111
+ Process incoming messages and manage the conversation flow
112
+ """
113
  try:
114
  # Initialize state for new conversations
115
  if conversation_id not in self.consultation_states:
 
163
  # Prepare messages for the model
164
  messages = [
165
  {"role": "system", "content": NURSE_OGE_IDENTITY},
166
+ {"role": "user", "content": f"Based on the following patient information, provide thorough assessment, diagnosis and recommendations:\n\n{context}\n\nOriginal query: {message}"}
167
  ]
168
 
169
  # Implement retry logic for model inference
 
204
  finished=True
205
  )
206
 
207
+ # Define FastAPI lifespan for startup/shutdown events
208
+ @asynccontextmanager
209
+ async def lifespan(app: FastAPI):
210
+ # Initialize on startup
211
+ global nurse_oge
212
+ try:
213
+ nurse_oge = NurseOgeAssistant()
214
+ except Exception as e:
215
+ print(f"Failed to initialize NurseOgeAssistant: {e}")
216
+ yield
217
+ # Clean up on shutdown if needed
218
+ # Add cleanup code here
219
 
220
+ # Initialize FastAPI with lifespan
221
+ app = FastAPI(lifespan=lifespan)
222
 
223
  # Add memory management middleware
224
  @app.middleware("http")
225
  async def add_memory_management(request: Request, call_next):
226
+ """Middleware to help manage memory usage"""
227
  gc.collect()
228
  response = await call_next(request)
229
  gc.collect()
230
  return response
231
 
 
 
 
 
 
 
 
 
 
232
  # Health check endpoint
233
  @app.get("/health")
234
  async def health_check():
235
+ """Endpoint to verify service health"""
236
  return {"status": "healthy", "model_loaded": nurse_oge is not None}
237
 
238
  # Chat endpoint
239
  @app.post("/chat")
240
  async def chat_endpoint(request: ChatRequest):
241
+ """Main chat endpoint for API interactions"""
242
  if nurse_oge is None:
243
  raise HTTPException(
244
  status_code=503,
 
259
  return response
260
 
261
  # Gradio chat interface function
262
+ async def gradio_chat(message, history):
263
+ """Handler for Gradio chat interface"""
264
  if nurse_oge is None:
265
  return "The medical assistant is not available at the moment. Please try again later."
266
 
267
+ response = await nurse_oge.process_message("gradio_user", message, history)
268
  return response.response
269
 
270
+ # Create and configure Gradio interface
271
  demo = gr.ChatInterface(
272
  fn=gradio_chat,
273
  title="Nurse Oge - Medical Assistant",
 
281
  theme=gr.themes.Soft(
282
  primary_hue="blue",
283
  secondary_hue="purple",
284
+ )
 
 
 
285
  )
286
 
287
  # Add custom CSS for better appearance