# Spaces: Sleeping
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
class GroundednessChecker:
    """Classify whether an answer is supported by a reference context.

    Wraps a sequence-classification checkpoint loaded with ``transformers``.
    The question is encoded as the first segment and ``answer + [SEP] + context``
    as the second; the softmax over the model's logits is turned into a
    grounded / not-grounded verdict.
    """

    # Index of the "grounded" class in the model's logits. Both ``is_grounded``
    # and ``confidence`` read this index so the two fields always agree.
    # NOTE(review): assumes label 1 means "grounded" - confirm against the
    # checkpoint's ``config.id2label``.
    GROUNDED_LABEL = 1

    # Number of context characters echoed back in ``details["context_snippet"]``.
    SNIPPET_CHARS = 200

    def __init__(self, model_path="./grounding_detector"):
        """Load the tokenizer and classifier from *model_path*.

        The model is moved to CUDA when available, otherwise to CPU.

        Args:
            model_path: Path or identifier accepted by ``from_pretrained``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        # from_pretrained already returns the model in eval mode; calling it
        # explicitly documents that this object is inference-only (no dropout).
        self.model.eval()

    def check(
        self,
        question: str,
        answer: str,
        context: str,
        *,
        max_length: int = 512,
    ) -> dict:
        """Check if answer is grounded in context.

        Args:
            question: Question the answer responds to (first tokenizer segment).
            answer: Candidate answer to verify.
            context: Reference text the answer should be supported by.
            max_length: Token budget for the encoded pair; longer inputs are
                truncated by the tokenizer.

        Returns:
            A dict with:
              - ``is_grounded``: True when the most probable label is the
                grounded class.
              - ``confidence``: softmax probability of the grounded class. This is
                P(grounded) even when ``is_grounded`` is False.
              - ``details``: the question, the answer and a context snippet of at
                most ``SNIPPET_CHARS`` characters, with "..." appended when the
                context is longer.
        """
        # Answer and context share the second segment, joined by a literal
        # "[SEP]" string. NOTE(review): that string is the tokenizer's real
        # separator only for BERT-style vocabularies. Presumably it mirrors the
        # training format, so confirm before switching to
        # ``self.tokenizer.sep_token``.
        inputs = self.tokenizer(
            question,
            answer + " [SEP] " + context,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        ).to(self.device)
        with torch.no_grad():
            logits = self.model(**inputs).logits
        probs = torch.nn.functional.softmax(logits, dim=-1)
        return self._summarize(probs, question, answer, context)

    @classmethod
    def _summarize(cls, probs, question, answer, context):
        """Build the result dict from the class probabilities of one example."""
        # Only the first (and only) encoded example is scored. The argmax is
        # taken along its label axis and compared with GROUNDED_LABEL, rather
        # than casting a flattened index to bool. This keeps is_grounded and
        # confidence pointing at the same class.
        label_probs = probs[0]
        predicted = int(torch.argmax(label_probs, dim=-1))
        if len(context) > cls.SNIPPET_CHARS:
            snippet = context[: cls.SNIPPET_CHARS] + "..."
        else:
            snippet = context
        return {
            "is_grounded": predicted == cls.GROUNDED_LABEL,
            "confidence": label_probs[cls.GROUNDED_LABEL].item(),
            "details": {
                "question": question,
                "answer": answer,
                "context_snippet": snippet,
            },
        }
# Usage example: score one supported and one contradicted answer against a
# short banking product-disclosure excerpt.
if __name__ == "__main__":
    checker = GroundednessChecker()

    # Reference context (banking PDS excerpt), built line by line.
    context = "\n".join(
        [
            "",
            "Premium Savings Account Terms:",
            "- Annual Percentage Yield (APY): 4.25%",
            "- Minimum opening deposit: $1,000",
            "- Monthly maintenance fee: $5 (waived if daily balance >= $1,000)",
            "- Maximum withdrawals: 6 per month",
            "",
        ]
    )

    # (printed label, question, answer). The first answer matches the context;
    # the second contradicts the stated $5 fee.
    demo_cases = (
        (
            "Grounded Result:",
            "What is the minimum opening deposit?",
            "$1,000",
        ),
        (
            "Ungrounded Result:",
            "What is the monthly maintenance fee?",
            "$10 monthly charge",
        ),
    )
    for label, question, answer in demo_cases:
        outcome = checker.check(question=question, answer=answer, context=context)
        print(label, outcome)