# llmgaurdrails/custom_models/groundedness_checker/ungrounded_answer_generator.py
# Uploaded by Sasidhar in commit 826f9a4 ("Upload 16 files").
import re
import spacy
from sentence_transformers import SentenceTransformer
import numpy as np
import random
from datetime import datetime, timedelta
from dateutil.parser import parse as parse_date
# A simplistic Ungrounded Answer Generator.
class UngroundedAnswerGenerator:
    """Create plausible-looking but ungrounded variants of an answer.

    ``generate`` picks a perturbation strategy from the spaCy entity labels
    found in the answer:

    * DATE entities            -> shift the date by a random non-zero number of days
    * MONEY / PERCENT entities -> rescale the dollar amount or percentage
    * anything else            -> return a related-but-different banking term

    If a targeted perturbation cannot be applied, ``generate`` falls back to a
    semantic distractor. This guarantees the (grounded) input answer is never
    returned as-is.

    NOTE(review): judging by the module path (groundedness_checker), this is
    presumably used to synthesise negative examples for a groundedness
    classifier -- confirm against callers.
    """

    # First number in a string. Comma thousands groups ("1,500.25") are tried
    # before a plain number ("21.99", "5"), so "$1,500" is read as 1500, not 1.
    _NUMBER_RE = re.compile(r"\d{1,3}(?:,\d{3})+(?:\.\d+)?|\d+(?:\.\d+)?")

    def __init__(self):
        self.nlp = spacy.load("en_core_web_sm")
        self.sim_model = SentenceTransformer('all-MiniLM-L6-v2')
        # Candidate distractor phrases (credit-card / banking vocabulary)
        # returned by _semantic_distractor.
        self.financial_terms = [
            "CommBank Credit Card",
            "Personal credit cards",
            "Business credit cards",
            "PIN",
            "ePayments Code",
            "Conditions of Use",
            "Schedule of Credit Card Particulars",
            "Banking Code of Practice",
            "NetBank",
            "CommBank app",
            "Electronic Banking Terms and Conditions",
            "Tap & Pay",
            "cash advance",
            "credit limit",
            "ATM cash withdrawals",
            "international transaction fee",
            "Mastercard",
            "Visa",
            "balance transfers",
            "regular payments",
            "additional cardholder",
            "digital wallet",
            "statements and notices",
            "closing balance",
            "minimum payment",
            "interest-free period on purchases",
            "SurePay instalment plan",
            "AutoPay",
            "fees and interest rates",
            "annual interest rates",
            "daily interest rate",
            "statement period",
            "balance transfer period",
            "unauthorised transaction",
            "card scheme refunds",
            "purchase plan",
            "card balance plan",
            "cash advance balance plan",
            "instalment setup fee",
            "purchase balance",
            "cash advances balance",
            "interest rate for the plan",
            "credit card account",
            "default under your contract",
        ]
        # Lazily filled (terms_tuple, embeddings) pair; see _term_embeddings.
        self._term_emb_cache = None

    def generate(self, context: str, answer: str) -> str:
        """Return an ungrounded variant of *answer*.

        Args:
            context: Source passage. Accepted for interface compatibility;
                none of the current strategies read it.
            answer: The grounded answer to perturb.

        Returns:
            A perturbed date or number, or a distractor term from
            ``financial_terms``.
        """
        strategy = self._select_strategy(answer)
        candidate = strategy(context, answer)
        # Targeted perturbations hand back the input unchanged when they
        # cannot apply, e.g. a MONEY entity written without "$", or a DATE
        # entity dateutil cannot parse. That text is still grounded, so it
        # must not be emitted as an "ungrounded" answer.
        if candidate == answer and strategy != self._semantic_distractor:
            candidate = self._semantic_distractor(context, answer)
        return candidate

    def _select_strategy(self, answer: str):
        """Choose a perturbation method from the spaCy entity labels in *answer*.

        DATE takes priority over MONEY/PERCENT. With neither present, a
        semantic distractor is used.
        """
        labels = {ent.label_ for ent in self.nlp(answer).ents}
        if "DATE" in labels:
            return self._perturb_dates
        if labels & {"MONEY", "PERCENT"}:
            return self._perturb_numbers
        return self._semantic_distractor

    def _perturb_numbers(self, context: str, answer: str) -> str:
        """Rescale the first number of a dollar amount or percentage.

        A dollar amount is scaled by a factor in [0.8, 1.2] and formatted as
        "$X.XX". A percentage is scaled by a factor in [0.5, 1.5] and
        formatted as "X.X%". Only the new value is returned, not the
        surrounding sentence. *answer* is returned unchanged if it contains
        neither "$" nor "%".
        """
        if "$" in answer:
            base = self._extract_number(answer)
            return self._rescale(base, 0.8, 1.2, "${:.2f}")
        elif "%" in answer:
            base = self._extract_number(answer)
            return self._rescale(base, 0.5, 1.5, "{:.1f}%")
        return answer

    @staticmethod
    def _rescale(base: float, low: float, high: float, template: str,
                 max_tries: int = 10) -> str:
        """Format ``base * uniform(low, high)`` so it differs from ``base`` once formatted."""
        original = template.format(base)
        for _ in range(max_tries):
            candidate = template.format(base * random.uniform(low, high))
            # A factor close to 1 can round back to the grounded value.
            if candidate != original:
                return candidate
        # Scaling cannot move zero (e.g. "$0 annual fee", "0% p.a."), so shift
        # additively instead.
        return template.format(base + random.uniform(1, 10))

    def _perturb_dates(self, context: str, answer: str) -> str:
        """Shift the date found in *answer* by 1-30 days in either direction.

        Returns the shifted date as "YYYY-MM-DD". If no date can be parsed,
        returns *answer* unchanged; ``generate`` then falls back to a
        distractor.
        """
        try:
            # fuzzy=True lets dateutil skip the non-date words of a sentence
            # such as "Payment is due on 5 March 2024."
            dt = parse_date(answer, fuzzy=True)
            # Never 0: a zero offset would reproduce the grounded date.
            offset = random.choice((-1, 1)) * random.randint(1, 30)
            return (dt + timedelta(days=offset)).strftime("%Y-%m-%d")
        except (ValueError, OverflowError):
            # dateutil's ParserError subclasses ValueError. OverflowError
            # covers out-of-range dates, in parsing or in the timedelta shift.
            return answer

    def _semantic_distractor(self, context: str, answer: str) -> str:
        """Return the financial term that is second-most similar to *answer*."""
        term_embs = self._term_embeddings()
        answer_emb = self.sim_model.encode(answer)
        similarities = np.dot(term_embs, answer_emb)
        # argsort is ascending, so [-2] skips the single best match.
        # NOTE(review): presumably this avoids returning the answer itself or
        # its closest paraphrase -- confirm intent.
        return self.financial_terms[np.argsort(similarities)[-2]]

    def _term_embeddings(self):
        """Encode ``financial_terms`` once and reuse the result until the list changes."""
        key = tuple(self.financial_terms)
        cache = getattr(self, "_term_emb_cache", None)
        if cache is None or cache[0] != key:
            cache = (key, self.sim_model.encode(self.financial_terms))
            self._term_emb_cache = cache
        return cache[1]

    def _extract_number(self, text: str) -> float:
        """Return the first number in *text*, or a random value in [1, 1000] if there is none.

        Thousands separators are stripped, so "$1,500.25" yields 1500.25.
        """
        match = self._NUMBER_RE.search(text)
        if match is None:
            return random.uniform(1, 1000)
        return float(match.group().replace(",", ""))