|
import logging
import os
from typing import Dict, List, Optional

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
|
|
# Configure root logging at INFO level.
logging.basicConfig(level=logging.INFO)

# Module-level logger used by the functions and class in this file.
logger = logging.getLogger(__name__)
|
|
|
class MedicalAssistant:
    """Medical question-answering assistant built on a Hugging Face causal LM.

    ``__init__`` loads the tokenizer and model named by ``self.model_name``.
    ``generate_response`` turns the user's message (plus any earlier turns)
    into a plain-text ``User:``/``Assistant:`` prompt, samples a completion
    and returns only the newly generated answer.
    """

    # Instruction that opens every prompt sent to the model.
    SYSTEM_PROMPT = (
        "You are a medical AI assistant. Respond to medical queries\n"
        "professionally and accurately. If you're unsure, always recommend consulting\n"
        "with a healthcare provider."
    )

    def __init__(self):
        """Load the tokenizer and the model, logging progress along the way.

        Raises:
            Exception: Whatever the tokenizer/model loaders raise; the error
                is logged with its traceback and re-raised unchanged.
        """
        try:
            logger.info("Starting model initialization...")

            # NOTE(review): this repo id names GGUF weights. The Transformers
            # GGUF documentation loads such checkpoints via a `gguf_file=`
            # argument -- confirm this id loads as intended.
            self.model_name = "mradermacher/Llama3-Med42-8B-GGUF"
            # Upper bound, in tokens, on the prompt handed to the model.
            self.max_length = 1048
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            logger.info("Using device: %s", self.device)

            logger.info("Loading tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                padding_side="left",
                trust_remote_code=True,
            )
            # generate() needs a pad token; fall back to EOS when none is set.
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            logger.info("Loading model...")
            on_gpu = self.device == "cuda"
            load_kwargs = {
                # float16 is only used on GPU; CPU loads plain float32 weights.
                "torch_dtype": torch.float16 if on_gpu else torch.float32,
                "device_map": "auto",
                "trust_remote_code": True,
            }
            if on_gpu:
                # 8-bit quantization is requested only when CUDA is available,
                # so a CPU-only host does not fail at load time.
                load_kwargs["load_in_8bit"] = True
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                **load_kwargs,
            )

            logger.info("Model initialization completed successfully!")
        except Exception as e:
            logger.exception("Error during initialization: %s", e)
            raise

    @staticmethod
    def _history_lines(chat_history) -> List[str]:
        """Render earlier turns as ``User: ...`` / ``Assistant: ...`` lines.

        Accepts ``{"role": ..., "content": ...}`` dicts as well as two-item
        ``(user, assistant)`` pairs. Entries with an unknown role or with
        non-text / empty content are skipped.
        """
        labels = {"user": "User", "assistant": "Assistant"}
        lines = []
        for turn in chat_history or []:
            if isinstance(turn, dict):
                pairs = [(turn.get("role"), turn.get("content"))]
            elif isinstance(turn, (list, tuple)) and len(turn) == 2:
                pairs = [("user", turn[0]), ("assistant", turn[1])]
            else:
                continue
            for role, text in pairs:
                if role in labels and isinstance(text, str) and text.strip():
                    lines.append(f"{labels[role]}: {text.strip()}")
        return lines

    @classmethod
    def _compose_prompt(cls, history_lines, message: str) -> str:
        """Join system prompt, prior turns and the new message into one prompt."""
        turns = list(history_lines) + [f"User: {message}", "Assistant:"]
        return cls.SYSTEM_PROMPT + "\n\n" + "\n".join(turns)

    @staticmethod
    def _clean_response(text: str) -> str:
        """Trim the completion and cut it where the model starts a new user turn."""
        return text.split("\nUser:", 1)[0].strip()

    def generate_response(self, message: str, chat_history: Optional[List[Dict]] = None) -> str:
        """Generate a response to the user's message.

        Args:
            message: The user's current question.
            chat_history: Earlier turns of the conversation, as
                ``{"role": ..., "content": ...}`` dicts or ``(user, assistant)``
                pairs. ``None`` or an empty list means no prior context.

        Returns:
            The assistant's reply, or a generic apology when anything fails
            (the error is logged, never raised).
        """
        try:
            history = self._history_lines(chat_history)
            prompt = self._compose_prompt(history, message)
            # Drop the oldest turns first, so the system prompt and the current
            # question are what remains when the conversation grows too long.
            while history and len(self.tokenizer(prompt)["input_ids"]) > self.max_length:
                history = history[1:]
                prompt = self._compose_prompt(history, message)

            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=self.max_length,
            ).to(self.model.device)  # follow wherever device_map placed the model

            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=512,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.95,
                    pad_token_id=self.tokenizer.pad_token_id,
                    repetition_penalty=1.1,
                )

            # The output sequence starts with the prompt tokens; decode only
            # what follows them, so an "Assistant:" marker occurring in the
            # user's text or in the answer cannot truncate the reply.
            prompt_len = inputs["input_ids"].shape[-1]
            completion = self.tokenizer.decode(
                outputs[0][prompt_len:],
                skip_special_tokens=True,
            )
            return self._clean_response(completion)
        except Exception as e:
            logger.exception("Error during response generation: %s", e)
            return "I apologize, but I encountered an error. Please try again."
|
|
|
|
|
# Shared assistant instance; stays None until initialize_assistant() succeeds.
assistant = None
|
|
|
def initialize_assistant() -> bool:
    """Create the module-level ``assistant`` instance.

    Returns:
        ``True`` when a ``MedicalAssistant`` was built and stored in the
        global ``assistant``; ``False`` when construction raised, in which
        case the error is logged and ``assistant`` is left unchanged.
    """
    global assistant
    try:
        assistant = MedicalAssistant()
    except Exception as e:
        # logger.exception keeps the traceback alongside the message.
        logger.exception("Failed to initialize assistant: %s", e)
        return False
    return True
|
|
|
def chat_response(message: str, history: List[Dict]) -> str:
    """Handle one chat message and return the assistant's reply.

    Args:
        message: The text the user just sent.
        history: Earlier turns of the conversation, forwarded unchanged to
            ``MedicalAssistant.generate_response``.

    Returns:
        The generated reply, or a fixed fallback sentence when the assistant
        cannot be initialized or generation raises.
    """
    # Build the assistant lazily on first use.
    if assistant is None and not initialize_assistant():
        return "I apologize, but I'm currently unavailable. Please try again later."

    try:
        return assistant.generate_response(message, history)
    except Exception as e:
        logger.exception("Error in chat response: %s", e)
        return "I encountered an error. Please try again."
|
|
|
|
|
# Gradio chat UI that routes every message through chat_response.
demo = gr.ChatInterface(
    fn=chat_response,
    title="Medical Assistant (Test Version)",
    description=(
        "This is a test version of the medical assistant.\n"
        "Please use it to verify basic functionality."
    ),
    examples=[
        "What are the symptoms of malaria?",
        "How can I prevent type 2 diabetes?",
        "What should I do for a mild headache?",
    ],
)


# Launch the UI only when this file is run as a script, not when it is imported.
if __name__ == "__main__":
    demo.launch()