PhoenixDecim committed on
Commit
18d1c8f
·
1 Parent(s): 2100725

Added penalty for reasoning and future prediction questions

Browse files
Files changed (2) hide show
  1. app.py +42 -7
  2. data_filters.py +4 -0
app.py CHANGED
@@ -24,6 +24,7 @@ from data_filters import (
24
  FINANCIAL_ENTITY_LABELS,
25
  GENERAL_KNOWLEDGE_PATTERNS,
26
  sensitive_terms,
 
27
  FINANCIAL_TERMS,
28
  )
29
 
@@ -266,6 +267,15 @@ def is_general_knowledge_query(query):
266
  return False
267
 
268
 
 
 
 
 
 
 
 
 
 
269
  def is_irrelevant_query(query):
270
  """Check if the query is not finance related"""
271
  # If the query is general knowledge and not finance-related
@@ -365,6 +375,20 @@ def compute_entropy(logits):
365
  return entropy.mean().item()
366
 
367
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
368
  # A confidence score is computed using FAISS and BM25 ranking
369
  # FAISS: The similarity score between the response and the retrieved chunks are normalized
370
  # BM25: The BM25 scores for the query and response combined tokens is normalized
@@ -375,12 +399,14 @@ def compute_response_confidence(
375
  response,
376
  retrieved_chunks,
377
  bm25,
378
- model_conf_signal=0.5,
379
- lambda_faiss=0.4,
380
- lambda_conf=0.4,
381
- lambda_bm25=1.8,
 
 
382
  ):
383
- """Calculates a confidence score using FAISS, BM25, top token probabilites and entropy score"""
384
  if not retrieved_chunks:
385
  return 0.0
386
  # Compute FAISS similarity
@@ -406,15 +432,24 @@ def compute_response_confidence(
406
  normalized_bm25 = max(0, min(1, normalized_bm25))
407
  else:
408
  normalized_bm25 = 0.0
 
 
 
 
 
 
409
  logger.info(
410
- f"Faiss score: {normalized_faiss}, bm25: {normalized_bm25}, "
411
- f"Mean Top Token + 1-Entropy Avg: {model_conf_signal}"
 
412
  )
413
  # Weighted sum of all the normalized scores
414
  confidence_score = (
415
  lambda_faiss * normalized_faiss
416
  + model_conf_signal * lambda_conf
417
  + lambda_bm25 * normalized_bm25
 
 
418
  )
419
  return round(min(100, max(0, confidence_score.item() * 100)), 2)
420
 
 
24
  FINANCIAL_ENTITY_LABELS,
25
  GENERAL_KNOWLEDGE_PATTERNS,
26
  sensitive_terms,
27
+ EXPLANATORY_PATTERNS,
28
  FINANCIAL_TERMS,
29
  )
30
 
 
267
  return False
268
 
269
 
270
def get_latest_available_year(retrieved_chunks, default_year=2024):
    """Extracts the latest available year from retrieved financial data.

    Args:
        retrieved_chunks: Iterable of text chunks scanned for 4-digit years.
        default_year: Year to return when no ``20xx`` year appears in any
            chunk. Defaults to 2024, preserving the previous hard-coded
            fallback, but is now configurable instead of a magic constant.

    Returns:
        int: The largest year in the 2000-2099 range found across all
        chunks, or ``default_year`` if none is present (including when
        ``retrieved_chunks`` is empty).
    """
    # \b guards keep the match to standalone 4-digit years (2000-2099),
    # so longer digit runs such as account numbers are not picked up.
    year_pattern = r"\b(20\d{2})\b"
    years = set()
    for chunk in retrieved_chunks:
        years.update(map(int, re.findall(year_pattern, chunk)))
    return max(years) if years else default_year
277
+
278
+
279
  def is_irrelevant_query(query):
280
  """Check if the query is not finance related"""
281
  # If the query is general knowledge and not finance-related
 
375
  return entropy.mean().item()
376
 
377
 
378
def contains_future_year(query, retrieved_chunks):
    """Detects if the query asks for future data beyond available reports"""
    # The newest year seen in the retrieved data is the cutoff; anything
    # in the query beyond it is a future-prediction request.
    cutoff = get_latest_available_year(retrieved_chunks)
    for year_text in re.findall(r"\b(20\d{2})\b", query):
        if int(year_text) > cutoff:
            return True
    return False
384
+
385
+
386
def is_explanatory_query(query):
    """Checks if the query requires an explanation rather than factual data"""
    # Patterns are written for lowercase text, so normalize first.
    lowered = query.lower()
    for pattern in EXPLANATORY_PATTERNS:
        if re.search(pattern, lowered):
            return True
    return False
390
+
391
+
392
  # A confidence score is computed using FAISS and BM25 ranking
393
  # FAISS: The similarity score between the response and the retrieved chunks are normalized
394
  # BM25: The BM25 scores for the query and response combined tokens is normalized
 
399
  response,
400
  retrieved_chunks,
401
  bm25,
402
+ model_conf_signal,
403
+ lambda_faiss=0.6,
404
+ lambda_conf=0.3,
405
+ lambda_bm25=1.0,
406
+ future_penalty=-0.3,
407
+ explanation_penalty=-0.2,
408
  ):
409
+ """Calculates a confidence score for the model response"""
410
  if not retrieved_chunks:
411
  return 0.0
412
  # Compute FAISS similarity
 
432
  normalized_bm25 = max(0, min(1, normalized_bm25))
433
  else:
434
  normalized_bm25 = 0.0
435
+ # Penalize if query contains future years
436
+ future_penalty = -0.3 if contains_future_year(query, retrieved_chunks) else 0.0
437
+ # Penalize if query is reasoning based
438
+ explanation_penalty_value = (
439
+ explanation_penalty if is_explanatory_query(query) else 0.0
440
+ )
441
  logger.info(
442
+ f"Faiss score: {normalized_faiss}, BM25: {normalized_bm25}\n"
443
+ f"Mean Top Token + 1-Entropy Avg: {model_conf_signal}\n"
444
+ f"Future penalty: {future_penalty}, Reasoning penalty: {explanation_penalty_value}"
445
  )
446
  # Weighted sum of all the normalized scores
447
  confidence_score = (
448
  lambda_faiss * normalized_faiss
449
  + model_conf_signal * lambda_conf
450
  + lambda_bm25 * normalized_bm25
451
+ + future_penalty
452
+ + explanation_penalty_value
453
  )
454
  return round(min(100, max(0, confidence_score.item() * 100)), 2)
455
 
data_filters.py CHANGED
@@ -48,6 +48,10 @@ sensitive_terms = {
48
  "wages",
49
  }
50
 
 
 
 
 
51
 
52
  FINANCIAL_DATA_PATTERNS = (
53
  r"\b(\₹?\s?\d{1,3}(?:,\d{2,3})*(?:\.\d+)?\s*(million|billion|crore|lakh|%)"
 
48
  "wages",
49
  }
50
 
51
# Regex patterns matched (via re.search) against the lowercased query to
# flag reasoning/explanatory questions ("why", "what caused", ...) as
# opposed to factual data lookups; used by is_explanatory_query in app.py.
EXPLANATORY_PATTERNS = [
    r"\b(why|reason|cause|explanation|due to|because|factor|impact of|effect of|influence of|driven by)\b",
    r"\b(how did|what led to|what caused|why did|how was|contributing factor|explain)\b",
]
55
 
56
  FINANCIAL_DATA_PATTERNS = (
57
  r"\b(\₹?\s?\d{1,3}(?:,\d{2,3})*(?:\.\d+)?\s*(million|billion|crore|lakh|%)"