rycont zzrng76 commited on
Commit
91ee34e
1 Parent(s): 3e6f468

Create pipeline.py (#2)

Browse files

- Create pipeline.py (7ac259c19ecc70c2aad0575c7f2ff667f2e95fac)


Co-authored-by: yrlee <[email protected]>

Files changed (1) hide show
  1. pipeline.py +55 -0
pipeline.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import Pipeline
2
+
3
+ class MyPipeline(Pipeline):
4
+ def _sanitize_parameters(self, **kwargs):
5
+ preprocess_kwargs = {}
6
+ if "max_length" in kwargs:
7
+ preprocess_kwargs["max_length"] = kwargs["max_length"]
8
+ if "num_beams" in kwargs:
9
+ preprocess_kwargs["num_beams"] = kwargs["num_beams"]
10
+
11
+ return preprocess_kwargs, {}, {}
12
+ def preprocess(self, inputs, **kwargs):
13
+ inputs = re.sub(r'[^A-Za-z가-힣,<>0-9:&# ]', '', inputs)
14
+ inputs = "질문 생성: <unused0>"+inputs
15
+
16
+ input_ids = [tokenizer.bos_token_id] + tokenizer.encode(inputs) + [tokenizer.eos_token_id]
17
+ return {"inputs":torch.tensor([input_ids]),'max_length':kwargs['max_length'],'num_beams':kwargs['num_beams'] }
18
+
19
+ def _forward(self, model_inputs):
20
+ res_ids = model.generate(
21
+ model_inputs['inputs'],
22
+ max_length=model_inputs['max_length'],
23
+ num_beams=model_inputs['num_beams'],
24
+ eos_token_id=tokenizer.eos_token_id,
25
+ bad_words_ids=[[tokenizer.unk_token_id]]
26
+ )
27
+ return {"logits": res_ids}
28
+
29
+ def postprocess(self, model_outputs):
30
+ a = tokenizer.batch_decode(model_outputs["logits"].tolist())[0]
31
+ out_question = a.replace('<s>', '').replace('</s>', '')
32
+ return out_question
33
+
34
+ def _inference(self,paragraph,**kwargs):
35
+ input_ids = self.preprocess(paragraph,**kwargs)
36
+ reds_ids = self._forward(input_ids)
37
+ out_question = self.postprocess(reds_ids)
38
+ return out_question
39
+
40
+ def make_question(self, text, **kwargs):
41
+ words = text.split(" ")
42
+ frame_size = kwargs['frame_size']
43
+ hop_length = kwargs['hop_length']
44
+ steps = round((len(words)-frame_size)/hop_length) + 1
45
+ outs = []
46
+ for step in range(steps):
47
+ try:
48
+ script = " ".join(words[step*hop_length:step*hop_length+frame_size])
49
+ except:
50
+ script = " ".join(words[(1+step)*hop_length:])
51
+
52
+ outs.append(self._inference(script,**kwargs))
53
+ #if step>4:
54
+ # break
55
+ return outs