import os
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
from typing import List, Dict
import logging
import traceback

# Set up logging to help us track what's happening
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

class MedicalAssistant:
    def __init__(self):
        """
        Initialize the medical assistant with the Llama3-Med42 model.
        This model is specifically trained on medical data and quantized to 4-bit
        precision for better memory efficiency while maintaining good performance.
        """
        try:
            logger.info("Starting model initialization...")

            # Llama3-Med42 in a 4-bit quantized variant
            self.model_name = "emircanerol/Llama3-Med42-8B-4bit"
            self.max_length = 2048

            # Initialize the pipeline for simplified text generation.
            # The pipeline handles tokenizer and model loading automatically.
            # Quantization options are passed via model_kwargs so they reach
            # from_pretrained; if the checkpoint already ships 4-bit weights,
            # the flag may be redundant.
            logger.info("Initializing pipeline...")
            self.pipe = pipeline(
                "text-generation",
                model=self.model_name,
                token=os.getenv('HUGGING_FACE_TOKEN'),
                device_map="auto",
                torch_dtype=torch.float16,  # half precision for non-quantized modules
                model_kwargs={"load_in_4bit": True}  # enable 4-bit quantization via bitsandbytes
            )
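
            # Sketch of an alternative (assumption: a transformers release that
            # provides BitsAndBytesConfig); newer versions prefer an explicit
            # quantization config over the bare load_in_4bit flag:
            #
            #   from transformers import BitsAndBytesConfig
            #   quant_config = BitsAndBytesConfig(load_in_4bit=True)
            #   self.pipe = pipeline(
            #       "text-generation",
            #       model=self.model_name,
            #       model_kwargs={"quantization_config": quant_config},
            #   )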

            # Load tokenizer separately for more control over text processing
            logger.info("Loading tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                token=os.getenv('HUGGING_FACE_TOKEN'),
                trust_remote_code=True
            )

            # Ensure proper padding token configuration
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            logger.info("Medical Assistant initialized successfully!")

        except Exception as e:
            logger.error(f"Initialization failed: {str(e)}")
            logger.error(traceback.format_exc())
            raise

    def generate_response(self, message: str, chat_history: List[Dict] = None) -> str:
        """
        Generate a response using the Llama3-Med42 pipeline.
        Formats the conversation history and generates an appropriate medical response.
        """
        try:
            logger.info("Preparing message for generation")

            # Create a medical context-aware prompt
            system_prompt = (
                "You are a medical AI assistant based on Llama3-Med42, "
                "specifically trained on medical knowledge. Provide accurate, professional "
                "medical guidance while acknowledging limitations. Always recommend "
                "consulting healthcare providers for specific medical advice."
            )

            # Format the conversation for the model: system prompt first,
            # then any prior turns, then the current user message last
            messages = [{"role": "system", "content": system_prompt}]

            # Add chat history if available (each entry is a {"role", "content"} dict)
            if chat_history:
                for chat in chat_history:
                    messages.append({
                        "role": "user" if chat["role"] == "user" else "assistant",
                        "content": chat["content"]
                    })

            # The current message goes after the history so the model sees it last
            messages.append({"role": "user", "content": message})
logger.info("Generating response") | |
# Generate response using the pipeline | |
response = self.pipe( | |
messages, | |
max_new_tokens=256, | |
do_sample=True, | |
temperature=0.7, | |
top_p=0.95, | |
repetition_penalty=1.1 | |
)[0]["generated_text"] | |
# Clean up the response by extracting the last assistant message | |
response = response.split("assistant:")[-1].strip() | |
logger.info("Response generated successfully") | |
return response | |

        except Exception as e:
            logger.error(f"Error during response generation: {str(e)}")
            logger.error(traceback.format_exc())
            return f"I apologize, but I encountered an error: {str(e)}"

# Initialize the assistant
assistant = None


def initialize_assistant():
    """Initialize the assistant with proper error handling"""
    global assistant
    try:
        logger.info("Attempting to initialize assistant")
        assistant = MedicalAssistant()
        logger.info("Assistant initialized successfully")
        return True
    except Exception as e:
        logger.error(f"Failed to initialize assistant: {str(e)}")
        logger.error(traceback.format_exc())
        return False

def chat_response(message: str, history: List[Dict]):
    """Handle chat interactions with error recovery"""
    global assistant

    if assistant is None:
        logger.info("Assistant not initialized, attempting initialization")
        if not initialize_assistant():
            return "I apologize, but I'm currently unavailable. Please try again later."

    try:
        return assistant.generate_response(message, history)
    except Exception as e:
        logger.error(f"Error in chat response: {str(e)}")
        logger.error(traceback.format_exc())
        return f"I encountered an error: {str(e)}"

# Create the Gradio interface
demo = gr.ChatInterface(
    fn=chat_response,
    type="messages",  # history arrives as {"role", "content"} dicts (needs a Gradio version that supports this)
    title="Medical Assistant (Llama3-Med42)",
    description="""This medical assistant is powered by Llama3-Med42,
    a model specifically trained on medical knowledge. It provides
    guidance and information about health-related queries while
    maintaining professional medical standards.""",
    examples=[
        "What are the symptoms of malaria?",
        "How can I prevent type 2 diabetes?",
        "What should I do for a mild headache?"
    ]
)

# Launch the interface
if __name__ == "__main__":
    logger.info("Starting the application")
    demo.launch()
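    # When deploying in a container or on a Space, launch options such as
    # demo.launch(server_name="0.0.0.0", server_port=7860) can be used instead;
    # treat the exact values as an assumption for your environment.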