PhoenixDecim commited on
Commit
e1e8013
·
1 Parent(s): a704797

Improved input query filters

Browse files
Files changed (2) hide show
  1. app.py +53 -14
  2. data_filters.py +9 -0
app.py CHANGED
@@ -21,6 +21,8 @@ from data_filters import (
21
  restricted_patterns,
22
  restricted_topics,
23
  FINANCIAL_DATA_PATTERNS,
 
 
24
  sensitive_terms,
25
  FINANCIAL_TERMS,
26
  )
@@ -37,8 +39,8 @@ os.makedirs("data", exist_ok=True)
37
  # SLM: Microsoft PHI-2 model is loaded
38
  # It does have higher memory and compute requirements compared to TinyLlama and Falcon
39
  # But it gives the best results among the three
40
- DEVICE = "cpu" # or cuda
41
- # DEVICE = "cuda" # or cuda
42
  # MODEL_NAME = "TinyLlama/TinyLlama_v1.1"
43
  # MODEL_NAME = "tiiuae/falcon-rw-1b"
44
  MODEL_NAME = "microsoft/phi-2"
@@ -55,7 +57,7 @@ if tokenizer.pad_token is None:
55
  # Since the model is to be hosted on a cpu instance, we use float32
56
  # For GPU, we can use float16 or bfloat16
57
  model = AutoModelForCausalLM.from_pretrained(
58
- MODEL_NAME, torch_dtype=torch.float32, trust_remote_code=True
59
  ).to(DEVICE)
60
  model.eval()
61
  # model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
@@ -234,25 +236,62 @@ def process_files(files, chunk_size=512):
234
  pickle.dump(bm25_data, f)
235
  return "Files processed successfully! You can now query."
236
 
 
237
  def contains_financial_entities(query):
238
- """Check if the query has financial entities"""
239
  doc = nlp(query)
240
  for ent in doc.ents:
241
  if ent.label_ in FINANCIAL_ENTITY_LABELS:
242
  return True
243
  return False
244
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
245
  # Input guardrail implementation
 
246
  # Regex is used to filter queries related to sensitive topics
247
  # Uses spaCy model's Named Entity Recognition to filter queries for personal details
248
  # Uses cosine similarity with the embedded query and sensitive topic vectors
249
  # to filter out queries violating confidential/security rules (additional)
250
  def is_query_allowed(query):
251
  """Checks if the query violates security or confidentiality rules"""
 
 
252
  for pattern in restricted_patterns:
253
  if re.search(pattern, query.lower(), re.IGNORECASE):
254
  return False, "This query requests sensitive or confidential information."
255
  doc = nlp(query)
 
256
  for ent in doc.ents:
257
  if ent.label_ == "PERSON":
258
  for token in ent.subtree:
@@ -265,6 +304,7 @@ def is_query_allowed(query):
265
  topic_embeddings = embed_model.encode(
266
  list(restricted_topics), normalize_embeddings=True
267
  )
 
268
  similarities = np.dot(topic_embeddings, query_embedding)
269
  if np.max(similarities) > 0.85:
270
  return False, "This query requests sensitive or confidential information."
