benardo0 committed
Commit 94dc8bb · verified · 1 Parent(s): b1de9b2

Update app.py

Files changed (1)
  1. app.py +60 -28
app.py CHANGED
@@ -9,20 +9,14 @@ import os
 import time
 import gc
 from huggingface_hub import hf_hub_download
+from llama_cpp import Llama
 
-# Environment variables for configuration
+# Configuration variables that can be set through environment variables
 MODEL_REPO_ID = os.getenv("MODEL_REPO_ID", "mradermacher/Llama3-Med42-8B-GGUF")
-MODEL_FILENAME = os.getenv("MODEL_FILENAME", "Llama3-Med42-8B.Q4_K_M.gguf")
+MODEL_FILENAME = os.getenv("MODEL_FILENAME", "Llama3-Med42-8B.Q5_K_M.gguf")
 N_THREADS = int(os.getenv("N_THREADS", "4"))
 
-# Import llama_cpp with error handling for better debugging
-try:
-    from llama_cpp import Llama
-    LLAMA_IMPORT_ERROR = None
-except Exception as e:
-    LLAMA_IMPORT_ERROR = str(e)
-    print(f"Warning: Failed to import llama_cpp: {e}")
-
+# Define our data models for API requests and responses
 class ConsultationState(Enum):
     INITIAL = "initial"
     GATHERING_INFO = "gathering_info"
@@ -39,7 +33,7 @@ class ChatResponse(BaseModel):
     response: str
     finished: bool
 
-# Standard health assessment questions for thorough patient evaluation
+# Define our standard health assessment questions
 HEALTH_ASSESSMENT_QUESTIONS = [
     "What are your current symptoms and how long have you been experiencing them?",
     "Do you have any pre-existing medical conditions or chronic illnesses?",
@@ -58,22 +52,27 @@ health information before providing any medical advice.
 
 class NurseOgeAssistant:
     def __init__(self):
-        if LLAMA_IMPORT_ERROR:
-            raise ImportError(f"Cannot initialize NurseOgeAssistant due to llama_cpp import error: {LLAMA_IMPORT_ERROR}")
-
         try:
-            # Initialize the model using from_pretrained for better compatibility with free tier
-            self.llm = Llama.from_pretrained(
+            # Download the model file from Hugging Face
+            model_path = hf_hub_download(
                 repo_id=MODEL_REPO_ID,
                 filename=MODEL_FILENAME,
+                resume_download=True
+            )
+
+            # Initialize the Llama model with appropriate parameters
+            self.llm = Llama(
+                model_path=model_path,
                 n_ctx=2048,  # Context window size
-                n_threads=N_THREADS,  # Adjust based on available CPU resources
-                n_gpu_layers=0  # CPU-only inference for free tier
+                n_threads=N_THREADS,  # CPU threads to use
+                n_gpu_layers=0,  # CPU-only inference
+                verbose=False  # Set to True for debugging
             )
 
         except Exception as e:
             raise RuntimeError(f"Failed to initialize the model: {str(e)}")
 
+        # Initialize conversation state management
         self.consultation_states = {}
         self.gathered_info = {}
 
@@ -151,17 +150,19 @@ class NurseOgeAssistant:
             )
         else:
             self.consultation_states[conversation_id] = ConsultationState.DIAGNOSIS
+            # Prepare context from gathered information
             context = "\n".join([
                 f"Q: {q}\nA: {a}" for q, a in
                 zip(HEALTH_ASSESSMENT_QUESTIONS, self.gathered_info[conversation_id])
             ])
 
+            # Prepare messages for the model
             messages = [
                 {"role": "system", "content": NURSE_OGE_IDENTITY},
                 {"role": "user", "content": f"Based on the following patient information, provide a thorough assessment and recommendations:\n\n{context}\n\nOriginal query: {message}"}
             ]
 
-            # Implement retry logic for API calls
+            # Implement retry logic for model inference
            max_retries = 3
            retry_delay = 2
 
@@ -169,8 +170,10 @@ class NurseOgeAssistant:
             try:
                 response = self.llm.create_chat_completion(
                     messages=messages,
-                    max_tokens=512,  # Reduced for free tier
-                    temperature=0.7
+                    max_tokens=512,
+                    temperature=0.7,
+                    top_p=0.95,
+                    stop=["</s>"]
                 )
                 break
             except Exception as e:
@@ -182,6 +185,7 @@
                     finished=True
                 )
 
+        # Reset conversation state
         self.consultation_states[conversation_id] = ConsultationState.INITIAL
         self.gathered_info[conversation_id] = []
 
@@ -205,11 +209,12 @@ nurse_oge = None
 # Add memory management middleware
 @app.middleware("http")
 async def add_memory_management(request: Request, call_next):
-    gc.collect()  # Force garbage collection before processing request
+    gc.collect()
     response = await call_next(request)
-    gc.collect()  # Clean up after request
+    gc.collect()
     return response
 
+# Initialize the assistant during startup
 @app.on_event("startup")
 async def startup_event():
     global nurse_oge
@@ -218,10 +223,12 @@ async def startup_event():
     except Exception as e:
         print(f"Failed to initialize NurseOgeAssistant: {e}")
 
+# Health check endpoint
 @app.get("/health")
 async def health_check():
     return {"status": "healthy", "model_loaded": nurse_oge is not None}
 
+# Chat endpoint
 @app.post("/chat")
 async def chat_endpoint(request: ChatRequest):
     if nurse_oge is None:
@@ -243,7 +250,7 @@ async def chat_endpoint(request: ChatRequest):
 
     return response
 
-# Gradio interface
+# Gradio chat interface function
 def gradio_chat(message, history):
     if nurse_oge is None:
         return "The medical assistant is not available at the moment. Please try again later."
@@ -251,17 +258,42 @@ def gradio_chat(message, history):
     response = nurse_oge.process_message("gradio_user", message, history)
     return response.response
 
-# Create and configure Gradio interface
+# Create and configure Gradio interface with enhanced styling
 demo = gr.ChatInterface(
     fn=gradio_chat,
-    title="Nurse Oge",
-    description="Finetuned llama 3.0 for medical diagnosis and all. This is just a demo",
-    theme="soft"
+    title="Nurse Oge - Medical Assistant",
+    description="""Welcome to Nurse Oge, your AI medical assistant specialized in serving Nigerian communities.
+    This system provides medical guidance while ensuring comprehensive health information gathering.""",
+    examples=[
+        ["What are the common symptoms of malaria?"],
+        ["I've been having headaches for the past week"],
+        ["How can I prevent typhoid fever?"],
+    ],
+    theme=gr.themes.Soft(
+        primary_hue="blue",
+        secondary_hue="purple",
+    ),
+    retry_btn="Try Again",
+    undo_btn="Undo Last",
+    clear_btn="Clear Chat"
 )
 
+# Add custom CSS for better appearance
+demo.css = """
+.gradio-container {
+    font-family: 'Arial', sans-serif;
+}
+.chat-message {
+    padding: 1rem;
+    border-radius: 0.5rem;
+    margin-bottom: 0.5rem;
+}
+"""
+
 # Mount both FastAPI and Gradio
 app = gr.mount_gradio_app(app, demo, path="/gradio")
 
+# Run the application
 if __name__ == "__main__":
     import uvicorn
     uvicorn.run(app, host="0.0.0.0", port=8000)
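For reviewers: the core change above replaces Llama.from_pretrained with an explicit hf_hub_download followed by a local Llama(...) load. A minimal standalone sketch of that path, using the repo and filename defaults from the diff (note that recent huggingface_hub releases resume downloads automatically and deprecate the resume_download flag):

    # Sanity-check the new loading path outside the app.
    # Assumes llama-cpp-python and huggingface_hub are installed.
    from huggingface_hub import hf_hub_download
    from llama_cpp import Llama

    model_path = hf_hub_download(
        repo_id="mradermacher/Llama3-Med42-8B-GGUF",
        filename="Llama3-Med42-8B.Q5_K_M.gguf",
    )
    llm = Llama(model_path=model_path, n_ctx=2048, n_threads=4, n_gpu_layers=0)
    out = llm.create_chat_completion(
        messages=[{"role": "user", "content": "Say hello."}],
        max_tokens=32,
    )
    print(out["choices"][0]["message"]["content"])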
 
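The hunk around create_chat_completion only shows the try/break half of the retry loop; the except side that consumes max_retries and retry_delay sits outside the visible context. A sketch of the pattern those variables imply (a hypothetical helper, not the file's exact code):

    import time

    def call_with_retries(fn, max_retries=3, retry_delay=2):
        # Retry fn up to max_retries times, sleeping retry_delay seconds
        # between attempts; raise once all attempts are exhausted.
        last_err = None
        for attempt in range(max_retries):
            try:
                return fn()
            except Exception as e:
                last_err = e
                if attempt < max_retries - 1:
                    time.sleep(retry_delay)
        raise RuntimeError(f"All {max_retries} attempts failed: {last_err}")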
 
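One compatibility note on the ChatInterface changes: retry_btn, undo_btn, and clear_btn are accepted by Gradio 3.x and early 4.x but were deprecated and later removed, so this commit assumes an older Gradio. On a current release the equivalent call drops those arguments (a sketch, reusing gradio_chat from the diff; description shortened here):

    # ChatInterface without the since-removed button kwargs.
    demo = gr.ChatInterface(
        fn=gradio_chat,
        title="Nurse Oge - Medical Assistant",
        description="Welcome to Nurse Oge, your AI medical assistant.",
        examples=[
            ["What are the common symptoms of malaria?"],
            ["I've been having headaches for the past week"],
            ["How can I prevent typhoid fever?"],
        ],
        theme=gr.themes.Soft(primary_hue="blue", secondary_hue="purple"),
    )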
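Once the app is running, the endpoints touched by this commit can be smoke-tested from Python (assumes the server was started with python app.py and the requests package is installed; the ChatRequest schema is not shown in this diff, so /chat is not exercised here):

    import requests

    # Health check: reports whether the model finished loading at startup.
    health = requests.get("http://localhost:8000/health").json()
    print(health)  # e.g. {"status": "healthy", "model_loaded": True}

    # The Gradio UI is mounted by gr.mount_gradio_app at /gradio:
    # open http://localhost:8000/gradio in a browser.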