File size: 8,227 Bytes
65b8407
6ea4317
65b8407
 
48823b1
65b8407
 
48823b1
6ea4317
65b8407
48823b1
65b8407
 
 
 
 
48823b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65b8407
 
48823b1
65b8407
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1215503
65b8407
 
8b70f49
 
1215503
 
65b8407
 
1215503
65b8407
8b70f49
65b8407
1215503
 
 
65b8407
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1215503
 
 
65b8407
 
 
 
 
 
 
48823b1
 
 
 
 
 
 
 
 
 
65b8407
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48823b1
1215503
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
from groq import Groq
import streamlit as st
import re
from datetime import datetime
import os
from typing import Generator, List, Tuple, Optional
import logging
from dotenv import load_dotenv

# --- Load Environment Variables ---
# Copy entries from a local .env file (if one is found) into os.environ so
# get_api_key() can read GROQ_API_KEY without a shell export.
load_dotenv()

# --- Logging Configuration ---
# Root logger at INFO; this module logs through its own named logger.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)

# --- API Key Management ---
def get_api_key() -> Optional[str]:
    """Return the Groq API key, prompting for it in the sidebar if needed.

    The ``GROQ_API_KEY`` environment variable (which ``load_dotenv()`` may have
    populated from a ``.env`` file at import time) takes precedence. When it is
    unset or blank, a password field is rendered in the sidebar together with a
    success/warning status message.

    Returns:
        The key with surrounding whitespace removed, or an empty string when no
        key is available yet. Callers should treat any falsy value as "no key".
    """
    # Strip whitespace: keys copied from a browser or a .env file often carry a
    # stray space/newline, which the API would likely reject as invalid.
    api_key = (os.getenv("GROQ_API_KEY") or "").strip()

    if not api_key:
        st.sidebar.markdown("## 🔑 API Configuration")
        entered = st.sidebar.text_input(
            "Enter your Groq API Key:",
            type="password",
            help="Get your API key from https://console.groq.com",
            key="groq_api_key",
        )
        api_key = (entered or "").strip()
        if api_key:
            st.sidebar.success("API Key set successfully!")
        else:
            st.sidebar.warning("Please enter your Groq API Key to continue")

    return api_key

# --- Constants ---
# Keyword arguments forwarded verbatim to client.chat.completions.create().
# "stream": True is relied upon: generate_medical_response iterates the result
# chunk by chunk.
MODEL_CONFIG = {
    "model": "meta-llama/llama-4-scout-17b-16e-instruct",
    "temperature": 0.5,
    "max_completion_tokens": 1024,
    "stream": True,
    "stop": None
}

# Label -> text pairs shown in the sidebar by render_sidebar().
# NOTE(review): these values look like placeholders (sample hospital name and
# a sample phone number) — confirm real contact details before deployment.
EMERGENCY_CONTACTS = {
    "Hospital Address": "John Smith Hospital, 123 Health St.",
    "Ambulance": "911 (US) / 112 (EU)",
    "24/7 Support": "+1-800-123-4567"
}

# Lower-case substrings; process_emergency() lower-cases the query and flags it
# if any of these appear anywhere inside it.
EMERGENCY_KEYWORDS = [
    "emergency", "911", "112", "immediate help",
    "severe pain", "cannot breathe", "chest pain",
    "unconscious", "seizure", "stroke"
]

# System message placed first in every request sent to the model. The emoji
# section markers below must be real Unicode characters; mis-encoded bytes here
# would be sent to (and echoed by) the model as garbage.
SYSTEM_PROMPT = """You are DoctorX, a medical AI assistant. Follow these guidelines:
1. Never diagnose - only suggest possible conditions
2. Always recommend consulting a medical professional
3. Prioritize patient safety and well-being
4. Maintain professional yet empathetic tone
5. Be clear about being an AI
6. For emergencies, direct to emergency services

Format your responses as follows:
🤖 AI Assistant: [Your greeting]
💭 Understanding: [Brief interpretation of the query]
🏥 Medical Context: [Relevant medical information]
📋 Suggestions:
- [Point 1]
- [Point 2]
⚠️ Important: Always consult a healthcare professional for proper diagnosis and treatment."""

