Upload pipeline.py
pipeline.py +197 -0
pipeline.py
ADDED
@@ -0,0 +1,197 @@
#!/usr/bin/python
from transformers import Pipeline, AutoModelForSeq2SeqLM, AutoTokenizer
from transformers.tokenization_utils_base import TruncationStrategy
import unicodedata  # used by the commented-out NFKC normalisation below
import sys, re

# Translation table mapping non-printable characters to None, used by
# make_printable() below. Assumed standard definition: the uploaded file
# references NOPRINT_TRANS_TABLE without defining it.
NOPRINT_TRANS_TABLE = {
    i: None for i in range(sys.maxunicode + 1) if not chr(i).isprintable()
}


class ReaccentPipeline(Pipeline):

    def __init__(self, beam_size=5, batch_size=32, **kwargs):
        self.beam_size = beam_size
        # forward batch_size to the base Pipeline (it pops it from kwargs);
        # otherwise the argument would be silently swallowed here
        super().__init__(batch_size=batch_size, **kwargs)

    def _sanitize_parameters(self, clean_up_tokenisation_spaces=None, truncation=None, **generate_kwargs):
        preprocess_params = {}
        if truncation is not None:
            preprocess_params["truncation"] = truncation

        forward_params = generate_kwargs

        postprocess_params = {}
        if clean_up_tokenisation_spaces is not None:
            postprocess_params["clean_up_tokenisation_spaces"] = clean_up_tokenisation_spaces

        return preprocess_params, forward_params, postprocess_params

    def check_inputs(self, input_length: int, min_length: int, max_length: int):
        """
        Checks whether there might be something wrong with the given input with regard to the model.
        """
        return True

    def make_printable(self, s):
        '''Replace non-printable characters in a string.'''
        return s.translate(NOPRINT_TRANS_TABLE)

    def normalise(self, line):
        #line = unicodedata.normalize('NFKC', line)
        #line = self.make_printable(line)
        # standardise quotes, collapse repeated quotes/spaces, trim the ends
        for before, after in [('[«»“”]', '"'),
                              ('[‘’]', "'"),
                              (' +', ' '),
                              ('"+', '"'),
                              ("'+", "'"),
                              ('^ *', ''),
                              (' *$', '')]:
            line = re.sub(before, after, line)
        return line.strip() + ' </s>'
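
    # Illustrative example of the rewriting above (the input string is invented):
    #   normalise('  «Bonjour»   le  monde ')  ->  '"Bonjour" le monde </s>'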

    def _parse_and_tokenise(self, *args, truncation):
        prefix = ""
        if isinstance(args[0], list):
            if self.tokenizer.pad_token_id is None:
                raise ValueError("Please make sure that the tokeniser has a pad_token_id when using a batch input")
            texts = [prefix + arg for arg in args[0]]
            padding = True

        elif isinstance(args[0], str):
            texts = [prefix + args[0]]
            padding = False
        else:
            raise ValueError(
                f"`args[0]`: {args[0]} has the wrong format. It should be of type `str` or type `list`"
            )
        inputs = [self.normalise(x) for x in texts]
        inputs = self.tokenizer(inputs, padding=padding, truncation=truncation, return_tensors=self.framework)
        # token_type_ids is produced by some tokenisers but is an invalid generate kwarg
        if "token_type_ids" in inputs:
            del inputs["token_type_ids"]
        return inputs

    def preprocess(self, inputs, truncation=TruncationStrategy.DO_NOT_TRUNCATE, **kwargs):
        inputs = self._parse_and_tokenise(inputs, truncation=truncation, **kwargs)
        return inputs

    def _forward(self, model_inputs, **generate_kwargs):
        in_b, input_length = model_inputs["input_ids"].shape

        generate_kwargs["min_length"] = generate_kwargs.get("min_length", self.model.config.min_length)
        generate_kwargs["max_length"] = generate_kwargs.get("max_length", self.model.config.max_length)
        generate_kwargs['num_beams'] = self.beam_size
        self.check_inputs(input_length, generate_kwargs["min_length"], generate_kwargs["max_length"])
        output_ids = self.model.generate(**model_inputs, **generate_kwargs)
        out_b = output_ids.shape[0]
        # group the generated sequences by input: (in_b * n_seq, len) -> (in_b, n_seq, len)
        output_ids = output_ids.reshape(in_b, out_b // in_b, *output_ids.shape[1:])
        return {"output_ids": output_ids}
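
    # Shape note (illustrative): with a batch of 2 inputs and the default
    # num_return_sequences=1, generate() returns output_ids of shape (2, seq_len),
    # which the reshape above regroups into (2, 1, seq_len).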

    def postprocess(self, model_outputs, clean_up_tokenisation_spaces=False):
        records = []
        for output_ids in model_outputs["output_ids"][0]:
            record = {
                "text": self.tokenizer.decode(
                    output_ids,
                    skip_special_tokens=True,
                    # the transformers decode API spells this parameter with a "z"
                    clean_up_tokenization_spaces=clean_up_tokenisation_spaces,
                )
            }
            records.append(record)
        return records

    def correct_hallucinations(self, orig, output):
        # TODO: align the original and output tokens

        # TODO: check that the correspondences are legitimate and correct them if not

        # TODO: replace <EMOJI> symbols by the original ones
        return output

    def __call__(self, *args, **kwargs):
        r"""
        Generate the output text(s) using the text(s) given as inputs.

        Args:
            args (`str` or `List[str]`):
                Input text for the encoder.
            clean_up_tokenisation_spaces (`bool`, *optional*, defaults to `False`):
                Whether or not to clean up the potential extra spaces in the text output.
            truncation (`TruncationStrategy`, *optional*, defaults to `TruncationStrategy.DO_NOT_TRUNCATE`):
                The truncation strategy for the tokenisation within the pipeline.
                `TruncationStrategy.DO_NOT_TRUNCATE` (default) will never truncate, but it is sometimes
                desirable to truncate the input to fit the model's max_length instead of throwing an
                error down the line.
            generate_kwargs:
                Additional keyword arguments to pass along to the generate method of the model (see the
                generate method corresponding to your framework [here](./model#generative-models)).

        Return:
            A list or a list of lists of `dict`: each result comes as a dictionary with the following key:

            - **text** (`str`) -- The normalised text.
        """

        result = super().__call__(*args, **kwargs)
        if (
            isinstance(args[0], list)
            and all(isinstance(el, str) for el in args[0])
            and all(len(res) == 1 for res in result)
        ):
            return [res[0] for res in result]
        return result


def normalise_text(list_sents, batch_size=32, beam_size=5):
    tokeniser = AutoTokenizer.from_pretrained("rbawden/modern_french_normalisation", use_auth_token=True)
    model = AutoModelForSeq2SeqLM.from_pretrained("rbawden/modern_french_normalisation", use_auth_token=True)
    normalisation_pipeline = ReaccentPipeline(model=model,
                                              tokenizer=tokeniser,
                                              batch_size=batch_size,
                                              beam_size=beam_size)
    normalised_outputs = normalisation_pipeline(list_sents)
    return normalised_outputs


def normalise_from_stdin(batch_size=32, beam_size=5):
    tokeniser = AutoTokenizer.from_pretrained("rbawden/modern_french_normalisation", use_auth_token=True)
    model = AutoModelForSeq2SeqLM.from_pretrained("rbawden/modern_french_normalisation", use_auth_token=True)
    normalisation_pipeline = ReaccentPipeline(model=model,
                                              tokenizer=tokeniser,
                                              batch_size=batch_size,
                                              beam_size=beam_size)
    list_sents = []
    for sent in sys.stdin:
        list_sents.append(sent.strip())
    normalised_outputs = normalisation_pipeline(list_sents)
    for sent in normalised_outputs:
        print(sent['text'].strip())
    return normalised_outputs


if __name__ == '__main__':

    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('-k', '--batch_size', type=int, default=32, help='Set the batch size for decoding')
    parser.add_argument('-b', '--beam_size', type=int, default=5, help='Set the beam size for decoding')
    parser.add_argument('-i', '--input_file', type=str, default=None, help='Input file. If None, read from STDIN')
    args = parser.parse_args()

    if args.input_file is None:
        normalise_from_stdin(batch_size=args.batch_size, beam_size=args.beam_size)
    else:
        list_sents = []
        with open(args.input_file) as fp:
            for line in fp:
                list_sents.append(line.strip())
        output_sents = normalise_text(list_sents, batch_size=args.batch_size, beam_size=args.beam_size)
        for output_sent in output_sents:
            print(output_sent['text'].strip())
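
For reference, a minimal usage sketch (the input sentence is an invented example, and loading
the rbawden/modern_french_normalisation checkpoint requires access to it, as above):

    # from Python, using the helper defined in this file
    from pipeline import normalise_text
    outputs = normalise_text(["Il auoit beaucoup d'amis."], batch_size=1, beam_size=5)
    print(outputs[0]['text'])

    # or from the command line, reading STDIN:
    #   echo "Il auoit beaucoup d'amis." | python pipeline.py -k 32 -b 5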