YanBoChen commited on
Commit
fa23be2
·
1 Parent(s): 4c919d2

feat(user_prompt): add keyword index checking and enhance medical query validation (add validate_medical_query() and _check_keyword_in_index(), based on implementation_todo_20250730_user_prompt.md

Browse files
Files changed (1) hide show
  1. src/user_prompt.py +153 -0
src/user_prompt.py CHANGED
@@ -15,6 +15,9 @@ import logging
15
  from typing import Dict, Optional, Any, List
16
  from sentence_transformers import SentenceTransformer
17
  import numpy as np # Added missing import for numpy
 
 
 
18
 
19
  # Import our centralized medical conditions configuration
20
  from medical_conditions import (
@@ -42,6 +45,10 @@ class UserPromptProcessor:
42
  self.meditron_client = meditron_client
43
  self.retrieval_system = retrieval_system
44
  self.embedding_model = SentenceTransformer("NeuML/pubmedbert-base-embeddings")
 
 
 
 
45
  logger.info("UserPromptProcessor initialized")
46
 
47
  def extract_condition_keywords(self, user_query: str) -> Dict[str, str]:
@@ -254,6 +261,72 @@ class UserPromptProcessor:
254
  # Basic validation: check if any keyword is non-empty
255
  return any(kw.strip() for kw in emergency_kws + treatment_kws)
256
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
257
  def handle_user_confirmation(self, extracted_info: Dict[str, str]) -> Dict[str, Any]:
258
  """
259
  Handle user confirmation for extracted condition and keywords
@@ -296,6 +369,86 @@ Please confirm:
296
  'extracted_info': extracted_info
297
  }
298
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
299
  def main():
300
  """
301
  Example usage and testing of UserPromptProcessor
 
15
  from typing import Dict, Optional, Any, List
16
  from sentence_transformers import SentenceTransformer
17
  import numpy as np # Added missing import for numpy
18
+ import os # Added missing import for os
19
+ import json # Added missing import for json
20
+ import re # Added missing import for re
21
 
22
  # Import our centralized medical conditions configuration
23
  from medical_conditions import (
 
45
  self.meditron_client = meditron_client
46
  self.retrieval_system = retrieval_system
47
  self.embedding_model = SentenceTransformer("NeuML/pubmedbert-base-embeddings")
48
+
49
+ # Add embeddings directory path
50
+ self.embeddings_dir = os.path.join(os.path.dirname(__file__), '..', 'models', 'embeddings')
51
+
52
  logger.info("UserPromptProcessor initialized")
53
 
54
  def extract_condition_keywords(self, user_query: str) -> Dict[str, str]:
 
261
  # Basic validation: check if any keyword is non-empty
262
  return any(kw.strip() for kw in emergency_kws + treatment_kws)
263
 
264
+ def _check_keyword_in_index(self, keyword: str, index_type: str) -> bool:
265
+ """
266
+ Check if a keyword exists in the specified medical index
267
+
268
+ Args:
269
+ keyword: Keyword to check
270
+ index_type: Type of index ('emergency' or 'treatment')
271
+
272
+ Returns:
273
+ Boolean indicating keyword existence in the index
274
+ """
275
+ # Validate input parameters
276
+ if not keyword or not index_type:
277
+ logger.warning(f"Invalid input: keyword='{keyword}', index_type='{index_type}'")
278
+ return False
279
+
280
+ # Supported index types
281
+ valid_index_types = ['emergency', 'treatment']
282
+ if index_type not in valid_index_types:
283
+ logger.error(f"Unsupported index type: {index_type}")
284
+ return False
285
+
286
+ try:
287
+ # Construct path to chunks file
288
+ chunks_path = os.path.join(self.embeddings_dir, f"{index_type}_chunks.json")
289
+
290
+ # Check file existence
291
+ if not os.path.exists(chunks_path):
292
+ logger.error(f"Index file not found: {chunks_path}")
293
+ return False
294
+
295
+ # Load chunks with error handling
296
+ with open(chunks_path, 'r', encoding='utf-8') as f:
297
+ chunks = json.load(f)
298
+
299
+ # Normalize keyword for flexible matching
300
+ keyword_lower = keyword.lower().strip()
301
+
302
+ # Advanced keyword matching
303
+ for chunk in chunks:
304
+ chunk_text = chunk.get('text', '').lower()
305
+
306
+ # Exact match
307
+ if keyword_lower in chunk_text:
308
+ logger.info(f"Exact match found for '{keyword}' in {index_type} index")
309
+ return True
310
+
311
+ # Partial match with word boundaries
312
+ if re.search(r'\b' + re.escape(keyword_lower) + r'\b', chunk_text):
313
+ logger.info(f"Partial match found for '{keyword}' in {index_type} index")
314
+ return True
315
+
316
+ # No match found
317
+ logger.info(f"No match found for '{keyword}' in {index_type} index")
318
+ return False
319
+
320
+ except json.JSONDecodeError:
321
+ logger.error(f"Invalid JSON in {chunks_path}")
322
+ return False
323
+ except IOError as e:
324
+ logger.error(f"IO error reading {chunks_path}: {e}")
325
+ return False
326
+ except Exception as e:
327
+ logger.error(f"Unexpected error in _check_keyword_in_index: {e}")
328
+ return False
329
+
330
  def handle_user_confirmation(self, extracted_info: Dict[str, str]) -> Dict[str, Any]:
331
  """
332
  Handle user confirmation for extracted condition and keywords
 
369
  'extracted_info': extracted_info
370
  }
371
 
372
+ def validate_medical_query(self, user_query: str) -> Dict[str, Any]:
373
+ """
374
+ Validate if the query is a medical-related query using multi-layer verification
375
+
376
+ Args:
377
+ user_query: User's input query
378
+
379
+ Returns:
380
+ Dict with validation result or None if medical query
381
+ """
382
+ # Expanded medical keywords covering comprehensive medical terminology
383
+ predefined_medical_keywords = {
384
+ # Symptoms and signs
385
+ 'pain', 'symptom', 'ache', 'fever', 'inflammation',
386
+ 'bleeding', 'swelling', 'rash', 'bruise', 'wound',
387
+
388
+ # Medical professional terms
389
+ 'disease', 'condition', 'syndrome', 'disorder',
390
+ 'medical', 'health', 'diagnosis', 'treatment',
391
+ 'therapy', 'medication', 'prescription',
392
+
393
+ # Body systems and organs
394
+ 'heart', 'lung', 'brain', 'kidney', 'liver',
395
+ 'blood', 'nerve', 'muscle', 'bone', 'joint',
396
+
397
+ # Medical actions
398
+ 'examine', 'check', 'test', 'scan', 'surgery',
399
+ 'operation', 'emergency', 'urgent', 'critical',
400
+
401
+ # Specific medical fields
402
+ 'cardiology', 'neurology', 'oncology', 'pediatrics',
403
+ 'psychiatry', 'dermatology', 'orthopedics'
404
+ }
405
+
406
+ # Check if query contains predefined medical keywords
407
+ query_lower = user_query.lower()
408
+ if any(kw in query_lower for kw in predefined_medical_keywords):
409
+ return None # Validated by predefined keywords
410
+
411
+ # Step 2: Use Meditron for final determination
412
+ try:
413
+ # Ensure Meditron client is properly initialized
414
+ if not hasattr(self, 'meditron_client') or self.meditron_client is None:
415
+ self.logger.warning("Meditron client not initialized")
416
+ return self._generate_invalid_query_response()
417
+
418
+ meditron_result = self.meditron_client.analyze_medical_query(
419
+ query=user_query,
420
+ max_tokens=100 # Limit tokens for efficiency
421
+ )
422
+
423
+ # If Meditron successfully extracts a medical condition
424
+ if meditron_result.get('extracted_condition'):
425
+ return None # Validated by Meditron
426
+
427
+ except Exception as e:
428
+ # Log Meditron analysis failure without blocking the process
429
+ self.logger.warning(f"Meditron query validation failed: {e}")
430
+
431
+ # If no medical relevance is found
432
+ return self._generate_invalid_query_response()
433
+
434
+ def _generate_invalid_query_response(self) -> Dict[str, Any]:
435
+ """
436
+ Generate response for non-medical queries
437
+
438
+ Returns:
439
+ Dict with invalid query guidance
440
+ """
441
+ return {
442
+ 'type': 'invalid_query',
443
+ 'message': "This is OnCall.AI, a clinical medical assistance platform. "
444
+ "Please input a medical problem you need help resolving. "
445
+ "\n\nExamples:\n"
446
+ "- 'I'm experiencing chest pain'\n"
447
+ "- 'What are symptoms of stroke?'\n"
448
+ "- 'How to manage acute asthma?'\n"
449
+ "- 'I have a persistent headache'"
450
+ }
451
+
452
  def main():
453
  """
454
  Example usage and testing of UserPromptProcessor