# --- Security & Input Validation ---
def sanitize_input(text: str) -> str:
    """Cap *text* at 2000 characters and drop markup-like characters.

    The characters ``< > { } [ ] ~`` and the backtick are deleted wherever they
    occur. Surrounding whitespace is trimmed afterwards. Falsy input (``""`` or
    ``None``) produces ``""``.
    """
    if not text:
        return ""
    # Deletion-only translation table: every listed character maps to nothing.
    strip_table = str.maketrans("", "", "<>{}[]~`")
    truncated = text[:2000]
    return truncated.translate(strip_table).strip()

def validate_response(response: str) -> bool:
    """Return False if *response* contains a banned phrase, True otherwise.

    Matching is a case-insensitive substring search over a fixed phrase list.
    """
    lowered = response.lower()
    for banned in ("take your own life", "kill yourself", "hate you"):
        if banned in lowered:
            return False
    return True

def process_emergency(query: str) -> bool:
    """Return True when *query* mentions any entry of EMERGENCY_KEYWORDS.

    The query is lower-cased and each keyword is tested as a plain substring,
    so e.g. "chest pain" also matches inside a general question about it.
    """
    haystack = query.lower()
    for keyword in EMERGENCY_KEYWORDS:
        if keyword in haystack:
            return True
    return False

def generate_medical_response(query: str, chat_history: List[Tuple[str, str]]) -> Generator[str, None, None]:
    """
    Generate a medical response using the LLM with streaming support.

    The request is SYSTEM_PROMPT, followed by the last 5 entries of
    *chat_history*, followed by *query* as the final user message. Uses the
    module-level ``client`` (created in main()) and MODEL_CONFIG.

    Args:
        query: The user's medical query.
        chat_history: Previous interactions as (role, message) tuples. It should
            NOT already contain *query*, because *query* is appended as the
            final user message. Any role other than "user" is sent as
            "assistant".

    Yields:
        Non-empty text chunks of the generated response. On any failure the
        error is logged and a single apology message is yielded instead (after
        whatever chunks were already produced).
    """
    try:
        # Only the most recent messages are forwarded, to bound the prompt size.
        history_messages = [
            {"role": "user" if role == "user" else "assistant", "content": msg}
            for role, msg in chat_history[-5:]
        ]

        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            *history_messages,
            {"role": "user", "content": query},
        ]

        completion = client.chat.completions.create(
            messages=messages,
            **MODEL_CONFIG
        )

        for chunk in completion:
            # Defensive: skip any stream chunk that carries no choices, rather
            # than raising IndexError and aborting a partially streamed answer.
            if not chunk.choices:
                continue
            content = chunk.choices[0].delta.content
            if content:
                yield content

    except Exception as e:
        # logger.exception records the traceback, not just the message.
        logger.exception("Error generating response: %s", e)
        yield "I apologize, but I encountered an error. Please try again or contact support."

# --- Main Application ---
def initialize_session_state() -> None:
    """Create an empty ``chat_history`` list in session state on first run.

    An existing history is left untouched, so the conversation persists across
    Streamlit reruns.
    """
    state = st.session_state
    if "chat_history" in state:
        return
    state.chat_history = []

def setup_page() -> None:
    """Apply page-level Streamlit settings: title, icon, wide layout, open sidebar.

    main() calls this before any other Streamlit command.
    """
    page_settings = {
        "page_title": "DoctorX - Your AI Health Assistant",
        "page_icon": "🧠",
        "layout": "wide",
        "initial_sidebar_state": "expanded",
    }
    st.set_page_config(**page_settings)

def render_sidebar() -> None:
    """Fill the sidebar with a logo, each EMERGENCY_CONTACTS entry, and a disclaimer."""
    sidebar = st.sidebar
    sidebar.image("https://img.icons8.com/color/96/000000/heart-health.png")
    sidebar.header("🚨 Emergency Contacts")
    # One subheader/caption pair per contact, in dict order.
    for label, detail in EMERGENCY_CONTACTS.items():
        sidebar.subheader(label)
        sidebar.caption(detail)
    sidebar.divider()
    sidebar.warning("⚠️ This is not a substitute for emergency medical care")

