rbawden committed
Commit 94cfe5a
1 Parent(s): ad5b64b

Upload pipeline.py

Files changed (1)
  1. pipeline.py +163 -31
pipeline.py CHANGED
@@ -8,6 +8,7 @@ import sys, os
 import re
 from tqdm.auto import tqdm
 import operator
+from datasets import load_dataset
 
 
 def basic_tokenise(string):
@@ -166,8 +167,30 @@ class NormalisationPipeline(Pipeline):
             self.classic_tokenise = tokenise_func
         else:
             self.classic_tokenise = basic_tokenise
+
+        # load lexicon
+        self.lexicon_orig, self.lexicon_homog = self.load_lexicon()
         super().__init__(**kwargs)
 
+
+    def load_lexicon(self):
+        #local_file = '../data/lexicons/lefff-3.4.mlex'
+        orig_words = []
+        homog_words = {}
+        remove = set([])
+        dataset = load_dataset("sagot/lefff_morpho")
+
+        for entry_dict in dataset['test']:
+            entry = entry_dict['form']
+            orig_words.append(entry.lower())
+            if homogenise(entry) not in homog_words:
+                homog_words[homogenise(entry)] = entry
+            else:
+                remove.add(homogenise(entry))
+
+        for entry in remove:
+            del homog_words[entry]
+        return orig_words, homog_words
 
     def _sanitize_parameters(self, clean_up_tokenisation_spaces=None, truncation=None, **generate_kwargs):
         preprocess_params = {}
@@ -262,13 +285,92 @@ class NormalisationPipeline(Pipeline):
             records.append(record)
         return records
 
