p-christ committed
Commit 89a0e11 · Parent(s): 2f355e5

Upload pipeline.py

Files changed (1): pipeline.py (+169, −0)

pipeline.py (new file):
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from typing import Dict, List, Any
import itertools
from nltk import sent_tokenize
# import torch
import nltk


class PreTrainedPipeline():

    def __init__(self, path=""):
        # Preload everything needed at inference (model, tokenizer, NLTK data).
        # This is only called once, so do all the heavy processing and I/O here.
        nltk.download('punkt')
        self.model = AutoModelForSeq2SeqLM.from_pretrained(path)
        self.tokenizer = AutoTokenizer.from_pretrained(path)

        self.model_type = "t5"
        # self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = "cpu"

        self.model.to(self.device)

    def __call__(self, inputs: str, max_words_per_answer: int = 3) -> List[Dict[str, Any]]:
        # Extract candidate answers from the text, then generate one question per answer.
        if len(inputs) == 0:
            return []
        inputs = " ".join(inputs.split())
        sents, answers = self._extract_answers(inputs)
        flat_answers = list(itertools.chain(*answers))

        if len(flat_answers) == 0:
            return []

        questions, qg_examples = self.prepare_and_generate_questions(sents, answers)
        output = [{'answer': example['answer'], 'question': que} for example, que in zip(qg_examples, questions)]
        output = self.clean_generated_QAs(output, max_words_per_answer)
        return output
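
    # Illustrative result shape (hypothetical input and output, for documentation only):
    #   pipeline("Paris is the capital of France.")
    #   -> [{'answer': 'Paris', 'question': 'What is the capital of France?'}]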

    def prepare_and_generate_questions(self, sents, answers):
        qg_examples = self._prepare_inputs_for_qg_from_answers_hl(sents, answers)

        qg_inputs = [example['source_text'] for example in qg_examples]
        questions = self._generate_questions(qg_inputs)
        return questions, qg_examples

    def clean_answers_list_of_lists(self, answers):
        # Drop the trailing end-of-sequence fragment left after the last '<sep>',
        # then strip whitespace and deduplicate each answer list.
        clean_answers = []
        for answer_list in answers:
            answer_list = answer_list[:-1]
            answer_list = list(set([a.strip() for a in answer_list]))
            clean_answers.append(answer_list)
        return clean_answers

    def _extract_answers(self, context):
        sents, inputs = self._prepare_inputs_for_ans_extraction(context)
        inputs = self._tokenize(inputs, padding=True, truncation=True)

        outs = self.model.generate(
            input_ids=inputs['input_ids'].to(self.device),
            attention_mask=inputs['attention_mask'].to(self.device),
            max_length=32,
        )

        # Keep special tokens so the '<sep>' answer delimiters survive decoding.
        dec = [self.tokenizer.decode(ids, skip_special_tokens=False) for ids in outs]
        answers = [item.split('<sep>') for item in dec]

        answers = self.clean_answers_list_of_lists(answers)

        return sents, answers

    def _prepare_inputs_for_ans_extraction(self, text):
        sents = sent_tokenize(text)

        # Build one extraction input per sentence, highlighting that sentence with <hl> markers.
        inputs = []
        for i in range(len(sents)):
            source_text = "extract answers:"
            for j, sent in enumerate(sents):
                if i == j:
                    sent = "<hl> %s <hl>" % sent
                source_text = "%s %s" % (source_text, sent)
            source_text = source_text.strip()

            if self.model_type == "t5":
                source_text = source_text + " </s>"
            inputs.append(source_text)

        return sents, inputs
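
    # Illustrative extraction input built above (hypothetical two-sentence text):
    #   "extract answers: <hl> Paris is the capital of France. <hl> It hosts the Louvre. </s>"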

    def _tokenize(self,
                  inputs,
                  padding=True,
                  truncation=True,
                  add_special_tokens=True,
                  max_length=512
                  ):
        # Note: the deprecated `pad_to_max_length` argument is dropped; the
        # `padding` strategy below already pads every sequence to max_length.
        inputs = self.tokenizer.batch_encode_plus(
            inputs,
            max_length=max_length,
            add_special_tokens=add_special_tokens,
            truncation=truncation,
            padding="max_length" if padding else False,
            return_tensors="pt"
        )
        return inputs

    def _generate_questions(self, inputs):
        inputs = self._tokenize(inputs, padding=True, truncation=True)

        outs = self.model.generate(
            input_ids=inputs['input_ids'].to(self.device),
            attention_mask=inputs['attention_mask'].to(self.device),
            max_length=32,
            num_beams=4,
        )

        questions = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in outs]
        return questions

    def _prepare_inputs_for_qg_from_answers_hl(self, sents, answers):
        inputs = []
        for i, answer in enumerate(answers):
            if len(answer) == 0:
                continue
            for answer_text in answer:
                sent = sents[i]
                sents_copy = sents[:]
                answer_text = self.remove_pad(answer_text)
                answer_text = answer_text.strip()

                try:
                    ans_start_idx = sent.lower().index(answer_text.lower())
                except ValueError:
                    # The answer does not appear in the sentence, so skip it.
                    continue

                sent = f"{sent[:ans_start_idx]} <hl> {answer_text} <hl> {sent[ans_start_idx + len(answer_text):]}"
                sents_copy[i] = sent

                source_text = " ".join(sents_copy)
                source_text = f"generate question: {source_text}"
                if self.model_type == "t5":
                    source_text = source_text + " </s>"

                inputs.append({"answer": answer_text, "source_text": source_text})

        return inputs
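
    # Illustrative question-generation input built above (hypothetical answer "Paris"):
    #   "generate question: <hl> Paris <hl> is the capital of France. It hosts the Louvre. </s>"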

    def clean_generated_QAs(self, generated_QAs, max_words_per_answer):
        clean_QAs = []
        answers_used = set()
        # Allow only one question per answer, keeping the first occurrence,
        # and drop answers longer than max_words_per_answer words.
        for qa in generated_QAs:
            answer_word_length = len(qa['answer'].strip().split())
            if qa['answer'] in answers_used or answer_word_length > max_words_per_answer:
                continue
            answers_used.add(qa['answer'])
            clean_QAs.append(qa)
        return clean_QAs

    def remove_pad(self, text):
        # Strip any '<pad>' tokens left over from decoding with special tokens kept.
        if "<pad>" in text:
            return text.replace("<pad>", "")
        return text
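

# --- Illustrative local usage (not part of the Inference API contract) ---
# A minimal sketch, assuming a question-generation checkpoint exists at the
# hypothetical local path "./model"; the input text below is made up.
if __name__ == "__main__":
    pipeline = PreTrainedPipeline(path="./model")
    text = "Paris is the capital of France. The Eiffel Tower was completed in 1889."
    for qa in pipeline(text):
        print(qa["question"], "->", qa["answer"])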