Upload pipeline.py

pipeline.py CHANGED: +140 -142

Old version (removed lines are prefixed with -):
@@ -20,9 +20,20 @@ def basic_tokenise(string):
         string = re.sub(char + '(?! )' , char + ' ', string)
     return string.strip()
 
-def homogenise(sent):
     sent = sent.lower()
-    #
     replace_from = "ǽǣáàâäąãăåćčçďéèêëęěğìíîĩĭıïĺľłńñňòóôõöøŕřśšşťţùúûũüǔỳýŷÿźẑżžÁÀÂÄĄÃĂÅĆČÇĎÉÈÊËĘĚĞÌÍÎĨĬİÏĹĽŁŃÑŇÒÓÔÕÖØŔŘŚŠŞŤŢÙÚÛŨÜǓỲÝŶŸŹẐŻŽſ"
     replace_into = "ææaaaaaaaacccdeeeeeegiiiiiiilllnnnoooooorrsssttuuuuuuyyyyzzzzAAAAAAAACCCDEEEEEEGIIIIIIILLLNNNOOOOOORRSSSTTUUUUUUYYYYZZZZs"
     table = sent.maketrans(replace_from, replace_into)
@@ -161,7 +172,7 @@ def space_before(idx, sent):
 ######## Normaliation pipeline #########
 class NormalisationPipeline(Pipeline):
 
-    def __init__(self, beam_size=5, batch_size=32, tokenise_func=None, cache_file=None, **kwargs):
         self.beam_size = beam_size
         # classic tokeniser function (used for alignments)
         if tokenise_func is not None:
@@ -170,7 +181,10 @@ class NormalisationPipeline(Pipeline):
             self.classic_tokenise = basic_tokenise
 
         # load lexicon
-
         super().__init__(**kwargs)
 
 
@@ -187,11 +201,11 @@ class NormalisationPipeline(Pipeline):
         for entry_dict in dataset['test']:
             entry = entry_dict['form']
             orig_words.append(entry.lower())
-            if homogenise(entry)
-                homog_words[homogenise(entry)] = entry
-            else:
                 remove.add(homogenise(entry))
-
         for entry in remove:
             del homog_words[entry]
 
@@ -203,11 +217,8 @@
         preprocess_params = {}
         if truncation is not None:
             preprocess_params["truncation"] = truncation
-
         forward_params = generate_kwargs
-
         postprocess_params = {}
-
         if clean_up_tokenisation_spaces is not None:
             postprocess_params["clean_up_tokenisation_spaces"] = clean_up_tokenisation_spaces
 
@@ -226,15 +237,8 @@
 
 
     def normalise(self, line):
-
-
-        for before, after in [('[«»\“\”]', '"'),
-                              ('[‘’]', "'"),
-                              (' +', ' '),
-                              ('\"+', '"'),
-                              ("'+", "'"),
-                              ('^ *', ''),
-                              (' *$', '')]:
             line = re.sub(before, after, line)
         return line.strip() + ' </s>'
 
@@ -269,7 +273,6 @@
 
     def _forward(self, model_inputs, **generate_kwargs):
         in_b, input_length = model_inputs["input_ids"].shape
-
         generate_kwargs["min_length"] = generate_kwargs.get("min_length", self.model.config.min_length)
         generate_kwargs["max_length"] = generate_kwargs.get("max_length", self.model.config.max_length)
         generate_kwargs['num_beams'] = self.beam_size
@@ -279,68 +282,66 @@
         output_ids = output_ids.reshape(in_b, out_b // in_b, *output_ids.shape[1:])
         return {"output_ids": output_ids}
 
-    def postprocess(self, model_outputs,
         records = []
         for output_ids in model_outputs["output_ids"][0]:
-            record = {
-
-                output_ids,
-                skip_special_tokens=True,
-                clean_up_tokenisation_spaces=clean_up_tokenisation_spaces,
-                )
-            }
             records.append(record)
         return records
 
-    def
-        #return [pred_sent]
-        print(alignment)
         output = []
-        # align the two
-        #alignments = self.align(orig_sent, pred_sent)
-        # correct word by word
-        len_diff_orig, len_diff_pred = 0, 0
-        pred_idxs = []
-        start = 0
-        for i, char in enumerate(re.sub(' +', ' ', pred_sent_tok) + " "):
-            if char == " ":
-                pred_idxs.append((start, i-1))
-                start = i+1
-        print(pred_idxs)
-        print('°°°°°°°°°°°°°°')
-        suffix_pred_sent = pred_sent
         for i, (orig_word, pred_word, _) in enumerate(alignment):
-
-
-
-
-            # replace word in tokenised sentence
-
-
-            output.append(postproc_word)
-        return re.sub(' +', ' ', ' '.join(output)), alignment
 
     def postprocess_correct_word(self, orig_word, pred_word, alignment):
         # pred_word exists in lexicon, take it
         if pred_word.lower() in self.lexicon_orig:
-
         # otherwise, if original word exists, take that
         if orig_word.lower() in self.lexicon_orig:
-
-
         # otherwise if pred word is in the lexicon with some changes, take that
         if pred_replacement is not None:
-
-            return pred_replacement,
-        orig_replacement = self.lexicon_homog.get(homogenise(orig_word), None)
         # otherwise if orig word is in the lexicon with some changes, take that
         if orig_replacement is not None:
-
-            return orig_replacement,
         # otherwise return original word (or pred?) + postprocessing?
-
-
     def get_caps(self, word):
         first, second, allcaps = False, False, False
         if len(word) > 0 and word[0].upper() == word[0]:
             first = True
@@ -356,45 +357,28 @@ class NormalisationPipeline(Pipeline):
         elif first and second:
             return word[0].upper() + word[1].upper() + word[2:]
         elif first:
-
         elif second:
-
         else:
             return word
-
-    def lexicon_lookup(self, candidate):
-        norm_candidate = homogenise(candidate.lower())
-        replacements = []
-        for candidate_word in candidate.split('▁'):
-            capitals = self.get_caps(candidate_word)
-            replacements.append([])
-            for word in self.lexicon:
-                if homogenise(word.lower()) == candidate_word:
-                    if len(replacements[-1]) > 0:
-                        return None # if ambiguity skip
-                    replacements[-1].append(self.set_caps(candidate, *capitals))
-
-        if [] not in replacements:
-            return ' '.join([x[0] for x in replacements]) # or some better strategy
-        else:
-            return None
 
-    def __call__(self,
         r"""
-        Generate the output
         Args:
-            args (`
                 Input text for the encoder.
-
-
-            return_text (`bool`, *optional*, defaults to `True`):
-                Whether or not to include the decoded texts in the outputs.
-            clean_up_tokenisation_spaces (`bool`, *optional*, defaults to `False`):
-                Whether or not to clean up the potential extra spaces in the text output.
-            truncation (`TruncationStrategy`, *optional*, defaults to `TruncationStrategy.DO_NOT_TRUNCATE`):
-                The truncation strategy for the tokenisation within the pipeline. `TruncationStrategy.DO_NOT_TRUNCATE`
-                (default) will never truncate, but it is sometimes desirable to truncate the input to fit the model's
-                max_length instead of throwing an error down the line.
         generate_kwargs:
             Additional keyword arguments to pass along to the generate method of the model (see the generate method
             corresponding to your framework [here](./model#generative-models)).
@@ -404,23 +388,18 @@ class NormalisationPipeline(Pipeline):
           - **generated_token_ids** (`torch.Tensor` or `tf.Tensor`, present when `return_tensors=True`) -- The token
             ids of the generated text.
         """
-
-
-
-
-
-
-
-
-
-
-
-
-                output.append({'text': result[i][0]['text'], 'alignment': char_spans})
-            return output
-
-        else:
-            return [{'text': result, 'alignment': self.align(args, result[0]['text'].strip())}]
 
     def align(self, sent_ref, sent_pred):
         sent_ref_tok = self.classic_tokenise(re.sub('[ ]', ' ', sent_ref))
@@ -431,28 +410,35 @@ class NormalisationPipeline(Pipeline):
             if i_ref == 0 and i_pred == 0:
                 continue
             # spaces in both, add straight away
-            if i_ref <= len(sent_ref_tok) and sent_ref_tok[i_ref-1] == ' '
-
                 alignment.append((current_word[0].strip(), current_word[1].strip(), weight-last_weight))
                 last_weight = weight
                 current_word = ['', '']
                 seen1.append(i_ref)
                 seen2.append(i_pred)
             else:
-                end_space = '░'
                 if i_ref <= len(sent_ref_tok) and i_ref not in seen1:
                     if i_ref > 0:
                         current_word[0] += sent_ref_tok[i_ref-1]
                     seen1.append(i_ref)
                 if i_pred <= len(sent_pred_tok) and i_pred not in seen2:
                     if i_pred > 0:
-                        current_word[1] += sent_pred_tok[i_pred-1] if sent_pred_tok[i_pred-1] != ' ' else '▁'
-                        end_space = '' if space_after(i_pred, sent_pred_tok) else '░'
                     seen2.append(i_pred)
                 if i_ref <= len(sent_ref_tok) and sent_ref_tok[i_ref-1] == ' ' and current_word[0].strip() != '':
                     alignment.append((current_word[0].strip(), current_word[1].strip() + end_space, weight-last_weight))
                     last_weight = weight
                     current_word = ['', '']
         # final word
         alignment.append((current_word[0].strip(), current_word[1].strip(), weight-last_weight))
         # check that both strings are entirely covered
@@ -465,48 +451,51 @@ class NormalisationPipeline(Pipeline):
                '\n2: ' + re.sub(' +', ' ', recovered2) + "\n2: " + re.sub(' +', ' ', sent_pred_tok)
         return alignment, sent_pred_tok
 
 
     def get_char_idx_align(self, sent_ref, sent_pred, alignment):
-        #sent_ref = self.classic_tokenise(re.sub('[ ]', ' ', sent_ref))
-        #sent_pred = self.classic_tokenise(re.sub('[ ]', ' ', sent_pred))
-
         covered_ref, covered_pred = 0, 0
-        ref_chars = [i for i, character in enumerate(sent_ref) if character not in [' ']]
-        pred_chars = [i for i, character in enumerate(sent_pred) if character not in [' ']]
         align_idx = []
-
         for a_ref, a_pred, _ in alignment:
             if a_ref == '' and a_pred == '':
                 continue
-            a_pred = re.sub(' +', '', a_pred).strip()
-            span_ref = [ref_chars[covered_ref], ref_chars[covered_ref + len(a_ref)
             covered_ref += len(a_ref)
-            span_pred = [pred_chars[covered_pred], pred_chars[covered_pred +
-            covered_pred +=
             align_idx.append((span_ref, span_pred))
 
         return align_idx
 
-def normalise_text(list_sents, batch_size=32, beam_size=5):
     tokeniser = AutoTokenizer.from_pretrained("rbawden/modern_french_normalisation", use_auth_token=True)
     model = AutoModelForSeq2SeqLM.from_pretrained("rbawden/modern_french_normalisation", use_auth_token=True)
     normalisation_pipeline = NormalisationPipeline(model=model,
                                                    tokenizer=tokeniser,
                                                    batch_size=batch_size,
                                                    beam_size=beam_size,
-                                                   cache_file=
     normalised_outputs = normalisation_pipeline(list_sents)
     return normalised_outputs
 
-def normalise_from_stdin(batch_size=32, beam_size=5):
     tokeniser = AutoTokenizer.from_pretrained("rbawden/modern_french_normalisation", use_auth_token=True)
     model = AutoModelForSeq2SeqLM.from_pretrained("rbawden/modern_french_normalisation", use_auth_token=True)
     normalisation_pipeline = NormalisationPipeline(model=model,
-
                                                    batch_size=batch_size,
                                                    beam_size=beam_size,
-                                                   cache_file=
     list_sents = []
     for sent in sys.stdin:
         list_sents.append(sent.strip())
     normalised_outputs = normalisation_pipeline(list_sents)
@@ -514,32 +503,41 @@ def normalise_from_stdin(batch_size=32, beam_size=5):
         alignment=sent['alignment']
 
         # printing in order to debug
-        print('src = ', list_sents[s])
-        print('
         # checking that the alignment makes sense
-        for b, a in alignment:
-
-
 
     return normalised_outputs
 
 
 if __name__ == '__main__':
-
     import argparse
     parser = argparse.ArgumentParser()
     parser.add_argument('-k', '--batch_size', type=int, default=32, help='Set the batch size for decoding')
     parser.add_argument('-b', '--beam_size', type=int, default=5, help='Set the beam size for decoding')
     parser.add_argument('-i', '--input_file', type=str, default=None, help='Input file. If None, read from STDIN')
     args = parser.parse_args()
 
     if args.input_file is None:
-        normalise_from_stdin(batch_size=args.batch_size,
     else:
         list_sents = []
         with open(args.input_file) as fp:
             for line in fp:
                 list_sents.append(line.strip())
-        output_sents = normalise_text(list_sents,
         for output_sent in output_sents:
-            print(output_sent)
New version (added lines are prefixed with +):

@@ -20,9 +20,20 @@ def basic_tokenise(string):
         string = re.sub(char + '(?! )' , char + ' ', string)
     return string.strip()
 
+def homogenise(sent, allow_alter_length=False):
+    '''
+    Homogenise an input sentence by lowercasing, removing diacritics, etc.
+    If allow_alter_length is False, then only applies changes that do not alter
+    the length of the original sentence (i.e. one-to-one modifications). If True,
+    then also apply n-m replacements.
+    '''
     sent = sent.lower()
+    # n-m replacemenets
+    if allow_alter_length:
+        for before, after in [('ã', 'an'), ('xoe', 'œ')]:
+            sent = sent.replace(before, after)
+        sent = sent.strip('-')
+    # 1-1 replacements only (must not change the number of characters
     replace_from = "ǽǣáàâäąãăåćčçďéèêëęěğìíîĩĭıïĺľłńñňòóôõöøŕřśšşťţùúûũüǔỳýŷÿźẑżžÁÀÂÄĄÃĂÅĆČÇĎÉÈÊËĘĚĞÌÍÎĨĬİÏĹĽŁŃÑŇÒÓÔÕÖØŔŘŚŠŞŤŢÙÚÛŨÜǓỲÝŶŸŹẐŻŽſ"
     replace_into = "ææaaaaaaaacccdeeeeeegiiiiiiilllnnnoooooorrsssttuuuuuuyyyyzzzzAAAAAAAACCCDEEEEEEGIIIIIIILLLNNNOOOOOORRSSSTTUUUUUUYYYYZZZZs"
     table = sent.maketrans(replace_from, replace_into)
######## Normaliation pipeline #########
|
173 |
class NormalisationPipeline(Pipeline):
|
174 |
|
175 |
+
def __init__(self, beam_size=5, batch_size=32, tokenise_func=None, cache_file=None, no_postproc=False, **kwargs):
|
176 |
self.beam_size = beam_size
|
177 |
# classic tokeniser function (used for alignments)
|
178 |
if tokenise_func is not None:
|
|
|
181 |
self.classic_tokenise = basic_tokenise
|
182 |
|
183 |
# load lexicon
|
184 |
+
if no_postproc:
|
185 |
+
self.lexicon_orig, self.lexicon_homog = None, None
|
186 |
+
else:
|
187 |
+
self.lexicon_orig, self.lexicon_homog = self.load_lexicon(cache_file=cache_file)
|
188 |
super().__init__(**kwargs)
|
189 |
|
190 |
|
|
|
201 |
for entry_dict in dataset['test']:
|
202 |
entry = entry_dict['form']
|
203 |
orig_words.append(entry.lower())
|
204 |
+
if homogenise(entry) in homog_words and homog_words[homogenise(entry)] != entry.lower():
|
|
|
|
|
205 |
remove.add(homogenise(entry))
|
206 |
+
if homogenise(entry) not in homog_words:
|
207 |
+
homog_words[homogenise(entry)] = entry.lower()
|
208 |
+
|
209 |
for entry in remove:
|
210 |
del homog_words[entry]
|
211 |
|
|
|
217 |
preprocess_params = {}
|
218 |
if truncation is not None:
|
219 |
preprocess_params["truncation"] = truncation
|
|
|
220 |
forward_params = generate_kwargs
|
|
|
221 |
postprocess_params = {}
|
|
|
222 |
if clean_up_tokenisation_spaces is not None:
|
223 |
postprocess_params["clean_up_tokenisation_spaces"] = clean_up_tokenisation_spaces
|
224 |
|
|
|
@@ -226,15 +237,8 @@
 
 
     def normalise(self, line):
+        for before, after in [('[«»\“\”]', '"'), ('[‘’]', "'"), (' +', ' '), ('\"+', '"'),
+                              ("'+", "'"), ('^ *', ''), (' *$', '')]:
             line = re.sub(before, after, line)
         return line.strip() + ' </s>'
 
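The effect of the rewritten normalise() on a raw line, using the same substitutions (the sample string is made up; the trailing ' </s>' marks end of sentence for the model):

import re

line = '  «Bonjour»…  “le  monde”  '
for before, after in [('[«»“”]', '"'), ('[‘’]', "'"), (' +', ' '), ('"+', '"'),
                      ("'+", "'"), ('^ *', ''), (' *$', '')]:
    line = re.sub(before, after, line)
print(line.strip() + ' </s>')   # '"Bonjour"… "le monde" </s>'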
@@ -269,7 +273,6 @@
 
     def _forward(self, model_inputs, **generate_kwargs):
         in_b, input_length = model_inputs["input_ids"].shape
         generate_kwargs["min_length"] = generate_kwargs.get("min_length", self.model.config.min_length)
         generate_kwargs["max_length"] = generate_kwargs.get("max_length", self.model.config.max_length)
         generate_kwargs['num_beams'] = self.beam_size
@@ -279,68 +282,66 @@
         output_ids = output_ids.reshape(in_b, out_b // in_b, *output_ids.shape[1:])
         return {"output_ids": output_ids}
 
+    def postprocess(self, model_outputs, clean_up_tok_spaces=False):
         records = []
         for output_ids in model_outputs["output_ids"][0]:
+            record = {"text": self.tokenizer.decode(output_ids, skip_special_tokens=True,
+                                                    clean_up_tokenisation_spaces=clean_up_tok_spaces).strip()}
             records.append(record)
         return records
 
+    def postprocess_correct_sent(self, alignment):
         output = []
         for i, (orig_word, pred_word, _) in enumerate(alignment):
+            if orig_word != '':
+                postproc_word = self.postprocess_correct_word(orig_word, pred_word, alignment)
+                alignment[i] = (orig_word, postproc_word, -1) # replace prediction in the alignment
+        return alignment
 
     def postprocess_correct_word(self, orig_word, pred_word, alignment):
+        #print('*' + pred_word + '*' + orig_word + '*')
         # pred_word exists in lexicon, take it
+        orig_caps = self.get_caps(orig_word)
         if pred_word.lower() in self.lexicon_orig:
+            #print('pred exists')
+            return self.set_caps(pred_word, *orig_caps)
         # otherwise, if original word exists, take that
         if orig_word.lower() in self.lexicon_orig:
+            #print('orig word = ', pred_word, orig_word)
+            return orig_word
+        pred_replacement = self.lexicon_homog.get(homogenise(pred_word, True), None)
         # otherwise if pred word is in the lexicon with some changes, take that
         if pred_replacement is not None:
+            #print('pred replace = ', pred_word, pred_replacement)
+            return self.add_orig_punct(pred_word, self.set_caps(pred_replacement, *orig_caps))
+        orig_replacement = self.lexicon_homog.get(homogenise(orig_word, True), None)
         # otherwise if orig word is in the lexicon with some changes, take that
         if orig_replacement is not None:
+            #print('orig replace = ', pred_word, orig_replacement)
+            return self.add_orig_punct(orig_word, self.set_caps(orig_replacement, *orig_caps))
         # otherwise return original word (or pred?) + postprocessing?
+        #print('last orig replace = ', pred_word, orig_word)
+
+        # TODO: how about, if close enough between src and pred, return pred?
+        return orig_word
+
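A simplified, standalone paraphrase of the five-step backoff implemented above (my own sketch: capitalisation and punctuation handling are omitted, and all names here are hypothetical):

def correct_word(orig, pred, lexicon, homog_lexicon, homogenise):
    if pred.lower() in lexicon:            # 1. prediction is a known modern word
        return pred
    if orig.lower() in lexicon:            # 2. source word is already modern
        return orig
    repl = homog_lexicon.get(homogenise(pred))
    if repl is not None:                   # 3. prediction matches modulo accents/spelling
        return repl
    repl = homog_lexicon.get(homogenise(orig))
    if repl is not None:                   # 4. source matches modulo accents/spelling
        return repl
    return orig                            # 5. fall back to the source word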
+    def get_surrounding_punct(self, word):
+        beginning_match = re.match("^(['\-]*)", word)
+        beginning, end = '', ''
+        if beginning_match:
+            beginning = beginning_match.group(1)
+        end_match = re.match("(['\-]*)$", word)
+        if end_match:
+            end = end.group(1)
+        return beginning, end
+
+    def add_orig_punct(self, old_word, new_word):
+        beginning, end = self.get_surrounding_punct(old_word)
+        return beginning + new_word + end
+
def get_caps(self, word):
|
343 |
+
# remove any non-alphatic characters at begining or end
|
344 |
+
word = word.strip("-'")
|
345 |
first, second, allcaps = False, False, False
|
346 |
if len(word) > 0 and word[0].upper() == word[0]:
|
347 |
first = True
|
|
|
@@ -356,45 +357,28 @@ class NormalisationPipeline(Pipeline):
         elif first and second:
             return word[0].upper() + word[1].upper() + word[2:]
         elif first:
+            if len(word) > 1:
+                return word[0].upper() + word[1:]
+            else:
+                return word[0].upper() + word[1:]
         elif second:
+            if len(word) > 2:
+                return word[0] + word[1].upper() + word[2:]
+            elif len(word) > 1:
+                return word[0] + word[1].upper() + word[2:]
+            else:
+                return word[0]
         else:
             return word
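An illustrative stand-in for the get_caps/set_caps pair shown above (the real code also tracks a second-letter flag; this sketch only handles first-letter and all-caps patterns):

def get_caps_demo(word):
    # record the capitalisation pattern of the source token
    word = word.strip("-'")
    first = bool(word) and word[0].isupper()
    allcaps = bool(word) and word.isupper()
    return first, allcaps

def set_caps_demo(word, first, allcaps):
    # re-apply that pattern to the replacement word
    if allcaps:
        return word.upper()
    if first and word:
        return word[0].upper() + word[1:]
    return word

caps = get_caps_demo("Meſme")
print(set_caps_demo("même", *caps))   # -> Même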
 
+    def __call__(self, input_sents, **kwargs):
         r"""
+        Generate the output texts using texts given as inputs.
         Args:
+            args (`List[str]`):
                 Input text for the encoder.
+            apply_postprocessing (`Bool`):
+                Apply postprocessing using the lexicon
         generate_kwargs:
             Additional keyword arguments to pass along to the generate method of the model (see the generate method
             corresponding to your framework [here](./model#generative-models)).
@@ -404,23 +388,18 @@ class NormalisationPipeline(Pipeline):
           - **generated_token_ids** (`torch.Tensor` or `tf.Tensor`, present when `return_tensors=True`) -- The token
             ids of the generated text.
         """
+        result = super().__call__(input_sents, **kwargs)
+
+        output = []
+        for i in range(len(result)):
+            input_sent, pred_sent = input_sents[i].strip(), result[i][0]['text'].strip()
+            alignment, pred_sent_tok = self.align(input_sent, pred_sent)
+            if self.lexicon_orig is not None:
+                alignment = self.postprocess_correct_sent(alignment)
+            pred_sent = self.get_pred_from_alignment(alignment)
+            char_spans = self.get_char_idx_align(input_sent, pred_sent, alignment)
+            output.append({'text': pred_sent, 'alignment': char_spans})
+        return output
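A hypothetical call showing the shape of what __call__ now returns, given an already constructed pipeline such as the one built in normalise_text further down (one dict per input sentence, pairing the normalised text with character-span alignments):

sents = ["Elle auoit enuie de conquerir l'eſtime des gens."]
outputs = normalisation_pipeline(sents)
for out in outputs:
    print(out['text'])        # normalised sentence
    print(out['alignment'])   # ([src_start, src_end], [pred_start, pred_end]) character spans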
 
     def align(self, sent_ref, sent_pred):
         sent_ref_tok = self.classic_tokenise(re.sub('[ ]', ' ', sent_ref))
@@ -431,28 +410,35 @@ class NormalisationPipeline(Pipeline):
             if i_ref == 0 and i_pred == 0:
                 continue
             # spaces in both, add straight away
+            if i_ref <= len(sent_ref_tok) and sent_ref_tok[i_ref-1] == ' ' \
+               and i_pred <= len(sent_pred_tok) and sent_pred_tok[i_pred-1] == ' ':
                 alignment.append((current_word[0].strip(), current_word[1].strip(), weight-last_weight))
                 last_weight = weight
                 current_word = ['', '']
                 seen1.append(i_ref)
                 seen2.append(i_pred)
             else:
+                end_space = '' #'░'
                 if i_ref <= len(sent_ref_tok) and i_ref not in seen1:
                     if i_ref > 0:
                         current_word[0] += sent_ref_tok[i_ref-1]
                     seen1.append(i_ref)
                 if i_pred <= len(sent_pred_tok) and i_pred not in seen2:
                     if i_pred > 0:
+                        current_word[1] += sent_pred_tok[i_pred-1] if sent_pred_tok[i_pred-1] != ' ' else ' ' #'▁'
+                        end_space = '' if space_after(i_pred, sent_pred_tok) else ''# '░'
                     seen2.append(i_pred)
                 if i_ref <= len(sent_ref_tok) and sent_ref_tok[i_ref-1] == ' ' and current_word[0].strip() != '':
                     alignment.append((current_word[0].strip(), current_word[1].strip() + end_space, weight-last_weight))
                     last_weight = weight
                     current_word = ['', '']
+                # space in ref but aligned to nothing in pred (under-translation)
+                elif i_ref <= len(sent_ref_tok) and sent_ref_tok[i_ref-1] == ' ' and current_word[1].strip() == '':
+                    alignment.append((current_word[0].strip(), current_word[1].strip(), weight-last_weight))
+                    last_weight = weight
+                    current_word = ['', '']
+                    seen1.append(i_ref)
+                    seen2.append(i_pred)
         # final word
         alignment.append((current_word[0].strip(), current_word[1].strip(), weight-last_weight))
         # check that both strings are entirely covered
@@ -465,48 +451,51 @@ class NormalisationPipeline(Pipeline):
                '\n2: ' + re.sub(' +', ' ', recovered2) + "\n2: " + re.sub(' +', ' ', sent_pred_tok)
         return alignment, sent_pred_tok
 
+    def get_pred_from_alignment(self, alignment):
+        return re.sub(' +', ' ', ''.join([x[1] if x[1] != "" else '\n' for x in alignment]).replace('\n', ' '))
 
     def get_char_idx_align(self, sent_ref, sent_pred, alignment):
         covered_ref, covered_pred = 0, 0
+        ref_chars = [i for i, character in enumerate(sent_ref) if character not in [' ']] + [len(sent_ref)]
+        pred_chars = [i for i, character in enumerate(sent_pred)] + [len(sent_pred)]# if character not in [' ']]
         align_idx = []
         for a_ref, a_pred, _ in alignment:
             if a_ref == '' and a_pred == '':
+                covered_pred += 1
                 continue
+            a_pred = re.sub(' +', ' ', a_pred).strip()
+            span_ref = [ref_chars[covered_ref], ref_chars[covered_ref + len(a_ref)]]
             covered_ref += len(a_ref)
+            span_pred = [pred_chars[covered_pred], pred_chars[covered_pred + len(a_pred)]]
+            covered_pred += len(a_pred)
             align_idx.append((span_ref, span_pred))
 
         return align_idx
 
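How the returned character spans can be read back against the input; all values below are illustrative, not actual model output:

src = "meſme choſe"
out = {'text': 'même chose',
       'alignment': [([0, 6], [0, 5]), ([6, 11], [5, 10])]}   # hypothetical spans

for (s0, s1), (p0, p1) in out['alignment']:
    print(src[s0:s1].strip(), '->', out['text'][p0:p1].strip())
# meſme -> même
# choſe -> chose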
+def normalise_text(list_sents, batch_size=32, beam_size=5, cache_file=None):
     tokeniser = AutoTokenizer.from_pretrained("rbawden/modern_french_normalisation", use_auth_token=True)
     model = AutoModelForSeq2SeqLM.from_pretrained("rbawden/modern_french_normalisation", use_auth_token=True)
     normalisation_pipeline = NormalisationPipeline(model=model,
                                                    tokenizer=tokeniser,
                                                    batch_size=batch_size,
                                                    beam_size=beam_size,
+                                                   cache_file=cache_file,
+                                                   no_postproc=no_postproc)
     normalised_outputs = normalisation_pipeline(list_sents)
     return normalised_outputs
 
+def normalise_from_stdin(batch_size=32, beam_size=5, cache_file=None, no_postproc=False):
     tokeniser = AutoTokenizer.from_pretrained("rbawden/modern_french_normalisation", use_auth_token=True)
     model = AutoModelForSeq2SeqLM.from_pretrained("rbawden/modern_french_normalisation", use_auth_token=True)
     normalisation_pipeline = NormalisationPipeline(model=model,
+                                                   tokenizer=tokeniser,
                                                    batch_size=batch_size,
                                                    beam_size=beam_size,
+                                                   cache_file=cache_file,
+                                                   no_postproc=no_postproc
+                                                   )
     list_sents = []
+    #ex = ["7. Qu'vne force plus grande de ſi peu que l'on voudra, que celle auec laquelle l'eau de la hauteur de trente & vn pieds, tend à couler en bas, ſuffit pour faire admettre ce vuide apparent, & meſme ſi grãd que l'on voudra, c'eſt à dire, pour faire des-vnir les corps d'vn ſi grand interualle que l'on voudra, pourueu qu'il n'y ait point d'autre obſtacle à leur ſeparation ny à leur eſloignement, que l'horreur que la Nature a pour ce vuide apparent."]
     for sent in sys.stdin:
         list_sents.append(sent.strip())
     normalised_outputs = normalisation_pipeline(list_sents)
@@ -514,32 +503,41 @@ def normalise_from_stdin(batch_size=32, beam_size=5):
         alignment=sent['alignment']
 
         # printing in order to debug
+        #print('src = ', list_sents[s])
+        print(sent['text'])
         # checking that the alignment makes sense
+        #for b, a in alignment:
+        #    print('input: ' + ''.join([list_sents[s][x] for x in range(b[0], max(len(b), b[1]))]) + '')
+        #    print('pred: ' + ''.join([sent['text'][x] for x in range(a[0], max(len(a), a[1]))]) + '')
 
     return normalised_outputs
 
 
 if __name__ == '__main__':
     import argparse
     parser = argparse.ArgumentParser()
     parser.add_argument('-k', '--batch_size', type=int, default=32, help='Set the batch size for decoding')
     parser.add_argument('-b', '--beam_size', type=int, default=5, help='Set the beam size for decoding')
     parser.add_argument('-i', '--input_file', type=str, default=None, help='Input file. If None, read from STDIN')
+    parser.add_argument('-c', '--cache_lexicon', type=str, default=None, help='Path to cache the lexicon file to speed up loading')
+    parser.add_argument('-n', '--no_postproc', default=False, action='store_true', help='Deactivate postprocessing to speed up normalisation, but this may degrade the output')
+
     args = parser.parse_args()
 
     if args.input_file is None:
+        normalise_from_stdin(batch_size=args.batch_size,
+                             beam_size=args.beam_size,
+                             cache_file=args.cache_lexicon,
+                             no_postproc=args.no_postproc)
     else:
         list_sents = []
         with open(args.input_file) as fp:
            for line in fp:
                list_sents.append(line.strip())
+        output_sents = normalise_text(list_sents,
+                                      batch_size=args.batch_size,
+                                      beam_size=args.beam_size,
+                                      cache_file=args.cache_lexicon,
+                                      no_postproc=args.no_postproc)
        for output_sent in output_sents:
+            print(output_sent['text'])