-    def correct_hallunications(self, orig, output):
-        # align the original and output tokens
-
-        # check that the correspondences are legitimate and correct if not
-
-        # replace <EMOJI> symbols by the original ones
-        return output
+    def postprocess_correct_sents(self, alignment, pred_sent_tok):
+        #return [pred_sent]
+        print(alignment)
+        output = []
+        # align the two
+        #alignments = self.align(orig_sent, pred_sent)
+        # correct word by word
+        len_diff_orig, len_diff_pred = 0, 0
+        pred_idxs = []
+        start = 0
+        for i, char in enumerate(re.sub(' +', ' ', pred_sent_tok) + " "):
+            if char == " ":
+                pred_idxs.append((start, i-1))
+                start = i+1
+        print(pred_idxs)
+        print('°°°°°°°°°°°°°°')
+        suffix_pred_sent = pred_sent
+        for i, (orig_word, pred_word, _) in enumerate(alignment):
+            #print(orig_word, pred_word)
+            start_idx, end_idx = 1, 1
+            postproc_word, alignment = self.postprocess_correct_word(orig_word, pred_word, alignment)
+            #print(postproc_word)
+            # replace word in tokenised sentence
+
+
+            output.append(postproc_word)
+        return re.sub(' +', ' ', ' '.join(output)), alignment
+
+    def postprocess_correct_word(self, orig_word, pred_word, alignment):
+        # pred_word exists in lexicon, take it
+        if pred_word.lower() in self.lexicon_orig:
+            return pred_word, alignment
+        # otherwise, if original word exists, take that
+        if orig_word.lower() in self.lexicon_orig:
+            return orig_word, alignment
+        pred_replacement = self.lexicon_homog.get(homogenise(pred_word), None)
+        # otherwise if pred word is in the lexicon with some changes, take that
+        if pred_replacement is not None:
+            alignment = (alignment[0], pred_replacement, alignment[2])
+            return pred_replacement, alignment
+        orig_replacement = self.lexicon_homog.get(homogenise(orig_word), None)
+        # otherwise if orig word is in the lexicon with some changes, take that
+        if orig_replacement is not None:
+            alignment = (orig_replacement, alignment[1], alignment[2])
+            return orig_replacement, alignment
+        # otherwise return original word (or pred?) + postprocessing?
+        return orig_word, alignment
+
+    def get_caps(self, word):
+        first, second, allcaps = False, False, False
+        if len(word) > 0 and word[0].upper() == word[0]:
+            first = True
+        if len(word) > 1 and word[1].upper() == word[1]:
+            second = True
+        if word.upper() == word:
+            allcaps = True
+        return first, second, allcaps
+
+    def set_caps(self, word, first, second, allcaps):
+        if allcaps:
+            return word.upper()
+        elif first and second:
+            return word[0].upper() + word[1].upper() + word[2:]
+        elif first:
+            return word[0].upper()
+        elif second:
+            return word[1].upper()
+        else:
+            return word
+
+    def lexicon_lookup(self, candidate):
+        norm_candidate = homogenise(candidate.lower())
+        replacements = []
+        for candidate_word in candidate.split('▁'):
+            capitals = self.get_caps(candidate_word)
+            replacements.append([])
+            for word in self.lexicon:
+                if homogenise(word.lower()) == candidate_word:
+                    if len(replacements[-1]) > 0:
+                        return None # if ambiguity skip
+                    replacements[-1].append(self.set_caps(candidate, *capitals))
+
+        if [] not in replacements:
+            return ' '.join([x[0] for x in replacements]) # or some better strategy
+        else:
+            return None
 
     def __call__(self, *args, **kwargs):
         r"""
@@ -303,8 +405,20 @@ class NormalisationPipeline(Pipeline):
             output = []
             for i in range(len(result)):
                 input_sent, pred_sent = args[0][i].strip(), result[i][0]['text'].strip()
-                alignment = self.align(input_sent, pred_sent)
+                # correct pred sent
+                print('prediction = ', pred_sent)
+                print('source = ', input_sent)
+
+
+
+                alignment, pred_sent_tok = self.align(input_sent, pred_sent)
+                pred_sent, alignment = self.postprocess_correct_sents(alignment, pred_sent_tok)
+                print('alignment = ', alignment)
+                print('corrected pred = ', pred_sent)
+                print([x[1] for x in alignment])
+                print('******')
                 char_spans = self.get_char_idx_align(input_sent, pred_sent, alignment)
+
                 output.append({'text': result[i][0]['text'], 'alignment': char_spans})
             return output
 
@@ -312,34 +426,36 @@ class NormalisationPipeline(Pipeline):
             return [{'text': result, 'alignment': self.align(args, result[0]['text'].strip())}]
 
     def align(self, sent_ref, sent_pred):
-        backpointers = wedit_distance_align(homogenise(self.classic_tokenise(re.sub('[ ]', ' ', sent_ref))),
-                                            homogenise(self.classic_tokenise(re.sub('[ ]', ' ', sent_pred))))
+        print("*", sent_pred)
+        sent_ref_tok = self.classic_tokenise(re.sub('[ ]', ' ', sent_ref))
+        sent_pred_tok = self.classic_tokenise(re.sub('[ ]', ' ', sent_pred))
+        backpointers = wedit_distance_align(homogenise(sent_ref_tok), homogenise(sent_pred_tok))
         alignment, current_word, seen1, seen2, last_weight = [], ['', ''], [], [], 0
 
-        print(homogenise(sent_ref), homogenise(sent_pred))
+        print('before align = ', homogenise(sent_ref_tok), homogenise(sent_pred_tok))
        for i_ref, i_pred, weight in backpointers:
            if i_ref == 0 and i_pred == 0:
                continue
            # spaces in both, add straight away
-           if i_ref <= len(sent_ref) and sent_ref[i_ref-1] == ' ' and \
-              i_pred <= len(sent_pred) and sent_pred[i_pred-1] == ' ':
+           if i_ref <= len(sent_ref_tok) and sent_ref_tok[i_ref-1] == ' ' and \
+              i_pred <= len(sent_pred_tok) and sent_pred_tok[i_pred-1] == ' ':
                alignment.append((current_word[0].strip(), current_word[1].strip(), weight-last_weight))
                last_weight = weight
                current_word = ['', '']
                seen1.append(i_ref)
                seen2.append(i_pred)
            else:
-               end_space = '░'
-               if i_ref <= len(sent_ref) and i_ref not in seen1:
+               end_space = '' #'░'
+               if i_ref <= len(sent_ref_tok) and i_ref not in seen1:
                    if i_ref > 0:
-                       current_word[0] += sent_ref[i_ref-1]
+                       current_word[0] += sent_ref_tok[i_ref-1]
                    seen1.append(i_ref)
-               if i_pred <= len(sent_pred) and i_pred not in seen2:
+               if i_pred <= len(sent_pred_tok) and i_pred not in seen2:
                    if i_pred > 0:
-                       current_word[1] += sent_pred[i_pred-1] if sent_pred[i_pred-1] != ' ' else '▁'
-                       end_space = '' if space_after(i_pred, sent_pred) else '░'
+                       current_word[1] += sent_pred_tok[i_pred-1] if sent_pred_tok[i_pred-1] != ' ' else '▁'
+                       end_space = '' if space_after(i_pred, sent_pred_tok) else '░'
                    seen2.append(i_pred)
-               if i_ref <= len(sent_ref) and sent_ref[i_ref-1] == ' ' and current_word[0].strip() != '':
+               if i_ref <= len(sent_ref_tok) and sent_ref_tok[i_ref-1] == ' ' and current_word[0].strip() != '':
                    alignment.append((current_word[0].strip(), current_word[1].strip() + end_space, weight-last_weight))
                    last_weight = weight
                    current_word = ['', '']
@@ -349,27 +465,37 @@ class NormalisationPipeline(Pipeline):
        recovered1 = re.sub(' +', ' ', ' '.join([x[0] for x in alignment]))
        recovered2 = re.sub(' +', ' ', ' '.join([x[1] for x in alignment]))
 
-       assert recovered1 == re.sub(' +', ' ', sent_ref), \
-           '\n1: ' + re.sub(' +', ' ', recovered1) + "\n1: " + re.sub(' +', ' ', sent_ref)
-       assert re.sub('[░▁ ]+', '', recovered2) == re.sub('[▁ ]+', '', sent_pred), \
-           '\n2: ' + re.sub(' +', ' ', recovered2) + "\n2: " + re.sub(' +', ' ', sent_pred)
-       return alignment
+       assert recovered1 == re.sub(' +', ' ', sent_ref_tok), \
+           '\n1: ' + re.sub(' +', ' ', recovered1) + "\n1: " + re.sub(' +', ' ', sent_ref_tok)
+       assert re.sub('[░▁ ]+', '', recovered2) == re.sub('[▁ ]+', '', sent_pred_tok), \
+           '\n2: ' + re.sub(' +', ' ', recovered2) + "\n2: " + re.sub(' +', ' ', sent_pred_tok)
+       return alignment, sent_pred_tok
 
 
    def get_char_idx_align(self, sent_ref, sent_pred, alignment):
+       sent_ref = self.classic_tokenise(re.sub('[ ]', ' ', sent_ref))
+       sent_pred = self.classic_tokenise(re.sub('[ ]', ' ', sent_pred))
+
       covered_ref, covered_pred = 0, 0
       ref_chars = [i for i, character in enumerate(sent_ref) if character not in [' ']]
       pred_chars = [i for i, character in enumerate(sent_pred) if character not in [' ']]
       align_idx = []
+       #print(ref_chars)
+       #print(pred_chars)
       for a_ref, a_pred, _ in alignment:
           if a_ref == '' and a_pred == '':
               continue
+           #print('ref: ', sent_ref)
+           #print('pred: ', sent_pred)
+           #print('align: ', a_ref, a_pred)
-           a_pred = re.sub('[░▁ ]+', '', a_pred).strip()
+           a_pred = re.sub(' +', '', a_pred).strip()
           span_ref = [ref_chars[covered_ref], ref_chars[covered_ref + len(a_ref) - 1]]
           covered_ref += len(a_ref)
           span_pred = [pred_chars[covered_pred], pred_chars[covered_pred + max(0, len(a_pred) - 1)]]
           covered_pred += max(0, len(a_pred))
           align_idx.append((span_ref, span_pred))
+           #print(span_ref, span_pred)
+           #print('---')
       return align_idx
 
 def normalise_text(list_sents, batch_size=32, beam_size=5):
@@ -393,14 +519,20 @@ def normalise_from_stdin(batch_size=32, beam_size=5):
     for sent in sys.stdin:
         list_sents.append(sent.strip())
     normalised_outputs = normalisation_pipeline(list_sents)
+    print('norm outputs = ', normalised_outputs)
     for s, sent in enumerate(normalised_outputs):
         alignment=sent['alignment']
-        print(list_sents[s], len(list_sents[s]))
-        print(sent['text'], len(sent['text']))
+        #print(list_sents[s], len(list_sents[s]))
+        #print(sent['text'], len(sent['text']))
         print(sent['alignment'])
-        #for b, a in alignment:
-        # print('input: [' + ''.join([list_sents[s][x] for x in range(b[0], b[1]+1)]) + ']')
-        # print('pred: [' + ''.join([sent['text'][x] for x in range(a[0], a[1]+1)]) + ']')
+        print('src = ', list_sents[s])
+        print('norm = ', sent)
+        for b, a in alignment:
+            #print(b, a)
+            #print(list_sents[s])
+            #print(sent['text'])
+            print('input: ' + ''.join([list_sents[s][x] for x in range(b[0], max(len(b), b[1]+1))]) + '')
+            print('pred: ' + ''.join([sent['text'][x] for x in range(a[0], max(len(a), a[1]+1))]) + '')
 
     return normalised_outputs
 
538