benardo0 committed on
Commit
b4ff37d
·
verified ·
1 Parent(s): e53bd9c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +121 -280
app.py CHANGED
@@ -1,305 +1,146 @@
1
- from fastapi import FastAPI, HTTPException, Request
2
- from pydantic import BaseModel
3
- from typing import List, Optional, Dict
4
- import gradio as gr
5
- import json
6
- from enum import Enum
7
- import re
8
  import os
9
- import time
10
- import gc
11
- from contextlib import asynccontextmanager
12
- from huggingface_hub import hf_hub_download
13
- from llama_cpp import Llama
14
-
15
- # Configuration variables that can be set through environment variables
16
- # These allow for flexible deployment configuration without code changes
17
- MODEL_REPO_ID = os.getenv("MODEL_REPO_ID", "mradermacher/Llama3-Med42-8B-GGUF")
18
- MODEL_FILENAME = os.getenv("MODEL_FILENAME", "Llama3-Med42-8B.Q5_K_M.gguf")
19
- N_THREADS = int(os.getenv("N_THREADS", "4"))
20
-
21
- # Data models for API request/response handling
22
class ConsultationState(Enum):
    """Stage of a single consultation, tracked per conversation id."""

    # No assessment question has been asked yet.
    INITIAL = "initial"
    # The standard assessment questions are being asked one by one.
    GATHERING_INFO = "gathering_info"
    # All answers are collected and the model is being asked for an assessment.
    DIAGNOSIS = "diagnosis"
26
-
27
class Message(BaseModel):
    """One chat message exchanged through the API."""

    # Speaker label supplied by the client (free-form string, not validated here).
    role: str
    # Text of the message.
    content: str
30
-
31
class ChatRequest(BaseModel):
    """Request body of the chat endpoint."""

    # Conversation so far; the last entry is treated as the newest message.
    messages: List[Message]
33
-
34
class ChatResponse(BaseModel):
    """Reply returned by the assistant for one incoming message."""

    # Text to show to the user.
    response: str
    # False while further assessment questions are pending, True otherwise.
    finished: bool
37
-
38
# Standardized health assessment questions for consistent patient evaluation.
# They are asked in this order; the answer index matches the question index.
HEALTH_ASSESSMENT_QUESTIONS = [
    "What are your current symptoms and how long have you been experiencing them?",
    "Do you have any pre-existing medical conditions or chronic illnesses?",
    "Are you currently taking any medications? If yes, please list them.",
    "Is there any relevant family medical history I should know about?",
    "Have you had any similar symptoms in the past? If yes, what treatments worked?",
]
46
 
47
# AI assistant's identity and role definition; sent to the model as the
# system message when the final assessment is requested.
NURSE_OGE_IDENTITY = """
You are Nurse Oge, a medical AI assistant focused on serving patients in Nigeria. Always be empathetic,
professional, and thorough in your assessments. When asked about your identity, explain that you are
Nurse Oge, a medical AI assistant serving Nigerian communities. Remember that you must gather complete
health information before providing any medical advice.
"""
54
 
