from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch


class GroundednessChecker:
    """Classify whether an answer is supported by a reference context.

    Wraps a Hugging Face sequence-classification model loaded from
    ``model_path``. Output label index ``_GROUNDED_LABEL`` (1) is treated
    as the "grounded" class.
    """

    # Index of the "grounded" class in the classifier's logits.
    # NOTE(review): taken from the original code's use of probs[0][1];
    # confirm against the fine-tuned model's config.id2label.
    _GROUNDED_LABEL = 1

    # Number of context characters echoed back in the result's details.
    _SNIPPET_CHARS = 200

    def __init__(self, model_path="./grounding_detector"):
        """Load tokenizer and classifier from *model_path*.

        The model is moved to CUDA when available, otherwise CPU.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        # Inference only: make sure dropout and similar layers are disabled.
        self.model.eval()

    def check(self, question: str, answer: str, context: str) -> dict:
        """Check if answer is grounded in context.

        Args:
            question: The question that was asked; encoded as the first
                segment of the model input.
            answer: The candidate answer to verify.
            context: The source text the answer should be supported by.

        Returns:
            A dict with:
              - ``"is_grounded"``: True when the grounded label has the
                highest probability.
              - ``"confidence"``: softmax probability of the grounded label
                (label 1), regardless of which label was predicted.
              - ``"details"``: the question, the answer, and the first
                ``_SNIPPET_CHARS`` characters of the context ("..." appended
                when the context was cut).
        """
        # Question is segment one; answer and context are joined into
        # segment two. truncation=True clips the pair to 512 tokens.
        # NOTE(review): the literal " [SEP] " is only recognised as a special
        # token by BERT-style tokenizers; confirm it matches how the model
        # was fine-tuned before switching to self.tokenizer.sep_token.
        inputs = self.tokenizer(
            question,
            answer + " [SEP] " + context,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt",
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model(**inputs)

        # One example per call: take row 0, softmax over the class dimension.
        probs = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]
        # Argmax over classes (not the flattened tensor), compared to the
        # grounded label explicitly so it agrees with "confidence" below.
        predicted_label = int(torch.argmax(probs))

        if len(context) > self._SNIPPET_CHARS:
            snippet = context[: self._SNIPPET_CHARS] + "..."
        else:
            snippet = context

        return {
            "is_grounded": predicted_label == self._GROUNDED_LABEL,
            "confidence": probs[self._GROUNDED_LABEL].item(),
            "details": {
                "question": question,
                "answer": answer,
                "context_snippet": snippet,
            },
        }


def main():
    """Run the checker on one grounded and one ungrounded banking example."""
    # Initialize checker
    checker = GroundednessChecker()

    # Example from banking PDS
    context = """
    Premium Savings Account Terms:
    - Annual Percentage Yield (APY): 4.25%
    - Minimum opening deposit: $1,000
    - Monthly maintenance fee: $5 (waived if daily balance >= $1,000)
    - Maximum withdrawals: 6 per month
    """

    # Grounded example
    grounded_result = checker.check(
        question="What is the minimum opening deposit?",
        answer="$1,000",
        context=context,
    )
    print("Grounded Result:", grounded_result)

    # Ungrounded example
    ungrounded_result = checker.check(
        question="What is the monthly maintenance fee?",
        answer="$10 monthly charge",
        context=context,
    )
    print("Ungrounded Result:", ungrounded_result)


# Usage Example
if __name__ == "__main__":
    main()