# NOTE(review): removed non-Python residue (file size, git-blame hashes, and a
# line-number gutter) that was accidentally pasted in from a repository viewer.
import os
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from typing import List, Dict
import logging
# Set up logging to help us debug model loading and inference
# basicConfig is a module-import side effect; it configures the root logger
# once, so this file should be imported before any other logging config runs.
logging.basicConfig(level=logging.INFO)
# Module-level logger named after this module (stdlib convention).
logger = logging.getLogger(__name__)
class MedicalAssistant:
    """Chat-style medical Q&A assistant backed by a local causal LM.

    Loads a tokenizer and an 8-bit quantized model once at construction,
    then answers user messages via `generate_response`.
    """

    def __init__(self):
        """Initialize the medical assistant with model and tokenizer.

        Raises:
            Exception: re-raised after logging if either the tokenizer or the
                model fails to load (e.g. missing weights, no bitsandbytes).
        """
        try:
            logger.info("Starting model initialization...")
            # Model configuration - adjust these based on your available compute.
            # NOTE(review): this repo hosts GGUF weights; AutoModelForCausalLM
            # normally needs PyTorch/safetensors weights (or a `gguf_file=`
            # argument) - confirm this checkpoint actually loads.
            self.model_name = "mradermacher/Llama3-Med42-8B-GGUF"
            # Prompt-token budget for tokenization.
            # NOTE(review): 1048 is an unusual value - was 1024 (or 2048) intended?
            self.max_length = 1048
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            logger.info(f"Using device: {self.device}")

            # Load tokenizer first - this is typically faster and can catch issues early.
            logger.info("Loading tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                padding_side="left",
                trust_remote_code=True,
            )
            # Some checkpoints ship without a pad token; reuse EOS so batched
            # tokenization with padding=True works.
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            # Load model with memory optimizations (fp16 + 8-bit quantization;
            # device_map="auto" lets accelerate place layers across devices).
            logger.info("Loading model...")
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                torch_dtype=torch.float16,
                device_map="auto",
                load_in_8bit=True,
                trust_remote_code=True,
            )
            logger.info("Model initialization completed successfully!")
        except Exception as e:
            logger.error(f"Error during initialization: {str(e)}")
            raise

    def generate_response(self, message: str, chat_history: List[Dict] = None) -> str:
        """Generate a response to the user's message.

        Args:
            message: The current user message.
            chat_history: Prior conversation turns. Presumably dicts with
                "role"/"content" keys (Gradio "messages" format) - TODO
                confirm against the caller. May be None/empty.

        Returns:
            The model's reply text, or an apology string if generation failed.
        """
        try:
            # Prepare the prompt.
            system_prompt = """You are a medical AI assistant. Respond to medical queries
            professionally and accurately. If you're unsure, always recommend consulting
            with a healthcare provider."""

            # BUG FIX: the original accepted chat_history but never used it,
            # so the model saw only the latest message. Fold prior turns into
            # the prompt so multi-turn context is preserved. With no history
            # the prompt is identical to the original's.
            history_lines = []
            for turn in chat_history or []:
                role = turn.get("role", "")
                content = turn.get("content", "")
                speaker = "Assistant" if role == "assistant" else "User"
                history_lines.append(f"{speaker}: {content}\n")
            history_text = "".join(history_lines)

            full_prompt = f"{system_prompt}\n\n{history_text}User: {message}\nAssistant:"

            # Tokenize input, truncating to the configured budget.
            inputs = self.tokenizer(
                full_prompt,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=self.max_length,
            ).to(self.device)

            # Generate response (sampling; no gradients needed at inference).
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=512,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.95,
                    pad_token_id=self.tokenizer.pad_token_id,
                    repetition_penalty=1.1,
                )

            # Decode and clean up response.
            response = self.tokenizer.decode(
                outputs[0],
                skip_special_tokens=True,
            )
            # The decoded text echoes the prompt; keep only what follows the
            # final "Assistant:" marker.
            response = response.split("Assistant:")[-1].strip()
            return response
        except Exception as e:
            logger.error(f"Error during response generation: {str(e)}")
            # Plain string (the original used an f-string with no placeholders).
            return "I apologize, but I encountered an error. Please try again."
# Module-level singleton; populated lazily by initialize_assistant().
assistant = None


def initialize_assistant():
    """Construct the global MedicalAssistant instance.

    Returns:
        bool: True when construction succeeded, False when it failed
        (the failure is logged, never raised to the caller).
    """
    global assistant
    try:
        assistant = MedicalAssistant()
    except Exception as e:
        logger.error(f"Failed to initialize assistant: {str(e)}")
        return False
    else:
        return True
def chat_response(message: str, history: List[Dict]):
    """Route one chat turn to the assistant, creating it on first use.

    Args:
        message: The user's latest message.
        history: Prior conversation turns supplied by the UI.

    Returns:
        str: The assistant's reply, or a fallback apology string.
    """
    global assistant
    # Lazy startup: the model is only loaded when the first message arrives.
    # Short-circuit keeps initialize_assistant() from running when the
    # assistant already exists.
    if assistant is None and not initialize_assistant():
        return "I apologize, but I'm currently unavailable. Please try again later."
    try:
        return assistant.generate_response(message, history)
    except Exception as e:
        logger.error(f"Error in chat response: {str(e)}")
        return "I encountered an error. Please try again."
# Create the Gradio chat UI; `fn` is called with (message, history) per turn.
demo = gr.ChatInterface(
    fn=chat_response,
    title="Medical Assistant (Test Version)",
    description="""This is a test version of the medical assistant.
    Please use it to verify basic functionality.""",
    examples=[
        "What are the symptoms of malaria?",
        "How can I prevent type 2 diabetes?",
        "What should I do for a mild headache?",
    ],
    # Removed dead commented-out kwargs (retry_btn/undo_btn/clear_btn):
    # those ChatInterface parameters were dropped in Gradio 4.x.
)

# Launch only when run as a script, not when imported.
# BUG FIX: the original final line ended with a stray "|" (paste artifact),
# which is a SyntaxError.
if __name__ == "__main__":
    demo.launch()