55
class NurseOgeAssistant:
    """Conversation manager for the Nurse Oge medical consultation flow.

    For every conversation id the assistant first asks the questions in
    ``HEALTH_ASSESSMENT_QUESTIONS`` one at a time, then sends the collected
    answers to the local Llama model and returns its assessment.
    """

    # Compiled case-insensitively: the previous code lower-cased the message
    # but kept an upper-case "I" in one pattern, so it could never match.
    _IDENTITY_PATTERNS = tuple(
        re.compile(pattern, re.IGNORECASE)
        for pattern in (
            r"who are you",
            r"what are you",
            r"your name",
            r"what should I call you",
            r"tell me about yourself",
        )
    )
    _LOCATION_PATTERNS = tuple(
        re.compile(pattern, re.IGNORECASE)
        for pattern in (
            r"where are you",
            r"which country",
            r"your location",
            r"where do you work",
            r"where are you based",
        )
    )

    # Retry policy for model inference.
    _MAX_RETRIES = 3
    _RETRY_DELAY_SECONDS = 2

    def __init__(self):
        """Load the GGUF model and create empty per-conversation state.

        Raises:
            RuntimeError: if the model cannot be loaded; the original
                exception is attached as ``__cause__``.
        """
        try:
            # Initialize the Llama model using from_pretrained as per documentation
            self.llm = Llama.from_pretrained(
                repo_id=MODEL_REPO_ID,
                filename=MODEL_FILENAME,
                n_ctx=2048,           # Context window size
                n_threads=N_THREADS,  # CPU threads to use
                n_gpu_layers=0,       # CPU-only inference
            )
        except Exception as e:
            raise RuntimeError(f"Failed to initialize the model: {str(e)}") from e

        # State management for multiple concurrent conversations,
        # all keyed by conversation id.
        self.consultation_states = {}
        self.gathered_info = {}
        self.initial_queries = {}

    def _is_identity_question(self, message: str) -> bool:
        """Detect if the user is asking about the assistant's identity."""
        return any(pattern.search(message) for pattern in self._IDENTITY_PATTERNS)

    def _is_location_question(self, message: str) -> bool:
        """Detect if the user is asking about the assistant's location."""
        return any(pattern.search(message) for pattern in self._LOCATION_PATTERNS)

    def _get_next_assessment_question(self, conversation_id: str) -> Optional[str]:
        """Return the next unanswered assessment question, or None when done."""
        answers = self.gathered_info.setdefault(conversation_id, [])
        if len(answers) < len(HEALTH_ASSESSMENT_QUESTIONS):
            return HEALTH_ASSESSMENT_QUESTIONS[len(answers)]
        return None

    async def process_message(self, conversation_id: str, message: str, history: List[Dict]) -> "ChatResponse":
        """Process one incoming message and advance the consultation.

        Args:
            conversation_id: key under which this conversation's state is kept.
            message: the newest user message.
            history: earlier messages; accepted for interface compatibility
                but not used by the flow.

        Returns:
            A ChatResponse. ``finished`` is False while assessment questions
            are still pending and True otherwise. Errors are reported in the
            response text rather than raised.
        """
        # Local import keeps this block self-contained; asyncio.sleep lets the
        # retry delay yield to the event loop instead of blocking it.
        import asyncio

        try:
            # Attributes are created lazily so instances built without
            # __init__ (or by an older __init__) still work.
            if not hasattr(self, "initial_queries"):
                self.initial_queries = {}

            # Initialize state for new conversations
            state = self.consultation_states.setdefault(
                conversation_id, ConsultationState.INITIAL
            )

            # Handle identity questions
            if self._is_identity_question(message):
                return ChatResponse(
                    response="I am Nurse Oge, a medical AI assistant dedicated to helping patients in Nigeria. "
                             "I'm here to provide medical guidance while ensuring I gather all necessary health information "
                             "for accurate assessments.",
                    finished=True
                )

            # Handle location questions
            if self._is_location_question(message):
                return ChatResponse(
                    response="I am based in Nigeria and specifically trained to serve Nigerian communities, "
                             "taking into account local healthcare contexts and needs.",
                    finished=True
                )

            # Start health assessment for medical queries
            if state == ConsultationState.INITIAL:
                self.consultation_states[conversation_id] = ConsultationState.GATHERING_INFO
                # Keep the message that opened the consultation so the model
                # sees the real original query later.
                self.initial_queries[conversation_id] = message
                self.gathered_info[conversation_id] = []
                next_question = self._get_next_assessment_question(conversation_id)
                return ChatResponse(
                    response=f"Before I can provide any medical advice, I need to gather some important health information. "
                             f"{next_question}",
                    finished=False
                )

            # Continue gathering information. Any non-INITIAL state lands here,
            # so the method can never fall through and return None.
            answers = self.gathered_info.setdefault(conversation_id, [])
            if len(answers) < len(HEALTH_ASSESSMENT_QUESTIONS):
                # When all answers are already present (a previous inference
                # attempt failed) the message is only a retry trigger.
                answers.append(message)

            next_question = self._get_next_assessment_question(conversation_id)
            if next_question:
                return ChatResponse(
                    response=f"Thank you for that information. {next_question}",
                    finished=False
                )

            self.consultation_states[conversation_id] = ConsultationState.DIAGNOSIS

            # Prepare context from gathered information
            context = "\n".join(
                f"Q: {q}\nA: {a}"
                for q, a in zip(HEALTH_ASSESSMENT_QUESTIONS, answers)
            )
            original_query = self.initial_queries.get(conversation_id, message)

            # Prepare messages for the model
            messages = [
                {"role": "system", "content": NURSE_OGE_IDENTITY},
                {"role": "user", "content": f"Based on the following patient information, provide thorough assessment, diagnosis and recommendations:\n\n{context}\n\nOriginal query: {original_query}"}
            ]

            # Retry logic for model inference.
            # NOTE(review): create_chat_completion itself is a blocking call;
            # consider moving it to a worker thread if latency matters.
            response = None
            for attempt in range(self._MAX_RETRIES):
                try:
                    response = self.llm.create_chat_completion(
                        messages=messages,
                        max_tokens=512,
                        temperature=0.7,
                        top_p=0.95,
                        stop=["</s>"]
                    )
                    break
                except Exception:
                    if attempt < self._MAX_RETRIES - 1:
                        await asyncio.sleep(self._RETRY_DELAY_SECONDS)

            if response is None:
                # Keep the collected answers and step back to GATHERING_INFO so
                # the user's next message retries the assessment.
                self.consultation_states[conversation_id] = ConsultationState.GATHERING_INFO
                return ChatResponse(
                    response="I'm sorry, I'm experiencing some technical difficulties. Please try again in a moment.",
                    finished=True
                )

            # Reset conversation state
            self.consultation_states[conversation_id] = ConsultationState.INITIAL
            self.gathered_info[conversation_id] = []
            self.initial_queries.pop(conversation_id, None)

            return ChatResponse(
                response=response['choices'][0]['message']['content'],
                finished=True
            )

        except Exception:
            return ChatResponse(
                response="An error occurred while processing your request. Please try again.",
                finished=True
            )
 
 
 
 
 
 
 
 
 
 
 
 
