benardo0 committed (verified)
Commit 803e48a · 1 Parent(s): 7625d6b

Update app.py

Files changed (1)
  1. app.py +61 -70
app.py CHANGED
@@ -147,13 +147,13 @@
 
 import os
 import gradio as gr
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
 import torch
 from typing import List, Dict
 import logging
 import traceback
 
-# Configure detailed logging
+# Configure detailed logging to help us track the model's behavior
 logging.basicConfig(
     level=logging.INFO,
     format='%(asctime)s - %(levelname)s - %(message)s'
@@ -162,54 +162,40 @@ logger = logging.getLogger(__name__)
 
 class MedicalAssistant:
     def __init__(self):
-        """Initialize the medical assistant with model and tokenizer"""
+        """
+        Initialize the medical assistant using a pre-quantized 4-bit model.
+        This approach uses less memory while maintaining good performance.
+        """
         try:
             logger.info("Starting model initialization...")
 
-            # Model configuration
-            self.model_name = "mradermacher/Llama3-Med42-8B-GGUF"
+            # Define model configuration
+            self.model_name = "emircanerol/Llama3-Med42-8B-4bit"
             self.max_length = 2048
             self.device = "cuda" if torch.cuda.is_available() else "cpu"
 
+            # Log system information for debugging
             logger.info(f"Using device: {self.device}")
             logger.info(f"Available CUDA devices: {torch.cuda.device_count() if torch.cuda.is_available() else 'None'}")
             if torch.cuda.is_available():
                 logger.info(f"CUDA memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
 
-            # First, verify the model exists
-            logger.info(f"Attempting to load tokenizer from {self.model_name}")
-            try:
-                self.tokenizer = AutoTokenizer.from_pretrained(
-                    self.model_name,
-                    trust_remote_code=True
-                )
-                logger.info("Tokenizer loaded successfully")
-            except Exception as e:
-                logger.error(f"Failed to load tokenizer: {str(e)}")
-                logger.error(traceback.format_exc())
-                raise
+            # Initialize the pipeline for text generation
+            logger.info("Initializing text generation pipeline...")
+            self.pipe = pipeline(
+                "text-generation",
+                model=self.model_name,
+                device_map="auto",
+                torch_dtype=torch.float16
+            )
+            logger.info("Pipeline initialized successfully!")
 
-            # Set padding token if not set
+            # Load tokenizer separately for more control over text processing
+            logger.info("Loading tokenizer...")
+            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
             if self.tokenizer.pad_token is None:
                 self.tokenizer.pad_token = self.tokenizer.eos_token
-                logger.info("Set padding token to EOS token")
-
-            # Load model with more conservative settings
-            logger.info("Loading model - this may take a few minutes...")
-            try:
-                self.model = AutoModelForCausalLM.from_pretrained(
-                    self.model_name,
-                    torch_dtype=torch.float16,
-                    device_map="auto",
-                    load_in_4bit=True,  # More conservative than 8-bit
-                    trust_remote_code=True,
-                    low_cpu_mem_usage=True
-                )
-                logger.info("Model loaded successfully!")
-            except Exception as e:
-                logger.error(f"Failed to load model: {str(e)}")
-                logger.error(traceback.format_exc())
-                raise
+            logger.info("Tokenizer loaded successfully!")
 
         except Exception as e:
             logger.error(f"Initialization failed: {str(e)}")
@@ -217,43 +203,42 @@ class MedicalAssistant:
             raise
 
     def generate_response(self, message: str, chat_history: List[Dict] = None) -> str:
-        """Generate a response to the user's message"""
+        """
+        Generate a response using the text generation pipeline.
+        The pipeline handles most of the complexity of text generation for us.
+        """
         try:
-            logger.info("Generating response for message")
+            logger.info("Preparing message for generation")
 
-            # Prepare the prompt
+            # Prepare the conversation format
             system_prompt = """You are a medical AI assistant. Respond to medical queries
             professionally and accurately. If you're unsure, always recommend consulting
            with a healthcare provider."""
 
-            full_prompt = f"{system_prompt}\n\nUser: {message}\nAssistant:"
-            logger.info("Tokenizing input")
-
-            inputs = self.tokenizer(
-                full_prompt,
-                return_tensors="pt",
-                padding=True,
-                truncation=True,
-                max_length=self.max_length
-            )
+            # Format messages for the model
+            messages = [
+                {"role": "system", "content": system_prompt},
+                {"role": "user", "content": message}
+            ]
 
-            # Move inputs to the correct device
-            inputs = {k: v.to(self.device) for k, v in inputs.items()}
+            # Convert messages to a format the model expects
+            prompt = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
+            prompt += "\nassistant:"
 
             logger.info("Generating response")
-            with torch.no_grad():
-                outputs = self.model.generate(
-                    **inputs,
-                    max_new_tokens=512,
-                    do_sample=True,
-                    temperature=0.7,
-                    top_p=0.95,
-                    pad_token_id=self.tokenizer.pad_token_id,
-                    repetition_penalty=1.1
-                )
+            # Generate response using the pipeline
+            response = self.pipe(
+                prompt,
+                max_new_tokens=512,
+                do_sample=True,
+                temperature=0.7,
+                top_p=0.95,
+                repetition_penalty=1.1,
+                pad_token_id=self.tokenizer.pad_token_id
+            )[0]["generated_text"]
 
-            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
-            response = response.split("Assistant:")[-1].strip()
+            # Extract the assistant's response from the full generated text
+            response = response.split("assistant:")[-1].strip()
 
             logger.info("Response generated successfully")
             return response
@@ -263,11 +248,14 @@ class MedicalAssistant:
             logger.error(traceback.format_exc())
             return f"I apologize, but I encountered an error: {str(e)}"
 
-# Global variable for the assistant
+# Initialize our global assistant
 assistant = None
 
 def initialize_assistant():
-    """Initialize the assistant and handle any errors"""
+    """
+    Initialize the assistant with error handling and logging.
+    This helps us track any issues during startup.
+    """
     global assistant
     try:
         logger.info("Attempting to initialize assistant")
@@ -280,7 +268,9 @@ def initialize_assistant():
     return False
 
 def chat_response(message: str, history: List[Dict]):
-    """Handle chat messages and return responses"""
+    """
+    Handle chat messages and maintain conversation context.
+    """
     global assistant
 
     if assistant is None:
@@ -295,11 +285,12 @@ def chat_response(message: str, history: List[Dict]):
         logger.error(traceback.format_exc())
         return f"I encountered an error: {str(e)}"
 
-# Create Gradio interface
+# Create the Gradio interface with a clean, professional design
 demo = gr.ChatInterface(
     fn=chat_response,
-    title="Medical Assistant (Test Version)",
-    description="This is a test version of the medical assistant. Please use it to verify basic functionality.",
+    title="Medical Assistant (4-bit Quantized Version)",
+    description="""This medical assistant uses a 4-bit quantized model for efficient operation.
+    It provides medical guidance while ensuring comprehensive health information gathering.""",
     examples=[
         "What are the symptoms of malaria?",
         "How can I prevent type 2 diabetes?",
@@ -307,7 +298,7 @@ demo = gr.ChatInterface(
     ]
 )
 
-# Launch the interface
+# Launch the application
 if __name__ == "__main__":
     logger.info("Starting the application")
     demo.launch()
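For a quick sanity check of the pipeline path this commit switches to, outside the Gradio app, a minimal sketch could look like the following. It assumes the emircanerol/Llama3-Med42-8B-4bit checkpoint downloads cleanly and that a CUDA device is available for device_map="auto"; the prompt mirrors the app's plain "role: content" format.

import torch
from transformers import pipeline

# Build the same kind of text-generation pipeline the commit switches to;
# the pre-quantized 4-bit checkpoint keeps memory usage modest on one GPU.
pipe = pipeline(
    "text-generation",
    model="emircanerol/Llama3-Med42-8B-4bit",
    device_map="auto",
    torch_dtype=torch.float16,
)

# Reproduce the app's hand-built prompt, including the trailing cue that
# generate_response later splits on.
prompt = (
    "system: You are a medical AI assistant.\n"
    "user: What are the symptoms of malaria?\n"
    "assistant:"
)

out = pipe(prompt, max_new_tokens=64, do_sample=True, temperature=0.7)
# The pipeline returns the prompt plus the completion, so strip the prompt
# off the same way the app does.
print(out[0]["generated_text"].split("assistant:")[-1].strip())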
 
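One design note on the new prompt construction: the commit joins messages into a plain "role: content" string, whereas Llama 3 instruct checkpoints usually ship a chat template matching their training format. If this 4-bit repo's tokenizer defines one (an assumption worth verifying), the same messages list could be rendered with apply_chat_template instead, as in this sketch:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("emircanerol/Llama3-Med42-8B-4bit")

messages = [
    {"role": "system", "content": "You are a medical AI assistant."},
    {"role": "user", "content": "How can I prevent type 2 diabetes?"},
]

# Render the conversation with the model's own template; this assumes the
# checkpoint's tokenizer config actually defines a chat_template.
prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,              # return a string rather than token ids
    add_generation_prompt=True,  # append the assistant header for generation
)

The rendered string drops into self.pipe unchanged; only the final split marker would need to match the template's assistant header rather than the literal "assistant:".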