def handle_user_input(user_input: str) -> None:
    """Sanitize a user query, screen it for emergencies, and stream a reply.

    Flow:
      1. ``sanitize_input``; an input that is empty after sanitizing is ignored.
      2. If ``process_emergency`` flags it, show an emergency notice and stop
         (nothing is recorded in the chat history).
      3. Otherwise record and display the user turn, stream the model's answer
         into a placeholder, and replace it with a safe fallback if
         ``validate_response`` rejects it. The final text is appended to
         ``st.session_state.chat_history``.
    """
    cleaned_input = sanitize_input(user_input)
    if not cleaned_input:
        # Nothing meaningful left (e.g. the input was only stripped characters).
        return

    if process_emergency(cleaned_input):
        st.error("🚨 This appears to be an emergency. Please contact emergency services immediately!")
        st.info("See emergency contacts in the sidebar →")
        return

    # Snapshot the history BEFORE recording this turn: generate_medical_response
    # appends the query itself, so passing the updated list would send it twice.
    prior_history = list(st.session_state.chat_history)
    st.session_state.chat_history.append(("user", cleaned_input))

    # main() has already rendered the stored history for this run, so echo the
    # new user message here or it would not appear until the next rerun.
    with st.chat_message("user"):
        st.markdown(cleaned_input)

    with st.chat_message("assistant"):
        response_placeholder = st.empty()
        chunks: List[str] = []

        with st.spinner("🤔 Analyzing your query..."):
            for response_chunk in generate_medical_response(cleaned_input, prior_history):
                chunks.append(response_chunk)
                response_placeholder.markdown("".join(chunks))

        response_text = "".join(chunks)
        if not validate_response(response_text):
            response_text = "I apologize, but I cannot provide that information. Please consult a healthcare professional."
            response_placeholder.markdown(response_text)
        st.session_state.chat_history.append(("assistant", response_text))

def render_quick_access_buttons() -> None:
    """Render one-click buttons for common health questions.

    Each question gets its own column. A click submits the button's label
    through handle_user_input() during the same script run, after the button
    row has been laid out.
    """
    st.divider()
    st.subheader("📌 Common Health Topics")

    common_queries = [
        "What are common symptoms of anxiety?",
        "How to maintain good sleep hygiene?",
        "When should I see a doctor about headaches?",
        "Tips for managing stress",
        "Understanding blood pressure readings"
    ]

    cols = st.columns(len(common_queries))
    for col, query in zip(cols, common_queries):
        if col.button(query):
            handle_user_input(query)

def main() -> None:
    """Main application function: configure the page, authenticate, render the chat.

    Steps:
      1. Page config and session state.
      2. Obtain an API key; halt this run with ``st.stop()`` if none is given.
      3. Create the module-level Groq ``client`` that generate_medical_response uses.
      4. Render the sidebar, the stored chat history, the chat input and the
         quick-access buttons.

    Any ``Exception`` is logged with its traceback and shown to the user as a
    generic error message.
    """
    # generate_medical_response reads this module-level name.
    global client

    try:
        setup_page()
        initialize_session_state()

        api_key = get_api_key()
        if not api_key:
            st.stop()

        client = Groq(api_key=api_key)

        render_sidebar()
        st.title("🧠 DoctorX")
        st.caption("Preliminary health guidance - Always consult healthcare professionals")

        # Replay the conversation stored from earlier runs.
        for role, message in st.session_state.chat_history:
            with st.chat_message(role):
                st.markdown(message)

        if user_input := st.chat_input("Type your health question here...", key="user_input"):
            handle_user_input(user_input)

        render_quick_access_buttons()

    except Exception as e:
        # logger.exception keeps the traceback, which str(e) alone would lose.
        logger.exception("Application error: %s", e)
        st.error("An unexpected error occurred. Please refresh the page and try again.")

# Script entry point. The Groq client is created inside main() once an API key
# is available, not at import time.
if __name__ == "__main__":
    main()