206
 
207
# Define FastAPI lifespan for startup/shutdown events
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the global assistant on startup and drop it on shutdown.

    A failed model load is reported on stdout and leaves ``nurse_oge`` set
    to None, so the application still starts and callers can detect that
    the assistant is unavailable.
    """
    global nurse_oge
    # Bind the name before the attempt: if the constructor raises, later
    # reads of the global would otherwise hit a NameError.
    nurse_oge = None
    try:
        nurse_oge = NurseOgeAssistant()
    except Exception as e:
        print(f"Failed to initialize NurseOgeAssistant: {e}")
    yield
    # Clean up on shutdown: release the reference to the loaded model.
    nurse_oge = None
219
 
220
# Initialize FastAPI with the lifespan handler that loads the assistant
app = FastAPI(lifespan=lifespan)
222
-
223
# Add memory management middleware
@app.middleware("http")
async def add_memory_management(request: Request, call_next):
    """Run a garbage-collection pass after each HTTP request is handled.

    The collection before the handler was dropped: it repeats the pass made
    at the end of the previous request and only adds latency.
    """
    response = await call_next(request)
    gc.collect()
    return response
231
-
232
# Health check endpoint
@app.get("/health")
async def health_check():
    """Report service health and whether the assistant has been created."""
    # globals().get tolerates the name not being bound yet (for example
    # when startup has not run or the model failed to load).
    return {"status": "healthy", "model_loaded": globals().get("nurse_oge") is not None}
237
-
238
# Chat endpoint
@app.post("/chat")
async def chat_endpoint(request: ChatRequest):
    """Main chat endpoint for API interactions.

    Raises:
        HTTPException: 503 when the assistant is not available, 400 when
            the request contains no messages.
    """
    active_assistant = globals().get("nurse_oge")
    if active_assistant is None:
        raise HTTPException(
            status_code=503,
            detail="The medical assistant is not available at the moment. Please try again later."
        )

    if not request.messages:
        raise HTTPException(status_code=400, detail="No messages provided")

    *earlier_messages, latest = request.messages
    # process_message declares history as a list of dicts, so convert the
    # pydantic models instead of passing them through unchanged.
    history = [{"role": m.role, "content": m.content} for m in earlier_messages]

    # NOTE(review): every API client shares the "default" conversation id, so
    # concurrent users share one consultation state. The request model has no
    # conversation identifier to use instead.
    return await active_assistant.process_message(
        conversation_id="default",
        message=latest.content,
        history=history
    )
260
-
261
# Gradio chat interface function
async def gradio_chat(message, history):
    """Handler for the Gradio chat interface; returns the reply text.

    NOTE(review): all Gradio sessions share the "gradio_user" conversation id.
    """
    active_assistant = globals().get("nurse_oge")
    if active_assistant is None:
        return "The medical assistant is not available at the moment. Please try again later."

    response = await active_assistant.process_message("gradio_user", message, history)
    if response is None:
        # Defensive: never surface an AttributeError to the chat UI.
        return "An error occurred while processing your request. Please try again."
    return response.response
269
 
270
# Create and configure Gradio interface
demo = gr.ChatInterface(
    fn=gradio_chat,
    title="Nurse Oge - Medical Assistant",
    description="""Welcome to Nurse Oge, your AI medical assistant specialized in serving Nigerian communities.
