# NOTE: "Spaces: Sleeping" was page-scrape residue, not part of the program.
import re | |
import spacy | |
from sentence_transformers import SentenceTransformer | |
import numpy as np | |
import random | |
from datetime import datetime, timedelta | |
from dateutil.parser import parse as parse_date | |
# A simplistic ungrounded-answer generator.
class UngroundedAnswerGenerator:
    """Turn a reference answer into a plausible but wrong ("ungrounded") one.

    The reference answer is run through spaCy NER and routed to one of three
    strategies:

    * a ``DATE`` entity        -> shift the date by a random number of days,
    * a ``MONEY``/``PERCENT``  -> rescale the first number in the answer,
    * anything else            -> substitute a semantically close term from
      ``financial_terms``.

    If the chosen strategy cannot change the answer (for example the date
    cannot be parsed), the remaining strategies are tried so that the
    reference answer is not handed back as its own distractor.
    """

    # Matches "1,500.50" (thousands separators), "1500", "12.5" and ".5".
    _NUMBER_RE = re.compile(
        r"\d{1,3}(?:,\d{3})+(?!\d)(?:\.\d+)?|\d+(?:\.\d+)?|\.\d+"
    )
    # How many random draws to make before accepting a value that formats to
    # the same text as the original number.
    _MAX_ATTEMPTS = 10
    # Lazily filled ``(terms_tuple, embeddings)`` pair; see
    # ``_semantic_distractor``.
    _term_cache = None

    def __init__(self):
        """Load the spaCy pipeline, the sentence encoder and the term list."""
        self.nlp = spacy.load("en_core_web_sm")
        self.sim_model = SentenceTransformer('all-MiniLM-L6-v2')
        # Candidate pool for the semantic-distractor strategy.
        self.financial_terms = [
            "CommBank Credit Card",
            "Personal credit cards",
            "Business credit cards",
            "PIN",
            "ePayments Code",
            "Conditions of Use",
            "Schedule of Credit Card Particulars",
            "Banking Code of Practice",
            "NetBank",
            "CommBank app",
            "Electronic Banking Terms and Conditions",
            "Tap & Pay",
            "cash advance",
            "credit limit",
            "ATM cash withdrawals",
            "international transaction fee",
            "Mastercard",
            "Visa",
            "balance transfers",
            "regular payments",
            "additional cardholder",
            "digital wallet",
            "statements and notices",
            "closing balance",
            "minimum payment",
            "interest-free period on purchases",
            "SurePay instalment plan",
            "AutoPay",
            "fees and interest rates",
            "annual interest rates",
            "daily interest rate",
            "statement period",
            "balance transfer period",
            "unauthorised transaction",
            "card scheme refunds",
            "purchase plan",
            "card balance plan",
            "cash advance balance plan",
            "instalment setup fee",
            "purchase balance",
            "cash advances balance",
            "interest rate for the plan",
            "credit card account",
            "default under your contract",
        ]

    def generate(self, context: str, answer: str) -> str:
        """Return an ungrounded variant of ``answer``.

        Args:
            context: Source passage; passed through to the strategies, none
                of which currently reads it.
            answer: The reference (grounded) answer.

        Returns:
            The perturbed answer. The reference answer itself is returned
            only if every applicable strategy failed to change it.
        """
        primary = self._select_strategy(answer)
        result = primary(context, answer)
        # A strategy signals "could not apply" by returning the answer
        # unchanged; fall through to the more general strategies then.
        for fallback in (self._perturb_numbers, self._semantic_distractor):
            if result != answer:
                break
            if fallback != primary:
                result = fallback(context, answer)
        return result

    def _select_strategy(self, answer: str):
        """Pick a perturbation method from the entity labels in ``answer``.

        Returns a bound method taking ``(context, answer)`` and returning a
        string. ``DATE`` takes precedence over ``MONEY``/``PERCENT``.
        """
        labels = {ent.label_ for ent in self.nlp(answer).ents}
        if "DATE" in labels:
            return self._perturb_dates
        if labels & {"MONEY", "PERCENT"}:
            return self._perturb_numbers
        return self._semantic_distractor

    def _perturb_numbers(self, context: str, answer: str) -> str:
        """Rescale the first number of a dollar or percent answer.

        Dollar answers are scaled by a factor in [0.8, 1.2] and rendered as
        ``$x.xx``; percent answers by a factor in [0.5, 1.5] and rendered as
        ``x.x%``. Only the number is returned, not the surrounding text.
        Answers with neither ``$`` nor ``%`` are returned unchanged.
        """
        if "$" in answer:
            return "$" + self._rescaled(self._extract_number(answer), 0.8, 1.2, ".2f")
        if "%" in answer:
            return self._rescaled(self._extract_number(answer), 0.5, 1.5, ".1f") + "%"
        return answer

    def _rescaled(self, base: float, low: float, high: float, spec: str) -> str:
        """Format ``base`` times a random factor in [low, high] with ``spec``.

        Redraws (up to ``_MAX_ATTEMPTS`` times) while the result would format
        to the same text as ``base``, so rounding cannot undo the perturbation.
        """
        unchanged = format(base, spec)
        text = unchanged
        for _ in range(self._MAX_ATTEMPTS):
            text = format(base * random.uniform(low, high), spec)
            if text != unchanged:
                break
        return text

    def _perturb_dates(self, context: str, answer: str) -> str:
        """Shift the date in ``answer`` by 1-30 days in either direction.

        Returns the shifted date as ``YYYY-MM-DD``, or ``answer`` unchanged
        when it cannot be parsed as a date.
        """
        try:
            parsed = parse_date(answer)
            # Never draw a zero offset: that would reproduce the original date.
            days = random.randint(1, 30) * random.choice((-1, 1))
            return (parsed + timedelta(days=days)).strftime("%Y-%m-%d")
        except (ValueError, OverflowError, TypeError):
            # dateutil raises ParserError (a ValueError) or OverflowError for
            # text it cannot interpret; date arithmetic can overflow too.
            return answer

    def _semantic_distractor(self, context: str, answer: str) -> str:
        """Return the second most similar term in ``financial_terms``.

        Similarity is the dot product between the sentence embeddings of the
        answer and of each term. The top-ranked term is skipped (it is
        presumably the answer itself or too close to it), as is any term that
        equals the answer ignoring case.
        """
        terms = tuple(self.financial_terms)
        if not terms:
            return answer
        # Encode the term list once and reuse it until the list changes.
        cache = self._term_cache
        if cache is None or cache[0] != terms:
            cache = (terms, np.asarray(self.sim_model.encode(list(terms))))
            self._term_cache = cache
        answer_emb = self.sim_model.encode(answer)
        similarities = np.dot(cache[1], answer_emb)
        ranked = np.argsort(similarities)[::-1]
        target = answer.strip().casefold()
        for idx in ranked[1:]:
            if terms[idx].casefold() != target:
                return terms[idx]
        # Only reachable with a single-term list or all-duplicate terms.
        return terms[ranked[0]]

    def _extract_number(self, text: str) -> float:
        """Return the first number in ``text``, ignoring thousands separators.

        Falls back to a random value in [1, 1000] when ``text`` contains no
        digits, so callers always receive something to perturb.
        """
        match = self._NUMBER_RE.search(text)
        if match is None:
            return random.uniform(1, 1000)
        return float(match.group().replace(",", ""))