PhoenixDecim committed
Commit ab6a69a · 1 Parent(s): 92ae1b2

Modified the prompt and confidence score

Files changed (3)
  1. README.md +1 -3
  2. app.py +96 -45
  3. data_filters.py +5 -5
README.md CHANGED
@@ -1,5 +1,5 @@
 ---
-title: Slm Financial Rag
+title: SLM Financial RAG
 emoji: 🚀
 colorFrom: green
 colorTo: indigo
@@ -10,5 +10,3 @@ pinned: false
 license: unknown
 short_description: SLM with RAG for Financial Reports
 ---
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -37,7 +37,8 @@ os.makedirs("data", exist_ok=True)
 # SLM: Microsoft PHI-2 model is loaded
 # It does have higher memory and compute requirements compared to TinyLlama and Falcon
 # But it gives the best results among the three
-DEVICE = "cpu"  # or cuda
+# DEVICE = "cpu"
+DEVICE = "cuda"
 # MODEL_NAME = "TinyLlama/TinyLlama_v1.1"
 # MODEL_NAME = "tiiuae/falcon-rw-1b"
 MODEL_NAME = "microsoft/phi-2"
@@ -54,7 +55,7 @@ if tokenizer.pad_token is None:
 # Since the model is to be hosted on a cpu instance, we use float32
 # For GPU, we can use float16 or bfloat16
 model = AutoModelForCausalLM.from_pretrained(
-    MODEL_NAME, torch_dtype=torch.float32, trust_remote_code=True
+    MODEL_NAME, torch_dtype=torch.bfloat16, trust_remote_code=True
 ).to(DEVICE)
 model.eval()
 # model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
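The commit flips DEVICE to "cuda" and the dtype to bfloat16 in two separate places, so the pair can drift apart. A minimal sketch of deriving both from a single check (the helper name pick_device_and_dtype is hypothetical, not part of this repo):

import torch
from transformers import AutoModelForCausalLM

def pick_device_and_dtype():
    # bfloat16 halves memory on GPUs that support it; CPU inference stays in float32.
    if torch.cuda.is_available():
        return "cuda", torch.bfloat16
    return "cpu", torch.float32

DEVICE, DTYPE = pick_device_and_dtype()
model = AutoModelForCausalLM.from_pretrained(
    "microsoft/phi-2", torch_dtype=DTYPE, trust_remote_code=True
).to(DEVICE)
model.eval()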
@@ -233,6 +234,13 @@ def process_files(files, chunk_size=512):
         pickle.dump(bm25_data, f)
     return "Files processed successfully! You can now query."
 
+def contains_financial_entities(query):
+    """Check if the query has financial entities"""
+    doc = nlp(query)
+    for ent in doc.ents:
+        if ent.label_ in FINANCIAL_ENTITY_LABELS:
+            return True
+    return False
 
 # Input guardrail implementation
 # Regex is used to filter queries related to sensitive topics
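The new contains_financial_entities helper relies on the spaCy pipeline (nlp) and a FINANCIAL_ENTITY_LABELS set defined elsewhere in the repo. A self-contained sketch of the same idea, with an assumed label set (the repo's actual labels are not shown in this diff):

import spacy

nlp = spacy.load("en_core_web_sm")
FINANCIAL_ENTITY_LABELS = {"MONEY", "PERCENT", "CARDINAL", "ORG"}  # assumed labels

def contains_financial_entities(query):
    """Check if the query mentions any financial entity."""
    return any(ent.label_ in FINANCIAL_ENTITY_LABELS for ent in nlp(query).ents)

print(contains_financial_entities("Did revenue grow 12% to $4.2 billion?"))  # True: PERCENT and MONEY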
@@ -242,7 +250,7 @@ def process_files(files, chunk_size=512):
 def is_query_allowed(query):
     """Checks if the query violates security or confidentiality rules"""
     for pattern in restricted_patterns:
-        if re.search(pattern, query, re.IGNORECASE):
+        if re.search(pattern, query.lower(), re.IGNORECASE):
             return False, "This query requests sensitive or confidential information."
     doc = nlp(query)
     for ent in doc.ents:
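is_query_allowed returns a (False, reason) pair on the blocked path, so a call site (hypothetical; the actual caller is outside this diff) would unpack both values:

allowed, reason = is_query_allowed("What is the CEO's salary?")
if not allowed:
    print(reason)  # "This query requests sensitive or confidential information."

Note that query.lower() plus re.IGNORECASE is belt and braces: with the patterns now lowercased in data_filters.py, either one alone would match.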
@@ -309,45 +317,65 @@ def hybrid_retrieve(query, chunk_texts, index, bm25, top_k=5, lambda_faiss=0.7):
     return final_results
 
 
+def compute_entropy(logits):
+    """Compute entropy from logits."""
+    probs = torch.softmax(logits, dim=-1)
+    log_probs = torch.log(probs + 1e-9)
+    entropy = -(probs * log_probs).sum(dim=-1)
+    return entropy.mean().item()
+
+
 # A confidence score is computed using FAISS and BM25 ranking
-# FAISS: The similarity score between the query (with response) and the retrieved chunks are normalized
-# BM25: The BM25 scores for the query is normalized
-# Both the scores are aggregated using a weighted sum (lambda FAISS) and normalized
-def compute_confidence_score(query, retrieved_chunks, bm25, lambda_faiss):
-    """Calculates a confidence score using FAISS and BM25 rankings."""
+# FAISS: The similarity score between the response and the retrieved chunks is normalized
+# BM25: The BM25 scores over the combined query and response tokens are normalized
+# The mean of the top-token probability mean and the 1-entropy score is the model_conf_signal
+# FAISS, BM25 and the model_conf_signal are combined using a weighted sum
+def compute_response_confidence(
+    query,
+    response,
+    retrieved_chunks,
+    bm25,
+    model_conf_signal=0.5,
+    lambda_faiss=0.4,
+    lambda_conf=0.4,
+    lambda_bm25=1.8,
+):
+    """Calculates a confidence score using FAISS, BM25, top token probabilities and entropy score"""
     if not retrieved_chunks:
-        return 0
-    query_embedding = embed_model.encode(query, normalize_embeddings=True)
-    response_embedding = embed_model.encode(
+        return 0.0
+    # Compute FAISS similarity
+    retrieved_embedding = embed_model.encode(
         " ".join(retrieved_chunks), normalize_embeddings=True
     )
-    # FAISS Similarity
-    faiss_score = np.dot(query_embedding, response_embedding)
+    response_embedding = embed_model.encode(response, normalize_embeddings=True)
+    faiss_score = np.dot(retrieved_embedding, response_embedding)
+    # Normalize the FAISS score
     normalized_faiss = (faiss_score + 1) / 2
-    # BM25 Ranking
-    tokenized_query = query.lower().split()
-    bm25_scores = bm25.get_scores(tokenized_query)
+    # Compute BM25 for combined query + response
+    tokenized_combined = (query + " " + response).lower().split()
+    bm25_scores = bm25.get_scores(tokenized_combined)
+    # Normalize the BM25 score
    if bm25_scores.size > 0:
-        min_bm25 = (
-            np.min(bm25_scores) if np.min(bm25_scores) != np.max(bm25_scores) else 0
-        )
-        max_bm25 = (
-            np.max(bm25_scores) if np.min(bm25_scores) != np.max(bm25_scores) else 1
-        )
-        bm25_score = (
-            np.mean([bm25_scores[idx] for idx in range(len(retrieved_chunks))])
-            if len(retrieved_chunks) > 0
+        bm25_score = np.mean(bm25_scores)
+        min_bm25, max_bm25 = np.min(bm25_scores), np.max(bm25_scores)
+        normalized_bm25 = (
+            (bm25_score - min_bm25) / (max_bm25 - min_bm25 + 1e-6)
+            if min_bm25 != max_bm25
             else 0
         )
-        normalized_bm25 = (bm25_score - min_bm25) / (max_bm25 - min_bm25)
         normalized_bm25 = max(0, min(1, normalized_bm25))
     else:
-        normalized_bm25 = 0
-    # Final Confidence Score (use Lambda FAISS value for weighted sum)
-    confidence_score = round(
-        (normalized_faiss * lambda_faiss + normalized_bm25 * (1 - lambda_faiss)), 2
+        normalized_bm25 = 0.0
+    logger.info(
+        f"Faiss score: {normalized_faiss}, bm25: {normalized_bm25}, "
+        f"Mean Top Token + Entropy Avg: {model_conf_signal}"
+    )
+    confidence_score = (
+        lambda_faiss * normalized_faiss
+        + model_conf_signal * lambda_conf
+        + lambda_bm25 * normalized_bm25
     )
-    return confidence_score
+    return round(min(100, max(0, float(confidence_score) * 100)), 2)
 
 
 # UI handle for query model button
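For a concrete feel for the new compute_response_confidence weighting: the default weights sum to 0.4 + 0.4 + 1.8 = 2.6, so the raw score can exceed 1 and the min(100, ...) clamp does real work. A worked example with made-up component values:

normalized_faiss = 0.80   # response vs. retrieved-context similarity (made up)
model_conf_signal = 0.60  # generation-time signal (made up)
normalized_bm25 = 0.10    # lexical overlap (made up)

score = 0.4 * normalized_faiss + 0.4 * model_conf_signal + 1.8 * normalized_bm25
print(round(min(100, max(0, score * 100)), 2))  # 0.32 + 0.24 + 0.18 = 0.74 -> 74.0
# With normalized_bm25 = 0.30 the raw sum is 1.10, and the clamp caps the score at 100.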
@@ -405,40 +433,62 @@ def query_model(
     else:
         break
     prompt = (
-        f"Based on the following information:\n\n{context}\n\n"
-        "Answer the query in one or two sentences. "
-        "Do not provide follow-ups. "
-        f"Answer the query: {query}"
+        "You are a financial analyst. Answer financial queries concisely using only the numerical data "
+        "explicitly present in the provided financial context:\n\n"
+        f"{context}\n\n"
+        "Strictly use only the given financial data. Do not assume, infer, or generate missing data."
+        " Retain the original format of financial figures exactly as given."
+        " Do not attempt to convert the currency into any other format."
+        " If the requested information is not available in the provided context, respond with "
+        "'No relevant financial data available.'"
+        " Provide exactly one answer in a single sentence."
+        " Do not generate explanations, additional text, or answer multiple queries."
+        f"\nQuery: {query}"
     )
     inputs = tokenizer(prompt, return_tensors="pt", padding=True).to(DEVICE)
     inputs.pop("token_type_ids", None)
     logger.info("Generating output")
     input_len = inputs["input_ids"].shape[-1]
+    logger.info(f"Input len: {input_len}")
     with torch.inference_mode():
         output = model.generate(
             **inputs,
             max_new_tokens=max_new_tokens,
             num_return_sequences=1,
             repetition_penalty=repetition_penalty,
+            output_scores=True,
+            return_dict_in_generate=True,
             pad_token_id=tokenizer.eos_token_id,
         )
-    start_len = 0
-    if use_extraction:
-        start_len = input_len
-    output = output[0][start_len:]
+    sequences = output["sequences"][0][input_len:]
     execution_time = time.perf_counter() - start_time
     logger.info(f"Query processed in {execution_time:.2f} seconds.")
-    response = tokenizer.decode(output, skip_special_tokens=True)
-    confidence_score = compute_confidence_score(
-        query + " " + response, retrieved_chunks, bm25, lambda_faiss
+    log_probs = output["scores"]  # List of logits per generated token
+    token_probs = [torch.softmax(lp, dim=-1) for lp in log_probs]
+    # Extract top token probabilities for each step
+    token_confidences = [tp.max().item() for tp in token_probs]
+    # Compute final confidence score
+    top_token_conf = sum(token_confidences) / len(token_confidences)
+    print(f"Top Token Probability Mean: {top_token_conf:.4f}")
+    entropy_score = sum(compute_entropy(lp) for lp in log_probs) / len(log_probs)
+    entropy_conf = 1 - (entropy_score / torch.log(torch.tensor(tokenizer.vocab_size)))
+    print(f"Entropy-based Confidence: {entropy_conf:.4f}")
+    model_conf_signal = (top_token_conf + entropy_conf) / 2
+    response = tokenizer.decode(sequences, skip_special_tokens=True)
+    confidence_score = compute_response_confidence(
+        query, response, retrieved_chunks, bm25, model_conf_signal
     )
-    logger.info(f"Confidence: {confidence_score*100}%")
-    if confidence_score <= 0.3:
+    logger.info(f"Confidence: {confidence_score}%")
+    if confidence_score <= 30:
         logger.error(f"The system is unsure about this response.")
         response += "\nThe system is unsure about this response."
+    final_out = ""
+    if not use_extraction:
+        final_out += f"Context: {context}\nQuery: {query}\n"
+    final_out += f"Response: {response}"
     return (
-        response,
-        f"Confidence: {confidence_score*100}%\nTime taken: {execution_time:.2f} seconds",
+        final_out,
+        f"Confidence: {confidence_score}%\nTime taken: {execution_time:.2f} seconds",
     )
 
444
 
@@ -464,7 +514,7 @@ with gr.Blocks(title="Financial Statement RAG with LLM") as ui:
         top_k_input = gr.Number(value=15, label="Top K (Default: 15)")
         lambda_faiss_input = gr.Slider(0, 1, value=0.5, label="Lambda FAISS (0-1)")
         repetition_penalty = gr.Slider(
-            1, 2, value=1.0, label="Repetition Penalty (1-2)"
+            1, 2, value=1.2, label="Repetition Penalty (1-2)"
         )
         max_tokens_input = gr.Number(value=100, label="Max New Tokens")
         use_extraction = gr.Checkbox(label="Retrieve only the answer", value=False)
@@ -485,6 +535,7 @@ with gr.Blocks(title="Financial Statement RAG with LLM") as ui:
         ],
         outputs=[query_output, time_output],
     )
+
 # Application entry point
 if __name__ == "__main__":
     logger.info("Starting Gradio server...")
data_filters.py CHANGED
@@ -1,16 +1,16 @@
 """Sensitive data filters"""
 
 restricted_patterns = [
-    r"\b(?:CFO|CEO|CTO|executive|director|manager|employee|staff|worker)\b.*\b(?:salary|compensation|bonus|pay|income)\b",
-    r"\b(?:salary|compensation|bonus|pay|income)\b.*\b(?:CFO|CEO|CTO|executive|director|manager|employee|staff|worker)\b",
+    r"\b(?:cfo|ceo|cto|executive|director|manager|employee|staff|worker)\b.*\b(?:salary|compensation|bonus|pay|income)\b",
+    r"\b(?:salary|compensation|bonus|pay|income)\b.*\b(?:cfo|ceo|cto|executive|director|manager|employee|staff|worker)\b",
     r"\b(?:acquisition|merger|buyout)\b.*\b(?:before|pre-announcement|leak|inside information)\b",
     r"\b(?:before|pre-announcement|leak|inside information)\b.*\b(?:acquisition|merger|buyout)\b",
     r"\b(?:stock price|share price|insider trading|buying shares)\b",
     r"\b(?:internal policy|data breach|security protocol|confidential|classified)\b",
     r"\b(?:password|access credentials|encryption key|secure key)\b",
-    r"\b(?:social security number|SSN|passport number|credit card|bank account|tax ID|TIN|personal details)\b",
-    r"\b(?:employee records|payroll|medical records|HR data|salary data|PII|personally identifiable information)\b",
-    r"\b(?:CFO|CEO|CTO|executive|director|manager|employee|staff|worker)\b.*\b(?:address|work location|home location|residence|personal contact|phone number|email|office location)\b",
+    r"\b(?:social security number|ssn|passport number|credit card|bank account|tax id|tin|personal details)\b",
+    r"\b(?:employee records|payroll|medical records|hr data|salary data|pii|personally identifiable information)\b",
+    r"\b(?:cfo|ceo|cto|executive|director|manager|employee|staff|worker)\b.*\b(?:address|work location|home location|residence|personal contact|phone number|email|office location)\b",
 ]
 
 restricted_topics = {
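A quick smoke test of the lowercased patterns, assuming data_filters.py is importable from the working directory:

import re
from data_filters import restricted_patterns

query = "Tell me the CFO's bonus for 2023."
print(any(re.search(p, query.lower(), re.IGNORECASE) for p in restricted_patterns))  # True, via the first pattern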