import logging
from typing import Dict, List

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Set up logging to help us debug model loading and inference
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class MedicalAssistant:
    def __init__(self):
        """Initialize the medical assistant with model and tokenizer."""
        try:
            logger.info("Starting model initialization...")

            # Model configuration - adjust these based on your available compute.
            # Note: the original model id ("mradermacher/Llama3-Med42-8B-GGUF")
            # points at a GGUF quantization meant for llama.cpp-style runtimes;
            # AutoModelForCausalLM expects a standard transformers repo, so the
            # base Med42 checkpoint is used here instead.
            self.model_name = "m42-health/Llama3-Med42-8B"
            self.max_length = 2048  # the original value 1048 appears to be a typo
            self.device = "cuda" if torch.cuda.is_available() else "cpu"

            logger.info(f"Using device: {self.device}")

            # Load the tokenizer first - this is typically faster and can catch
            # issues early
            logger.info("Loading tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                padding_side="left",
                trust_remote_code=True,
            )

            # Set the padding token if it is not already set
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            # Load the model with memory optimizations. 8-bit loading requires
            # the bitsandbytes package and a CUDA device; the bare
            # load_in_8bit=True kwarg is deprecated in recent transformers in
            # favor of quantization_config.
            logger.info("Loading model...")
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                torch_dtype=torch.float16,
                device_map="auto",
                quantization_config=BitsAndBytesConfig(load_in_8bit=True),
                trust_remote_code=True,
            )

            logger.info("Model initialization completed successfully!")

        except Exception as e:
            logger.error(f"Error during initialization: {str(e)}")
            raise

    def generate_response(self, message: str, chat_history: List[Dict] = None) -> str:
        """Generate a response to the user's message."""
        try:
            # Prepare the prompt
            system_prompt = (
                "You are a medical AI assistant. Respond to medical queries "
                "professionally and accurately. If you're unsure, always "
                "recommend consulting with a healthcare provider."
            )

            # Combine the system prompt, chat history, and current message.
            # With type="messages" below, Gradio passes the history as a list
            # of {"role": ..., "content": ...} dicts.
            history_text = ""
            for turn in chat_history or []:
                history_text += f"{turn['role'].capitalize()}: {turn['content']}\n"

            full_prompt = f"{system_prompt}\n\n{history_text}User: {message}\nAssistant:"

            # Tokenize the input and move it to the same device as the model
            inputs = self.tokenizer(
                full_prompt,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=self.max_length,
            ).to(self.model.device)

            # Generate the response
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=512,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.95,
                    pad_token_id=self.tokenizer.pad_token_id,
                    repetition_penalty=1.1,
                )

            # Decode and clean up the response
            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            # The decoded text includes the prompt, so keep only the assistant's reply
            response = response.split("Assistant:")[-1].strip()

            return response

        except Exception as e:
            logger.error(f"Error during response generation: {str(e)}")
            return "I apologize, but I encountered an error. Please try again."


# Initialize the assistant lazily so the app can start even if loading fails
assistant = None


def initialize_assistant():
    """Initialize the assistant and handle any errors."""
    global assistant
    try:
        assistant = MedicalAssistant()
        return True
    except Exception as e:
        logger.error(f"Failed to initialize assistant: {str(e)}")
        return False


def chat_response(message: str, history: List[Dict]):
    """Handle chat messages and return responses."""
    global assistant

    # Check whether the assistant is initialized
    if assistant is None:
        if not initialize_assistant():
            return "I apologize, but I'm currently unavailable. Please try again later."

    try:
        return assistant.generate_response(message, history)
    except Exception as e:
        logger.error(f"Error in chat response: {str(e)}")
        return "I encountered an error. Please try again."


# Create the Gradio interface. type="messages" makes the history a list of
# role/content dicts, matching the type hints above. The retry_btn, undo_btn,
# and clear_btn kwargs were removed in newer Gradio releases, so they are
# left out here.
demo = gr.ChatInterface(
    fn=chat_response,
    type="messages",
    title="Medical Assistant (Test Version)",
    description="""This is a test version of the medical assistant.
    Please use it to verify basic functionality.""",
    examples=[
        "What are the symptoms of malaria?",
        "How can I prevent type 2 diabetes?",
        "What should I do for a mild headache?",
    ],
)

# Launch the interface
if __name__ == "__main__":
    demo.launch()
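
# Setup notes (assumptions, not pinned by the original script):
# * Dependencies: roughly `pip install torch transformers accelerate gradio
#   bitsandbytes`; bitsandbytes and a CUDA GPU are needed for the 8-bit load,
#   and accelerate for device_map="auto".
# * Prompting: instruct-tuned Llama-3 checkpoints generally respond better to
#   the tokenizer's built-in chat template (tokenizer.apply_chat_template with
#   add_generation_prompt=True) than to the plain "User:/Assistant:" format
#   used above.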