rbawden committed
Commit
eac4b97
1 Parent(s): 659adff

Upload pipeline.py

Files changed (1): pipeline.py (+27 -40)
pipeline.py CHANGED
@@ -6,6 +6,7 @@ import html.parser
 import unicodedata
 import sys, os
 import re
+import pickle
 from tqdm.auto import tqdm
 import operator
 from datasets import load_dataset
@@ -160,7 +161,7 @@ def space_before(idx, sent):
 ######## Normalisation pipeline #########
 class NormalisationPipeline(Pipeline):
 
-    def __init__(self, beam_size=5, batch_size=32, tokenise_func=None, **kwargs):
+    def __init__(self, beam_size=5, batch_size=32, tokenise_func=None, cache_file=None, **kwargs):
         self.beam_size = beam_size
         # classic tokeniser function (used for alignments)
         if tokenise_func is not None:
@@ -169,15 +170,18 @@ class NormalisationPipeline(Pipeline):
             self.classic_tokenise = basic_tokenise
 
         # load lexicon
-        self.lexicon_orig, self.lexicon_homog = self.load_lexicon()
+        self.lexicon_orig, self.lexicon_homog = self.load_lexicon(cache_file=cache_file)
         super().__init__(**kwargs)
 
 
-    def load_lexicon(self):
-        #local_file = '../data/lexicons/lefff-3.4.mlex'
+    def load_lexicon(self, cache_file=None):
         orig_words = []
         homog_words = {}
         remove = set([])
+
+        # load pickled version if there
+        if cache_file is not None and os.path.exists(cache_file):
+            return pickle.load(open(cache_file, 'rb'))
         dataset = load_dataset("sagot/lefff_morpho")
 
         for entry_dict in dataset['test']:
@@ -190,6 +194,9 @@ class NormalisationPipeline(Pipeline):
 
         for entry in remove:
             del homog_words[entry]
+
+        if cache_file is not None:
+            pickle.dump((orig_words, homog_words), open(cache_file, 'wb'))
         return orig_words, homog_words
 
     def _sanitize_parameters(self, clean_up_tokenisation_spaces=None, truncation=None, **generate_kwargs):
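
The caching added to load_lexicon() follows the usual pickle pattern: return the cached object if the file exists, otherwise build it and write it out for next time. A minimal self-contained sketch of the same pattern (load_with_cache and build_fn are illustrative names, not part of pipeline.py):

import os
import pickle

def load_with_cache(cache_file, build_fn):
    # return the pickled object if a cache file already exists
    if cache_file is not None and os.path.exists(cache_file):
        with open(cache_file, 'rb') as f:
            return pickle.load(f)
    # otherwise build the object from scratch and cache the result
    obj = build_fn()
    if cache_file is not None:
        with open(cache_file, 'wb') as f:
            pickle.dump(obj, f)
    return obj

# illustrative usage mirroring load_lexicon()
orig_words, homog_words = load_with_cache(
    '/tmp/lexicon.pickle', lambda: (['estoit'], {'estoit': 'était'}))

The diff itself passes bare open() calls to pickle; the context-manager form above only makes the file handles explicit.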
@@ -405,18 +412,8 @@ class NormalisationPipeline(Pipeline):
         output = []
         for i in range(len(result)):
             input_sent, pred_sent = args[0][i].strip(), result[i][0]['text'].strip()
-            # correct pred sent
-            print('prediction = ', pred_sent)
-            print('source = ', input_sent)
-
-
-
             alignment, pred_sent_tok = self.align(input_sent, pred_sent)
-            pred_sent, alignment = self.postprocess_correct_sents(alignment, pred_sent_tok)
-            print('alignment = ', alignment)
-            print('corrected pred = ', pred_sent)
-            print([x[1] for x in alignment])
-            print('******')
+            #pred_sent, alignment = self.postprocess_correct_sents(alignment, pred_sent_tok)
             char_spans = self.get_char_idx_align(input_sent, pred_sent, alignment)
 
             output.append({'text': result[i][0]['text'], 'alignment': char_spans})
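
For reference, each entry built here has the shape {'text': ..., 'alignment': ...}, where the alignment pairs inclusive character spans of the input with spans of the prediction. An illustrative consumer (the sentence and span values below are made up, but mutually consistent):

src = 'elle estoit'
output = {'text': 'elle était',
          'alignment': [([0, 3], [0, 3]), ([5, 10], [5, 9])]}

for (b0, b1), (a0, a1) in output['alignment']:
    # spans are inclusive, hence the +1 when slicing
    print(src[b0:b1 + 1], '->', output['text'][a0:a1 + 1])
# elle -> elle
# estoit -> était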
@@ -426,13 +423,10 @@ class NormalisationPipeline(Pipeline):
         return [{'text': result, 'alignment': self.align(args, result[0]['text'].strip())}]
 
     def align(self, sent_ref, sent_pred):
-        print("*", sent_pred)
         sent_ref_tok = self.classic_tokenise(re.sub('[ ]', ' ', sent_ref))
         sent_pred_tok = self.classic_tokenise(re.sub('[ ]', ' ', sent_pred))
         backpointers = wedit_distance_align(homogenise(sent_ref_tok), homogenise(sent_pred_tok))
         alignment, current_word, seen1, seen2, last_weight = [], ['', ''], [], [], 0
-
-        print('before align = ', homogenise(sent_ref_tok), homogenise(sent_pred_tok))
         for i_ref, i_pred, weight in backpointers:
             if i_ref == 0 and i_pred == 0:
                 continue
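
wedit_distance_align() itself is outside this diff; from its use here it returns backpointers as (i_ref, i_pred, weight) triples tracing a path through an edit-distance matrix over the two homogenised sequences. A minimal unweighted sketch of that idea (the real function uses task-specific weights):

def edit_distance_align(ref, pred):
    # classic Levenshtein table: dist[i][j] = cost of aligning ref[:i] with pred[:j]
    n, m = len(ref), len(pred)
    dist = [[0] * (m + 1) for _ in range(n + 1)]
    for i in range(1, n + 1):
        dist[i][0] = i
    for j in range(1, m + 1):
        dist[0][j] = j
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            sub = 0 if ref[i - 1] == pred[j - 1] else 1
            dist[i][j] = min(dist[i - 1][j] + 1,        # deletion
                             dist[i][j - 1] + 1,        # insertion
                             dist[i - 1][j - 1] + sub)  # match/substitution
    # walk back from the bottom-right corner, collecting (i_ref, i_pred, cost)
    backpointers, i, j = [], n, m
    while i > 0 or j > 0:
        if i > 0 and j > 0 and dist[i][j] == dist[i - 1][j - 1] + (0 if ref[i - 1] == pred[j - 1] else 1):
            backpointers.append((i, j, dist[i][j]))
            i, j = i - 1, j - 1
        elif i > 0 and dist[i][j] == dist[i - 1][j] + 1:
            backpointers.append((i, j, dist[i][j]))
            i -= 1
        else:
            backpointers.append((i, j, dist[i][j]))
            j -= 1
    return list(reversed(backpointers))

print(edit_distance_align('estoit', 'était'))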
@@ -445,7 +439,7 @@ class NormalisationPipeline(Pipeline):
                 seen1.append(i_ref)
                 seen2.append(i_pred)
             else:
-                end_space = '' #'░'
+                end_space = '░'
                 if i_ref <= len(sent_ref_tok) and i_ref not in seen1:
                     if i_ref > 0:
                         current_word[0] += sent_ref_tok[i_ref-1]
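
Restoring '░' here (it had been commented out in favour of the empty string) keeps a visible placeholder at token boundaries when aligned units are assembled, which is easier to inspect than a plain space. Illustration only:

aligned_unit = 'de la'
print(aligned_unit.replace(' ', '░'))  # de░la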
@@ -473,38 +467,34 @@ class NormalisationPipeline(Pipeline):
 
 
     def get_char_idx_align(self, sent_ref, sent_pred, alignment):
-        sent_ref = self.classic_tokenise(re.sub('[ ]', ' ', sent_ref))
-        sent_pred = self.classic_tokenise(re.sub('[ ]', ' ', sent_pred))
+        #sent_ref = self.classic_tokenise(re.sub('[ ]', ' ', sent_ref))
+        #sent_pred = self.classic_tokenise(re.sub('[ ]', ' ', sent_pred))
 
         covered_ref, covered_pred = 0, 0
         ref_chars = [i for i, character in enumerate(sent_ref) if character not in [' ']]
         pred_chars = [i for i, character in enumerate(sent_pred) if character not in [' ']]
         align_idx = []
-        #print(ref_chars)
-        #print(pred_chars)
+
         for a_ref, a_pred, _ in alignment:
             if a_ref == '' and a_pred == '':
                 continue
-            #print('ref: ', sent_ref)
-            #print('pred: ', sent_pred)
-            #print('align: ', a_ref, a_pred)
             a_pred = re.sub(' +', '', a_pred).strip()
             span_ref = [ref_chars[covered_ref], ref_chars[covered_ref + len(a_ref) - 1]]
             covered_ref += len(a_ref)
             span_pred = [pred_chars[covered_pred], pred_chars[covered_pred + max(0, len(a_pred) - 1)]]
             covered_pred += max(0, len(a_pred))
             align_idx.append((span_ref, span_pred))
-            #print(span_ref, span_pred)
-            #print('---')
+
         return align_idx
 
 def normalise_text(list_sents, batch_size=32, beam_size=5):
     tokeniser = AutoTokenizer.from_pretrained("rbawden/modern_french_normalisation", use_auth_token=True)
     model = AutoModelForSeq2SeqLM.from_pretrained("rbawden/modern_french_normalisation", use_auth_token=True)
     normalisation_pipeline = NormalisationPipeline(model=model,
-                                                   tokenizer=tokeniser,
-                                                   batch_size=batch_size,
-                                                   beam_size=beam_size)
+                                                   tokenizer=tokeniser,
+                                                   batch_size=batch_size,
+                                                   beam_size=beam_size,
+                                                   cache_file="/home/rbawden/scratch/.normalisation_lefff.pickle")
     normalised_outputs = normalisation_pipeline(list_sents)
     return normalised_outputs
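
With the cache argument threaded through, normalise_text() can be exercised as below. This is a sketch: it assumes pipeline.py is importable, that the gated rbawden/modern_french_normalisation model is accessible, and that the hard-coded cache path has been adapted to the local machine; the input sentences are illustrative:

from pipeline import normalise_text

sents = ['Elle estoit fort belle.', 'Il auoit vn cheual.']
outputs = normalise_text(sents, batch_size=2, beam_size=5)
for src, out in zip(sents, outputs):
    # each output carries the normalised text plus its character alignment
    print(src, '->', out['text'])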
 
@@ -513,24 +503,21 @@ def normalise_from_stdin(batch_size=32, beam_size=5):
     model = AutoModelForSeq2SeqLM.from_pretrained("rbawden/modern_french_normalisation", use_auth_token=True)
     normalisation_pipeline = NormalisationPipeline(model=model,
                                                    tokenizer=tokeniser,
-                                                   batch_size=batch_size,
-                                                   beam_size=beam_size)
+                                                   batch_size=batch_size,
+                                                   beam_size=beam_size,
+                                                   cache_file="/home/rbawden/scratch/.normalisation_lefff.pickle")
     list_sents = []
     for sent in sys.stdin:
         list_sents.append(sent.strip())
     normalised_outputs = normalisation_pipeline(list_sents)
-    print('norm outputs = ', normalised_outputs)
     for s, sent in enumerate(normalised_outputs):
         alignment=sent['alignment']
-        #print(list_sents[s], len(list_sents[s]))
-        #print(sent['text'], len(sent['text']))
-        print(sent['alignment'])
+
+        # printing in order to debug
         print('src = ', list_sents[s])
         print('norm = ', sent)
+        # checking that the alignment makes sense
         for b, a in alignment:
-            #print(b, a)
-            #print(list_sents[s])
-            #print(sent['text'])
             print('input: ' + ''.join([list_sents[s][x] for x in range(b[0], max(len(b), b[1]+1))]) + '')
             print('pred: ' + ''.join([sent['text'][x] for x in range(a[0], max(len(a), a[1]+1))]) + '')
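
normalise_from_stdin() reads one sentence per line from standard input. An equivalent in-process driver, standing in for a shell pipe (illustrative; same assumptions as the sketch above):

import io
import sys
from pipeline import normalise_from_stdin

# stands in for piping a file of sentences into the script from a shell
sys.stdin = io.StringIO('Elle estoit fort belle.\n')
normalise_from_stdin(batch_size=1, beam_size=5)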
 
 