PearlIsa committed
Commit d93da55 · verified · 1 Parent(s): 0742e5c

Update app.py

Files changed (1)
  1. app.py +169 -137
app.py CHANGED
@@ -1,5 +1,6 @@
 # ✅ Optimized Triage Chatbot Code for Hugging Face Space (NVIDIA T4 GPU)
 # Covers: Memory optimizations, 4-bit quantization, lazy loading, FAISS caching, faster inference, safe Gradio UI
+# Includes: Proper Gradio history handling, response cleaning, safety checks
 
 import os
 import time
@@ -10,10 +11,8 @@ import psutil
 from datetime import datetime
 from huggingface_hub import login
 from dotenv import load_dotenv
-from datasets import load_dataset, load_from_disk, Dataset
 from transformers import (
     AutoTokenizer, AutoModelForCausalLM,
-    TrainingArguments, Trainer,
     BitsAndBytesConfig
 )
 from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
@@ -33,10 +32,7 @@ class SecretsManager:
     def setup():
         load_dotenv()
         creds = {
-            'KAGGLE_USERNAME': os.getenv('KAGGLE_USERNAME'),
-            'KAGGLE_KEY': os.getenv('KAGGLE_KEY'),
             'HF_TOKEN': os.getenv('HF_TOKEN'),
-            'WANDB_KEY': os.getenv('WANDB_KEY')
         }
         if creds['HF_TOKEN']:
             login(token=creds['HF_TOKEN'])
@@ -44,9 +40,9 @@ class SecretsManager:
         return creds
 
 # ===========================
-# 🧠 CHATBOT CLASS
+# 🧠 MEDICAL CHATBOT CORE
 # ===========================
-class PearlyBot:
+class MedicalTriageBot:
     def __init__(self):
         self.tokenizer = None
         self.model = None
@@ -57,196 +53,232 @@ class PearlyBot:
         self.num_relevant_chunks = 3
         self.last_interaction_time = time.time()
         self.interaction_cooldown = 1.0
+        self.safety_phrases = [
+            "999", "111", "emergency", "GP", "NHS",
+            "consult a doctor", "seek medical attention"
+        ]
 
-    def setup_model_and_tokenizer(self, model_name="google/gemma-7b"):
+    def setup_model(self, model_name="google/gemma-7b-it"):
         if self.model is not None:
             return
-        logger.info("🚀 Loading model & tokenizer")
+
+        logger.info("🚀 Initializing medical AI model")
         bnb_config = BitsAndBytesConfig(
             load_in_4bit=True,
             bnb_4bit_use_double_quant=True,
             bnb_4bit_quant_type="nf4",
             bnb_4bit_compute_dtype=torch.float16
         )
+
         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
         self.tokenizer.pad_token = self.tokenizer.eos_token
-        model = AutoModelForCausalLM.from_pretrained(
+
+        base_model = AutoModelForCausalLM.from_pretrained(
             model_name,
             device_map="auto",
             quantization_config=bnb_config,
-            torch_dtype=torch.float16,
-            low_cpu_mem_usage=True
+            torch_dtype=torch.float16
         )
-        model = prepare_model_for_kbit_training(model)
-        lora_config = LoraConfig(
-            r=4,
-            lora_alpha=16,
+
+        peft_config = LoraConfig(
+            r=8,
+            lora_alpha=32,
             target_modules=["q_proj", "v_proj"],
             lora_dropout=0.05,
-            bias="none",
             task_type="CAUSAL_LM"
         )
-        self.model = get_peft_model(model, lora_config)
-        self.model.to("cuda" if torch.cuda.is_available() else "cpu")
-        logger.info("✅ Model & tokenizer ready")
+
+        self.model = get_peft_model(base_model, peft_config)
+        logger.info("✅ Medical AI model ready")
 
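A note on the load path this hunk settles on: the base weights are loaded once in 4-bit NF4 and LoRA adapters are layered on top. A minimal standalone sketch of the same path, assuming a CUDA machine with bitsandbytes installed (the checkpoint name is the one the commit switches to; everything else is illustrative):

# Sketch: 4-bit NF4 base model wrapped with LoRA adapters, as in setup_model.
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model

MODEL_NAME = "google/gemma-7b-it"

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                     # store weights as 4-bit NF4
    bnb_4bit_use_double_quant=True,        # also quantize the quantization constants
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,  # dequantize to fp16 for matmuls
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",                     # let accelerate place layers on the GPU
    quantization_config=bnb_config,
)

peft_config = LoraConfig(
    r=8, lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05, task_type="CAUSAL_LM",
)
model = get_peft_model(base_model, peft_config)
model.print_trainable_parameters()  # only the small adapter matrices are trainable

Freshly initialized LoRA adapters are an identity transform (their B matrices start at zero), so until they are trained the wrapper does not change inference output; note the new version also stops calling prepare_model_for_kbit_training, although the import is kept.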
 
-    def setup_embeddings(self):
-        if self.embeddings is None:
-            logger.info("📌 Loading sentence-transformer embeddings")
-            self.embeddings = HuggingFaceEmbeddings(
-                model_name="sentence-transformers/all-MiniLM-L6-v2",
-                cache_folder="./embeddings_cache"
-            )
-
-    def load_faiss_index(self):
-        logger.info("📁 Loading FAISS index")
-        if os.path.exists("index_store/index.faiss"):
-            self.vector_store = FAISS.load_local("index_store", self.embeddings)
-        else:
-            self.build_faiss_index()
-
-    def build_faiss_index(self):
-        logger.info("🔧 Building FAISS index from knowledge base")
-        knowledge_base = self._load_knowledge_base()
-        self.setup_embeddings()
-        texts = self._split_texts(knowledge_base)
-        self.vector_store = FAISS.from_texts(
-            texts,
+    def setup_rag_system(self):
+        logger.info("📚 Initializing medical knowledge base")
+        self.embeddings = HuggingFaceEmbeddings(
+            model_name="sentence-transformers/all-MiniLM-L6-v2",
+            model_kwargs={"device": "cpu"}
+        )
+
+        if not os.path.exists("medical_index/index.faiss"):
+            self.build_medical_index()
+
+        self.vector_store = FAISS.load_local(
+            "medical_index",
             self.embeddings,
-            metadatas=[{"source": f"chunk_{i}"} for i in range(len(texts))]
+            allow_dangerous_deserialization=True
         )
-        self.vector_store.save_local("index_store")
 
-    def _load_knowledge_base(self):
-        kb_dir = "knowledge_base"
-        os.makedirs(kb_dir, exist_ok=True)
-        kb_files = {
-            "triage.txt": "Severe chest pain? Call 999. Persistent cough? Book GP.",
-            "emergency.txt": "Unconscious? 999. Breathing issues? 999.",
-            "cultural.txt": "Respect prayer times, language needs, traditional remedies.",
-            "gp_booking.txt": "I need to book a GP for routine care next week."
+    def build_medical_index(self):
+        medical_knowledge = {
+            "emergency_protocols.txt": """Emergency Protocols:
+- Chest pain: Call 999 immediately
+- Breathing difficulties: Urgent 999 call
+- Severe bleeding: Apply pressure, call 999""",
+
+            "triage_guidelines.txt": """Triage Guidelines:
+- Persistent fever >48h: Contact 111
+- Minor injuries: Visit urgent care
+- Medication questions: Consult GP"""
         }
-        for file, content in kb_files.items():
-            with open(os.path.join(kb_dir, file), 'w') as f:
+
+        os.makedirs("medical_knowledge", exist_ok=True)
+        for filename, content in medical_knowledge.items():
+            with open(f"medical_knowledge/{filename}", "w") as f:
                 f.write(content)
-        return kb_files
-
-    def _split_texts(self, kb):
-        splitter = RecursiveCharacterTextSplitter(
+
+        text_splitter = RecursiveCharacterTextSplitter(
             chunk_size=self.chunk_size,
-            chunk_overlap=self.chunk_overlap,
-            length_function=len,
-            add_start_index=True
+            chunk_overlap=self.chunk_overlap
         )
-        texts = []
-        for text in kb.values():
-            chunks = splitter.split_text(text)
-            texts.extend(chunks)
-        return texts
+
+        documents = []
+        for text in medical_knowledge.values():
+            documents.extend(text_splitter.split_text(text))
+
+        vector_store = FAISS.from_texts(
+            documents,
+            self.embeddings,
+            metadatas=[{"source": f"doc_{i}"} for i in range(len(documents))]
+        )
+        vector_store.save_local("medical_index")
 
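The RAG plumbing above reduces to a build-once, load-afterwards FAISS round trip. A self-contained sketch with the same embedding model (the sample texts and index path are placeholders; the langchain_community import locations are an assumption, since app.py's own imports sit in an unchanged part of the file not shown in this diff):

# Sketch: build a FAISS index once, then reload it from disk, mirroring
# build_medical_index / setup_rag_system.
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS

embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    model_kwargs={"device": "cpu"},
)

texts = [  # placeholders standing in for the chunked knowledge files
    "Chest pain: Call 999 immediately",
    "Persistent fever >48h: Contact 111",
]
FAISS.from_texts(texts, embeddings).save_local("demo_index")

# load_local deserializes a pickle, hence the explicit opt-in flag the commit adds
store = FAISS.load_local("demo_index", embeddings,
                         allow_dangerous_deserialization=True)
print(store.similarity_search("fever for three days", k=1)[0].page_content)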
 
-    def _get_enhanced_context(self, query):
+    def get_medical_context(self, query):
         try:
-            results = self.vector_store.similarity_search_with_score(query, k=self.num_relevant_chunks)
-            context = [f"[Source: {doc.metadata.get('source', 'unknown')}]:\n{doc.page_content}"
-                       for doc, score in results if score < 0.8]
-            return "\n\n".join(context)
+            docs = self.vector_store.similarity_search(query, k=2)
+            return "\n".join([d.page_content for d in docs])
         except Exception as e:
             logger.error(f"Context error: {e}")
             return ""
 
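The hunk above drops the score-filtered retrieval in favour of a plain top-k search. For reference, the older behaviour as a small helper (FAISS scores are L2 distances by default, so lower means closer; 0.8 is the commit's original, essentially arbitrary cutoff):

# Sketch: the score-thresholded retrieval this hunk removes.
def get_filtered_context(store, query, k=3, max_distance=0.8):
    results = store.similarity_search_with_score(query, k=k)
    kept = [doc.page_content for doc, score in results if score < max_distance]
    return "\n\n".join(kept)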
 
     @torch.inference_mode()
-    def generate_response(self, message, history):
+    def generate_safe_response(self, message, history):
         try:
-            # Throttle
+            # Rate limiting
             if time.time() - self.last_interaction_time < self.interaction_cooldown:
                 time.sleep(self.interaction_cooldown)
-
-            self.setup_model_and_tokenizer()
-            self.setup_embeddings()
-            self.load_faiss_index()
-
-            context = self._get_enhanced_context(message)
-            conv_history = "\n".join([
-                f"User: {turn['content']}" if turn['role'] == 'user' else f"Assistant: {turn['content']}"
-                for turn in history[-3:]
-            ])
-
+
+            # Convert Gradio history to conversational format
+            conversation = "\n".join(
+                [f"User: {user}\nAssistant: {bot}" for user, bot in history[-3:]]
+            )
+
+            # Get medical context
+            context = self.get_medical_context(message)
+
+            # Create safety-focused prompt
             prompt = f"""<start_of_turn>system
-Context:
+You are a medical triage assistant. Use this context:
 {context}
-Conversation:
-{conv_history}
+Current conversation:
+{conversation}
+Guidelines:
+1. Assess symptom severity
+2. Recommend appropriate care level
+3. Never diagnose or prescribe
+4. Always include safety netting
 <end_of_turn>
 <start_of_turn>user
 {message}
 <end_of_turn>
 <start_of_turn>assistant"""
-
-            inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(self.model.device)
+
+            # Generate response
+            inputs = self.tokenizer(
+                prompt,
+                return_tensors="pt",
+                truncation=True,
+                max_length=1024
+            ).to(self.model.device)
+
             outputs = self.model.generate(
                 **inputs,
-                max_new_tokens=128,
-                min_new_tokens=20,
-                do_sample=True,
+                max_new_tokens=256,
                 temperature=0.7,
                 top_p=0.9,
-                repetition_penalty=1.2,
-                no_repeat_ngram_size=3
+                repetition_penalty=1.2
             )
-            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
-            response = response.split("<start_of_turn>assistant")[-1].strip().split("<end_of_turn>")[0].strip()
+
+            # Clean and validate response
+            full_response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+            clean_response = full_response.split("<start_of_turn>assistant")[-1]
+            clean_response = clean_response.split("<end_of_turn>")[0].strip()
+
+            # Ensure medical safety
+            if not any(phrase in clean_response.lower() for phrase in self.safety_phrases):
+                clean_response += "\n\nIf symptoms persist, please contact NHS 111."
+
             self.last_interaction_time = time.time()
-            logger.info(f"💬 Memory use: {psutil.virtual_memory().percent}%")
-            return response
-
+            return clean_response[:500]  # Limit response length
+
         except Exception as e:
             logger.error(f"Generation error: {e}")
-            return "I encountered an error. Please try again."
+            return "Please contact NHS 111 directly for urgent medical advice."
 
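The tail of generate_safe_response is the part that can be unit-tested without a GPU. A sketch of it as a pure function; note that the committed safety_phrases list mixes upper- and lower-case entries while the check lowercases the reply, so "GP" and "NHS" as written can never match, and the sketch therefore lowercases the whole list:

# Sketch: response cleaning + safety netting as a testable pure function.
SAFETY_PHRASES = ["999", "111", "emergency", "gp", "nhs",
                  "consult a doctor", "seek medical attention"]

def clean_and_safety_net(full_response: str, max_len: int = 500) -> str:
    reply = full_response.split("<start_of_turn>assistant")[-1]
    reply = reply.split("<end_of_turn>")[0].strip()
    # plain substring matching is crude ("gp" also matches "upgrade")
    if not any(p in reply.lower() for p in SAFETY_PHRASES):
        reply += "\n\nIf symptoms persist, please contact NHS 111."
    return reply[:max_len]

assert "111" in clean_and_safety_net(
    "<start_of_turn>assistant Try to rest and stay hydrated.<end_of_turn>"
)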
 
 # ===========================
-# 💬 GRADIO UI
+# 💬 SAFE GRADIO INTERFACE
 # ===========================
-def create_demo():
-    bot = PearlyBot()
-
-    def chat(message, history):
-        if not message.strip():
-            return history
-
-        # Convert Gradio-style history to model-style format
-        structured_history = []
-        for user_msg, bot_msg in history:
-            structured_history.append({"role": "user", "content": user_msg})
-            structured_history.append({"role": "assistant", "content": bot_msg})
-
-        # Generate bot response
-        response = bot.generate_response(message, structured_history)
-
-        # Append new message pair
-        history.append([message, response])
-        return history
+def create_medical_interface():
+    bot = MedicalTriageBot()
+    bot.setup_model()
+    bot.setup_rag_system()
 
+    def handle_conversation(message, history):
+        try:
+            response = bot.generate_safe_response(message, history)
+            return history + [(message, response)]
+        except Exception as e:
+            logger.error(f"Conversation error: {e}")
+            return history + [(message, "System error - please refresh the page")]
 
-    with gr.Blocks() as demo:
-        chatbot = gr.Chatbot(
-            value=[["Hello!", "Hi, I'm Pearly, your GP triage assistant. I'm here to help you assess your symptoms and guide you to the right care. How are you feeling today?"]],
-            height=500,
-            show_label=False
-        )
-        msg = gr.Textbox(label="Type your message")
-        send = gr.Button("Send")
-        clear = gr.Button("Clear Chat")
-
-        msg.submit(chat, [msg, chatbot], [chatbot]).then(lambda: gr.update(value=""), None, [msg])
-        send.click(chat, [msg, chatbot], [chatbot]).then(lambda: gr.update(value=""), None, [msg])
-        clear.click(lambda: [], None, chatbot)
+    with gr.Blocks(theme=gr.themes.Soft()) as interface:
+        gr.Markdown("# NHS Triage Assistant")
+        gr.HTML("""<div class="emergency-banner">🚨 In emergencies, always call 999 immediately</div>""")
+
+        with gr.Row():
+            chatbot = gr.Chatbot(
+                value=[("", "Hello! I'm your NHS digital assistant. How can I help you today?")],
+                height=500,
+                label="Medical Triage Chat"
+            )
+
+        with gr.Row():
+            message_input = gr.Textbox(
+                placeholder="Describe your symptoms...",
+                label="Your Message",
+                max_lines=3
+            )
+            submit_btn = gr.Button("Send", variant="primary")
+
+        clear_btn = gr.Button("Clear History")
+
+        # Event handlers
+        message_input.submit(
+            handle_conversation,
+            [message_input, chatbot],
+            [chatbot]
+        ).then(lambda: "", None, [message_input])
+
+        submit_btn.click(
+            handle_conversation,
+            [message_input, chatbot],
+            [chatbot]
+        ).then(lambda: "", None, [message_input])
+
+        clear_btn.click(
+            lambda: [("", "Hello! I'm your NHS digital assistant. How can I help you today?")],
+            None,
+            [chatbot]
+        )
 
-    return demo
+    return interface
 
 # ===========================
-# 🚀 MAIN
+# 🚀 LAUNCH APPLICATION
 # ===========================
 if __name__ == "__main__":
     SecretsManager.setup()
-    demo = create_demo()
-    demo.launch(server_name="0.0.0.0", server_port=7860)
+    medical_app = create_medical_interface()
+    medical_app.launch(
+        server_name="0.0.0.0",
+        server_port=7860,
+        share=False
+    )
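The interface keeps Gradio's tuple-style chat history throughout: handle_conversation appends (message, response) pairs, and generate_safe_response unpacks the same pairs with "for user, bot in history". A runnable sketch of just that wiring with the bot stubbed out (no model needed; the echo stub is purely illustrative):

# Sketch: the tuple-history wiring of create_medical_interface, with a stub bot.
import gradio as gr

def stub_response(message, history):
    return f"Echo: {message}"  # stands in for bot.generate_safe_response

def handle(message, history):
    return history + [(message, stub_response(message, history))]

with gr.Blocks() as demo:
    chat = gr.Chatbot(value=[("", "Hello! How can I help?")], height=300)
    box = gr.Textbox(label="Message")
    box.submit(handle, [box, chat], [chat]).then(lambda: "", None, [box])

if __name__ == "__main__":
    demo.launch()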
 
 
 
 
 