This system provides medical guidance while ensuring comprehensive health information gathering.""",
    examples=[
        ["What are the common symptoms of malaria?"],
        ["I've been having headaches for the past week"],
        ["How can I prevent typhoid fever?"],
    ],
    theme=gr.themes.Soft(
        primary_hue="blue",
        secondary_hue="purple",
    )
)

# Add custom CSS for better appearance
demo.css = """
.gradio-container {
    font-family: 'Arial', sans-serif;
}
.chat-message {
    padding: 1rem;
    border-radius: 0.5rem;
    margin-bottom: 0.5rem;
}
"""

# Mount both FastAPI and Gradio: the chat UI is served under /gradio
app = gr.mount_gradio_app(app, demo, path="/gradio")

# Run the application
if __name__ == "__main__":
    import uvicorn
    # The port follows the file's env-var configuration style; the default
    # stays 8000.
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "8000")))
 
 
 
 
 
 
 
 
1
  import os
2
+ import gradio as gr
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer
4
+ import torch
5
+ from typing import List, Dict
6
+ import logging
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
# Set up logging to help us debug model loading and inference
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
 
 
 
 
11
 
12
class MedicalAssistant:
    """Wrap a causal language model and tokenizer behind a chat method."""

    # Instruction text placed at the top of every prompt.
    SYSTEM_PROMPT = (
        "You are a medical AI assistant. Respond to medical queries "
        "professionally and accurately. If you're unsure, always recommend "
        "consulting with a healthcare provider."
    )

    def __init__(self):
        """Initialize the medical assistant with model and tokenizer.

        Raises:
            Exception: any loading error is logged and re-raised unchanged.
        """
        try:
            logger.info("Starting model initialization...")

            # Model configuration - adjust these based on your available compute
            self.model_name = "mradermacher/Llama3-Med42-8B-GGUF"
            # The repository only ships GGUF checkpoints, so transformers has
            # to be told which file to read (this also requires the `gguf`
            # package). Overridable through the MODEL_FILENAME env var.
            self.gguf_file = os.getenv("MODEL_FILENAME", "Llama3-Med42-8B.Q5_K_M.gguf")
            self.max_length = 1048
            self.device = "cuda" if torch.cuda.is_available() else "cpu"

            logger.info(f"Using device: {self.device}")

            # Load tokenizer first - this is typically faster and can catch issues early
            logger.info("Loading tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                gguf_file=self.gguf_file,
                padding_side="left",
                # Drop the oldest text when the prompt is too long, so the
                # current question and the "Assistant:" cue always survive.
                truncation_side="left",
            )

            # Set padding token if not set
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            # load_in_8bit was removed: it cannot be combined with a GGUF
            # checkpoint. trust_remote_code was removed because no custom
            # repository code is needed. Half precision is only used on GPU.
            logger.info("Loading model...")
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                gguf_file=self.gguf_file,
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
                device_map="auto",
            )

            logger.info("Model initialization completed successfully!")

        except Exception as e:
            logger.error(f"Error during initialization: {str(e)}")
            raise

    def _build_prompt(self, message, chat_history=None):
        """Return the full text prompt: system prompt, earlier turns, new message.

        ``chat_history`` may hold dicts with "role"/"content" keys or
        (user, assistant) pairs; entries without text are skipped.
        """
        turns = []
        for entry in chat_history or []:
            if isinstance(entry, dict):
                content = entry.get("content")
                if not isinstance(content, str) or not content:
                    continue
                speaker = "Assistant" if entry.get("role") == "assistant" else "User"
                turns.append(f"{speaker}: {content}")
            else:
                # Pair format: pad so a short entry cannot raise.
                user_text, assistant_text = (list(entry) + [None, None])[:2]
                if user_text:
                    turns.append(f"User: {user_text}")
                if assistant_text:
                    turns.append(f"Assistant: {assistant_text}")
        turns.append(f"User: {message}")
        return f"{self.SYSTEM_PROMPT}\n\n" + "\n".join(turns) + "\nAssistant:"

    def generate_response(self, message: str, chat_history: List[Dict] = None) -> str:
        """Generate a response to the user's message.

        Args:
            message: the newest user message.
            chat_history: earlier turns to include in the prompt; may be None.

        Returns:
            The generated reply, or a fixed apology string if anything fails
            (the error is logged, not raised).
        """
        try:
            # Combine system prompt, chat history, and current message
            full_prompt = self._build_prompt(message, chat_history)

            # Tokenize input and move it to wherever the model was placed
            inputs = self.tokenizer(
                full_prompt,
                return_tensors="pt",
                truncation=True,
                max_length=self.max_length
            ).to(self.model.device)

            # Generate response
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=512,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.95,
                    pad_token_id=self.tokenizer.pad_token_id,
                    repetition_penalty=1.1
                )

            # Decode only the newly generated tokens; splitting the full text
            # on "Assistant:" breaks when the reply itself contains that word.
            prompt_length = inputs["input_ids"].shape[-1]
            response = self.tokenizer.decode(
                outputs[0][prompt_length:],
                skip_special_tokens=True
            )

            # Cut off any follow-up user turn the model invented.
            return response.split("\nUser:")[0].strip()

        except Exception as e:
            logger.error(f"Error during response generation: {str(e)}")
            return "I apologize, but I encountered an error. Please try again."
99
+
100
# Module-level assistant instance; stays None until initialize_assistant()
# succeeds, so the heavy model load happens on first use.
assistant = None
102
 
103
def initialize_assistant():
    """Create the global assistant; return True on success, False on failure.

    Failures are logged with their traceback and never raised.
    """
    global assistant
    try:
        assistant = MedicalAssistant()
        return True
    except Exception as e:
        # logger.exception records the traceback, which the plain error
        # message used before did not.
        logger.exception(f"Failed to initialize assistant: {str(e)}")
        return False
 
 
112
 
113
def chat_response(message: str, history: List[Dict]):
    """Handle chat messages and return responses.

    Args:
        message: text typed by the user.
        history: earlier turns, passed through to the assistant.

    Returns:
        The reply text; problems are reported as a message, never raised.
    """
    # Do not load the model or run generation for an empty message.
    if not message or not message.strip():
        return "Please enter a question so I can help."

    # Create the assistant on first use
    if assistant is None and not initialize_assistant():
        return "I apologize, but I'm currently unavailable. Please try again later."

    try:
        return assistant.generate_response(message, history)
    except Exception as e:
        logger.error(f"Error in chat response: {str(e)}")
        return "I encountered an error. Please try again."
 
 
 
 
 
 
 
 
 
 
 
 
 
127
 
128
# Create Gradio interface
demo = gr.ChatInterface(
    fn=chat_response,
    title="Medical Assistant (Test Version)",
    description="""This is a test version of the medical assistant.
Please use it to verify basic functionality.""",
    examples=[
        "What are the symptoms of malaria?",
        "How can I prevent type 2 diabetes?",
        "What should I do for a mild headache?"
    ],
)

# Launch the interface
if __name__ == "__main__":
    demo.launch()