# llmgaurdrails / model_inference / groundedness_checker.py
# Uploaded by Sasidhar ("Upload 3 files", commit 8ab2445, verified).
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
class GroundednessChecker:
    """Score whether an answer is supported by a given context passage.

    Wraps a Hugging Face sequence-classification checkpoint loaded from
    ``model_path``. Output class index 1 is interpreted as "grounded"
    (see ``check``). NOTE(review): confirm this matches the label mapping
    the checkpoint was trained with.
    """

    # Default token budget for the (question, answer + context) pair; longer
    # inputs are truncated by the tokenizer. Overridden per instance in
    # __init__; the class-level value keeps the default visible in one place.
    max_length: int = 512

    # Index of the model output class treated as "grounded".
    _GROUNDED_LABEL = 1

    def __init__(self, model_path="./grounding_detector", *, max_length=512, device=None):
        """Load the tokenizer and classifier and move the model to a device.

        Args:
            model_path: Directory (or hub id) passed to ``from_pretrained``.
            max_length: Maximum number of tokens fed to the model per check.
            device: Optional torch device or device string (e.g. ``"cpu"``).
                When ``None``, CUDA is used if available, otherwise CPU.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.max_length = max_length
        self.model.to(self.device)
        # Inference only: make sure dropout and similar layers are disabled.
        # from_pretrained normally does this already; this is defensive.
        self.model.eval()

    def check(self, question: str, answer: str, context: str) -> dict:
        """Check whether ``answer`` is grounded in ``context``.

        Args:
            question: The question the answer responds to.
            answer: The candidate answer to verify.
            context: Source text the answer should be supported by.

        Returns:
            A dict with:
              - ``"is_grounded"``: True when class 1 has the highest
                probability for this input.
              - ``"confidence"``: softmax probability of class 1. This is
                P(grounded), not the probability of the predicted label, so
                it is low when ``is_grounded`` is False.
              - ``"details"``: the question, the answer, and the first 200
                characters of ``context`` (``"..."`` appended when cut).
        """
        # Question is the first segment; answer and context are joined into
        # the second. NOTE(review): " [SEP] " is inserted as literal text and
        # only matches the separator token for BERT-style vocabularies --
        # keep it identical to the format used at training time.
        inputs = self.tokenizer(
            question,
            answer + " [SEP] " + context,
            padding=True,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model(**inputs)

        # One example was tokenized, so take row 0 of the softmaxed logits and
        # pick the winning class within that row (not over a flattened tensor).
        probs = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]
        predicted = int(torch.argmax(probs))

        if len(context) > 200:
            snippet = context[:200] + "..."
        else:
            snippet = context

        return {
            "is_grounded": predicted == self._GROUNDED_LABEL,
            "confidence": probs[self._GROUNDED_LABEL].item(),
            "details": {
                "question": question,
                "answer": answer,
                "context_snippet": snippet,
            },
        }
# Usage example: run this module directly to score two sample answers.
if __name__ == "__main__":
    # Build the checker once; loading the model is the expensive step.
    checker = GroundednessChecker()

    # Sample banking product-disclosure text used as the grounding context.
    context = """
Premium Savings Account Terms:
- Annual Percentage Yield (APY): 4.25%
- Minimum opening deposit: $1,000
- Monthly maintenance fee: $5 (waived if daily balance >= $1,000)
- Maximum withdrawals: 6 per month
"""

    # (print label, question, answer). The first answer matches the context;
    # the second contradicts the $5 fee stated there.
    demo_cases = [
        ("Grounded Result:", "What is the minimum opening deposit?", "$1,000"),
        ("Ungrounded Result:", "What is the monthly maintenance fee?", "$10 monthly charge"),
    ]

    for heading, demo_question, demo_answer in demo_cases:
        outcome = checker.check(
            question=demo_question,
            answer=demo_answer,
            context=context,
        )
        print(heading, outcome)