@@ -368,8 +408,9 @@ def compute_response_confidence(
368
  normalized_bm25 = 0.0
369
  logger.info(
370
  f"Faiss score: {normalized_faiss}, bm25: {normalized_bm25}, "
371
- f"Mean Top Token + Entropy Avg: {model_conf_signal}"
372
  )
 
373
  confidence_score = (
374
  lambda_faiss * normalized_faiss
375
  + model_conf_signal * lambda_conf
@@ -436,13 +477,10 @@ def query_model(
436
  "You are a financial analyst. Answer financial queries concisely using only the numerical data "
437
  "explicitly present in the provided financial context:\n\n"
438
  f"{context}\n\n"
439
- "Strictly use only the given financial data. Do not assume, infer, or generate missing data."
440
- " Retain the original format of financial figures exactly as given."
441
- " Do not attempt to convert the currency into any other format."
442
- " If the requested information is not available in the provided context, respond with "
443
- "'No relevant financial data available.'"
444
- " Provide exactly one answer in a single sentence."
445
- " Do not generate explanations, additional text, or answer multiple queries."
446
  f"\nQuery: {query}"
447
  )
448
  inputs = tokenizer(prompt, return_tensors="pt", padding=True).to(DEVICE)
@@ -463,7 +501,8 @@ def query_model(
463
  sequences = output["sequences"][0][input_len:]
464
  execution_time = time.perf_counter() - start_time
465
  logger.info(f"Query processed in {execution_time:.2f} seconds.")
466
- log_probs = output["scores"] # List of logits per generated token
 
467
  token_probs = [torch.softmax(lp, dim=-1) for lp in log_probs]
468
  # Extract top token probabilities for each step
469
  token_confidences = [tp.max().item() for tp in token_probs]
@@ -487,7 +526,7 @@ def query_model(
487
  final_out += f"Context: {context}\nQuery: {query}\n"
488
  final_out += f"Response: {response}"
489
  return (
490
- response,
491
  f"Confidence: {confidence_score}%\nTime taken: {execution_time:.2f} seconds",
492
  )
493
 
 
21
  restricted_patterns,
22
  restricted_topics,
23
  FINANCIAL_DATA_PATTERNS,
24
+ FINANCIAL_ENTITY_LABELS,
25
+ GENERAL_KNOWLEDGE_PATTERNS,
26
  sensitive_terms,
27
  FINANCIAL_TERMS,
28
  )
 
39
  # SLM: Microsoft PHI-2 model is loaded
40
  # It does have higher memory and compute requirements compared to TinyLlama and Falcon
41
  # But it gives the best results among the three
42
+ # DEVICE = "cpu" # or cuda
43
+ DEVICE = "cuda" # or cuda
44
  # MODEL_NAME = "TinyLlama/TinyLlama_v1.1"
45
  # MODEL_NAME = "tiiuae/falcon-rw-1b"
46
  MODEL_NAME = "microsoft/phi-2"
 
57
  # Since the model is to be hosted on a cpu instance, we use float32
58
  # For GPU, we can use float16 or bfloat16
59
  model = AutoModelForCausalLM.from_pretrained(
60
+ MODEL_NAME, torch_dtype=torch.bfloat16, trust_remote_code=True
61
  ).to(DEVICE)
62
  model.eval()
63
  # model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
 
236
  pickle.dump(bm25_data, f)
237
  return "Files processed successfully! You can now query."
238
 
239
+
240
  def contains_financial_entities(query):
241
+ """Check if query contains financial entities"""
242
  doc = nlp(query)
243
  for ent in doc.ents:
244
  if ent.label_ in FINANCIAL_ENTITY_LABELS:
245
  return True
246
  return False
247
 
248
+
249
+ def contains_geographical_entities(query):
250
+ """Check if the query contains geographical entities"""
251
+ doc = nlp(query)
252
+ return any(ent.label_ == "GPE" for ent in doc.ents)
253
+
254
+
255
+ def contains_financial_terms(query):
256
+ """Check if the query contains financial terms"""
257
+ return any(term in query.lower() for term in FINANCIAL_TERMS)
258
+
259
+
260
+ def is_general_knowledge_query(query):
261
+ """Check if query contains general knowledge"""
262
+ query_lower = query.lower()
263
+ for pattern in GENERAL_KNOWLEDGE_PATTERNS:
264
+ if re.search(pattern, query_lower):
265
+ return True
266
+ return False
267
+
268
+
269
+ def is_irrelevant_query(query):
270
+ """Check if the query is not finance related"""
271
+ # If the query is general knowledge and not finance-related
272
+ if is_general_knowledge_query(query) and not contains_financial_terms(query):
273
+ return True
274
+ # If the query contains only geographical terms without financial entities
275
+ if contains_geographical_entities(query) and not contains_financial_entities(query):
276
+ return True
277
+ return False
278
+
279
+
280
  # Input guardrail implementation
281
+ # NER + Regex + List of terms used to filter irrelevant queries
282
  # Regex is used to filter queries related to sensitive topics
283
  # Uses spaCy model's Named Entity Recognition to filter queries for personal details
284
  # Uses cosine similarity with the embedded query and sensitive topic vectors
285
  # to filter out queries violating confidential/security rules (additional)
286
  def is_query_allowed(query):
287
  """Checks if the query violates security or confidentiality rules"""
288
+ if is_irrelevant_query(query):
289
+ return False, "Query is not finance-related. Please ask a financial question."
290
  for pattern in restricted_patterns:
291
  if re.search(pattern, query.lower(), re.IGNORECASE):
292
  return False, "This query requests sensitive or confidential information."
293
  doc = nlp(query)
294
+ # Check if there's a person entity and contains sensitive terms
295
  for ent in doc.ents:
296
  if ent.label_ == "PERSON":
297
  for token in ent.subtree:
 
304
  topic_embeddings = embed_model.encode(
305
  list(restricted_topics), normalize_embeddings=True
306
  )
307
+ # Check similarities between the restricted topics and the query
308
  similarities = np.dot(topic_embeddings, query_embedding)
309
  if np.max(similarities) > 0.85:
310
  return False, "This query requests sensitive or confidential information."
 
408
  normalized_bm25 = 0.0
409
  logger.info(
410
  f"Faiss score: {normalized_faiss}, bm25: {normalized_bm25}, "
411
+ f"Mean Top Token + 1-Entropy Avg: {model_conf_signal}"
412
  )
413
+ # Weighted sum of all the normalized scores
414
  confidence_score = (
415
  lambda_faiss * normalized_faiss
416
  + model_conf_signal * lambda_conf
 
477
  "You are a financial analyst. Answer financial queries concisely using only the numerical data "
478
  "explicitly present in the provided financial context:\n\n"
479
  f"{context}\n\n"
480
+ "Use only the given financial data—do not assume, infer, or generate missing values."
481
+ " Retain the original format of financial figures without conversion."
482
+ " If the requested information is unavailable, respond with 'No relevant financial data available.'"
483
+ " Provide a single-sentence answer without explanations, additional text, or multiple responses."
 
 
 
484
  f"\nQuery: {query}"
485
  )
486
  inputs = tokenizer(prompt, return_tensors="pt", padding=True).to(DEVICE)
 
501
  sequences = output["sequences"][0][input_len:]
502
  execution_time = time.perf_counter() - start_time
503
  logger.info(f"Query processed in {execution_time:.2f} seconds.")
504
+ # Get the logits per generated token
505
+ log_probs = output["scores"]
506
  token_probs = [torch.softmax(lp, dim=-1) for lp in log_probs]
507
  # Extract top token probabilities for each step
508
  token_confidences = [tp.max().item() for tp in token_probs]
 
526
  final_out += f"Context: {context}\nQuery: {query}\n"
527
  final_out += f"Response: {response}"
528
  return (
529
+ final_out,
530
  f"Confidence: {confidence_score}%\nTime taken: {execution_time:.2f} seconds",
531
  )
532
 
data_filters.py CHANGED
@@ -29,6 +29,15 @@ restricted_topics = {
29
  "financial package",
30
  }
31
 
 
 
 
 
 
 
 
 
 
32
  sensitive_terms = {
33
  "salary",
34
  "compensation",
 
29
  "financial package",
30
  }
31
 
32
+ FINANCIAL_ENTITY_LABELS = {"MONEY", "PERCENT", "CARDINAL", "ORG"}
33
+
34
+ GENERAL_KNOWLEDGE_PATTERNS = [
35
+ r"\b(?:capital of|where is|who is|when did|what is|history of|define|meaning of|synonym of|antonym of|explain|how does|why is)\b",
36
+ r"\b(?:country|city|continent|leader|president|prime minister|language|currency|population|politics|war|anthem|flag|national animal|national bird|national flower|national sport|monarch|king|queen|ruler|army|military|constitution|government|laws|famous person|historical figure|famous landmark|ocean|mountain|river|lake|climate|weather|culture|tradition|festival|holiday|invention|discovery|science|technology|art|literature|music|religion|mythology|folklore|education|university|school|mathematics|physics|chemistry|biology|philosophy|astronomy|space|planet|star|galaxy|universe|health|medicine|disease|virus|bacteria|genetics|DNA|evolution|ecology|environment|pollution|wildlife|habitat|natural disaster|earthquake|volcano|tsunami|hurricane|storm|flood|drought)\b",
37
+ r"\b(?:[A-Z][a-z]+(?:'s)?\s+(?:capital|president|prime minister|national animal|national bird|national flower|national sport|anthem|flag|currency|language|leader|government|constitution|laws|monarch|king|queen|army|military|famous person|historical figure|landmark|river|ocean|mountain|religion|festival|holiday))\b",
38
+ ]
39
+
40
+
41
  sensitive_terms = {
42
  "salary",
43
  "compensation",