from transformers import AutoTokenizer, Pipeline
import torch


class PairTextClassificationPipeline(Pipeline):
    """Scores a (premise, hypothesis) text pair with a prompted classification model."""

    def __init__(self, model, tokenizer=None, **kwargs):
        # Fall back to the tokenizer bundled with the model checkpoint.
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained(model.config._name_or_path)
        # The base Pipeline stores the tokenizer itself, so no manual
        # self.tokenizer assignment is needed before calling super().
        super().__init__(model=model, tokenizer=tokenizer, **kwargs)
        # The leading <pad> token reserves position 0 in the sequence; its
        # logits are read back out as the classification scores in postprocess().
        self.prompt = "<pad> Determine if the hypothesis is true given the premise?\n\nPremise: {text1}\n\nHypothesis: {text2}"

    def _sanitize_parameters(self, **kwargs):
        # No call-time parameters are supported: return empty kwargs for the
        # preprocess, _forward, and postprocess stages respectively.
        return {}, {}, {}

    def preprocess(self, inputs):
        # `inputs` is a (premise, hypothesis) pair of strings.
        pair_dict = {'text1': inputs[0], 'text2': inputs[1]}
        formatted_prompt = self.prompt.format(**pair_dict)
        model_inputs = self.tokenizer(
            formatted_prompt,
            return_tensors='pt',
            padding=True
        )
        return model_inputs

    def _forward(self, model_inputs):
        # The base Pipeline already handles device placement and wraps this
        # call in an inference (no-grad) context, so a plain forward suffices.
        model_outputs = self.model(**model_inputs)
        return model_outputs

    def postprocess(self, model_outputs):
        # Logits arrive with shape [batch, seq_len, num_classes]; keep only
        # position 0, which corresponds to the leading <pad> token.
        logits = model_outputs.logits
        logits = logits[:, 0, :]
        # Softmax over the classes, then return the probability of class 1
        # (the "hypothesis is true" class, per the prompt) as a single float.
        transformed_probs = torch.softmax(logits, dim=-1)
        raw_scores = transformed_probs[:, 1]
        return raw_scores.item()
|
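
# Usage sketch (an added illustration, not part of the original pipeline code):
# the checkpoint name below is a hypothetical placeholder, and this assumes a
# model whose logits have shape [batch, seq_len, num_classes] with
# num_classes == 2 (for example, a token-classification head), so that
# postprocess() can index position 0 and class 1 as written above.
if __name__ == "__main__":
    from transformers import AutoModelForTokenClassification

    model = AutoModelForTokenClassification.from_pretrained(
        "your-org/your-pair-classifier",  # hypothetical placeholder checkpoint
        num_labels=2,
    )
    pipe = PairTextClassificationPipeline(model=model)

    premise = "A soccer game with multiple males playing."
    hypothesis = "Some men are playing a sport."
    # A non-list input is routed through the pipeline as a single example,
    # so the (premise, hypothesis) tuple reaches preprocess() intact.
    score = pipe((premise, hypothesis))
    print(f"P(hypothesis is true | premise) = {score:.3f}")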