Spaces:
Runtime error
Update app.py
app.py
CHANGED
@@ -1,5 +1,6 @@
 # ✅ Optimized Triage Chatbot Code for Hugging Face Space (NVIDIA T4 GPU)
 # Covers: Memory optimizations, 4-bit quantization, lazy loading, FAISS caching, faster inference, safe Gradio UI
+# Includes: Proper Gradio history handling, response cleaning, safety checks
 
 import os
 import time
@@ -10,10 +11,8 @@ import psutil
 from datetime import datetime
 from huggingface_hub import login
 from dotenv import load_dotenv
-from datasets import load_dataset, load_from_disk, Dataset
 from transformers import (
     AutoTokenizer, AutoModelForCausalLM,
-    TrainingArguments, Trainer,
     BitsAndBytesConfig
 )
 from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
@@ -33,10 +32,7 @@ class SecretsManager:
     def setup():
         load_dotenv()
         creds = {
-            'KAGGLE_USERNAME': os.getenv('KAGGLE_USERNAME'),
-            'KAGGLE_KEY': os.getenv('KAGGLE_KEY'),
             'HF_TOKEN': os.getenv('HF_TOKEN'),
-            'WANDB_KEY': os.getenv('WANDB_KEY')
         }
         if creds['HF_TOKEN']:
             login(token=creds['HF_TOKEN'])
@@ -44,9 +40,9 @@ class SecretsManager:
         return creds
 
 # ===========================
-# 🧠 CHATBOT
+# 🧠 MEDICAL CHATBOT CORE
 # ===========================
-class PearlyBot:
+class MedicalTriageBot:
     def __init__(self):
         self.tokenizer = None
         self.model = None
@@ -57,196 +53,232 @@ class PearlyBot:
         self.num_relevant_chunks = 3
         self.last_interaction_time = time.time()
         self.interaction_cooldown = 1.0
+        self.safety_phrases = [
+            "999", "111", "emergency", "GP", "NHS",
+            "consult a doctor", "seek medical attention"
+        ]
 
-    def …
+    def setup_model(self, model_name="google/gemma-7b-it"):
         if self.model is not None:
             return
-
+
+        logger.info("🚀 Initializing medical AI model")
         bnb_config = BitsAndBytesConfig(
             load_in_4bit=True,
             bnb_4bit_use_double_quant=True,
             bnb_4bit_quant_type="nf4",
             bnb_4bit_compute_dtype=torch.float16
         )
+
         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
         self.tokenizer.pad_token = self.tokenizer.eos_token
