# hallu_scorer / custom_pipeline.py
# Provenance: uploaded by tcapelle to the Hugging Face Hub (commit 501a0bd, verified).
from transformers import AutoTokenizer, Pipeline
import torch
class PairTextClassificationPipeline(Pipeline):
    """Score a (premise, hypothesis) text pair with an entailment-style model.

    Each input is a 2-element sequence ``(premise, hypothesis)``.  The pair is
    rendered into a fixed prompt template, tokenized, and run through the
    model; the pipeline returns the softmax probability of class 1 computed
    from the logits at the first token position.
    """

    def __init__(self, model, tokenizer=None, **kwargs):
        # Fall back to the tokenizer matching the model checkpoint when the
        # caller does not supply one.
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained(model.config._name_or_path)
        # Stash the tokenizer before Pipeline.__init__ runs — the base class
        # setup may touch self.tokenizer.
        self.tokenizer = tokenizer
        super().__init__(model=model, tokenizer=tokenizer, **kwargs)
        # Prompt template filled with the pair in preprocess().
        self.prompt = "<pad> Determine if the hypothesis is true given the premise?\n\nPremise: {text1}\n\nHypothesis: {text2}"

    def _sanitize_parameters(self, **kwargs):
        """No call-time parameters are routed to any stage of this pipeline."""
        return {}, {}, {}

    def preprocess(self, inputs):
        """Format one (premise, hypothesis) pair and tokenize it.

        ``inputs[0]`` is the premise, ``inputs[1]`` the hypothesis.
        Returns the tokenizer's tensor dict (``return_tensors='pt'``).
        """
        premise, hypothesis = inputs[0], inputs[1]
        text = self.prompt.format(text1=premise, text2=hypothesis)
        return self.tokenizer(text, return_tensors='pt', padding=True)

    def _forward(self, model_inputs):
        """Run the underlying model on the tokenized batch."""
        return self.model(**model_inputs)

    def postprocess(self, model_outputs):
        """Extract P(class 1) from the first token's logits.

        Uses only position 0 (tok_cls) of the sequence logits, then a softmax
        over the class dimension.  ``.item()`` assumes a batch of size 1 —
        which preprocess() produces, since it tokenizes a single string.
        """
        first_token_logits = model_outputs.logits[:, 0, :]  # tok_cls position
        class_probs = torch.softmax(first_token_logits, dim=-1)
        return class_probs[:, 1].item()  # probability of class 1