benardo0 committed
Commit 030bf70 · verified · Parent(s): 3c8150d

Update app.py

Files changed (1):
  1. app.py +54 -56
app.py CHANGED
@@ -147,13 +147,13 @@

 import os
 import gradio as gr
-from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
 import torch
 from typing import List, Dict
 import logging
 import traceback

-# Configure detailed logging to help us track the model's behavior
+# Set up logging to help us track what's happening
 logging.basicConfig(
     level=logging.INFO,
     format='%(asctime)s - %(levelname)s - %(message)s'
@@ -163,39 +163,48 @@ logger = logging.getLogger(__name__)
 class MedicalAssistant:
     def __init__(self):
         """
-        Initialize the medical assistant using a pre-quantized 4-bit model.
-        This approach uses less memory while maintaining good performance.
+        Initialize the medical assistant with CPU-friendly settings.
+        We'll use careful memory management and avoid GPU-specific features.
         """
         try:
             logger.info("Starting model initialization...")

-            # Define model configuration
+            # Model configuration
             self.model_name = "emircanerol/Llama3-Med42-8B-4bit"
             self.max_length = 2048
-            self.device = "cuda" if torch.cuda.is_available() else "cpu"

-            # Log system information for debugging
-            logger.info(f"Using device: {self.device}")
-            logger.info(f"Available CUDA devices: {torch.cuda.device_count() if torch.cuda.is_available() else 'None'}")
-            if torch.cuda.is_available():
-                logger.info(f"CUDA memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
+            # First load the tokenizer as it's lighter on memory
+            logger.info("Loading tokenizer...")
+            self.tokenizer = AutoTokenizer.from_pretrained(
+                self.model_name,
+                trust_remote_code=True
+            )
+
+            # Handle padding token
+            if self.tokenizer.pad_token is None:
+                self.tokenizer.pad_token = self.tokenizer.eos_token
+            logger.info("Tokenizer loaded successfully")

-            # Initialize the pipeline for text generation
-            logger.info("Initializing text generation pipeline...")
+            # Load model with CPU-friendly settings
+            logger.info("Loading model - this may take a few minutes...")
+            self.model = AutoModelForCausalLM.from_pretrained(
+                self.model_name,
+                torch_dtype=torch.float32,  # Use float32 for CPU
+                low_cpu_mem_usage=True,
+                trust_remote_code=True
+            )
+
+            # Create the pipeline with our loaded components
+            logger.info("Creating pipeline...")
             self.pipe = pipeline(
                 "text-generation",
-                model=self.model_name,
-                device_map="auto",
-                torch_dtype=torch.float16
+                model=self.model,
+                tokenizer=self.tokenizer,
+                device=-1,  # Force CPU usage
+                torch_dtype=torch.float32
             )
-            logger.info("Pipeline initialized successfully!")

-            # Load tokenizer separately for more control over text processing
-            logger.info("Loading tokenizer...")
-            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
-            if self.tokenizer.pad_token is None:
-                self.tokenizer.pad_token = self.tokenizer.eos_token
-            logger.info("Tokenizer loaded successfully!")
+            logger.info("Initialization completed successfully!")

         except Exception as e:
             logger.error(f"Initialization failed: {str(e)}")
@@ -205,40 +214,33 @@ class MedicalAssistant:
     def generate_response(self, message: str, chat_history: List[Dict] = None) -> str:
         """
         Generate a response using the text generation pipeline.
-        The pipeline handles most of the complexity of text generation for us.
+        Includes careful error handling and response processing.
         """
         try:
             logger.info("Preparing message for generation")

-            # Prepare the conversation format
-            system_prompt = """You are a medical AI assistant. Respond to medical queries
-            professionally and accurately. If you're unsure, always recommend consulting
-            with a healthcare provider."""
-
-            # Format messages for the model
-            messages = [
-                {"role": "system", "content": system_prompt},
-                {"role": "user", "content": message}
-            ]
+            # Create a medical context-aware prompt
+            system_prompt = """You are a medical AI assistant. Provide accurate,
+            professional medical guidance. Always recommend consulting healthcare
+            providers for specific medical advice."""

-            # Convert messages to a format the model expects
-            prompt = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
-            prompt += "\nassistant:"
+            # Format the conversation
+            prompt = f"{system_prompt}\n\nUser: {message}\nAssistant:"

             logger.info("Generating response")
-            # Generate response using the pipeline
+            # Generate with conservative settings for CPU
             response = self.pipe(
                 prompt,
-                max_new_tokens=512,
+                max_new_tokens=256,  # Reduced for CPU efficiency
                 do_sample=True,
                 temperature=0.7,
                 top_p=0.95,
-                repetition_penalty=1.1,
+                num_return_sequences=1,
                 pad_token_id=self.tokenizer.pad_token_id
             )[0]["generated_text"]

-            # Extract the assistant's response from the full generated text
-            response = response.split("assistant:")[-1].strip()
+            # Clean up the response
+            response = response.split("Assistant:")[-1].strip()

             logger.info("Response generated successfully")
             return response
@@ -248,14 +250,11 @@ class MedicalAssistant:
             logger.error(traceback.format_exc())
             return f"I apologize, but I encountered an error: {str(e)}"

-# Initialize our global assistant
+# Global assistant instance
 assistant = None

 def initialize_assistant():
-    """
-    Initialize the assistant with error handling and logging.
-    This helps us track any issues during startup.
-    """
+    """Initialize the assistant with proper error handling"""
     global assistant
     try:
         logger.info("Attempting to initialize assistant")
@@ -268,15 +267,13 @@ def initialize_assistant():
         return False

 def chat_response(message: str, history: List[Dict]):
-    """
-    Handle chat messages and maintain conversation context.
-    """
+    """Handle chat interactions with error recovery"""
     global assistant

     if assistant is None:
         logger.info("Assistant not initialized, attempting initialization")
         if not initialize_assistant():
-            return "I apologize, but I'm currently unavailable. The error has been logged for investigation."
+            return "I apologize, but I'm currently unavailable. Please try again later."

     try:
         return assistant.generate_response(message, history)
@@ -285,12 +282,13 @@ def chat_response(message: str, history: List[Dict]):
         logger.error(traceback.format_exc())
         return f"I encountered an error: {str(e)}"

-# Create the Gradio interface with a clean, professional design
+# Create the Gradio interface
 demo = gr.ChatInterface(
     fn=chat_response,
-    title="Medical Assistant (4-bit Quantized Version)",
-    description="""This medical assistant uses a 4-bit quantized model for efficient operation.
-    It provides medical guidance while ensuring comprehensive health information gathering.""",
+    title="Medical Assistant (CPU Version)",
+    description="""This medical assistant provides guidance and information
+    about health-related queries. Note that this is running
+    in CPU mode for broader compatibility.""",
     examples=[
         "What are the symptoms of malaria?",
         "How can I prevent type 2 diabetes?",
@@ -298,7 +296,7 @@ demo = gr.ChatInterface(
     ]
 )

-# Launch the application
+# Launch the interface
 if __name__ == "__main__":
     logger.info("Starting the application")
     demo.launch()