Upload pipeline.py
pipeline.py  +27 -40
CHANGED
@@ -6,6 +6,7 @@ import html.parser
 import unicodedata
 import sys, os
 import re
+import pickle
 from tqdm.auto import tqdm
 import operator
 from datasets import load_dataset
@@ -160,7 +161,7 @@ def space_before(idx, sent):
 ######## Normaliation pipeline #########
 class NormalisationPipeline(Pipeline):
 
-    def __init__(self, beam_size=5, batch_size=32, tokenise_func=None, **kwargs):
+    def __init__(self, beam_size=5, batch_size=32, tokenise_func=None, cache_file=None, **kwargs):
         self.beam_size = beam_size
         # classic tokeniser function (used for alignments)
         if tokenise_func is not None:
@@ -169,15 +170,18 @@ class NormalisationPipeline(Pipeline):
             self.classic_tokenise = basic_tokenise
 
         # load lexicon
-        self.lexicon_orig, self.lexicon_homog = self.load_lexicon()
+        self.lexicon_orig, self.lexicon_homog = self.load_lexicon(cache_file=cache_file)
         super().__init__(**kwargs)
 
 
-    def load_lexicon(self):
-        #local_file = '../data/lexicons/lefff-3.4.mlex'
+    def load_lexicon(self, cache_file=None):
         orig_words = []
         homog_words = {}
         remove = set([])
+
+        # load pickled version if there
+        if cache_file is not None and os.path.exists(cache_file):
+            return pickle.load(open(cache_file, 'rb'))
         dataset = load_dataset("sagot/lefff_morpho")
 
         for entry_dict in dataset['test']:
@@ -190,6 +194,9 @@ class NormalisationPipeline(Pipeline):
 
         for entry in remove:
             del homog_words[entry]
+
+        if cache_file is not None:
+            pickle.dump((orig_words, homog_words), open(cache_file, 'wb'))
         return orig_words, homog_words
 
     def _sanitize_parameters(self, clean_up_tokenisation_spaces=None, truncation=None, **generate_kwargs):
@@ -405,18 +412,8 @@ class NormalisationPipeline(Pipeline):
         output = []
         for i in range(len(result)):
            input_sent, pred_sent = args[0][i].strip(), result[i][0]['text'].strip()
-            # correct pred sent
-            print('prediction = ', pred_sent)
-            print('source = ', input_sent)
-
-
-
            alignment, pred_sent_tok = self.align(input_sent, pred_sent)
-            pred_sent, alignment = self.postprocess_correct_sents(alignment, pred_sent_tok)
-            print('alignment = ', alignment)
-            print('corrected pred = ', pred_sent)
-            print([x[1] for x in alignment])
-            print('******')
+            #pred_sent, alignment = self.postprocess_correct_sents(alignment, pred_sent_tok)
            char_spans = self.get_char_idx_align(input_sent, pred_sent, alignment)
 
            output.append({'text': result[i][0]['text'], 'alignment': char_spans})
@@ -426,13 +423,10 @@ class NormalisationPipeline(Pipeline):
         return [{'text': result, 'alignment': self.align(args, result[0]['text'].strip())}]
 
     def align(self, sent_ref, sent_pred):
-        print("*", sent_pred)
         sent_ref_tok = self.classic_tokenise(re.sub('[ ]', ' ', sent_ref))
         sent_pred_tok = self.classic_tokenise(re.sub('[ ]', ' ', sent_pred))
         backpointers = wedit_distance_align(homogenise(sent_ref_tok), homogenise(sent_pred_tok))
         alignment, current_word, seen1, seen2, last_weight = [], ['', ''], [], [], 0
-
-        print('before align = ', homogenise(sent_ref_tok), homogenise(sent_pred_tok))
         for i_ref, i_pred, weight in backpointers:
             if i_ref == 0 and i_pred == 0:
                 continue
@@ -445,7 +439,7 @@ class NormalisationPipeline(Pipeline):
                 seen1.append(i_ref)
                 seen2.append(i_pred)
             else:
-                end_space = '
+                end_space = '░'
                 if i_ref <= len(sent_ref_tok) and i_ref not in seen1:
                     if i_ref > 0:
                         current_word[0] += sent_ref_tok[i_ref-1]
@@ -473,38 +467,34 @@ class NormalisationPipeline(Pipeline):
 
 
     def get_char_idx_align(self, sent_ref, sent_pred, alignment):
-        sent_ref = self.classic_tokenise(re.sub('[ ]', ' ', sent_ref))
-        sent_pred = self.classic_tokenise(re.sub('[ ]', ' ', sent_pred))
+        #sent_ref = self.classic_tokenise(re.sub('[ ]', ' ', sent_ref))
+        #sent_pred = self.classic_tokenise(re.sub('[ ]', ' ', sent_pred))
 
         covered_ref, covered_pred = 0, 0
         ref_chars = [i for i, character in enumerate(sent_ref) if character not in [' ']]
         pred_chars = [i for i, character in enumerate(sent_pred) if character not in [' ']]
         align_idx = []
-
-        #print(pred_chars)
+
         for a_ref, a_pred, _ in alignment:
             if a_ref == '' and a_pred == '':
                 continue
-            #print('ref: ', sent_ref)
-            #print('pred: ', sent_pred)
-            #print('align: ', a_ref, a_pred)
             a_pred = re.sub(' +', '', a_pred).strip()
             span_ref = [ref_chars[covered_ref], ref_chars[covered_ref + len(a_ref) - 1]]
             covered_ref += len(a_ref)
             span_pred = [pred_chars[covered_pred], pred_chars[covered_pred + max(0, len(a_pred) - 1)]]
             covered_pred += max(0, len(a_pred))
             align_idx.append((span_ref, span_pred))
-
-        #print('---')
+
         return align_idx
 
 def normalise_text(list_sents, batch_size=32, beam_size=5):
     tokeniser = AutoTokenizer.from_pretrained("rbawden/modern_french_normalisation", use_auth_token=True)
     model = AutoModelForSeq2SeqLM.from_pretrained("rbawden/modern_french_normalisation", use_auth_token=True)
     normalisation_pipeline = NormalisationPipeline(model=model,
-
-
-
+                                                   tokenizer=tokeniser,
+                                                   batch_size=batch_size,
+                                                   beam_size=beam_size,
+                                                   cache_file="/home/rbawden/scratch/.normalisation_lefff.pickle")
     normalised_outputs = normalisation_pipeline(list_sents)
     return normalised_outputs
 
@@ -513,24 +503,21 @@ def normalise_from_stdin(batch_size=32, beam_size=5):
     model = AutoModelForSeq2SeqLM.from_pretrained("rbawden/modern_french_normalisation", use_auth_token=True)
     normalisation_pipeline = NormalisationPipeline(model=model,
                                                    tokenizer=tokeniser,
-
-
+                                                   batch_size=batch_size,
+                                                   beam_size=beam_size,
+                                                   cache_file="/home/rbawden/scratch/.normalisation_lefff.pickle")
     list_sents = []
     for sent in sys.stdin:
         list_sents.append(sent.strip())
     normalised_outputs = normalisation_pipeline(list_sents)
-    print('norm outputs = ', normalised_outputs)
     for s, sent in enumerate(normalised_outputs):
         alignment=sent['alignment']
-
-        #
-        print(sent['alignment'])
+
+        # printing in order to debug
         print('src = ', list_sents[s])
         print('norm = ', sent)
+        # checking that the alignment makes sense
         for b, a in alignment:
-            #print(b, a)
-            #print(list_sents[s])
-            #print(sent['text'])
             print('input: ' + ''.join([list_sents[s][x] for x in range(b[0], max(len(b), b[1]+1))]) + '')
             print('pred: ' + ''.join([sent['text'][x] for x in range(a[0], max(len(a), a[1]+1))]) + '')
 
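The functional core of this commit is the lexicon cache: load_lexicon now returns a pickled copy when cache_file already exists on disk, and writes one after building the lexicon from sagot/lefff_morpho, so the dataset pass is only paid on the first run. Below is a minimal, self-contained sketch of that load-or-build-then-dump pattern; the build_lexicon stand-in, its dummy return value and the /tmp path are illustrative and not part of the commit.

import os
import pickle
from typing import Dict, List, Optional, Tuple

def build_lexicon() -> Tuple[List[str], Dict[str, str]]:
    # Stand-in for the expensive step: in pipeline.py this is the pass over
    # load_dataset("sagot/lefff_morpho") that fills orig_words and homog_words.
    return ["estre"], {"vn": "un"}

def load_lexicon(cache_file: Optional[str] = None) -> Tuple[List[str], Dict[str, str]]:
    # Short-circuit to the pickled copy if it is already on disk.
    if cache_file is not None and os.path.exists(cache_file):
        with open(cache_file, "rb") as f:
            return pickle.load(f)
    orig_words, homog_words = build_lexicon()
    # Persist the result so later constructions of the pipeline skip the rebuild.
    if cache_file is not None:
        with open(cache_file, "wb") as f:
            pickle.dump((orig_words, homog_words), f)
    return orig_words, homog_words

if __name__ == "__main__":
    # The second call loads the pickle instead of rebuilding.
    print(load_lexicon(cache_file="/tmp/lefff_lexicon.pickle"))
    print(load_lexicon(cache_file="/tmp/lefff_lexicon.pickle"))

Deleting the pickle (or leaving cache_file at its default of None) forces a full rebuild from the dataset.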
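For reference, this is how the updated constructor ends up being driven by normalise_text in the same file. The model id, keyword arguments and output keys ('text', 'alignment') are taken from the diff above; the cache path, the module import and the example sentence are placeholders to adapt.

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from pipeline import NormalisationPipeline  # the class defined in this file

tokeniser = AutoTokenizer.from_pretrained("rbawden/modern_french_normalisation", use_auth_token=True)
model = AutoModelForSeq2SeqLM.from_pretrained("rbawden/modern_french_normalisation", use_auth_token=True)

norm = NormalisationPipeline(model=model,
                             tokenizer=tokeniser,
                             batch_size=32,
                             beam_size=5,
                             cache_file="/tmp/normalisation_lefff.pickle")  # placeholder path

# Each output pairs the normalised text with character-span alignments
# back to the corresponding input sentence.
outputs = norm(["Quelques phrases a normaliser ."])  # placeholder input
for out in outputs:
    print(out["text"], out["alignment"])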