Please try again." # # Create Gradio interface # demo = gr.ChatInterface( # fn=chat_response, # title="Medical Assistant (Test Version)", # description="""This is a test version of the medical assistant. # Please use it to verify basic functionality.""", # examples=[ # "What are the symptoms of malaria?", # "How can I prevent type 2 diabetes?", # "What should I do for a mild headache?" # ], # # retry_btn=None, # # undo_btn=None, # # clear_btn="Clear" # ) # # Launch the interface # if __name__ == "__main__": # demo.launch() import os import gradio as gr from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM import torch from typing import List, Dict import logging import traceback # Configure detailed logging to help us track the model's behavior logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s' ) logger = logging.getLogger(__name__) class MedicalAssistant: def __init__(self): """ Initialize the medical assistant using a pre-quantized 4-bit model. This approach uses less memory while maintaining good performance. """ try: logger.info("Starting model initialization...") # Define model configuration self.model_name = "emircanerol/Llama3-Med42-8B-4bit" self.max_length = 2048 self.device = "cuda" if torch.cuda.is_available() else "cpu" # Log system information for debugging logger.info(f"Using device: {self.device}") logger.info(f"Available CUDA devices: {torch.cuda.device_count() if torch.cuda.is_available() else 'None'}") if torch.cuda.is_available(): logger.info(f"CUDA memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB") # Initialize the pipeline for text generation logger.info("Initializing text generation pipeline...") self.pipe = pipeline( "text-generation", model=self.model_name, device_map="auto", torch_dtype=torch.float16 ) logger.info("Pipeline initialized successfully!") # Load tokenizer separately for more control over text processing logger.info("Loading tokenizer...") self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) if self.tokenizer.pad_token is None: self.tokenizer.pad_token = self.tokenizer.eos_token logger.info("Tokenizer loaded successfully!") except Exception as e: logger.error(f"Initialization failed: {str(e)}") logger.error(traceback.format_exc()) raise def generate_response(self, message: str, chat_history: List[Dict] = None) -> str: """ Generate a response using the text generation pipeline. The pipeline handles most of the complexity of text generation for us. """ try: logger.info("Preparing message for generation") # Prepare the conversation format system_prompt = """You are a medical AI assistant. Respond to medical queries professionally and accurately. 
            If you're unsure, always recommend consulting with a healthcare provider."""

            # Format messages for the model
            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": message}
            ]

            # Convert messages to a simple "role: content" prompt string
            prompt = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
            prompt += "\nassistant:"

            logger.info("Generating response")

            # Generate response using the pipeline
            response = self.pipe(
                prompt,
                max_new_tokens=512,
                do_sample=True,
                temperature=0.7,
                top_p=0.95,
                repetition_penalty=1.1,
                pad_token_id=self.tokenizer.pad_token_id
            )[0]["generated_text"]

            # Extract the assistant's response from the full generated text
            response = response.split("assistant:")[-1].strip()

            logger.info("Response generated successfully")
            return response

        except Exception as e:
            logger.error(f"Error during response generation: {str(e)}")
            logger.error(traceback.format_exc())
            return f"I apologize, but I encountered an error: {str(e)}"


# Initialize our global assistant lazily on the first request
assistant = None


def initialize_assistant():
    """
    Initialize the assistant with error handling and logging.
    This helps us track any issues during startup.
    """
    global assistant
    try:
        logger.info("Attempting to initialize assistant")
        assistant = MedicalAssistant()
        logger.info("Assistant initialized successfully")
        return True
    except Exception as e:
        logger.error(f"Failed to initialize assistant: {str(e)}")
        logger.error(traceback.format_exc())
        return False


def chat_response(message: str, history: List[Dict]):
    """
    Handle chat messages from the Gradio interface, initializing the
    assistant on first use.
    """
    global assistant

    if assistant is None:
        logger.info("Assistant not initialized, attempting initialization")
        if not initialize_assistant():
            return "I apologize, but I'm currently unavailable. The error has been logged for investigation."

    try:
        return assistant.generate_response(message, history)
    except Exception as e:
        logger.error(f"Error in chat response: {str(e)}")
        logger.error(traceback.format_exc())
        return f"I encountered an error: {str(e)}"


# Create the Gradio interface with a clean, professional design
demo = gr.ChatInterface(
    fn=chat_response,
    title="Medical Assistant (4-bit Quantized Version)",
    description="""This medical assistant uses a 4-bit quantized model for efficient operation.
    It provides general medical guidance and recommends consulting a healthcare provider for specific concerns.""",
    examples=[
        "What are the symptoms of malaria?",
        "How can I prevent type 2 diabetes?",
        "What should I do for a mild headache?"
    ]
)

# Launch the application
if __name__ == "__main__":
    logger.info("Starting the application")
    demo.launch()