rbawden committed
Commit ad5b64b
1 Parent(s): ae761aa

Upload pipeline.py

Files changed (1):
  1. pipeline.py +241 -12
pipeline.py CHANGED
@@ -4,12 +4,168 @@ from transformers.tokenization_utils_base import TruncationStrategy
  from torch import Tensor
  import html.parser
  import unicodedata
- import sys, os, re
-
+ import sys, os
+ import re
+ from tqdm.auto import tqdm
+ import operator
+
+
+ def basic_tokenise(string):
+     # separate punctuation
+     for char in r',.;?!:)("…-':
+         string = re.sub('(?<! )' + re.escape(char) + '+', ' ' + char, string)
+     for char in '\'"’':
+         string = re.sub(char + '(?! )', char + ' ', string)
+     return string.strip()
+
+ def homogenise(sent):
+     sent = sent.lower()
+     # sent = sent.replace("oe", "œ").replace("OE", "Œ")
+     replace_from = "ǽǣáàâäąãăåćčçďéèêëęěğìíîĩĭıïĺľłńñňòóôõöøŕřśšşťţùúûũüǔỳýŷÿźẑżžÁÀÂÄĄÃĂÅĆČÇĎÉÈÊËĘĚĞÌÍÎĨĬİÏĹĽŁŃÑŇÒÓÔÕÖØŔŘŚŠŞŤŢÙÚÛŨÜǓỲÝŶŸŹẐŻŽſ"
+     replace_into = "ææaaaaaaaacccdeeeeeegiiiiiiilllnnnoooooorrsssttuuuuuuyyyyzzzzAAAAAAAACCCDEEEEEEGIIIIIIILLLNNNOOOOOORRSSSTTUUUUUUYYYYZZZZs"
+     table = sent.maketrans(replace_from, replace_into)
+     return sent.translate(table)
+
+ ######## Edit distance functions #######
+ def _wedit_dist_init(len1, len2):
+     lev = []
+     for i in range(len1):
+         lev.append([0] * len2)  # initialize 2D array to zero
+     for i in range(len1):
+         lev[i][0] = i  # column 0: 0,1,2,3,4,...
+     for j in range(len2):
+         lev[0][j] = j  # row 0: 0,1,2,3,4,...
+     return lev
+
+
+ def _wedit_dist_step(
+     lev, i, j, s1, s2, last_left, last_right, transpositions=False
+ ):
+     c1 = s1[i - 1]
+     c2 = s2[j - 1]
+
+     # skipping a character in s1
+     a = lev[i - 1][j] + _wedit_dist_deletion_cost(c1, c2)
+     # skipping a character in s2
+     b = lev[i][j - 1] + _wedit_dist_insertion_cost(c1, c2)
+     # substitution
+     c = lev[i - 1][j - 1] + (_wedit_dist_substitution_cost(c1, c2) if c1 != c2 else 0)
+
+     # pick the cheapest
+     lev[i][j] = min(a, b, c)
+
+ def _wedit_dist_backtrace(lev):
+     i, j = len(lev) - 1, len(lev[0]) - 1
+     alignment = [(i, j, lev[i][j])]
+
+     while (i, j) != (0, 0):
+         directions = [
+             (i - 1, j),      # skip s1
+             (i, j - 1),      # skip s2
+             (i - 1, j - 1),  # substitution
+         ]
+
+         direction_costs = (
+             (lev[i][j] if (i >= 0 and j >= 0) else float("inf"), (i, j))
+             for i, j in directions
+         )
+         _, (i, j) = min(direction_costs, key=operator.itemgetter(0))
+
+         alignment.append((i, j, lev[i][j]))
+     return list(reversed(alignment))
+
+ def _wedit_dist_substitution_cost(c1, c2):
+     if c1 == ' ' and c2 != ' ':
+         return 1000000
+     if c2 == ' ' and c1 != ' ':
+         return 30
+     for c in ",.;-!?'":
+         if c1 == c and c2 != c:
+             return 20
+         if c2 == c and c1 != c:
+             return 20
+     return 1
+
+ def _wedit_dist_deletion_cost(c1, c2):
+     if c1 == ' ':
+         return 2
+     if c2 == ' ':
+         return 1000000
+     return 0.8
+
+ def _wedit_dist_insertion_cost(c1, c2):
+     if c1 == ' ':
+         return 1000000
+     if c2 == ' ':
+         return 2
+     return 0.8
+
+ def wedit_distance_align(s1, s2):
+     """
+     Calculate the minimum Levenshtein edit-distance based alignment
+     mapping between two strings. The alignment finds the mapping
+     from string s1 to s2 that minimizes the edit distance cost.
+     For example, mapping "rain" to "shine" would involve 2
+     substitutions, 2 matches and an insertion resulting in
+     the following mapping:
+     [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4), (4, 5)]
+     NB: (0, 0) is the start state without any letters associated
+     See more: https://web.stanford.edu/class/cs124/lec/med.pdf
+     In case of multiple valid minimum-distance alignments, the
+     backtrace has the following operation precedence:
+     1. Skip s1 character
+     2. Skip s2 character
+     3. Substitute s1 and s2 characters
+     The backtrace is carried out in reverse string order.
+     This function does not support transposition.
+     :param s1, s2: The strings to be aligned
+     :type s1: str
+     :type s2: str
+     :rtype: List[Tuple(int, int)]
+     """
+     # set up a 2-D array
+     len1 = len(s1)
+     len2 = len(s2)
+     lev = _wedit_dist_init(len1 + 1, len2 + 1)
+
+     # iterate over the array
+     for i in range(len1):
+         for j in range(len2):
+             _wedit_dist_step(
+                 lev,
+                 i + 1,
+                 j + 1,
+                 s1,
+                 s2,
+                 0,
+                 0,
+                 transpositions=False,
+             )
+
+     # backtrace to find alignment
+     alignment = _wedit_dist_backtrace(lev)
+     return alignment
+
+ def space_after(idx, sent):
+     if idx < len(sent) - 1 and sent[idx + 1] == ' ':
+         return True
+     return False
+
+ def space_before(idx, sent):
+     if idx > 0 and sent[idx - 1] == ' ':
+         return True
+     return False
+
+ ######## Normalisation pipeline #########
  class NormalisationPipeline(Pipeline):

-     def __init__(self, beam_size=5, batch_size=32, **kwargs):
+     def __init__(self, beam_size=5, batch_size=32, tokenise_func=None, **kwargs):
          self.beam_size = beam_size
+         # classic tokeniser function (used for alignments)
+         if tokenise_func is not None:
+             self.classic_tokenise = tokenise_func
+         else:
+             self.classic_tokenise = basic_tokenise
          super().__init__(**kwargs)
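
This first hunk swaps the one-line import for the helper machinery used for alignment: basic_tokenise separates punctuation from words, homogenise lowercases and strips diacritics so that historical and modern spellings become comparable, and wedit_distance_align computes a weighted Levenshtein alignment whose costs heavily penalise aligning spaces to non-spaces (1000000) while keeping ordinary insertions and deletions cheap (0.8), which biases alignments towards preserving word boundaries. A minimal sketch of how these helpers behave, assuming they are importable from this file (outputs in comments are indicative):

    from pipeline import basic_tokenise, homogenise, wedit_distance_align

    print(basic_tokenise("Bonjour, monde!"))  # 'Bonjour , monde !'
    print(homogenise("Éternité"))             # 'eternite'

    # Docstring example: "rain" vs "shine"; each triple is
    # (index into s1, index into s2, cumulative cost so far)
    for i_s1, i_s2, cost in wedit_distance_align("rain", "shine"):
        print(i_s1, i_s2, cost)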
@@ -141,15 +297,81 @@ class NormalisationPipeline(Pipeline):
          """

          result = super().__call__(*args, **kwargs)
-         if (
-             isinstance(args[0], list)
+         if (isinstance(args[0], list)
              and all(isinstance(el, str) for el in args[0])
-             and all(len(res) == 1 for res in result)
-         ):
-             return [res[0] for res in result]
-         return result
+             and all(len(res) == 1 for res in result)):
+             output = []
+             for i in range(len(result)):
+                 input_sent, pred_sent = args[0][i].strip(), result[i][0]['text'].strip()
+                 alignment = self.align(input_sent, pred_sent)
+                 char_spans = self.get_char_idx_align(input_sent, pred_sent, alignment)
+                 output.append({'text': result[i][0]['text'], 'alignment': char_spans})
+             return output
+
+         else:
+             return [{'text': result, 'alignment': self.align(args, result[0]['text'].strip())}]
+
+     def align(self, sent_ref, sent_pred):
+         backpointers = wedit_distance_align(homogenise(self.classic_tokenise(re.sub('[ ]', ' ', sent_ref))),
+                                             homogenise(self.classic_tokenise(re.sub('[ ]', ' ', sent_pred))))
+         alignment, current_word, seen1, seen2, last_weight = [], ['', ''], [], [], 0

+         print(homogenise(sent_ref), homogenise(sent_pred))
+         for i_ref, i_pred, weight in backpointers:
+             if i_ref == 0 and i_pred == 0:
+                 continue
+             # spaces in both, add straight away
+             if i_ref <= len(sent_ref) and sent_ref[i_ref-1] == ' ' and \
+                i_pred <= len(sent_pred) and sent_pred[i_pred-1] == ' ':
+                 alignment.append((current_word[0].strip(), current_word[1].strip(), weight-last_weight))
+                 last_weight = weight
+                 current_word = ['', '']
+                 seen1.append(i_ref)
+                 seen2.append(i_pred)
+             else:
+                 end_space = '░'
+                 if i_ref <= len(sent_ref) and i_ref not in seen1:
+                     if i_ref > 0:
+                         current_word[0] += sent_ref[i_ref-1]
+                     seen1.append(i_ref)
+                 if i_pred <= len(sent_pred) and i_pred not in seen2:
+                     if i_pred > 0:
+                         current_word[1] += sent_pred[i_pred-1] if sent_pred[i_pred-1] != ' ' else '▁'
+                         end_space = '' if space_after(i_pred, sent_pred) else '░'
+                     seen2.append(i_pred)
+                 if i_ref <= len(sent_ref) and sent_ref[i_ref-1] == ' ' and current_word[0].strip() != '':
+                     alignment.append((current_word[0].strip(), current_word[1].strip() + end_space, weight-last_weight))
+                     last_weight = weight
+                     current_word = ['', '']
+         # final word
+         alignment.append((current_word[0].strip(), current_word[1].strip(), weight-last_weight))
+         # check that both strings are entirely covered
+         recovered1 = re.sub(' +', ' ', ' '.join([x[0] for x in alignment]))
+         recovered2 = re.sub(' +', ' ', ' '.join([x[1] for x in alignment]))

+         assert recovered1 == re.sub(' +', ' ', sent_ref), \
+             '\n1: ' + re.sub(' +', ' ', recovered1) + "\n1: " + re.sub(' +', ' ', sent_ref)
+         assert re.sub('[░▁ ]+', '', recovered2) == re.sub('[▁ ]+', '', sent_pred), \
+             '\n2: ' + re.sub(' +', ' ', recovered2) + "\n2: " + re.sub(' +', ' ', sent_pred)
+         return alignment
+
+
+     def get_char_idx_align(self, sent_ref, sent_pred, alignment):
+         covered_ref, covered_pred = 0, 0
+         ref_chars = [i for i, character in enumerate(sent_ref) if character not in [' ']]
+         pred_chars = [i for i, character in enumerate(sent_pred) if character not in [' ']]
+         align_idx = []
+         for a_ref, a_pred, _ in alignment:
+             if a_ref == '' and a_pred == '':
+                 continue
+             a_pred = re.sub('[░▁ ]+', '', a_pred).strip()
+             span_ref = [ref_chars[covered_ref], ref_chars[covered_ref + len(a_ref) - 1]]
+             covered_ref += len(a_ref)
+             span_pred = [pred_chars[covered_pred], pred_chars[covered_pred + max(0, len(a_pred) - 1)]]
+             covered_pred += max(0, len(a_pred))
+             align_idx.append((span_ref, span_pred))
+         return align_idx
+
  def normalise_text(list_sents, batch_size=32, beam_size=5):
      tokeniser = AutoTokenizer.from_pretrained("rbawden/modern_french_normalisation", use_auth_token=True)
      model = AutoModelForSeq2SeqLM.from_pretrained("rbawden/modern_french_normalisation", use_auth_token=True)
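
Hunk 2 changes what __call__ returns: instead of unwrapping the generated strings, each input sentence now yields a dict with 'text' (the normalised sentence) and 'alignment' (character-span pairs computed by align and get_char_idx_align). A short sketch of consuming that structure, assuming a pipeline built as in normalise_text just above (the input sentence is invented):

    outputs = normalisation_pipeline(["Quelque texte d'vne autre époque"])
    for out in outputs:
        print(out['text'])  # normalised sentence
        for ref_span, pred_span in out['alignment']:
            # each span is a [start, end] pair of character indices
            # (inclusive); span endpoints never fall on spaces
            print(ref_span, '->', pred_span)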
 
@@ -169,10 +391,17 @@ def normalise_from_stdin(batch_size=32, beam_size=5):
                                                  beam_size=beam_size)
      list_sents = []
      for sent in sys.stdin:
-         list_sents.append(sent)
+         list_sents.append(sent.strip())
      normalised_outputs = normalisation_pipeline(list_sents)
-     for sent in normalised_outputs:
-         print(sent['text'].strip())
+     for s, sent in enumerate(normalised_outputs):
+         alignment = sent['alignment']
+         print(list_sents[s], len(list_sents[s]))
+         print(sent['text'], len(sent['text']))
+         print(sent['alignment'])
+         # for b, a in alignment:
+         #     print('input: [' + ''.join([list_sents[s][x] for x in range(b[0], b[1]+1)]) + ']')
+         #     print('pred: [' + ''.join([sent['text'][x] for x in range(a[0], a[1]+1)]) + ']')
+
      return normalised_outputs
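
Finally, hunk 3 updates normalise_from_stdin to strip input lines and to print the text and alignment details for inspection. For scripted use, a minimal end-to-end sketch, assuming access to the rbawden/modern_french_normalisation checkpoint with a valid auth token, and that normalise_text returns the pipeline's list of dicts as normalise_from_stdin does (the example sentence is invented):

    from pipeline import normalise_text

    sents = ["Helas, ie ne scay que faire."]
    outputs = normalise_text(sents, batch_size=16, beam_size=5)
    for inp, out in zip(sents, outputs):
        print(inp, '->', out['text'])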