import os
import gradio as gr
from transformers import pipeline
import torch
from typing import List, Dict
import logging
import traceback
# Set up logging to help us track what's happening
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
class MedicalAssistant:
def __init__(self):
"""
Initialize the medical assistant with the Llama3-Med42 model.
This model is specifically trained on medical data and quantized to 4-bit precision
for better memory efficiency while maintaining good performance.
"""
try:
logger.info("Starting model initialization...")
            # Llama3-Med42, pre-quantized to 4-bit precision
            self.model_name = "emircanerol/Llama3-Med42-8B-4bit"
            # Initialize the pipeline for simplified text generation.
            # The pipeline loads the tokenizer and model automatically;
            # quantization options are forwarded to the underlying model
            # through model_kwargs rather than passed to pipeline() itself.
            logger.info("Initializing pipeline...")
            self.pipe = pipeline(
                "text-generation",
                model=self.model_name,
                token=os.getenv('HUGGING_FACE_TOKEN'),
                device_map="auto",
                torch_dtype=torch.float16,  # half-precision compute dtype
                model_kwargs={"load_in_4bit": True}  # 4-bit weights via bitsandbytes
            )
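            # A minimal alternative sketch (assuming bitsandbytes is installed):
            # the same 4-bit request can be made explicitly with a quantization
            # config instead of the shorthand flag, e.g.:
            #
            #   from transformers import BitsAndBytesConfig
            #   self.pipe = pipeline(
            #       "text-generation",
            #       model=self.model_name,
            #       device_map="auto",
            #       model_kwargs={"quantization_config": BitsAndBytesConfig(load_in_4bit=True)},
            #   )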
            # Configure padding on the pipeline's own tokenizer. Loading a
            # second tokenizer with AutoTokenizer would have no effect here,
            # because the pipeline only uses the one it loaded itself.
            logger.info("Configuring tokenizer...")
            self.tokenizer = self.pipe.tokenizer
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
logger.info("Medical Assistant initialized successfully!")
except Exception as e:
logger.error(f"Initialization failed: {str(e)}")
logger.error(traceback.format_exc())
raise
def generate_response(self, message: str, chat_history: List[Dict] = None) -> str:
"""
Generate a response using the Llama3-Med42 pipeline.
This method formats the conversation history and generates appropriate medical responses.
"""
try:
logger.info("Preparing message for generation")
# Create a medical context-aware prompt
            system_prompt = (
                "You are a medical AI assistant based on Llama3-Med42, "
                "specifically trained on medical knowledge. Provide accurate, "
                "professional medical guidance while acknowledging limitations. "
                "Always recommend consulting healthcare providers for specific "
                "medical advice."
            )
            # Format the conversation for the model: system prompt first,
            # then prior turns in order, then the current user message.
            # (History must come before the current message, or the turns
            # arrive out of order.)
            messages = [{"role": "system", "content": system_prompt}]
            if chat_history:
                for chat in chat_history:
                    messages.append({
                        "role": "user" if chat["role"] == "user" else "assistant",
                        "content": chat["content"]
                    })
            messages.append({"role": "user", "content": message})
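            # For illustration, the assembled list follows the standard
            # chat-messages format, e.g.:
            #   [{"role": "system", "content": "You are a medical AI assistant..."},
            #    {"role": "user", "content": "What are the symptoms of malaria?"},
            #    {"role": "assistant", "content": "..."},
            #    {"role": "user", "content": message}]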
logger.info("Generating response")
# Generate response using the pipeline
response = self.pipe(
messages,
max_new_tokens=256,
do_sample=True,
temperature=0.7,
top_p=0.95,
repetition_penalty=1.1
)[0]["generated_text"]
# Clean up the response by extracting the last assistant message
response = response.split("assistant:")[-1].strip()
logger.info("Response generated successfully")
return response
except Exception as e:
logger.error(f"Error during response generation: {str(e)}")
logger.error(traceback.format_exc())
return f"I apologize, but I encountered an error: {str(e)}"
# Initialize the assistant
assistant = None
def initialize_assistant():
"""Initialize the assistant with proper error handling"""
global assistant
try:
logger.info("Attempting to initialize assistant")
assistant = MedicalAssistant()
logger.info("Assistant initialized successfully")
return True
except Exception as e:
logger.error(f"Failed to initialize assistant: {str(e)}")
logger.error(traceback.format_exc())
return False
def chat_response(message: str, history: List[Dict]):
"""Handle chat interactions with error recovery"""
global assistant
if assistant is None:
logger.info("Assistant not initialized, attempting initialization")
if not initialize_assistant():
return "I apologize, but I'm currently unavailable. Please try again later."
try:
return assistant.generate_response(message, history)
except Exception as e:
logger.error(f"Error in chat response: {str(e)}")
logger.error(traceback.format_exc())
return f"I encountered an error: {str(e)}"
# Create the Gradio interface
demo = gr.ChatInterface(
    fn=chat_response,
    type="messages",  # pass history as {"role": ..., "content": ...} dicts, as generate_response expects
    title="Medical Assistant (Llama3-Med42)",
    description="""This medical assistant is powered by Llama3-Med42,
    a model specifically trained on medical knowledge. It provides
    guidance and information on health-related topics while
    maintaining professional medical standards.""",
    examples=[
        "What are the symptoms of malaria?",
        "How can I prevent type 2 diabetes?",
        "What should I do for a mild headache?"
    ]
)
# Launch the interface
if __name__ == "__main__":
logger.info("Starting the application")
demo.launch()
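# A quick smoke test without the UI (hypothetical invocation, not part of the Space):
#   python -c "from app import chat_response; print(chat_response('What causes anemia?', []))"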