rbawden committed on
Commit
7dab9c4
1 Parent(s): c14e159

Upload pipeline.py

Files changed (1)
  1. pipeline.py +197 -0
pipeline.py ADDED
@@ -0,0 +1,197 @@
+ #!/usr/bin/env python
+ from transformers import Pipeline, AutoModelForSeq2SeqLM, AutoTokenizer
+ from transformers.tokenization_utils_base import TruncationStrategy
+ import unicodedata
+ import sys, re
+
+ # translation table mapping non-printable characters to None
+ # (needed by make_printable below, which otherwise references an undefined name)
+ NOPRINT_TRANS_TABLE = {i: None for i in range(sys.maxunicode + 1) if not chr(i).isprintable()}
+
+ class ReaccentPipeline(Pipeline):
+
+     def __init__(self, beam_size=5, batch_size=32, **kwargs):
+         # batch_size is accepted for interface compatibility but is not used directly here
+         self.beam_size = beam_size
+         super().__init__(**kwargs)
+
+
+     def _sanitize_parameters(self, clean_up_tokenisation_spaces=None, truncation=None, **generate_kwargs):
+         preprocess_params = {}
+         if truncation is not None:
+             preprocess_params["truncation"] = truncation
+
+         forward_params = generate_kwargs
+
+         postprocess_params = {}
+
+         if clean_up_tokenisation_spaces is not None:
+             postprocess_params["clean_up_tokenisation_spaces"] = clean_up_tokenisation_spaces
+
+         return preprocess_params, forward_params, postprocess_params
+
+
+     def check_inputs(self, input_length: int, min_length: int, max_length: int):
+         """
+         Checks whether there might be something wrong with the given input with regard to the model.
+         """
+         return True
+
+     def make_printable(self, s):
+         '''Replace non-printable characters in a string.'''
+         return s.translate(NOPRINT_TRANS_TABLE)
+
+
+     def normalise(self, line):
+         # optional extra normalisation steps (currently disabled):
+         # line = unicodedata.normalize('NFKC', line)
+         # line = self.make_printable(line)
+         for before, after in [('[«»“”]', '"'),  # unify double quotes
+                               ('[‘’]', "'"),    # unify single quotes
+                               (' +', ' '),      # collapse repeated spaces
+                               ('"+', '"'),      # collapse repeated double quotes
+                               ("'+", "'"),      # collapse repeated single quotes
+                               ('^ *', ''),      # strip leading spaces
+                               (' *$', '')]:     # strip trailing spaces
+             line = re.sub(before, after, line)
+         return line.strip() + ' </s>'
+
+     def _parse_and_tokenise(self, *args, truncation):
+         prefix = ""
+         if isinstance(args[0], list):
+             if self.tokenizer.pad_token_id is None:
+                 raise ValueError("Please make sure that the tokeniser has a pad_token_id when using a batch input")
+             args = ([prefix + arg for arg in args[0]],)
+             padding = True
+
+         elif isinstance(args[0], str):
+             args = (prefix + args[0],)
+             padding = False
+         else:
+             raise ValueError(
+                 f"`args[0]`: {args[0]} has the wrong format. It should be either of type `str` or of type `list`"
+             )
+         # normalise each sentence individually (args[0] is a list for batch input, a string otherwise)
+         texts = args[0] if isinstance(args[0], list) else [args[0]]
+         inputs = [self.normalise(x) for x in texts]
+         inputs = self.tokenizer(inputs, padding=padding, truncation=truncation, return_tensors=self.framework)
+         # readable string form of the tokenised inputs (currently unused, kept for debugging)
+         toks = []
+         for tok_ids in inputs.input_ids:
+             toks.append(" ".join(self.tokenizer.convert_ids_to_tokens(tok_ids)))
+         # this is produced by tokenisers but is an invalid generate kwarg
+         if "token_type_ids" in inputs:
+             del inputs["token_type_ids"]
+         return inputs
+
+     def preprocess(self, inputs, truncation=TruncationStrategy.DO_NOT_TRUNCATE, **kwargs):
+         inputs = self._parse_and_tokenise(inputs, truncation=truncation, **kwargs)
+         return inputs
+
+     def _forward(self, model_inputs, **generate_kwargs):
+         in_b, input_length = model_inputs["input_ids"].shape
+
+         generate_kwargs["min_length"] = generate_kwargs.get("min_length", self.model.config.min_length)
+         generate_kwargs["max_length"] = generate_kwargs.get("max_length", self.model.config.max_length)
+         generate_kwargs['num_beams'] = self.beam_size
+         self.check_inputs(input_length, generate_kwargs["min_length"], generate_kwargs["max_length"])
+         output_ids = self.model.generate(**model_inputs, **generate_kwargs)
+         # group generated sequences by input example: (batch, num_return_sequences, seq_len)
+         out_b = output_ids.shape[0]
+         output_ids = output_ids.reshape(in_b, out_b // in_b, *output_ids.shape[1:])
+         return {"output_ids": output_ids}
+
+     def postprocess(self, model_outputs, clean_up_tokenisation_spaces=False):
+         records = []
+         for output_ids in model_outputs["output_ids"][0]:
+             record = {
+                 "text": self.tokenizer.decode(
+                     output_ids,
+                     skip_special_tokens=True,
+                     # the tokeniser expects the American spelling of this kwarg
+                     clean_up_tokenization_spaces=clean_up_tokenisation_spaces,
+                 )
+             }
+             records.append(record)
+         return records
+
+     def correct_hallucinations(self, orig, output):
+         # TODO, not yet implemented:
+         # - align the original and output tokens
+         # - check that the correspondences are legitimate and correct them if not
+         # - replace <EMOJI> symbols by the original ones
+         return output
+
+     def __call__(self, *args, **kwargs):
+         r"""
+         Generate the normalised output text(s) from the input text(s).
+         Args:
+             args (`str` or `List[str]`):
+                 Input text for the encoder.
+             clean_up_tokenisation_spaces (`bool`, *optional*, defaults to `False`):
+                 Whether or not to clean up the potential extra spaces in the text output.
+             truncation (`TruncationStrategy`, *optional*, defaults to `TruncationStrategy.DO_NOT_TRUNCATE`):
+                 The truncation strategy for the tokenisation within the pipeline. `TruncationStrategy.DO_NOT_TRUNCATE`
+                 (default) will never truncate, but it is sometimes desirable to truncate the input to fit the model's
+                 max_length instead of throwing an error down the line.
+             generate_kwargs:
+                 Additional keyword arguments to pass along to the generate method of the model (see the generate method
+                 corresponding to your framework [here](./model#generative-models)).
+         Return:
+             A list or a list of lists of `dict`: each result comes as a dictionary with the following key:
+             - **text** (`str`) -- The normalised text.
+         """
+
+         result = super().__call__(*args, **kwargs)
+         # when called on a batch of strings with one output each, unwrap the inner lists
+         if (
+             isinstance(args[0], list)
+             and all(isinstance(el, str) for el in args[0])
+             and all(len(res) == 1 for res in result)
+         ):
+             return [res[0] for res in result]
+         return result
+
+
+ def normalise_text(list_sents, batch_size=32, beam_size=5):
+     tokeniser = AutoTokenizer.from_pretrained("rbawden/modern_french_normalisation", use_auth_token=True)
+     model = AutoModelForSeq2SeqLM.from_pretrained("rbawden/modern_french_normalisation", use_auth_token=True)
+     normalisation_pipeline = ReaccentPipeline(model=model,
+                                               tokenizer=tokeniser,
+                                               batch_size=batch_size,
+                                               beam_size=beam_size)
+     normalised_outputs = normalisation_pipeline(list_sents)
+     return normalised_outputs
+
+ def normalise_from_stdin(batch_size=32, beam_size=5):
+     tokeniser = AutoTokenizer.from_pretrained("rbawden/modern_french_normalisation", use_auth_token=True)
+     model = AutoModelForSeq2SeqLM.from_pretrained("rbawden/modern_french_normalisation", use_auth_token=True)
+     normalisation_pipeline = ReaccentPipeline(model=model,
+                                               tokenizer=tokeniser,
+                                               batch_size=batch_size,
+                                               beam_size=beam_size)
+     list_sents = []
+     for sent in sys.stdin:
+         list_sents.append(sent)
+     normalised_outputs = normalisation_pipeline(list_sents)
+     for sent in normalised_outputs:
+         print(sent['text'].strip())
+     return normalised_outputs
+
+
+ if __name__ == '__main__':
+
+     import argparse
+     parser = argparse.ArgumentParser()
+     parser.add_argument('-k', '--batch_size', type=int, default=32, help='Set the batch size for decoding')
+     parser.add_argument('-b', '--beam_size', type=int, default=5, help='Set the beam size for decoding')
+     parser.add_argument('-i', '--input_file', type=str, default=None, help='Input file. If None, read from STDIN')
+     args = parser.parse_args()
+
+     if args.input_file is None:
+         normalise_from_stdin(batch_size=args.batch_size, beam_size=args.beam_size)
+     else:
+         list_sents = []
+         with open(args.input_file) as fp:
+             for line in fp:
+                 list_sents.append(line.strip())
+         output_sents = normalise_text(list_sents, batch_size=args.batch_size, beam_size=args.beam_size)
+         for output_sent in output_sents:
+             # each result is a dict with a 'text' key (see postprocess)
+             print(output_sent['text'].strip())
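
For context, a minimal usage sketch of the uploaded file, assuming it is saved locally as pipeline.py and that the rbawden/modern_french_normalisation checkpoint is accessible with your Hugging Face token; the example sentences are illustrative only:

    # minimal usage sketch -- assumes pipeline.py is importable from the working directory
    from pipeline import normalise_text

    # hypothetical sentences in early modern French spelling
    sents = ["Elle eſtoit fort sçavante.", "Il auoit beaucoup voyagé."]
    for out in normalise_text(sents, batch_size=2, beam_size=5):
        print(out['text'])

The same file can also be run directly, e.g. python pipeline.py -i input.txt, or by piping text on STDIN (see the argparse block above).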