Spaces:
Running
Running
Commit
·
e1e8013
1
Parent(s):
a704797
Improved input query filters
Browse files- app.py +53 -14
- data_filters.py +9 -0
app.py
CHANGED
@@ -21,6 +21,8 @@ from data_filters import (
|
|
21 |
restricted_patterns,
|
22 |
restricted_topics,
|
23 |
FINANCIAL_DATA_PATTERNS,
|
|
|
|
|
24 |
sensitive_terms,
|
25 |
FINANCIAL_TERMS,
|
26 |
)
|
@@ -37,8 +39,8 @@ os.makedirs("data", exist_ok=True)
|
|
37 |
# SLM: Microsoft PHI-2 model is loaded
|
38 |
# It does have higher memory and compute requirements compared to TinyLlama and Falcon
|
39 |
# But it gives the best results among the three
|
40 |
-
DEVICE = "cpu" # or cuda
|
41 |
-
|
42 |
# MODEL_NAME = "TinyLlama/TinyLlama_v1.1"
|
43 |
# MODEL_NAME = "tiiuae/falcon-rw-1b"
|
44 |
MODEL_NAME = "microsoft/phi-2"
|
@@ -55,7 +57,7 @@ if tokenizer.pad_token is None:
|
|
55 |
# Since the model is to be hosted on a cpu instance, we use float32
|
56 |
# For GPU, we can use float16 or bfloat16
|
57 |
model = AutoModelForCausalLM.from_pretrained(
|
58 |
-
MODEL_NAME, torch_dtype=torch.
|
59 |
).to(DEVICE)
|
60 |
model.eval()
|
61 |
# model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
|
@@ -234,25 +236,62 @@ def process_files(files, chunk_size=512):
|
|
234 |
pickle.dump(bm25_data, f)
|
235 |
return "Files processed successfully! You can now query."
|
236 |
|
|
|
237 |
def contains_financial_entities(query):
|
238 |
-
"""Check if
|
239 |
doc = nlp(query)
|
240 |
for ent in doc.ents:
|
241 |
if ent.label_ in FINANCIAL_ENTITY_LABELS:
|
242 |
return True
|
243 |
return False
|
244 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
245 |
# Input guardrail implementation
|
|
|
246 |
# Regex is used to filter queries related to sensitive topics
|
247 |
# Uses spaCy model's Named Entity Recognition to filter queries for personal details
|
248 |
# Uses cosine similarity with the embedded query and sensitive topic vectors
|
249 |
# to filter out queries violating confidential/security rules (additional)
|
250 |
def is_query_allowed(query):
|
251 |
"""Checks if the query violates security or confidentiality rules"""
|
|
|
|
|
252 |
for pattern in restricted_patterns:
|
253 |
if re.search(pattern, query.lower(), re.IGNORECASE):
|
254 |
return False, "This query requests sensitive or confidential information."
|
255 |
doc = nlp(query)
|
|
|
256 |
for ent in doc.ents:
|
257 |
if ent.label_ == "PERSON":
|
258 |
for token in ent.subtree:
|
@@ -265,6 +304,7 @@ def is_query_allowed(query):
|
|
265 |
topic_embeddings = embed_model.encode(
|
266 |
list(restricted_topics), normalize_embeddings=True
|
267 |
)
|
|
|
268 |
similarities = np.dot(topic_embeddings, query_embedding)
|
269 |
if np.max(similarities) > 0.85:
|
270 |
return False, "This query requests sensitive or confidential information."
|
@@ -368,8 +408,9 @@ def compute_response_confidence(
|
|
368 |
normalized_bm25 = 0.0
|
369 |
logger.info(
|
370 |
f"Faiss score: {normalized_faiss}, bm25: {normalized_bm25}, "
|
371 |
-
f"Mean Top Token + Entropy Avg: {model_conf_signal}"
|
372 |
)
|
|
|
373 |
confidence_score = (
|
374 |
lambda_faiss * normalized_faiss
|
375 |
+ model_conf_signal * lambda_conf
|
@@ -436,13 +477,10 @@ def query_model(
|
|
436 |
"You are a financial analyst. Answer financial queries concisely using only the numerical data "
|
437 |
"explicitly present in the provided financial context:\n\n"
|
438 |
f"{context}\n\n"
|
439 |
-
"
|
440 |
-
" Retain the original format of financial figures
|
441 |
-
"
|
442 |
-
"
|
443 |
-
"'No relevant financial data available.'"
|
444 |
-
" Provide exactly one answer in a single sentence."
|
445 |
-
" Do not generate explanations, additional text, or answer multiple queries."
|
446 |
f"\nQuery: {query}"
|
447 |
)
|
448 |
inputs = tokenizer(prompt, return_tensors="pt", padding=True).to(DEVICE)
|
@@ -463,7 +501,8 @@ def query_model(
|
|
463 |
sequences = output["sequences"][0][input_len:]
|
464 |
execution_time = time.perf_counter() - start_time
|
465 |
logger.info(f"Query processed in {execution_time:.2f} seconds.")
|
466 |
-
|
|
|
467 |
token_probs = [torch.softmax(lp, dim=-1) for lp in log_probs]
|
468 |
# Extract top token probabilities for each step
|
469 |
token_confidences = [tp.max().item() for tp in token_probs]
|
@@ -487,7 +526,7 @@ def query_model(
|
|
487 |
final_out += f"Context: {context}\nQuery: {query}\n"
|
488 |
final_out += f"Response: {response}"
|
489 |
return (
|
490 |
-
|
491 |
f"Confidence: {confidence_score}%\nTime taken: {execution_time:.2f} seconds",
|
492 |
)
|
493 |
|
|
|
21 |
restricted_patterns,
|
22 |
restricted_topics,
|
23 |
FINANCIAL_DATA_PATTERNS,
|
24 |
+
FINANCIAL_ENTITY_LABELS,
|
25 |
+
GENERAL_KNOWLEDGE_PATTERNS,
|
26 |
sensitive_terms,
|
27 |
FINANCIAL_TERMS,
|
28 |
)
|
|
|
39 |
# SLM: Microsoft PHI-2 model is loaded
|
40 |
# It does have higher memory and compute requirements compared to TinyLlama and Falcon
|
41 |
# But it gives the best results among the three
|
42 |
+
# DEVICE = "cpu" # or cuda
|
43 |
+
DEVICE = "cuda" # or cuda
|
44 |
# MODEL_NAME = "TinyLlama/TinyLlama_v1.1"
|
45 |
# MODEL_NAME = "tiiuae/falcon-rw-1b"
|
46 |
MODEL_NAME = "microsoft/phi-2"
|
|
|
57 |
# Since the model is to be hosted on a cpu instance, we use float32
|
58 |
# For GPU, we can use float16 or bfloat16
|
59 |
model = AutoModelForCausalLM.from_pretrained(
|
60 |
+
MODEL_NAME, torch_dtype=torch.bfloat16, trust_remote_code=True
|
61 |
).to(DEVICE)
|
62 |
model.eval()
|
63 |
# model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
|
|
|
236 |
pickle.dump(bm25_data, f)
|
237 |
return "Files processed successfully! You can now query."
|
238 |
|
239 |
+
|
240 |
def contains_financial_entities(query):
|
241 |
+
"""Check if query contains financial entities"""
|
242 |
doc = nlp(query)
|
243 |
for ent in doc.ents:
|
244 |
if ent.label_ in FINANCIAL_ENTITY_LABELS:
|
245 |
return True
|
246 |
return False
|
247 |
|
248 |
+
|
249 |
+
def contains_geographical_entities(query):
|
250 |
+
"""Check if the query contains geographical entities"""
|
251 |
+
doc = nlp(query)
|
252 |
+
return any(ent.label_ == "GPE" for ent in doc.ents)
|
253 |
+
|
254 |
+
|
255 |
+
def contains_financial_terms(query):
|
256 |
+
"""Check if the query contains financial terms"""
|
257 |
+
return any(term in query.lower() for term in FINANCIAL_TERMS)
|
258 |
+
|
259 |
+
|
260 |
+
def is_general_knowledge_query(query):
|
261 |
+
"""Check if query contains general knowledge"""
|
262 |
+
query_lower = query.lower()
|
263 |
+
for pattern in GENERAL_KNOWLEDGE_PATTERNS:
|
264 |
+
if re.search(pattern, query_lower):
|
265 |
+
return True
|
266 |
+
return False
|
267 |
+
|
268 |
+
|
269 |
+
def is_irrelevant_query(query):
|
270 |
+
"""Check if the query is not finance related"""
|
271 |
+
# If the query is general knowledge and not finance-related
|
272 |
+
if is_general_knowledge_query(query) and not contains_financial_terms(query):
|
273 |
+
return True
|
274 |
+
# If the query contains only geographical terms without financial entities
|
275 |
+
if contains_geographical_entities(query) and not contains_financial_entities(query):
|
276 |
+
return True
|
277 |
+
return False
|
278 |
+
|
279 |
+
|
280 |
# Input guardrail implementation
|
281 |
+
# NER + Regex + List of terms used to filter irrelevant queries
|
282 |
# Regex is used to filter queries related to sensitive topics
|
283 |
# Uses spaCy model's Named Entity Recognition to filter queries for personal details
|
284 |
# Uses cosine similarity with the embedded query and sensitive topic vectors
|
285 |
# to filter out queries violating confidential/security rules (additional)
|
286 |
def is_query_allowed(query):
|
287 |
"""Checks if the query violates security or confidentiality rules"""
|
288 |
+
if is_irrelevant_query(query):
|
289 |
+
return False, "Query is not finance-related. Please ask a financial question."
|
290 |
for pattern in restricted_patterns:
|
291 |
if re.search(pattern, query.lower(), re.IGNORECASE):
|
292 |
return False, "This query requests sensitive or confidential information."
|
293 |
doc = nlp(query)
|
294 |
+
# Check if there's a person entity and contains sensitive terms
|
295 |
for ent in doc.ents:
|
296 |
if ent.label_ == "PERSON":
|
297 |
for token in ent.subtree:
|
|
|
304 |
topic_embeddings = embed_model.encode(
|
305 |
list(restricted_topics), normalize_embeddings=True
|
306 |
)
|
307 |
+
# Check similarities between the restricted topics and the query
|
308 |
similarities = np.dot(topic_embeddings, query_embedding)
|
309 |
if np.max(similarities) > 0.85:
|
310 |
return False, "This query requests sensitive or confidential information."
|
|
|
408 |
normalized_bm25 = 0.0
|
409 |
logger.info(
|
410 |
f"Faiss score: {normalized_faiss}, bm25: {normalized_bm25}, "
|
411 |
+
f"Mean Top Token + 1-Entropy Avg: {model_conf_signal}"
|
412 |
)
|
413 |
+
# Weighted sum of all the normalized scores
|
414 |
confidence_score = (
|
415 |
lambda_faiss * normalized_faiss
|
416 |
+ model_conf_signal * lambda_conf
|
|
|
477 |
"You are a financial analyst. Answer financial queries concisely using only the numerical data "
|
478 |
"explicitly present in the provided financial context:\n\n"
|
479 |
f"{context}\n\n"
|
480 |
+
"Use only the given financial data—do not assume, infer, or generate missing values."
|
481 |
+
" Retain the original format of financial figures without conversion."
|
482 |
+
" If the requested information is unavailable, respond with 'No relevant financial data available.'"
|
483 |
+
" Provide a single-sentence answer without explanations, additional text, or multiple responses."
|
|
|
|
|
|
|
484 |
f"\nQuery: {query}"
|
485 |
)
|
486 |
inputs = tokenizer(prompt, return_tensors="pt", padding=True).to(DEVICE)
|
|
|
501 |
sequences = output["sequences"][0][input_len:]
|
502 |
execution_time = time.perf_counter() - start_time
|
503 |
logger.info(f"Query processed in {execution_time:.2f} seconds.")
|
504 |
+
# Get the logits per generated token
|
505 |
+
log_probs = output["scores"]
|
506 |
token_probs = [torch.softmax(lp, dim=-1) for lp in log_probs]
|
507 |
# Extract top token probabilities for each step
|
508 |
token_confidences = [tp.max().item() for tp in token_probs]
|
|
|
526 |
final_out += f"Context: {context}\nQuery: {query}\n"
|
527 |
final_out += f"Response: {response}"
|
528 |
return (
|
529 |
+
final_out,
|
530 |
f"Confidence: {confidence_score}%\nTime taken: {execution_time:.2f} seconds",
|
531 |
)
|
532 |
|
data_filters.py
CHANGED
@@ -29,6 +29,15 @@ restricted_topics = {
|
|
29 |
"financial package",
|
30 |
}
|
31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
sensitive_terms = {
|
33 |
"salary",
|
34 |
"compensation",
|
|
|
29 |
"financial package",
|
30 |
}
|
31 |
|
32 |
+
FINANCIAL_ENTITY_LABELS = {"MONEY", "PERCENT", "CARDINAL", "ORG"}
|
33 |
+
|
34 |
+
GENERAL_KNOWLEDGE_PATTERNS = [
|
35 |
+
r"\b(?:capital of|where is|who is|when did|what is|history of|define|meaning of|synonym of|antonym of|explain|how does|why is)\b",
|
36 |
+
r"\b(?:country|city|continent|leader|president|prime minister|language|currency|population|politics|war|anthem|flag|national animal|national bird|national flower|national sport|monarch|king|queen|ruler|army|military|constitution|government|laws|famous person|historical figure|famous landmark|ocean|mountain|river|lake|climate|weather|culture|tradition|festival|holiday|invention|discovery|science|technology|art|literature|music|religion|mythology|folklore|education|university|school|mathematics|physics|chemistry|biology|philosophy|astronomy|space|planet|star|galaxy|universe|health|medicine|disease|virus|bacteria|genetics|DNA|evolution|ecology|environment|pollution|wildlife|habitat|natural disaster|earthquake|volcano|tsunami|hurricane|storm|flood|drought)\b",
|
37 |
+
r"\b(?:[A-Z][a-z]+(?:'s)?\s+(?:capital|president|prime minister|national animal|national bird|national flower|national sport|anthem|flag|currency|language|leader|government|constitution|laws|monarch|king|queen|army|military|famous person|historical figure|landmark|river|ocean|mountain|religion|festival|holiday))\b",
|
38 |
+
]
|
39 |
+
|
40 |
+
|
41 |
sensitive_terms = {
|
42 |
"salary",
|
43 |
"compensation",
|