Commit 18d1c8f · Parent(s): 2100725
Added penalty for reasoning and future prediction questions
- app.py +42 -7
- data_filters.py +4 -0

app.py
CHANGED
@@ -24,6 +24,7 @@ from data_filters import (
     FINANCIAL_ENTITY_LABELS,
     GENERAL_KNOWLEDGE_PATTERNS,
     sensitive_terms,
+    EXPLANATORY_PATTERNS,
     FINANCIAL_TERMS,
 )
 
@@ -266,6 +267,15 @@ def is_general_knowledge_query(query):
     return False
 
 
+def get_latest_available_year(retrieved_chunks):
+    """Extracts the latest available year from retrieved financial data"""
+    years = set()
+    year_pattern = r"\b(20\d{2})\b"
+    for chunk in retrieved_chunks:
+        years.update(map(int, re.findall(year_pattern, chunk)))
+    return max(years) if years else 2024
+
+
 def is_irrelevant_query(query):
     """Check if the query is not finance related"""
     # If the query is general knowledge and not finance-related
@@ -365,6 +375,20 @@ def compute_entropy(logits):
     return entropy.mean().item()
 
 
+def contains_future_year(query, retrieved_chunks):
+    """Detects if the query asks for future data beyond available reports"""
+    latest_year = get_latest_available_year(retrieved_chunks)
+    # Extract years from query
+    future_years = set(map(int, re.findall(r"\b(20\d{2})\b", query)))
+    return any(year > latest_year for year in future_years)
+
+
+def is_explanatory_query(query):
+    """Checks if the query requires an explanation rather than factual data"""
+    query_lower = query.lower()
+    return any(re.search(pattern, query_lower) for pattern in EXPLANATORY_PATTERNS)
+
+
 # A confidence score is computed using FAISS and BM25 ranking
 # FAISS: The similarity score between the response and the retrieved chunks are normalized
 # BM25: The BM25 scores for the query and response combined tokens is normalized
@@ -375,12 +399,14 @@ def compute_response_confidence(
     response,
     retrieved_chunks,
     bm25,
-    model_conf_signal
-    lambda_faiss=0.
-    lambda_conf=0.
-    lambda_bm25=1.
+    model_conf_signal,
+    lambda_faiss=0.6,
+    lambda_conf=0.3,
+    lambda_bm25=1.0,
+    future_penalty=-0.3,
+    explanation_penalty=-0.2,
 ):
-    """Calculates a confidence score
+    """Calculates a confidence score for the model response"""
     if not retrieved_chunks:
         return 0.0
     # Compute FAISS similarity
@@ -406,15 +432,24 @@ def compute_response_confidence(
         normalized_bm25 = max(0, min(1, normalized_bm25))
     else:
         normalized_bm25 = 0.0
+    # Penalize if query contains future years
+    future_penalty_value = future_penalty if contains_future_year(query, retrieved_chunks) else 0.0
+    # Penalize if query is reasoning based
+    explanation_penalty_value = (
+        explanation_penalty if is_explanatory_query(query) else 0.0
+    )
     logger.info(
-        f"Faiss score: {normalized_faiss},
-        f"Mean Top Token + 1-Entropy Avg: {model_conf_signal}"
+        f"Faiss score: {normalized_faiss}, BM25: {normalized_bm25}\n"
+        f"Mean Top Token + 1-Entropy Avg: {model_conf_signal}\n"
+        f"Future penalty: {future_penalty_value}, Reasoning penalty: {explanation_penalty_value}"
     )
     # Weighted sum of all the normalized scores
     confidence_score = (
         lambda_faiss * normalized_faiss
         + model_conf_signal * lambda_conf
         + lambda_bm25 * normalized_bm25
+        + future_penalty_value
+        + explanation_penalty_value
     )
     return round(min(100, max(0, confidence_score.item() * 100)), 2)
 
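For reference, the year-extraction helper can be exercised on its own. A minimal sketch, restating the function from the diff above; the sample chunks are invented for illustration (in the app they come from the FAISS retriever):

    import re

    def get_latest_available_year(retrieved_chunks):
        """Extracts the latest available year from retrieved financial data"""
        years = set()
        year_pattern = r"\b(20\d{2})\b"  # any four-digit year starting with 20
        for chunk in retrieved_chunks:
            years.update(map(int, re.findall(year_pattern, chunk)))
        return max(years) if years else 2024  # hard-coded fallback when no year appears

    chunks = [
        "Revenue grew 12% in FY 2022 compared to 2021.",
        "The 2023 annual report cites higher input costs.",
    ]
    print(get_latest_available_year(chunks))        # 2023
    print(get_latest_available_year(["no years"]))  # 2024 (fallback)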
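The future-year guard builds on that helper. A standalone sketch; the signature is simplified here to take the latest year directly, whereas the diff's version derives it from retrieved_chunks:

    import re

    def contains_future_year(query, latest_year):
        # In app.py, latest_year comes from get_latest_available_year(retrieved_chunks)
        future_years = set(map(int, re.findall(r"\b(20\d{2})\b", query)))
        return any(year > latest_year for year in future_years)

    print(contains_future_year("Projected revenue for 2026?", 2023))  # True
    print(contains_future_year("What was revenue in 2022?", 2023))    # False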
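With the new defaults (lambda_faiss=0.6, lambda_conf=0.3, lambda_bm25=1.0), the confidence score is a weighted sum clamped to [0, 100]; note the weights sum to 1.9, so the min(100, ...) clamp does real work. A worked example with invented signal values, using plain floats (in the app, model_conf_signal is presumably a tensor, hence the .item() call in the diff):

    lambda_faiss, lambda_conf, lambda_bm25 = 0.6, 0.3, 1.0
    normalized_faiss, model_conf_signal, normalized_bm25 = 0.8, 0.7, 0.5
    future_penalty_value = -0.3      # query asked about a year beyond the reports
    explanation_penalty_value = 0.0  # not a reasoning question

    confidence_score = (
        lambda_faiss * normalized_faiss    # 0.48
        + lambda_conf * model_conf_signal  # 0.21
        + lambda_bm25 * normalized_bm25    # 0.50
        + future_penalty_value             # -0.30
        + explanation_penalty_value        # 0.00
    )
    print(round(min(100, max(0, confidence_score * 100)), 2))  # 89.0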
data_filters.py
CHANGED
@@ -48,6 +48,10 @@ sensitive_terms = {
     "wages",
 }
 
+EXPLANATORY_PATTERNS = [
+    r"\b(why|reason|cause|explanation|due to|because|factor|impact of|effect of|influence of|driven by)\b",
+    r"\b(how did|what led to|what caused|why did|how was|contributing factor|explain)\b",
+]
 
 FINANCIAL_DATA_PATTERNS = (
     r"\b(\₹?\s?\d{1,3}(?:,\d{2,3})*(?:\.\d+)?\s*(million|billion|crore|lakh|%)"
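The new patterns are plain word-boundary alternations; a quick check of what they do and do not match, with invented queries:

    import re

    EXPLANATORY_PATTERNS = [
        r"\b(why|reason|cause|explanation|due to|because|factor|impact of|effect of|influence of|driven by)\b",
        r"\b(how did|what led to|what caused|why did|how was|contributing factor|explain)\b",
    ]

    def is_explanatory(query):
        # Mirrors is_explanatory_query in app.py: lowercase, then search each pattern
        q = query.lower()
        return any(re.search(p, q) for p in EXPLANATORY_PATTERNS)

    print(is_explanatory("What caused the drop in net profit?"))  # True  ("what caused")
    print(is_explanatory("Impact of fuel prices on margins"))     # True  ("impact of")
    print(is_explanatory("Net profit for FY2023"))                # False (purely factual)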