Upload pipeline.py

pipeline.py  CHANGED  (+241 -12)
@@ -4,12 +4,168 @@ from transformers.tokenization_utils_base import TruncationStrategy
 from torch import Tensor
 import html.parser
 import unicodedata
-import sys, os
-
+import sys, os
+import re
+from tqdm.auto import tqdm
+import operator
+
+
+def basic_tokenise(string):
+    # separate punctuation
+    for char in r',.;?!:)("…-':
+        string = re.sub('(?<! )' + re.escape(char) + '+', ' ' + char, string)
+    for char in '\'"’':
+        string = re.sub(char + '(?! )' , char + ' ', string)
+    return string.strip()
+
+def homogenise(sent):
+    sent = sent.lower()
+    # sent = sent.replace("oe", "œ").replace("OE", "Œ")
+    replace_from = "ǽǣáàâäąãăåćčçďéèêëęěğìíîĩĭıïĺľłńñňòóôõöøŕřśšşťţùúûũüǔỳýŷÿźẑżžÁÀÂÄĄÃĂÅĆČÇĎÉÈÊËĘĚĞÌÍÎĨĬİÏĹĽŁŃÑŇÒÓÔÕÖØŔŘŚŠŞŤŢÙÚÛŨÜǓỲÝŶŸŹẐŻŽſ"
+    replace_into = "ææaaaaaaaacccdeeeeeegiiiiiiilllnnnoooooorrsssttuuuuuuyyyyzzzzAAAAAAAACCCDEEEEEEGIIIIIIILLLNNNOOOOOORRSSSTTUUUUUUYYYYZZZZs"
+    table = sent.maketrans(replace_from, replace_into)
+    return sent.translate(table)
+
+######## Edit distance functions #######
+def _wedit_dist_init(len1, len2):
+    lev = []
+    for i in range(len1):
+        lev.append([0] * len2)  # initialize 2D array to zero
+    for i in range(len1):
+        lev[i][0] = i  # column 0: 0,1,2,3,4,...
+    for j in range(len2):
+        lev[0][j] = j  # row 0: 0,1,2,3,4,...
+    return lev
+
+
+def _wedit_dist_step(
+    lev, i, j, s1, s2, last_left, last_right, transpositions=False
+):
+    c1 = s1[i - 1]
+    c2 = s2[j - 1]
+
+    # skipping a character in s1
+    a = lev[i - 1][j] + _wedit_dist_deletion_cost(c1,c2)
+    # skipping a character in s2
+    b = lev[i][j - 1] + _wedit_dist_insertion_cost(c1,c2)
+    # substitution
+    c = lev[i - 1][j - 1] + (_wedit_dist_substitution_cost(c1, c2) if c1 != c2 else 0)
+
+    # pick the cheapest
+    lev[i][j] = min(a, b, c)#, d)
+
+def _wedit_dist_backtrace(lev):
+    i, j = len(lev) - 1, len(lev[0]) - 1
+    alignment = [(i, j, lev[i][j])]
+
+    while (i, j) != (0, 0):
+        directions = [
+            (i - 1, j),  # skip s1
+            (i, j - 1),  # skip s2
+            (i - 1, j - 1),  # substitution
+        ]
+
+        direction_costs = (
+            (lev[i][j] if (i >= 0 and j >= 0) else float("inf"), (i, j))
+            for i, j in directions
+        )
+        _, (i, j) = min(direction_costs, key=operator.itemgetter(0))
+
+        alignment.append((i, j, lev[i][j]))
+    return list(reversed(alignment))
+
+def _wedit_dist_substitution_cost(c1, c2):
+    if c1 == ' ' and c2 != ' ':
+        return 1000000
+    if c2 == ' ' and c1 != ' ':
+        return 30
+    for c in ",.;-!?'":
+        if c1 == c and c2 != c:
+            return 20
+        if c2 == c and c1 != c:
+            return 20
+    return 1
+
+def _wedit_dist_deletion_cost(c1, c2):
+    if c1 == ' ':
+        return 2
+    if c2 == ' ':
+        return 1000000
+    return 0.8
+
+def _wedit_dist_insertion_cost(c1, c2):
+    if c1 == ' ':
+        return 1000000
+    if c2 == ' ':
+        return 2
+    return 0.8
+
+def wedit_distance_align(s1, s2):
+    """
+    Calculate the minimum Levenshtein edit-distance based alignment
+    mapping between two strings. The alignment finds the mapping
+    from string s1 to s2 that minimizes the edit distance cost.
+    For example, mapping "rain" to "shine" would involve 2
+    substitutions, 2 matches and an insertion resulting in
+    the following mapping:
+    [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4), (4, 5)]
+    NB: (0, 0) is the start state without any letters associated
+    See more: https://web.stanford.edu/class/cs124/lec/med.pdf
+    In case of multiple valid minimum-distance alignments, the
+    backtrace has the following operation precedence:
+    1. Skip s1 character
+    2. Skip s2 character
+    3. Substitute s1 and s2 characters
+    The backtrace is carried out in reverse string order.
+    This function does not support transposition.
+    :param s1, s2: The strings to be aligned
+    :type s1: str
+    :type s2: str
+    :rtype: List[Tuple(int, int)]
+    """
+    # set up a 2-D array
+    len1 = len(s1)
+    len2 = len(s2)
+    lev = _wedit_dist_init(len1 + 1, len2 + 1)
+
+    # iterate over the array
+    for i in range(len1):
+        for j in range(len2):
+            _wedit_dist_step(
+                lev,
+                i + 1,
+                j + 1,
+                s1,
+                s2,
+                0,
+                0,
+                transpositions=False,
+            )
+
+    # backtrace to find alignment
+    alignment = _wedit_dist_backtrace(lev)
+    return alignment
+
+def space_after(idx, sent):
+    if idx < len(sent) -1 and sent[idx + 1] == ' ':
+        return True
+    return False
+
+def space_before(idx, sent):
+    if idx > 0 and sent[idx - 1] == ' ':
+        return True
+    return False
+
+######## Normaliation pipeline #########
 class NormalisationPipeline(Pipeline):
 
-    def __init__(self, beam_size=5, batch_size=32, **kwargs):
+    def __init__(self, beam_size=5, batch_size=32, tokenise_func=None, **kwargs):
         self.beam_size = beam_size
+        # classic tokeniser function (used for alignments)
+        if tokenise_func is not None:
+            self.classic_tokenise = tokenise_func
+        else:
+            self.classic_tokenise = basic_tokenise
         super().__init__(**kwargs)
 
 
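The helper functions added in this hunk can be exercised in isolation. A minimal sketch follows, assuming this pipeline.py is importable as a module named pipeline; the example strings and the values in comments are illustrative only, not verified outputs:

    from pipeline import homogenise, wedit_distance_align

    old_form = "Ils eſtoient"   # hypothetical early-modern spelling
    new_form = "Ils étaient"    # hypothetical modernised spelling
    # homogenise() lowercases and maps accented letters and long s to plain forms
    print(homogenise(old_form))   # "ils estoient"
    print(homogenise(new_form))   # "ils etaient"
    # wedit_distance_align() returns backpointers (i, j, cumulative_cost),
    # running from (0, 0, 0) to (len(s1), len(s2), total_weighted_cost)
    backpointers = wedit_distance_align(homogenise(old_form), homogenise(new_form))
    print(backpointers[0], backpointers[-1])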
@@ -141,15 +297,81 @@
         """
 
         result = super().__call__(*args, **kwargs)
-        if (
-            isinstance(args[0], list)
+        if (isinstance(args[0], list)
             and all(isinstance(el, str) for el in args[0])
-            and all(len(res) == 1 for res in result)
-
-
-
+            and all(len(res) == 1 for res in result)):
+            output = []
+            for i in range(len(result)):
+                input_sent, pred_sent = args[0][i].strip(), result[i][0]['text'].strip()
+                alignment = self.align(input_sent, pred_sent)
+                char_spans = self.get_char_idx_align(input_sent, pred_sent, alignment)
+                output.append({'text': result[i][0]['text'], 'alignment': char_spans})
+            return output
+
+        else:
+            return [{'text': result, 'alignment': self.align(args, result[0]['text'].strip())}]
+
+    def align(self, sent_ref, sent_pred):
+        backpointers = wedit_distance_align(homogenise(self.classic_tokenise(re.sub('[ ]', ' ', sent_ref))),
+                                            homogenise(self.classic_tokenise(re.sub('[ ]', ' ', sent_pred))))
+        alignment, current_word, seen1, seen2, last_weight = [], ['', ''], [], [], 0
 
+        print(homogenise(sent_ref), homogenise(sent_pred))
+        for i_ref, i_pred, weight in backpointers:
+            if i_ref == 0 and i_pred == 0:
+                continue
+            # spaces in both, add straight away
+            if i_ref <= len(sent_ref) and sent_ref[i_ref-1] == ' ' and \
+               i_pred <= len(sent_pred) and sent_pred[i_pred-1] == ' ':
+                alignment.append((current_word[0].strip(), current_word[1].strip(), weight-last_weight))
+                last_weight = weight
+                current_word = ['', '']
+                seen1.append(i_ref)
+                seen2.append(i_pred)
+            else:
+                end_space = '░'
+                if i_ref <= len(sent_ref) and i_ref not in seen1:
+                    if i_ref > 0:
+                        current_word[0] += sent_ref[i_ref-1]
+                    seen1.append(i_ref)
+                if i_pred <= len(sent_pred) and i_pred not in seen2:
+                    if i_pred > 0:
+                        current_word[1] += sent_pred[i_pred-1] if sent_pred[i_pred-1] != ' ' else '▁'
+                        end_space = '' if space_after(i_pred, sent_pred) else '░'
+                    seen2.append(i_pred)
+                if i_ref <= len(sent_ref) and sent_ref[i_ref-1] == ' ' and current_word[0].strip() != '':
+                    alignment.append((current_word[0].strip(), current_word[1].strip() + end_space, weight-last_weight))
+                    last_weight = weight
+                    current_word = ['', '']
+        # final word
+        alignment.append((current_word[0].strip(), current_word[1].strip(), weight-last_weight))
+        # check that both strings are entirely covered
+        recovered1 = re.sub(' +', ' ', ' '.join([x[0] for x in alignment]))
+        recovered2 = re.sub(' +', ' ', ' '.join([x[1] for x in alignment]))
 
+        assert recovered1 == re.sub(' +', ' ', sent_ref), \
+            '\n1: ' + re.sub(' +', ' ', recovered1) + "\n1: " + re.sub(' +', ' ', sent_ref)
+        assert re.sub('[░▁ ]+', '', recovered2) == re.sub('[▁ ]+', '', sent_pred), \
+            '\n2: ' + re.sub(' +', ' ', recovered2) + "\n2: " + re.sub(' +', ' ', sent_pred)
+        return alignment
+
+
+    def get_char_idx_align(self, sent_ref, sent_pred, alignment):
+        covered_ref, covered_pred = 0, 0
+        ref_chars = [i for i, character in enumerate(sent_ref) if character not in [' ']]
+        pred_chars = [i for i, character in enumerate(sent_pred) if character not in [' ']]
+        align_idx = []
+        for a_ref, a_pred, _ in alignment:
+            if a_ref == '' and a_pred == '':
+                continue
+            a_pred = re.sub('[░▁ ]+', '', a_pred).strip()
+            span_ref = [ref_chars[covered_ref], ref_chars[covered_ref + len(a_ref) - 1]]
+            covered_ref += len(a_ref)
+            span_pred = [pred_chars[covered_pred], pred_chars[covered_pred + max(0, len(a_pred) - 1)]]
+            covered_pred += max(0, len(a_pred))
+            align_idx.append((span_ref, span_pred))
+        return align_idx
+
 def normalise_text(list_sents, batch_size=32, beam_size=5):
     tokeniser = AutoTokenizer.from_pretrained("rbawden/modern_french_normalisation", use_auth_token=True)
     model = AutoModelForSeq2SeqLM.from_pretrained("rbawden/modern_french_normalisation", use_auth_token=True)
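For list inputs, the __call__ method in this hunk returns one dictionary per sentence, with the raw prediction under 'text' and the character-span pairs computed by get_char_idx_align under 'alignment' (each span is an inclusive [start, end] index pair into the corresponding string, skipping spaces). A small sketch of consuming that structure, using a hypothetical output value for illustration:

    src = "Ils eſtoient"
    # hypothetical result of normalisation_pipeline([src])
    outputs = [{'text': 'Ils étaient',
                'alignment': [([0, 2], [0, 2]), ([4, 11], [4, 10])]}]
    for out in outputs:
        for (b_start, b_end), (a_start, a_end) in out['alignment']:
            print(src[b_start:b_end + 1], '->', out['text'][a_start:a_end + 1])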
@@ -169,10 +391,17 @@ def normalise_from_stdin(batch_size=32, beam_size=5):
                                                    beam_size=beam_size)
     list_sents = []
     for sent in sys.stdin:
-        list_sents.append(sent)
+        list_sents.append(sent.strip())
     normalised_outputs = normalisation_pipeline(list_sents)
-    for sent in normalised_outputs:
-
+    for s, sent in enumerate(normalised_outputs):
+        alignment=sent['alignment']
+        print(list_sents[s], len(list_sents[s]))
+        print(sent['text'], len(sent['text']))
+        print(sent['alignment'])
+        #for b, a in alignment:
+        #    print('input: [' + ''.join([list_sents[s][x] for x in range(b[0], b[1]+1)]) + ']')
+        #    print('pred: [' + ''.join([sent['text'][x] for x in range(a[0], a[1]+1)]) + ']')
+
     return normalised_outputs
 
 
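Taken together, normalise_text and normalise_from_stdin are the entry points around the pipeline class. A minimal usage sketch, assuming this file is importable as pipeline and that normalise_text returns the {'text', 'alignment'} dictionaries produced by the pipeline call (the model id and use_auth_token=True follow the calls shown above; the input sentences are illustrative):

    from pipeline import normalise_text

    sents = ["Ils eſtoient fort contens .", "Elle eſt arriuée hier ."]
    outputs = normalise_text(sents, batch_size=2, beam_size=5)
    for out in outputs:
        print(out['text'])
        print(out['alignment'])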