benardo0 committed
Commit 4b15044 · verified · 1 parent: 17025e8

Update app.py

Files changed (1):
  app.py (+57, -61)
app.py CHANGED
@@ -163,47 +163,42 @@ logger = logging.getLogger(__name__)
 class MedicalAssistant:
     def __init__(self):
         """
-        Initialize the medical assistant with CPU-friendly settings.
-        We use a base model instead of a quantized version to ensure CPU compatibility.
+        Initialize the medical assistant with the Llama3-Med42 model.
+        This model is specifically trained on medical data and quantized to 4-bit precision
+        for better memory efficiency while maintaining good performance.
         """
         try:
             logger.info("Starting model initialization...")

-            # Using a standard model instead of a 4-bit quantized version
-            # This model is larger but more compatible with CPU-only environments
-            self.model_name = "meta-llama/Llama-2-7b-chat-hf"
+            # Updated model to use Llama3-Med42
+            self.model_name = "emircanerol/Llama3-Med42-8B-4bit"
             self.max_length = 2048

-            # First load the tokenizer as it's lighter on memory
+            # Initialize the pipeline for simplified text generation
+            # The pipeline handles tokenizer and model loading automatically
+            logger.info("Initializing pipeline...")
+            self.pipe = pipeline(
+                "text-generation",
+                model=self.model_name,
+                token=os.getenv('HUGGING_FACE_TOKEN'),
+                device_map="auto",
+                torch_dtype=torch.float16,  # Use half precision for 4-bit model
+                load_in_4bit=True  # Enable 4-bit quantization
+            )
+
+            # Load tokenizer separately for more control over text processing
             logger.info("Loading tokenizer...")
             self.tokenizer = AutoTokenizer.from_pretrained(
                 self.model_name,
-                token=os.getenv('HUGGING_FACE_TOKEN'),  # Add your token in Space settings
+                token=os.getenv('HUGGING_FACE_TOKEN'),
                 trust_remote_code=True
             )

-            # Handle padding token
+            # Ensure proper padding token configuration
             if self.tokenizer.pad_token is None:
                 self.tokenizer.pad_token = self.tokenizer.eos_token
-            logger.info("Tokenizer loaded successfully")
-
-            # Load model with CPU-friendly settings
-            logger.info("Loading model - this may take a few minutes...")
-            self.model = AutoModelForCausalLM.from_pretrained(
-                self.model_name,
-                token=os.getenv('HUGGING_FACE_TOKEN'),
-                device_map="auto",  # This will default to CPU if no GPU is available
-                torch_dtype=torch.float32,  # Standard precision for CPU
-                low_cpu_mem_usage=True,  # Optimize memory usage
-                offload_folder="offload"  # Enable disk offloading for memory management
-            )
-
-            # Move model explicitly to CPU and clear any GPU memory
-            self.model = self.model.to('cpu')
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
-
-            logger.info("Model loaded successfully!")
+
+            logger.info("Medical Assistant initialized successfully!")

         except Exception as e:
             logger.error(f"Initialization failed: {str(e)}")
@@ -212,45 +207,45 @@ class MedicalAssistant:

     def generate_response(self, message: str, chat_history: List[Dict] = None) -> str:
         """
-        Generate a response directly using the model instead of a pipeline.
-        This gives us more control over the generation process.
+        Generate a response using the Llama3-Med42 pipeline.
+        This method formats the conversation history and generates appropriate medical responses.
         """
         try:
             logger.info("Preparing message for generation")

             # Create a medical context-aware prompt
-            system_prompt = """You are a medical AI assistant. Provide accurate,
-            professional medical guidance. Always recommend consulting healthcare
-            providers for specific medical advice."""
+            system_prompt = """You are a medical AI assistant based on Llama3-Med42,
+            specifically trained on medical knowledge. Provide accurate, professional
+            medical guidance while acknowledging limitations. Always recommend
+            consulting healthcare providers for specific medical advice."""

-            # Format the conversation
-            prompt = f"{system_prompt}\n\nUser: {message}\nAssistant:"
+            # Format the conversation for the model
+            messages = [
+                {"role": "system", "content": system_prompt},
+                {"role": "user", "content": message}
+            ]

-            # Tokenize the input
-            inputs = self.tokenizer(
-                prompt,
-                return_tensors="pt",
-                padding=True,
-                truncation=True,
-                max_length=self.max_length
-            ).to('cpu')  # Ensure inputs are on CPU
+            # Add chat history if available
+            if chat_history:
+                for chat in chat_history:
+                    messages.append({
+                        "role": "user" if chat["role"] == "user" else "assistant",
+                        "content": chat["content"]
+                    })

             logger.info("Generating response")
-            # Generate with conservative settings for CPU
-            with torch.no_grad():  # Disable gradient computation to save memory
-                outputs = self.model.generate(
-                    **inputs,
-                    max_new_tokens=256,  # Reduced for CPU efficiency
-                    do_sample=True,
-                    temperature=0.7,
-                    top_p=0.95,
-                    pad_token_id=self.tokenizer.pad_token_id,
-                    repetition_penalty=1.1
-                )
+            # Generate response using the pipeline
+            response = self.pipe(
+                messages,
+                max_new_tokens=256,
+                do_sample=True,
+                temperature=0.7,
+                top_p=0.95,
+                repetition_penalty=1.1
+            )[0]["generated_text"]

-            # Decode and clean up the response
-            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
-            response = response.split("Assistant:")[-1].strip()
+            # Clean up the response by extracting the last assistant message
+            response = response.split("assistant:")[-1].strip()

             logger.info("Response generated successfully")
             return response
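Two caveats in the new generate_response: earlier turns are appended after the current user message, so the model sees the conversation out of order, and when a pipeline is called with a role/content message list, generated_text typically comes back as a message list rather than a flat string, so split("assistant:") may return the input unchanged. A hedged sketch of a history-first ordering and structured extraction (names mirror the diff; the indexing assumes a recent transformers version with chat-template support in pipelines):

from typing import Dict, List

def build_reply(pipe, system_prompt: str, message: str,
                chat_history: List[Dict] = None) -> str:
    # History first, current turn last, so the model continues from it
    messages = [{"role": "system", "content": system_prompt}]
    for chat in chat_history or []:
        messages.append({
            "role": "user" if chat["role"] == "user" else "assistant",
            "content": chat["content"],
        })
    messages.append({"role": "user", "content": message})

    outputs = pipe(
        messages,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.7,
        top_p=0.95,
        repetition_penalty=1.1,
    )
    # Chat pipelines return the whole conversation; the final appended
    # message is the assistant's reply
    return outputs[0]["generated_text"][-1]["content"].strip()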
@@ -260,7 +255,7 @@ class MedicalAssistant:
             logger.error(traceback.format_exc())
             return f"I apologize, but I encountered an error: {str(e)}"

-# The rest of your code remains the same
+# Initialize the assistant
 assistant = None

 def initialize_assistant():
@@ -295,10 +290,11 @@ def chat_response(message: str, history: List[Dict]):
 # Create the Gradio interface
 demo = gr.ChatInterface(
     fn=chat_response,
-    title="Medical Assistant (CPU Version)",
-    description="""This medical assistant provides guidance and information
-                   about health-related queries. Please note that response
-                   generation may take longer as this is running in CPU mode.""",
+    title="Medical Assistant (Llama3-Med42)",
+    description="""This medical assistant is powered by Llama3-Med42,
+                   a model specifically trained on medical knowledge. It provides
+                   guidance and information about health-related queries while
+                   maintaining professional medical standards.""",
     examples=[
         "What are the symptoms of malaria?",
         "How can I prevent type 2 diabetes?",
 