# Nurses / app.py
# benardo0's picture
# Update app.py
# b4ff37d verified
# raw
# history blame
# 5.18 kB
import os
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from typing import List, Dict
import logging
# Configure logging at INFO level so model loading and inference can be debugged
logging.basicConfig(level=logging.INFO)
# Module-level logger used by the class and functions in this file
logger = logging.getLogger(__name__)
class MedicalAssistant:
    """Medical question-answering assistant backed by a causal language model.

    The constructor loads the tokenizer and the model; ``generate_response``
    turns a user message (plus optional earlier chat turns) into a plain-text
    prompt, samples a continuation and returns only the assistant's reply.
    """

    # Instruction text placed at the start of every prompt.
    SYSTEM_PROMPT = (
        "You are a medical AI assistant. Respond to medical queries\n"
        "professionally and accurately. If you're unsure, always recommend consulting\n"
        "with a healthcare provider."
    )

    # Only this many of the most recent history messages are replayed in the
    # prompt, so that a long conversation does not crowd out the current
    # question within the ``max_length`` token budget.
    MAX_HISTORY_MESSAGES = 6

    def __init__(self):
        """Load the tokenizer and model named by ``self.model_name``.

        Raises:
            Exception: any error raised while loading is logged and re-raised.
        """
        try:
            logger.info("Starting model initialization...")
            # Model configuration - adjust these based on your available compute
            # NOTE(review): the repository name suggests GGUF weights; loading
            # those through transformers normally needs a `gguf_file=` argument
            # -- confirm this checkpoint actually loads as configured.
            self.model_name = "mradermacher/Llama3-Med42-8B-GGUF"
            self.max_length = 1048
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            logger.info(f"Using device: {self.device}")

            # Load tokenizer first - this is typically faster and can catch issues early
            logger.info("Loading tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                padding_side="left",
                # Truncate from the left so an over-long prompt keeps the
                # latest question and the trailing "Assistant:" cue.
                truncation_side="left",
                trust_remote_code=True,
            )
            # Set padding token if not set
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            # Load model with memory optimizations
            # NOTE(review): 8-bit loading presumably requires a CUDA device;
            # verify behavior when self.device is "cpu".
            logger.info("Loading model...")
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                torch_dtype=torch.float16,
                device_map="auto",
                load_in_8bit=True,
                trust_remote_code=True,
            )
            logger.info("Model initialization completed successfully!")
        except Exception as e:
            logger.error("Error during initialization: %s", e)
            raise

    @staticmethod
    def _history_lines(chat_history):
        """Convert chat history into ``"User: ..."`` / ``"Assistant: ..."`` lines.

        Accepts both history layouts: a list of ``{"role", "content"}`` dicts
        and a list of ``[user_text, assistant_text]`` pairs. Entries that are
        empty, non-textual or of an unknown role are skipped.
        """
        labels = {"user": "User", "assistant": "Assistant"}
        lines = []
        for turn in chat_history or []:
            if isinstance(turn, dict):
                label = labels.get(turn.get("role"))
                content = turn.get("content")
                if label and isinstance(content, str) and content:
                    lines.append(f"{label}: {content}")
            elif isinstance(turn, (list, tuple)) and len(turn) == 2:
                for label, content in zip(("User", "Assistant"), turn):
                    if isinstance(content, str) and content:
                        lines.append(f"{label}: {content}")
        return lines

    @classmethod
    def _build_prompt(cls, message, chat_history=None):
        """Assemble system prompt, recent history and the new message into one prompt."""
        lines = cls._history_lines(chat_history)[-cls.MAX_HISTORY_MESSAGES:]
        lines.append(f"User: {message}")
        # Trailing cue that the model is expected to continue.
        lines.append("Assistant:")
        return cls.SYSTEM_PROMPT + "\n\n" + "\n".join(lines)

    @staticmethod
    def _clean_response(text):
        """Keep only the assistant's own turn from the generated continuation."""
        # Drop any follow-up "User:" turn the model may have invented.
        reply, _, _ = text.partition("\nUser:")
        return reply.strip()

    def generate_response(self, message: str, chat_history: List[Dict] = None) -> str:
        """Generate a response to the user's message.

        Args:
            message: The user's current question.
            chat_history: Earlier conversation turns, either role/content
                dicts or ``[user, assistant]`` pairs; ``None`` means no history.

        Returns:
            The assistant's reply, or a fixed apology string if tokenization,
            generation or decoding raised an exception (the error is logged).
        """
        try:
            # Combine system prompt, chat history, and current message
            full_prompt = self._build_prompt(message, chat_history)

            # Tokenize input and move it to wherever device_map placed the model
            inputs = self.tokenizer(
                full_prompt,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=self.max_length,
            ).to(self.model.device)

            # Generate response
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=512,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.95,
                    pad_token_id=self.tokenizer.pad_token_id,
                    repetition_penalty=1.1,
                )

            # The output sequence starts with the prompt tokens; decode only
            # what was generated after them instead of searching the decoded
            # text for an "Assistant:" marker.
            prompt_length = inputs["input_ids"].shape[-1]
            generated_text = self.tokenizer.decode(
                outputs[0][prompt_length:],
                skip_special_tokens=True,
            )
            return self._clean_response(generated_text)
        except Exception as e:
            logger.exception("Error during response generation: %s", e)
            return "I apologize, but I encountered an error. Please try again."
# Shared assistant instance; stays None until initialize_assistant() creates it
assistant = None
def initialize_assistant():
    """Create the shared MedicalAssistant and store it in the module global.

    Returns:
        True when the assistant was constructed, False when construction
        raised (the error is logged and the global is left unchanged).
    """
    global assistant
    try:
        created = MedicalAssistant()
    except Exception as err:
        logger.error(f"Failed to initialize assistant: {str(err)}")
        return False
    # Publish the instance only after construction fully succeeded.
    assistant = created
    return True
def chat_response(message: str, history: List[Dict]):
    """Answer one chat message, creating the assistant on first use.

    Args:
        message: The text the user just sent.
        history: Earlier conversation turns supplied by the chat interface.

    Returns:
        The assistant's reply, or a fixed fallback string when the assistant
        cannot be initialized or response generation raises.
    """
    # Lazy start-up: the assistant is only built when the first message arrives.
    not_ready = assistant is None
    if not_ready and not initialize_assistant():
        return "I apologize, but I'm currently unavailable. Please try again later."

    try:
        reply = assistant.generate_response(message, history)
    except Exception as err:
        logger.error(f"Error in chat response: {str(err)}")
        return "I encountered an error. Please try again."
    return reply
# Text shown under the title of the chat page
_DESCRIPTION = (
    "This is a test version of the medical assistant.\n"
    "Please use it to verify basic functionality."
)

# Sample questions offered to the user as clickable examples
_EXAMPLE_QUESTIONS = [
    "What are the symptoms of malaria?",
    "How can I prevent type 2 diabetes?",
    "What should I do for a mild headache?",
]

# Build the Gradio chat UI around chat_response.
# Button overrides left disabled: retry_btn=None, undo_btn=None, clear_btn="Clear"
demo = gr.ChatInterface(
    fn=chat_response,
    title="Medical Assistant (Test Version)",
    description=_DESCRIPTION,
    examples=_EXAMPLE_QUESTIONS,
)

# Start the web server only when this file is run as a script
if __name__ == "__main__":
    demo.launch()