Create pipeline.py (#2)
Browse files- Create pipeline.py (7ac259c19ecc70c2aad0575c7f2ff667f2e95fac)
Co-authored-by: yrlee <[email protected]>
- pipeline.py +55 -0
pipeline.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import Pipeline
|
2 |
+
|
3 |
+
class MyPipeline(Pipeline):
|
4 |
+
def _sanitize_parameters(self, **kwargs):
|
5 |
+
preprocess_kwargs = {}
|
6 |
+
if "max_length" in kwargs:
|
7 |
+
preprocess_kwargs["max_length"] = kwargs["max_length"]
|
8 |
+
if "num_beams" in kwargs:
|
9 |
+
preprocess_kwargs["num_beams"] = kwargs["num_beams"]
|
10 |
+
|
11 |
+
return preprocess_kwargs, {}, {}
|
12 |
+
def preprocess(self, inputs, **kwargs):
|
13 |
+
inputs = re.sub(r'[^A-Za-z가-힣,<>0-9:&# ]', '', inputs)
|
14 |
+
inputs = "질문 생성: <unused0>"+inputs
|
15 |
+
|
16 |
+
input_ids = [tokenizer.bos_token_id] + tokenizer.encode(inputs) + [tokenizer.eos_token_id]
|
17 |
+
return {"inputs":torch.tensor([input_ids]),'max_length':kwargs['max_length'],'num_beams':kwargs['num_beams'] }
|
18 |
+
|
19 |
+
def _forward(self, model_inputs):
|
20 |
+
res_ids = model.generate(
|
21 |
+
model_inputs['inputs'],
|
22 |
+
max_length=model_inputs['max_length'],
|
23 |
+
num_beams=model_inputs['num_beams'],
|
24 |
+
eos_token_id=tokenizer.eos_token_id,
|
25 |
+
bad_words_ids=[[tokenizer.unk_token_id]]
|
26 |
+
)
|
27 |
+
return {"logits": res_ids}
|
28 |
+
|
29 |
+
def postprocess(self, model_outputs):
|
30 |
+
a = tokenizer.batch_decode(model_outputs["logits"].tolist())[0]
|
31 |
+
out_question = a.replace('<s>', '').replace('</s>', '')
|
32 |
+
return out_question
|
33 |
+
|
34 |
+
def _inference(self,paragraph,**kwargs):
|
35 |
+
input_ids = self.preprocess(paragraph,**kwargs)
|
36 |
+
reds_ids = self._forward(input_ids)
|
37 |
+
out_question = self.postprocess(reds_ids)
|
38 |
+
return out_question
|
39 |
+
|
40 |
+
def make_question(self, text, **kwargs):
|
41 |
+
words = text.split(" ")
|
42 |
+
frame_size = kwargs['frame_size']
|
43 |
+
hop_length = kwargs['hop_length']
|
44 |
+
steps = round((len(words)-frame_size)/hop_length) + 1
|
45 |
+
outs = []
|
46 |
+
for step in range(steps):
|
47 |
+
try:
|
48 |
+
script = " ".join(words[step*hop_length:step*hop_length+frame_size])
|
49 |
+
except:
|
50 |
+
script = " ".join(words[(1+step)*hop_length:])
|
51 |
+
|
52 |
+
outs.append(self._inference(script,**kwargs))
|
53 |
+
#if step>4:
|
54 |
+
# break
|
55 |
+
return outs
|