rbawden committed on
Commit
713121a
1 Parent(s): 49e0b5e

Upload pipeline.py

Files changed (1)
  1. pipeline.py +140 -142
pipeline.py CHANGED
@@ -20,9 +20,20 @@ def basic_tokenise(string):
20
  string = re.sub(char + '(?! )' , char + ' ', string)
21
  return string.strip()
22
 
23
- def homogenise(sent):
24
  sent = sent.lower()
25
- # sent = sent.replace("oe", "œ").replace("OE", "Œ")
26
  replace_from = "ǽǣáàâäąãăåćčçďéèêëęěğìíîĩĭıïĺľłńñňòóôõöøŕřśšşťţùúûũüǔỳýŷÿźẑżžÁÀÂÄĄÃĂÅĆČÇĎÉÈÊËĘĚĞÌÍÎĨĬİÏĹĽŁŃÑŇÒÓÔÕÖØŔŘŚŠŞŤŢÙÚÛŨÜǓỲÝŶŸŹẐŻŽſ"
27
  replace_into = "ææaaaaaaaacccdeeeeeegiiiiiiilllnnnoooooorrsssttuuuuuuyyyyzzzzAAAAAAAACCCDEEEEEEGIIIIIIILLLNNNOOOOOORRSSSTTUUUUUUYYYYZZZZs"
28
  table = sent.maketrans(replace_from, replace_into)
@@ -161,7 +172,7 @@ def space_before(idx, sent):
161
 ######## Normalisation pipeline #########
162
  class NormalisationPipeline(Pipeline):
163
 
164
- def __init__(self, beam_size=5, batch_size=32, tokenise_func=None, cache_file=None, **kwargs):
165
  self.beam_size = beam_size
166
  # classic tokeniser function (used for alignments)
167
  if tokenise_func is not None:
@@ -170,7 +181,10 @@ class NormalisationPipeline(Pipeline):
170
  self.classic_tokenise = basic_tokenise
171
 
172
  # load lexicon
173
- self.lexicon_orig, self.lexicon_homog = self.load_lexicon(cache_file=cache_file)
174
  super().__init__(**kwargs)
175
 
176
 
@@ -187,11 +201,11 @@ class NormalisationPipeline(Pipeline):
187
  for entry_dict in dataset['test']:
188
  entry = entry_dict['form']
189
  orig_words.append(entry.lower())
190
- if homogenise(entry) not in homog_words:
191
- homog_words[homogenise(entry)] = entry
192
- else:
193
  remove.add(homogenise(entry))
194
-
 
 
195
  for entry in remove:
196
  del homog_words[entry]
197
 
@@ -203,11 +217,8 @@ class NormalisationPipeline(Pipeline):
203
  preprocess_params = {}
204
  if truncation is not None:
205
  preprocess_params["truncation"] = truncation
206
-
207
  forward_params = generate_kwargs
208
-
209
  postprocess_params = {}
210
-
211
  if clean_up_tokenisation_spaces is not None:
212
  postprocess_params["clean_up_tokenisation_spaces"] = clean_up_tokenisation_spaces
213
 
@@ -226,15 +237,8 @@ class NormalisationPipeline(Pipeline):
226
 
227
 
228
  def normalise(self, line):
229
- #line = unicodedata.normalize('NFKC', line)
230
- #line = self.make_printable(line)
231
- for before, after in [('[«»\“\”]', '"'),
232
- ('[‘’]', "'"),
233
- (' +', ' '),
234
- ('\"+', '"'),
235
- ("'+", "'"),
236
- ('^ *', ''),
237
- (' *$', '')]:
238
  line = re.sub(before, after, line)
239
  return line.strip() + ' </s>'
240
 
@@ -269,7 +273,6 @@ class NormalisationPipeline(Pipeline):
269
 
270
  def _forward(self, model_inputs, **generate_kwargs):
271
  in_b, input_length = model_inputs["input_ids"].shape
272
-
273
  generate_kwargs["min_length"] = generate_kwargs.get("min_length", self.model.config.min_length)
274
  generate_kwargs["max_length"] = generate_kwargs.get("max_length", self.model.config.max_length)
275
  generate_kwargs['num_beams'] = self.beam_size
@@ -279,68 +282,66 @@ class NormalisationPipeline(Pipeline):
279
  output_ids = output_ids.reshape(in_b, out_b // in_b, *output_ids.shape[1:])
280
  return {"output_ids": output_ids}
281
 
282
- def postprocess(self, model_outputs, clean_up_tokenisation_spaces=False):
283
  records = []
284
  for output_ids in model_outputs["output_ids"][0]:
285
- record = {
286
- "text": self.tokenizer.decode(
287
- output_ids,
288
- skip_special_tokens=True,
289
- clean_up_tokenisation_spaces=clean_up_tokenisation_spaces,
290
- )
291
- }
292
  records.append(record)
293
  return records
294
 
295
- def postprocess_correct_sents(self, alignment, pred_sent_tok):
296
- #return [pred_sent]
297
- print(alignment)
298
  output = []
299
- # align the two
300
- #alignments = self.align(orig_sent, pred_sent)
301
- # correct word by word
302
- len_diff_orig, len_diff_pred = 0, 0
303
- pred_idxs = []
304
- start = 0
305
- for i, char in enumerate(re.sub(' +', ' ', pred_sent_tok) + " "):
306
- if char == " ":
307
- pred_idxs.append((start, i-1))
308
- start = i+1
309
- print(pred_idxs)
310
- print('°°°°°°°°°°°°°°')
311
- suffix_pred_sent = pred_sent
312
  for i, (orig_word, pred_word, _) in enumerate(alignment):
313
- #print(orig_word, pred_word)
314
- start_idx, end_idx = 1, 1
315
- postproc_word, alignment = self.postprocess_correct_word(orig_word, pred_word, alignment)
316
- #print(postproc_word)
317
- # replace word in tokenised sentence
318
-
319
-
320
- output.append(postproc_word)
321
- return re.sub(' +', ' ', ' '.join(output)), alignment
322
 
323
  def postprocess_correct_word(self, orig_word, pred_word, alignment):
 
324
  # pred_word exists in lexicon, take it
 
325
  if pred_word.lower() in self.lexicon_orig:
326
- return pred_word, alignment
 
327
  # otherwise, if original word exists, take that
328
  if orig_word.lower() in self.lexicon_orig:
329
- return orig_word, alignment
330
- pred_replacement = self.lexicon_homog.get(homogenise(pred_word), None)
 
331
  # otherwise if pred word is in the lexicon with some changes, take that
332
  if pred_replacement is not None:
333
- alignment = (alignment[0], pred_replacement, alignment[2])
334
- return pred_replacement, alignment
335
- orig_replacement = self.lexicon_homog.get(homogenise(orig_word), None)
336
  # otherwise if orig word is in the lexicon with some changes, take that
337
  if orig_replacement is not None:
338
- alignment = (orig_replacement, alignment[1], alignment[2])
339
- return orig_replacement, alignment
340
  # otherwise return original word (or pred?) + postprocessing?
341
- return orig_word, alignment
342
-
343
  def get_caps(self, word):
 
 
344
  first, second, allcaps = False, False, False
345
  if len(word) > 0 and word[0].upper() == word[0]:
346
  first = True
@@ -356,45 +357,28 @@ class NormalisationPipeline(Pipeline):
356
  elif first and second:
357
  return word[0].upper() + word[1].upper() + word[2:]
358
  elif first:
359
- return word[0].upper()
360
  elif second:
361
- return word[1].upper()
362
  else:
363
  return word
364
-
365
- def lexicon_lookup(self, candidate):
366
- norm_candidate = homogenise(candidate.lower())
367
- replacements = []
368
- for candidate_word in candidate.split('▁'):
369
- capitals = self.get_caps(candidate_word)
370
- replacements.append([])
371
- for word in self.lexicon:
372
- if homogenise(word.lower()) == candidate_word:
373
- if len(replacements[-1]) > 0:
374
- return None # if ambiguity skip
375
- replacements[-1].append(self.set_caps(candidate, *capitals))
376
-
377
- if [] not in replacements:
378
- return ' '.join([x[0] for x in replacements]) # or some better strategy
379
- else:
380
- return None
381
 
382
- def __call__(self, *args, **kwargs):
383
  r"""
384
- Generate the output text(s) using text(s) given as inputs.
385
  Args:
386
- args (`str` or `List[str]`):
387
  Input text for the encoder.
388
- return_tensors (`bool`, *optional*, defaults to `False`):
389
- Whether or not to include the tensors of predictions (as token indices) in the outputs.
390
- return_text (`bool`, *optional*, defaults to `True`):
391
- Whether or not to include the decoded texts in the outputs.
392
- clean_up_tokenisation_spaces (`bool`, *optional*, defaults to `False`):
393
- Whether or not to clean up the potential extra spaces in the text output.
394
- truncation (`TruncationStrategy`, *optional*, defaults to `TruncationStrategy.DO_NOT_TRUNCATE`):
395
- The truncation strategy for the tokenisation within the pipeline. `TruncationStrategy.DO_NOT_TRUNCATE`
396
- (default) will never truncate, but it is sometimes desirable to truncate the input to fit the model's
397
- max_length instead of throwing an error down the line.
398
  generate_kwargs:
399
  Additional keyword arguments to pass along to the generate method of the model (see the generate method
400
  corresponding to your framework [here](./model#generative-models)).
@@ -404,23 +388,18 @@ class NormalisationPipeline(Pipeline):
404
  - **generated_token_ids** (`torch.Tensor` or `tf.Tensor`, present when `return_tensors=True`) -- The token
405
  ids of the generated text.
406
  """
407
-
408
- result = super().__call__(*args, **kwargs)
409
- if (isinstance(args[0], list)
410
- and all(isinstance(el, str) for el in args[0])
411
- and all(len(res) == 1 for res in result)):
412
- output = []
413
- for i in range(len(result)):
414
- input_sent, pred_sent = args[0][i].strip(), result[i][0]['text'].strip()
415
- alignment, pred_sent_tok = self.align(input_sent, pred_sent)
416
- #pred_sent, alignment = self.postprocess_correct_sents(alignment, pred_sent_tok)
417
- char_spans = self.get_char_idx_align(input_sent, pred_sent, alignment)
418
-
419
- output.append({'text': result[i][0]['text'], 'alignment': char_spans})
420
- return output
421
-
422
- else:
423
- return [{'text': result, 'alignment': self.align(args, result[0]['text'].strip())}]
424
 
425
  def align(self, sent_ref, sent_pred):
426
  sent_ref_tok = self.classic_tokenise(re.sub('[ ]', ' ', sent_ref))
@@ -431,28 +410,35 @@ class NormalisationPipeline(Pipeline):
431
  if i_ref == 0 and i_pred == 0:
432
  continue
433
  # spaces in both, add straight away
434
- if i_ref <= len(sent_ref_tok) and sent_ref_tok[i_ref-1] == ' ' and \
435
- i_pred <= len(sent_pred_tok) and sent_pred_tok[i_pred-1] == ' ':
436
  alignment.append((current_word[0].strip(), current_word[1].strip(), weight-last_weight))
437
  last_weight = weight
438
  current_word = ['', '']
439
  seen1.append(i_ref)
440
  seen2.append(i_pred)
441
  else:
442
- end_space = '░'
443
  if i_ref <= len(sent_ref_tok) and i_ref not in seen1:
444
  if i_ref > 0:
445
  current_word[0] += sent_ref_tok[i_ref-1]
446
  seen1.append(i_ref)
447
  if i_pred <= len(sent_pred_tok) and i_pred not in seen2:
448
  if i_pred > 0:
449
- current_word[1] += sent_pred_tok[i_pred-1] if sent_pred_tok[i_pred-1] != ' ' else '▁'
450
- end_space = '' if space_after(i_pred, sent_pred_tok) else '░'
451
  seen2.append(i_pred)
452
  if i_ref <= len(sent_ref_tok) and sent_ref_tok[i_ref-1] == ' ' and current_word[0].strip() != '':
453
  alignment.append((current_word[0].strip(), current_word[1].strip() + end_space, weight-last_weight))
454
  last_weight = weight
455
  current_word = ['', '']
456
  # final word
457
  alignment.append((current_word[0].strip(), current_word[1].strip(), weight-last_weight))
458
  # check that both strings are entirely covered
@@ -465,48 +451,51 @@ class NormalisationPipeline(Pipeline):
465
  '\n2: ' + re.sub(' +', ' ', recovered2) + "\n2: " + re.sub(' +', ' ', sent_pred_tok)
466
  return alignment, sent_pred_tok
467
 
 
 
468
 
469
  def get_char_idx_align(self, sent_ref, sent_pred, alignment):
470
- #sent_ref = self.classic_tokenise(re.sub('[ ]', ' ', sent_ref))
471
- #sent_pred = self.classic_tokenise(re.sub('[ ]', ' ', sent_pred))
472
-
473
  covered_ref, covered_pred = 0, 0
474
- ref_chars = [i for i, character in enumerate(sent_ref) if character not in [' ']]
475
- pred_chars = [i for i, character in enumerate(sent_pred) if character not in [' ']]
476
  align_idx = []
477
-
478
  for a_ref, a_pred, _ in alignment:
479
  if a_ref == '' and a_pred == '':
 
480
  continue
481
- a_pred = re.sub(' +', '', a_pred).strip()
482
- span_ref = [ref_chars[covered_ref], ref_chars[covered_ref + len(a_ref) - 1]]
483
  covered_ref += len(a_ref)
484
- span_pred = [pred_chars[covered_pred], pred_chars[covered_pred + max(0, len(a_pred) - 1)]]
485
- covered_pred += max(0, len(a_pred))
486
  align_idx.append((span_ref, span_pred))
487
 
488
  return align_idx
489
 
490
- def normalise_text(list_sents, batch_size=32, beam_size=5):
491
  tokeniser = AutoTokenizer.from_pretrained("rbawden/modern_french_normalisation", use_auth_token=True)
492
  model = AutoModelForSeq2SeqLM.from_pretrained("rbawden/modern_french_normalisation", use_auth_token=True)
493
  normalisation_pipeline = NormalisationPipeline(model=model,
494
  tokenizer=tokeniser,
495
  batch_size=batch_size,
496
  beam_size=beam_size,
497
- cache_file="/home/rbawden/scratch/.normalisation_lefff.pickle")
 
498
  normalised_outputs = normalisation_pipeline(list_sents)
499
  return normalised_outputs
500
 
501
- def normalise_from_stdin(batch_size=32, beam_size=5):
502
  tokeniser = AutoTokenizer.from_pretrained("rbawden/modern_french_normalisation", use_auth_token=True)
503
  model = AutoModelForSeq2SeqLM.from_pretrained("rbawden/modern_french_normalisation", use_auth_token=True)
504
  normalisation_pipeline = NormalisationPipeline(model=model,
505
- tokenizer=tokeniser,
506
  batch_size=batch_size,
507
  beam_size=beam_size,
508
- cache_file="/home/rbawden/scratch/.normalisation_lefff.pickle")
 
 
509
  list_sents = []
 
510
  for sent in sys.stdin:
511
  list_sents.append(sent.strip())
512
  normalised_outputs = normalisation_pipeline(list_sents)
@@ -514,32 +503,41 @@ def normalise_from_stdin(batch_size=32, beam_size=5):
514
  alignment=sent['alignment']
515
 
516
  # printing in order to debug
517
- print('src = ', list_sents[s])
518
- print('norm = ', sent)
519
  # checking that the alignment makes sense
520
- for b, a in alignment:
521
- print('input: ' + ''.join([list_sents[s][x] for x in range(b[0], max(len(b), b[1]+1))]) + '')
522
- print('pred: ' + ''.join([sent['text'][x] for x in range(a[0], max(len(a), a[1]+1))]) + '')
523
 
524
  return normalised_outputs
525
 
526
 
527
  if __name__ == '__main__':
528
-
529
  import argparse
530
  parser = argparse.ArgumentParser()
531
  parser.add_argument('-k', '--batch_size', type=int, default=32, help='Set the batch size for decoding')
532
  parser.add_argument('-b', '--beam_size', type=int, default=5, help='Set the beam size for decoding')
533
  parser.add_argument('-i', '--input_file', type=str, default=None, help='Input file. If None, read from STDIN')
534
  args = parser.parse_args()
535
 
536
  if args.input_file is None:
537
- normalise_from_stdin(batch_size=args.batch_size, beam_size=args.beam_size)
538
  else:
539
  list_sents = []
540
  with open(args.input_file) as fp:
541
  for line in fp:
542
  list_sents.append(line.strip())
543
- output_sents = normalise_text(list_sents, batch_size=args.batch_size, beam_size=args.beam_size)
544
  for output_sent in output_sents:
545
- print(output_sent)
 
20
  string = re.sub(char + '(?! )' , char + ' ', string)
21
  return string.strip()
22
 
23
+ def homogenise(sent, allow_alter_length=False):
24
+ '''
25
+ Homogenise an input sentence by lowercasing, removing diacritics, etc.
26
+ If allow_alter_length is False, then only apply changes that do not alter
27
+ the length of the original sentence (i.e. one-to-one modifications). If True,
28
+ then also apply n-m replacements.
29
+ '''
30
  sent = sent.lower()
31
+ # n-m replacements
32
+ if allow_alter_length:
33
+ for before, after in [('ã', 'an'), ('xoe', 'œ')]:
34
+ sent = sent.replace(before, after)
35
+ sent = sent.strip('-')
36
+ # 1-1 replacements only (must not change the number of characters)
37
  replace_from = "ǽǣáàâäąãăåćčçďéèêëęěğìíîĩĭıïĺľłńñňòóôõöøŕřśšşťţùúûũüǔỳýŷÿźẑżžÁÀÂÄĄÃĂÅĆČÇĎÉÈÊËĘĚĞÌÍÎĨĬİÏĹĽŁŃÑŇÒÓÔÕÖØŔŘŚŠŞŤŢÙÚÛŨÜǓỲÝŶŸŹẐŻŽſ"
38
  replace_into = "ææaaaaaaaacccdeeeeeegiiiiiiilllnnnoooooorrsssttuuuuuuyyyyzzzzAAAAAAAACCCDEEEEEEGIIIIIIILLLNNNOOOOOORRSSSTTUUUUUUYYYYZZZZs"
39
  table = sent.maketrans(replace_from, replace_into)
 
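To read the new homogenise at a glance, here is a minimal standalone sketch of the same behaviour (an illustration, not the committed code: the 1-1 folding table is abbreviated and the final sent.translate(table) call is assumed from the rest of the file):

    def homogenise_sketch(sent, allow_alter_length=False):
        # lowercase first, as in the committed function
        sent = sent.lower()
        if allow_alter_length:
            # n-to-m replacements that may change the sentence length
            for before, after in [('ã', 'an'), ('xoe', 'œ')]:
                sent = sent.replace(before, after)
            sent = sent.strip('-')
        # abbreviated 1-1 diacritic-folding table (the committed table is much longer)
        table = sent.maketrans("áàâäéèêëſ", "aaaaeeees")
        return sent.translate(table)

    print(homogenise_sketch("Tréſor"))       # -> tresor
    print(homogenise_sketch("grãd-", True))  # -> grand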
172
 ######## Normalisation pipeline #########
173
  class NormalisationPipeline(Pipeline):
174
 
175
+ def __init__(self, beam_size=5, batch_size=32, tokenise_func=None, cache_file=None, no_postproc=False, **kwargs):
176
  self.beam_size = beam_size
177
  # classic tokeniser function (used for alignments)
178
  if tokenise_func is not None:
 
181
  self.classic_tokenise = basic_tokenise
182
 
183
  # load lexicon
184
+ if no_postproc:
185
+ self.lexicon_orig, self.lexicon_homog = None, None
186
+ else:
187
+ self.lexicon_orig, self.lexicon_homog = self.load_lexicon(cache_file=cache_file)
188
  super().__init__(**kwargs)
189
 
190
 
 
201
  for entry_dict in dataset['test']:
202
  entry = entry_dict['form']
203
  orig_words.append(entry.lower())
204
+ if homogenise(entry) in homog_words and homog_words[homogenise(entry)] != entry.lower():
 
 
205
  remove.add(homogenise(entry))
206
+ if homogenise(entry) not in homog_words:
207
+ homog_words[homogenise(entry)] = entry.lower()
208
+
209
  for entry in remove:
210
  del homog_words[entry]
211
 
 
217
  preprocess_params = {}
218
  if truncation is not None:
219
  preprocess_params["truncation"] = truncation
 
220
  forward_params = generate_kwargs
 
221
  postprocess_params = {}
 
222
  if clean_up_tokenisation_spaces is not None:
223
  postprocess_params["clean_up_tokenisation_spaces"] = clean_up_tokenisation_spaces
224
 
 
237
 
238
 
239
  def normalise(self, line):
240
+ for before, after in [('[«»\“\”]', '"'), ('[‘’]', "'"), (' +', ' '), ('\"+', '"'),
241
+ ("'+", "'"), ('^ *', ''), (' *$', '')]:
242
  line = re.sub(before, after, line)
243
  return line.strip() + ' </s>'
244
 
 
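The rewritten normalise above is easier to follow with a concrete input. Below is a standalone re-implementation of the same substitution list (illustration only, not the pipeline method; the guillemets and long s are typical of the early modern French this model targets):

    import re

    def normalise_sketch(line):
        # same replacement list as in the method above
        for before, after in [('[«»“”]', '"'), ('[‘’]', "'"), (' +', ' '),
                              ('"+', '"'), ("'+", "'"), ('^ *', ''), (' *$', '')]:
            line = re.sub(before, after, line)
        return line.strip() + ' </s>'

    print(normalise_sketch('  «Voicy»   l’exemple  '))
    # -> "Voicy" l'exemple </s>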
273
 
274
  def _forward(self, model_inputs, **generate_kwargs):
275
  in_b, input_length = model_inputs["input_ids"].shape
 
276
  generate_kwargs["min_length"] = generate_kwargs.get("min_length", self.model.config.min_length)
277
  generate_kwargs["max_length"] = generate_kwargs.get("max_length", self.model.config.max_length)
278
  generate_kwargs['num_beams'] = self.beam_size
 
282
  output_ids = output_ids.reshape(in_b, out_b // in_b, *output_ids.shape[1:])
283
  return {"output_ids": output_ids}
284
 
285
+ def postprocess(self, model_outputs, clean_up_tok_spaces=False):
286
  records = []
287
  for output_ids in model_outputs["output_ids"][0]:
288
+ record = {"text": self.tokenizer.decode(output_ids, skip_special_tokens=True,
289
+ clean_up_tokenisation_spaces=clean_up_tok_spaces).strip()}
290
  records.append(record)
291
  return records
292
 
293
+ def postprocess_correct_sent(self, alignment):
 
 
294
  output = []
295
  for i, (orig_word, pred_word, _) in enumerate(alignment):
296
+ if orig_word != '':
297
+ postproc_word = self.postprocess_correct_word(orig_word, pred_word, alignment)
298
+ alignment[i] = (orig_word, postproc_word, -1) # replace prediction in the alignment
299
+ return alignment
300
 
301
  def postprocess_correct_word(self, orig_word, pred_word, alignment):
302
+ #print('*' + pred_word + '*' + orig_word + '*')
303
  # pred_word exists in lexicon, take it
304
+ orig_caps = self.get_caps(orig_word)
305
  if pred_word.lower() in self.lexicon_orig:
306
+ #print('pred exists')
307
+ return self.set_caps(pred_word, *orig_caps)
308
  # otherwise, if original word exists, take that
309
  if orig_word.lower() in self.lexicon_orig:
310
+ #print('orig word = ', pred_word, orig_word)
311
+ return orig_word
312
+ pred_replacement = self.lexicon_homog.get(homogenise(pred_word, True), None)
313
  # otherwise if pred word is in the lexicon with some changes, take that
314
  if pred_replacement is not None:
315
+ #print('pred replace = ', pred_word, pred_replacement)
316
+ return self.add_orig_punct(pred_word, self.set_caps(pred_replacement, *orig_caps))
317
+ orig_replacement = self.lexicon_homog.get(homogenise(orig_word, True), None)
318
  # otherwise if orig word is in the lexicon with some changes, take that
319
  if orig_replacement is not None:
320
+ #print('orig replace = ', pred_word, orig_replacement)
321
+ return self.add_orig_punct(orig_word, self.set_caps(orig_replacement, *orig_caps))
322
  # otherwise return original word (or pred?) + postprocessing?
323
+ #print('last orig replace = ', pred_word, orig_word)
324
+
325
+ # TODO: how about, if close enough between src and pred, return pred?
326
+ return orig_word
327
+
328
+ def get_surrounding_punct(self, word):
329
+ beginning_match = re.match("^(['\-]*)", word)
330
+ beginning, end = '', ''
331
+ if beginning_match:
332
+ beginning = beginning_match.group(1)
333
+ end_match = re.match("(['\-]*)$", word)
334
+ if end_match:
335
+ end = end_match.group(1)
336
+ return beginning, end
337
+
338
+ def add_orig_punct(self, old_word, new_word):
339
+ beginning, end = self.get_surrounding_punct(old_word)
340
+ return beginning + new_word + end
341
+
342
  def get_caps(self, word):
343
+ # remove any non-alphabetic characters at beginning or end
344
+ word = word.strip("-'")
345
  first, second, allcaps = False, False, False
346
  if len(word) > 0 and word[0].upper() == word[0]:
347
  first = True
 
357
  elif first and second:
358
  return word[0].upper() + word[1].upper() + word[2:]
359
  elif first:
360
+ if len(word) > 1:
361
+ return word[0].upper() + word[1:]
362
+ else:
363
+ return word[0].upper() + word[1:]
364
  elif second:
365
+ if len(word) > 2:
366
+ return word[0] + word[1].upper() + word[2:]
367
+ elif len(word) > 1:
368
+ return word[0] + word[1].upper() + word[2:]
369
+ else:
370
+ return word[0]
371
  else:
372
  return word
373
 
374
+ def __call__(self, input_sents, **kwargs):
375
  r"""
376
+ Generate the output texts using texts given as inputs.
377
  Args:
378
+ input_sents (`List[str]`):
379
  Input text for the encoder.
380
+ apply_postprocessing (`Bool`):
381
+ Apply postprocessing using the lexicon
382
  generate_kwargs:
383
  Additional keyword arguments to pass along to the generate method of the model (see the generate method
384
  corresponding to your framework [here](./model#generative-models)).
 
388
  - **generated_token_ids** (`torch.Tensor` or `tf.Tensor`, present when `return_tensors=True`) -- The token
389
  ids of the generated text.
390
  """
391
+ result = super().__call__(input_sents, **kwargs)
392
+
393
+ output = []
394
+ for i in range(len(result)):
395
+ input_sent, pred_sent = input_sents[i].strip(), result[i][0]['text'].strip()
396
+ alignment, pred_sent_tok = self.align(input_sent, pred_sent)
397
+ if self.lexicon_orig is not None:
398
+ alignment = self.postprocess_correct_sent(alignment)
399
+ pred_sent = self.get_pred_from_alignment(alignment)
400
+ char_spans = self.get_char_idx_align(input_sent, pred_sent, alignment)
401
+ output.append({'text': pred_sent, 'alignment': char_spans})
402
+ return output
403
 
404
  def align(self, sent_ref, sent_pred):
405
  sent_ref_tok = self.classic_tokenise(re.sub('[ ]', ' ', sent_ref))
 
410
  if i_ref == 0 and i_pred == 0:
411
  continue
412
  # spaces in both, add straight away
413
+ if i_ref <= len(sent_ref_tok) and sent_ref_tok[i_ref-1] == ' ' \
414
+ and i_pred <= len(sent_pred_tok) and sent_pred_tok[i_pred-1] == ' ':
415
  alignment.append((current_word[0].strip(), current_word[1].strip(), weight-last_weight))
416
  last_weight = weight
417
  current_word = ['', '']
418
  seen1.append(i_ref)
419
  seen2.append(i_pred)
420
  else:
421
+ end_space = '' #'░'
422
  if i_ref <= len(sent_ref_tok) and i_ref not in seen1:
423
  if i_ref > 0:
424
  current_word[0] += sent_ref_tok[i_ref-1]
425
  seen1.append(i_ref)
426
  if i_pred <= len(sent_pred_tok) and i_pred not in seen2:
427
  if i_pred > 0:
428
+ current_word[1] += sent_pred_tok[i_pred-1] if sent_pred_tok[i_pred-1] != ' ' else ' ' #'▁'
429
+ end_space = '' if space_after(i_pred, sent_pred_tok) else ''# '░'
430
  seen2.append(i_pred)
431
  if i_ref <= len(sent_ref_tok) and sent_ref_tok[i_ref-1] == ' ' and current_word[0].strip() != '':
432
  alignment.append((current_word[0].strip(), current_word[1].strip() + end_space, weight-last_weight))
433
  last_weight = weight
434
  current_word = ['', '']
435
+ # space in ref but aligned to nothing in pred (under-translation)
436
+ elif i_ref <= len(sent_ref_tok) and sent_ref_tok[i_ref-1] == ' ' and current_word[1].strip() == '':
437
+ alignment.append((current_word[0].strip(), current_word[1].strip(), weight-last_weight))
438
+ last_weight = weight
439
+ current_word = ['', '']
440
+ seen1.append(i_ref)
441
+ seen2.append(i_pred)
442
  # final word
443
  alignment.append((current_word[0].strip(), current_word[1].strip(), weight-last_weight))
444
  # check that both strings are entirely covered
 
451
  '\n2: ' + re.sub(' +', ' ', recovered2) + "\n2: " + re.sub(' +', ' ', sent_pred_tok)
452
  return alignment, sent_pred_tok
453
 
454
+ def get_pred_from_alignment(self, alignment):
455
+ return re.sub(' +', ' ', ''.join([x[1] if x[1] != "" else '\n' for x in alignment]).replace('\n', ' '))
456
 
457
  def get_char_idx_align(self, sent_ref, sent_pred, alignment):
458
  covered_ref, covered_pred = 0, 0
459
+ ref_chars = [i for i, character in enumerate(sent_ref) if character not in [' ']] + [len(sent_ref)]
460
+ pred_chars = [i for i, character in enumerate(sent_pred)] + [len(sent_pred)]# if character not in [' ']]
461
  align_idx = []
 
462
  for a_ref, a_pred, _ in alignment:
463
  if a_ref == '' and a_pred == '':
464
+ covered_pred += 1
465
  continue
466
+ a_pred = re.sub(' +', ' ', a_pred).strip()
467
+ span_ref = [ref_chars[covered_ref], ref_chars[covered_ref + len(a_ref)]]
468
  covered_ref += len(a_ref)
469
+ span_pred = [pred_chars[covered_pred], pred_chars[covered_pred + len(a_pred)]]
470
+ covered_pred += len(a_pred)
471
  align_idx.append((span_ref, span_pred))
472
 
473
  return align_idx
474
 
475
+ def normalise_text(list_sents, batch_size=32, beam_size=5, cache_file=None, no_postproc=False):
476
  tokeniser = AutoTokenizer.from_pretrained("rbawden/modern_french_normalisation", use_auth_token=True)
477
  model = AutoModelForSeq2SeqLM.from_pretrained("rbawden/modern_french_normalisation", use_auth_token=True)
478
  normalisation_pipeline = NormalisationPipeline(model=model,
479
  tokenizer=tokeniser,
480
  batch_size=batch_size,
481
  beam_size=beam_size,
482
+ cache_file=cache_file,
483
+ no_postproc=no_postproc)
484
  normalised_outputs = normalisation_pipeline(list_sents)
485
  return normalised_outputs
486
 
487
+ def normalise_from_stdin(batch_size=32, beam_size=5, cache_file=None, no_postproc=False):
488
  tokeniser = AutoTokenizer.from_pretrained("rbawden/modern_french_normalisation", use_auth_token=True)
489
  model = AutoModelForSeq2SeqLM.from_pretrained("rbawden/modern_french_normalisation", use_auth_token=True)
490
  normalisation_pipeline = NormalisationPipeline(model=model,
491
+ tokenizer=tokeniser,
492
  batch_size=batch_size,
493
  beam_size=beam_size,
494
+ cache_file=cache_file,
495
+ no_postproc=no_postproc
496
+ )
497
  list_sents = []
498
+ #ex = ["7. Qu'vne force plus grande de ſi peu que l'on voudra, que celle auec laquelle l'eau de la hauteur de trente & vn pieds, tend à couler en bas, ſuffit pour faire admettre ce vuide apparent, & meſme ſi grãd que l'on voudra, c'eſt à dire, pour faire des-vnir les corps d'vn ſi grand interualle que l'on voudra, pourueu qu'il n'y ait point d'autre obſtacle à leur ſeparation ny à leur eſloignement, que l'horreur que la Nature a pour ce vuide apparent."]
499
  for sent in sys.stdin:
500
  list_sents.append(sent.strip())
501
  normalised_outputs = normalisation_pipeline(list_sents)
 
503
  alignment=sent['alignment']
504
 
505
  # printing in order to debug
506
+ #print('src = ', list_sents[s])
507
+ print(sent['text'])
508
  # checking that the alignment makes sense
509
+ #for b, a in alignment:
510
+ # print('input: ' + ''.join([list_sents[s][x] for x in range(b[0], max(len(b), b[1]))]) + '')
511
+ # print('pred: ' + ''.join([sent['text'][x] for x in range(a[0], max(len(a), a[1]))]) + '')
512
 
513
  return normalised_outputs
514
 
515
 
516
  if __name__ == '__main__':
 
517
  import argparse
518
  parser = argparse.ArgumentParser()
519
  parser.add_argument('-k', '--batch_size', type=int, default=32, help='Set the batch size for decoding')
520
  parser.add_argument('-b', '--beam_size', type=int, default=5, help='Set the beam size for decoding')
521
  parser.add_argument('-i', '--input_file', type=str, default=None, help='Input file. If None, read from STDIN')
522
+ parser.add_argument('-c', '--cache_lexicon', type=str, default=None, help='Path to cache the lexicon file to speed up loading')
523
+ parser.add_argument('-n', '--no_postproc', default=False, action='store_true', help='Deactivate postprocessing to speed up normalisation, but this may degrade the output')
524
+
525
  args = parser.parse_args()
526
 
527
  if args.input_file is None:
528
+ normalise_from_stdin(batch_size=args.batch_size,
529
+ beam_size=args.beam_size,
530
+ cache_file=args.cache_lexicon,
531
+ no_postproc=args.no_postproc)
532
  else:
533
  list_sents = []
534
  with open(args.input_file) as fp:
535
  for line in fp:
536
  list_sents.append(line.strip())
537
+ output_sents = normalise_text(list_sents,
538
+ batch_size=args.batch_size,
539
+ beam_size=args.beam_size,
540
+ cache_file=args.cache_lexicon,
541
+ no_postproc=args.no_postproc)
542
  for output_sent in output_sents:
543
+ print(output_sent['text'])
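
For reference, the entry points touched by this commit can be driven as follows. The command lines mirror the argparse options above, and the Python snippet mirrors normalise_text; the input file, cache path and example sentence are placeholders, and access to the rbawden/modern_french_normalisation model is required:

    # From the command line (STDIN is read when -i is omitted):
    #   python pipeline.py -i input.txt -k 32 -b 5 -c lexicon_cache.pickle
    #   cat input.txt | python pipeline.py --no_postproc

    from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
    from pipeline import NormalisationPipeline

    tokeniser = AutoTokenizer.from_pretrained("rbawden/modern_french_normalisation", use_auth_token=True)
    model = AutoModelForSeq2SeqLM.from_pretrained("rbawden/modern_french_normalisation", use_auth_token=True)
    norm_pipeline = NormalisationPipeline(model=model, tokenizer=tokeniser,
                                          batch_size=32, beam_size=5,
                                          cache_file=None, no_postproc=False)

    outputs = norm_pipeline(["C'eſt vn exemple de phraſe."])
    print(outputs[0]['text'])       # normalised sentence
    print(outputs[0]['alignment'])  # (source span, prediction span) character index pairs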