import logging
from typing import List, Dict

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Set up logging to help us debug model loading and inference
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class MedicalAssistant:
    def __init__(self):
        """Initialize the medical assistant with model and tokenizer"""
        try:
            logger.info("Starting model initialization...")
            
            # Model configuration - adjust these based on your available compute.
            # Note: a GGUF quant repo (mradermacher/Llama3-Med42-8B-GGUF) cannot be
            # loaded with a plain from_pretrained() call, so we point at the
            # underlying full-precision checkpoint instead.
            self.model_name = "m42-health/Llama3-Med42-8B"
            self.max_length = 1024  # prompt token budget used for truncation
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            
            logger.info(f"Using device: {self.device}")
            
            # Load tokenizer first - this is typically faster and can catch issues early
            logger.info("Loading tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                padding_side="left",
                trust_remote_code=True
            )
            
            # Llama 3 tokenizers ship without a pad token, so fall back to EOS
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            
            # Load model with memory optimizations. 8-bit quantization requires
            # bitsandbytes and a CUDA device, so only request it on GPU; passing
            # load_in_8bit directly is deprecated, BitsAndBytesConfig is the
            # supported path.
            logger.info("Loading model...")
            quant_config = BitsAndBytesConfig(load_in_8bit=True) if self.device == "cuda" else None
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
                device_map="auto",
                quantization_config=quant_config,
                trust_remote_code=True
            )
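            
            # Sketch (assumption): to serve the GGUF quantization instead, recent
            # transformers releases accept a gguf_file argument; the exact file
            # name inside that repo is hypothetical here:
            #   self.model = AutoModelForCausalLM.from_pretrained(
            #       "mradermacher/Llama3-Med42-8B-GGUF",
            #       gguf_file="Llama3-Med42-8B.Q4_K_M.gguf",  # hypothetical file name
            #   )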
            
            logger.info("Model initialization completed successfully!")
            
        except Exception as e:
            logger.error(f"Error during initialization: {str(e)}")
            raise

    def generate_response(self, message: str, chat_history: List[Dict] = None) -> str:
        """Generate a response to the user's message"""
        try:
            # Prepare the prompt (built as a single string so the triple-quote
            # indentation does not leak into the model input)
            system_prompt = (
                "You are a medical AI assistant. Respond to medical queries "
                "professionally and accurately. If you're unsure, always "
                "recommend consulting with a healthcare provider."
            )
            
            # Combine system prompt, chat history, and current message
            history_text = ""
            for turn in chat_history or []:
                role = "User" if turn.get("role") == "user" else "Assistant"
                history_text += f"{role}: {turn.get('content', '')}\n"
            full_prompt = f"{system_prompt}\n\n{history_text}User: {message}\nAssistant:"
            
            # Tokenize input
            inputs = self.tokenizer(
                full_prompt,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=self.max_length
            ).to(self.model.device)  # follow wherever device_map placed the weights
            
            # Generate a response; no gradients are needed at inference time
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=512,
                    do_sample=True,
                    temperature=0.7,         # moderate sampling randomness
                    top_p=0.95,              # nucleus sampling cutoff
                    pad_token_id=self.tokenizer.pad_token_id,
                    repetition_penalty=1.1   # discourage verbatim repetition
                )
            
            # Decode only the newly generated tokens so the echoed prompt
            # (system text and history) never leaks into the reply
            new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
            response = self.tokenizer.decode(
                new_tokens,
                skip_special_tokens=True
            ).strip()
            
            return response
            
        except Exception as e:
            logger.error(f"Error during response generation: {str(e)}")
            return f"I apologize, but I encountered an error. Please try again."

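# Sketch (assumption, not in the original app): streaming partial tokens to the
# UI could be done with transformers' TextIteratorStreamer, running generate()
# in a background thread inside generate_response and yielding text as it
# arrives:
#
#   from threading import Thread
#   from transformers import TextIteratorStreamer
#
#   streamer = TextIteratorStreamer(
#       self.tokenizer, skip_prompt=True, skip_special_tokens=True
#   )
#   Thread(target=self.model.generate,
#          kwargs={**inputs, "streamer": streamer, "max_new_tokens": 512}).start()
#   partial = ""
#   for chunk in streamer:
#       partial += chunk
#       yield partial
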
# Initialize the assistant
assistant = None

def initialize_assistant():
    """Initialize the assistant and handle any errors"""
    global assistant
    try:
        assistant = MedicalAssistant()
        return True
    except Exception as e:
        logger.error(f"Failed to initialize assistant: {str(e)}")
        return False
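
# Optional (an assumption about deployment preference): initializing eagerly at
# import time surfaces model-loading failures in the startup logs rather than
# on the first chat request:
#   initialize_assistant()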

def chat_response(message: str, history: List[Dict]):
    """Handle chat messages and return responses"""
    global assistant
    
    # Check if assistant is initialized
    if assistant is None:
        if not initialize_assistant():
            return "I apologize, but I'm currently unavailable. Please try again later."
    
    try:
        return assistant.generate_response(message, history)
    except Exception as e:
        logger.error(f"Error in chat response: {str(e)}")
        return "I encountered an error. Please try again."

# Create Gradio interface
demo = gr.ChatInterface(
    fn=chat_response,
    type="messages",  # pass history as a list of {"role": ..., "content": ...} dicts
    title="Medical Assistant (Test Version)",
    description="""This is a test version of the medical assistant.
                   Please use it to verify basic functionality.""",
    examples=[
        "What are the symptoms of malaria?",
        "How can I prevent type 2 diabetes?",
        "What should I do for a mild headache?"
    ],
    # The retry_btn/undo_btn/clear_btn kwargs were removed in Gradio 5, so they
    # are intentionally left out.
)

# Launch the interface
if __name__ == "__main__":
    demo.launch()
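    # On Hugging Face Spaces, launch() picks up host/port automatically. For
    # local debugging you could instead run (an assumption, adjust as needed):
    #   demo.launch(server_name="0.0.0.0", server_port=7860)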