tcapelle committed
Commit 501a0bd · verified · 1 parent: 41c6ba2

Upload custom_pipeline.py

Files changed (1):
  1. custom_pipeline.py +38 -0
custom_pipeline.py ADDED
@@ -0,0 +1,38 @@
+ from transformers import AutoTokenizer, Pipeline
+ import torch
+
+ class PairTextClassificationPipeline(Pipeline):
+     def __init__(self, model, tokenizer=None, **kwargs):
+         # Initialize the tokenizer first, defaulting to the model's own checkpoint
+         if tokenizer is None:
+             tokenizer = AutoTokenizer.from_pretrained(model.config._name_or_path)
+         # Make sure we store the tokenizer before calling super().__init__
+         self.tokenizer = tokenizer
+         super().__init__(model=model, tokenizer=tokenizer, **kwargs)
+         self.prompt = "<pad> Determine if the hypothesis is true given the premise?\n\nPremise: {text1}\n\nHypothesis: {text2}"
+
+     def _sanitize_parameters(self, **kwargs):
+         preprocess_kwargs = {}  # no configurable parameters for any stage
+         return preprocess_kwargs, {}, {}
+
+     def preprocess(self, inputs):
+         # Expect a single (premise, hypothesis) pair per call
+         pair_dict = {'text1': inputs[0], 'text2': inputs[1]}
+         formatted_prompt = self.prompt.format(**pair_dict)
+         model_inputs = self.tokenizer(
+             formatted_prompt,
+             return_tensors='pt',
+             padding=True
+         )
+         return model_inputs
+
+     def _forward(self, model_inputs):
+         model_outputs = self.model(**model_inputs)
+         return model_outputs
+
+     def postprocess(self, model_outputs):
+         logits = model_outputs.logits
+         logits = logits[:, 0, :]  # keep only the first-token (tok_cls) logits
+         transformed_probs = torch.softmax(logits, dim=-1)
+         raw_scores = transformed_probs[:, 1]  # probability of class 1
+         return raw_scores.item()
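
For context, here is a minimal usage sketch (not part of the commit). The checkpoint name is a placeholder and the AutoModelForTokenClassification head is an assumption; any model whose output logits have shape (batch, sequence, num_classes) matches the postprocess step above.

from transformers import AutoModelForTokenClassification
from custom_pipeline import PairTextClassificationPipeline

# Hypothetical checkpoint name -- substitute the repo this pipeline ships with
model = AutoModelForTokenClassification.from_pretrained("your-org/your-checkpoint")
pipe = PairTextClassificationPipeline(model=model)

# One (premise, hypothesis) pair in, one float out
score = pipe(("A man is playing a guitar.", "A person is making music."))
print(score)  # probability of class 1

Because postprocess calls .item(), each pair yields a bare float rather than the list of label dicts that built-in transformers pipelines usually return.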