-        self.model = AutoModelForCausalLM.from_pretrained(
+
+        base_model = AutoModelForCausalLM.from_pretrained(
             model_name,
             device_map="auto",
             quantization_config=bnb_config,
-            torch_dtype=torch.float16,
-            low_cpu_mem_usage=True
+            torch_dtype=torch.float16
         )
-
-        … LoraConfig(
-            r=…,
-            lora_alpha=…,
+
+        peft_config = LoraConfig(
+            r=8,
+            lora_alpha=32,
             target_modules=["q_proj", "v_proj"],
             lora_dropout=0.05,
-            bias="none",
             task_type="CAUSAL_LM"
         )
-
-        self.model = …
-        logger.info("✅ …")
+
+        self.model = get_peft_model(base_model, peft_config)
+        logger.info("✅ Medical AI model ready")
 
-    def …
-        …
-        else:
-            self.build_faiss_index()
-
-    def build_faiss_index(self):
-        logger.info("🔧 Building FAISS index from knowledge base")
-        knowledge_base = self._load_knowledge_base()
-        self.setup_embeddings()
-        texts = self._split_texts(knowledge_base)
-        self.vector_store = FAISS.from_texts(
-            texts,
+    def setup_rag_system(self):
+        logger.info("📚 Initializing medical knowledge base")
+        self.embeddings = HuggingFaceEmbeddings(
+            model_name="sentence-transformers/all-MiniLM-L6-v2",
+            model_kwargs={"device": "cpu"}
+        )
+
+        if not os.path.exists("medical_index/index.faiss"):
+            self.build_medical_index()
+
+        self.vector_store = FAISS.load_local(
+            "medical_index",
             self.embeddings,
-            …
+            allow_dangerous_deserialization=True
         )
-        self.vector_store.save_local("index_store")
 
-    def _load_knowledge_base(self):
-        …
+    def build_medical_index(self):
+        medical_knowledge = {
+            "emergency_protocols.txt": """Emergency Protocols:
+            - Chest pain: Call 999 immediately
+            - Breathing difficulties: Urgent 999 call
+            - Severe bleeding: Apply pressure, call 999""",
+
+            "triage_guidelines.txt": """Triage Guidelines:
+            - Persistent fever >48h: Contact 111
+            - Minor injuries: Visit urgent care
+            - Medication questions: Consult GP"""
         }
-
-        …
+
+        os.makedirs("medical_knowledge", exist_ok=True)
+        for filename, content in medical_knowledge.items():
+            with open(f"medical_knowledge/{filename}", "w") as f:
                 f.write(content)
-
-    def _split_texts(self, kb):
-        splitter = RecursiveCharacterTextSplitter(
+
+        text_splitter = RecursiveCharacterTextSplitter(
             chunk_size=self.chunk_size,
-            chunk_overlap=self.chunk_overlap,
-            length_function=len,
-            add_start_index=True
+            chunk_overlap=self.chunk_overlap
         )
-        …
+
+        documents = []
+        for text in medical_knowledge.values():
+            documents.extend(text_splitter.split_text(text))
+
+        vector_store = FAISS.from_texts(
+            documents,
+            self.embeddings,
+            metadatas=[{"source": f"doc_{i}"} for i in range(len(documents))]
+        )
+        vector_store.save_local("medical_index")
 
-    def …
+    def get_medical_context(self, query):
         try:
-            …
-            … for doc, score in results if score < 0.8]
-            return "\n\n".join(context)
+            docs = self.vector_store.similarity_search(query, k=2)
+            return "\n".join([d.page_content for d in docs])
         except Exception as e:
             logger.error(f"Context error: {e}")
             return ""
 
     @torch.inference_mode()
-    def …
+    def generate_safe_response(self, message, history):
         try:
-            # …
+            # Rate limiting
             if time.time() - self.last_interaction_time < self.interaction_cooldown:
                 time.sleep(self.interaction_cooldown)
-            …
+
+            # Convert Gradio history to conversational format
+            conversation = "\n".join(
+                [f"User: {user}\nAssistant: {bot}" for user, bot in history[-3:]]
+            )
+
+            # Get medical context
+            context = self.get_medical_context(message)
+
+            # Create safety-focused prompt
             prompt = f"""<start_of_turn>system
-…
+You are a medical triage assistant. Use this context:
 {context}
-…
-{…
+Current conversation:
+{conversation}
+Guidelines:
+1. Assess symptom severity
+2. Recommend appropriate care level
+3. Never diagnose or prescribe
+4. Always include safety netting
 <end_of_turn>
 <start_of_turn>user
 {message}
 <end_of_turn>
 <start_of_turn>assistant"""
-
-            …
+
+            # Generate response
+            inputs = self.tokenizer(
+                prompt,
+                return_tensors="pt",
+                truncation=True,
+                max_length=1024
+            ).to(self.model.device)
+
             outputs = self.model.generate(
                 **inputs,
-                max_new_tokens=…,
-                min_new_tokens=20,
-                do_sample=True,
+                max_new_tokens=256,
                 temperature=0.7,
                 top_p=0.9,
-                repetition_penalty=1.2,
-                no_repeat_ngram_size=3
+                repetition_penalty=1.2
             )
-
-            …
+
+            # Clean and validate response
+            full_response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+            clean_response = full_response.split("<start_of_turn>assistant")[-1]
+            clean_response = clean_response.split("<end_of_turn>")[0].strip()
+
+            # Ensure medical safety
+            if not any(phrase in clean_response.lower() for phrase in self.safety_phrases):
+                clean_response += "\n\nIf symptoms persist, please contact NHS 111."
+
             self.last_interaction_time = time.time()
-            …
+            return clean_response[:500]  # Limit response length
+
         except Exception as e:
             logger.error(f"Generation error: {e}")
-            return "…
+            return "Please contact NHS 111 directly for urgent medical advice."
 
 # ===========================
-# 💬 GRADIO
+# 💬 SAFE GRADIO INTERFACE
 # ===========================
-def …
-    bot = PearlyBot()
-
-    def …
-        if not message.strip():
-            return history
-
-        # Convert Gradio-style history to model-style format
-        structured_history = []
-        for user_msg, bot_msg in history:
-            structured_history.append({"role": "user", "content": user_msg})
-            structured_history.append({"role": "assistant", "content": bot_msg})
-
-        # Generate bot response
-        response = bot.generate_response(message, structured_history)
-
-        # Append new message pair
-        history.append([message, response])
-        return history
+def create_medical_interface():
+    bot = MedicalTriageBot()
+    bot.setup_model()
+    bot.setup_rag_system()
 
+    def handle_conversation(message, history):
+        try:
+            response = bot.generate_safe_response(message, history)
+            return history + [(message, response)]
+        except Exception as e:
+            logger.error(f"Conversation error: {e}")
+            return history + [(message, "System error - please refresh the page")]
 
-    with gr.Blocks() as …
-        …
+    with gr.Blocks(theme=gr.themes.Soft()) as interface:
+        gr.Markdown("# NHS Triage Assistant")
+        gr.HTML("""<div class="emergency-banner">🚨 In emergencies, always call 999 immediately</div>""")
+
+        with gr.Row():
+            chatbot = gr.Chatbot(
+                value=[("", "Hello! I'm your NHS digital assistant. How can I help you today?")],
+                height=500,
+                label="Medical Triage Chat"
+            )
+
+        with gr.Row():
+            message_input = gr.Textbox(
+                placeholder="Describe your symptoms...",
+                label="Your Message",
+                max_lines=3
+            )
+            submit_btn = gr.Button("Send", variant="primary")
+
+        clear_btn = gr.Button("Clear History")
+
+        # Event handlers
+        message_input.submit(
+            handle_conversation,
+            [message_input, chatbot],
+            [chatbot]
+        ).then(lambda: "", None, [message_input])
+
+        submit_btn.click(
+            handle_conversation,
+            [message_input, chatbot],
+            [chatbot]
+        ).then(lambda: "", None, [message_input])
+
+        clear_btn.click(
+            lambda: [("", "Hello! I'm your NHS digital assistant. How can I help you today?")],
+            None,
+            [chatbot]
+        )
 
-    return …
+    return interface
 
 # ===========================
-# 🚀 …
+# 🚀 LAUNCH APPLICATION
 # ===========================
 if __name__ == "__main__":
     SecretsManager.setup()
-    …
+    medical_app = create_medical_interface()
+    medical_app.launch(
+        server_name="0.0.0.0",
+        server_port=7860,
+        share=False
+    )