import torch
def normalize_abbreviations(text):
    # Re-attach clitics that tokenizers split off ("do n't" -> "don't",
    # "we 'll" -> "we'll"), in both lower- and upper-case variants.
    for clitic in ["n't", "'ll", "'re", "'ve", "'m", "'s", "'d"]:
        text = text.replace(f" {clitic} ", f"{clitic} ")
        text = text.replace(f" {clitic.upper()} ", f"{clitic.upper()} ")
    return text
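# Illustrative example (not part of the original file):
#   normalize_abbreviations("I do n't know , she 's here")
#   -> "I don't know , she's here"
# Each pattern requires a trailing space, so a clitic at the very end of the
# string is left untouched.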
def fix_quotes(text, quote_symbol='"'):
    # Number of quote characters with a space on at least one side
    # (inclusion-exclusion, so quotes spaced on both sides are counted once).
    n_quotes = text.count(f" {quote_symbol}") + text.count(f"{quote_symbol} ") - text.count(f" {quote_symbol} ")
    # Bail out when pairing would be ambiguous: no detached quotes, an odd
    # number of them, or two quotes directly adjacent to each other.
    if (
n_quotes == 0
or (n_quotes % 2) == 1
or f"{quote_symbol}{quote_symbol}" in text
or f"{quote_symbol} {quote_symbol}" in text
):
return text
    i, i_quote, n_changes = 0, 0, 0
    while i < len(text):
        # Skip non-quote characters, and quotes embedded inside a word
        # (e.g. apostrophes in contractions), which touch no space at all.
        if text[i] != quote_symbol or (i - 1 >= 0 and text[i - 1] != ' ' and i + 1 < len(text) and text[i + 1] != ' '):
            i += 1
            continue
        if (i_quote % 2) == 0:  # opening quote: ensure a space before, none after
if i > 0 and text[i - 1] != ' ':
text = text[:i] + ' ' + text[i:]
i += 1
n_changes += 1
if i + 1 < len(text) and text[i + 1] == ' ':
text = text[:i + 1] + text[i + 2:]
n_changes += 1
        else:  # closing quote: no space before, a space after if a word follows
if i > 0 and text[i - 1] == ' ':
text = text[:i - 1] + text[i:]
i -= 1
n_changes += 1
if i + 1 < len(text) and text[i + 1].isalnum():
text = text[:i + 1] + ' ' + text[i + 1:]
n_changes += 1
i_quote += 1
i += 1
return text
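# Illustrative example (not part of the original file):
#   fix_quotes('he said " hello world " .')
#   -> 'he said "hello world" .'
# Detached quotes alternate between opening and closing, and each one is
# snapped onto the word it introduces or terminates.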
def detokenize(tokens, compact_dashes=False):
    text = ' '.join(tokens)
    text = normalize_abbreviations(text)
    if compact_dashes:
        text = text.replace(' - ', '-')
    # Walk backwards through the text, inserting missing spaces after
    # punctuation and deleting spurious ones before it; going backwards means
    # the edits never invalidate the indices still to be visited.
    for i in range(len(text) - 2, -1, -1):
        if text[i] == '.' and (text[i + 1].isupper() or text[i + 1] in ['“', '(', '[', '{']):
            text = text[:i+1] + ' ' + text[i+1:]
        elif text[i] in ['?', '!', '…', '”'] and (text[i + 1].isalnum() or text[i + 1] in ['“', '(', '[', '{']):
            text = text[:i+1] + ' ' + text[i+1:]
        elif i >= 2 and text[i] == '.' and text[i - 1] == '.' and text[i - 2] == '.' and text[i + 1] != ' ':
            # an ellipsis ("...") directly followed by text
            text = text[:i+1] + ' ' + text[i+1:]
        elif text[i] == ',' and (text[i + 1].isalpha() or text[i + 1] in ['“', '(', '[', '{']):
            text = text[:i+1] + ' ' + text[i+1:]
        elif text[i] in [';', ')', ']', '}', '%'] and (text[i + 1].isalnum() or text[i + 1] in ['“', '(', '[', '{']):
            text = text[:i+1] + ' ' + text[i+1:]
        elif text[i] == ':' and (text[i + 1] in ['“', '(', '[', '{'] or (text[i + 1].isalnum() and (not text[i + 1].isnumeric() or i - 1 < 0 or not text[i - 1].isnumeric()))):
            # ':' is followed by a space before words and brackets, but not
            # inside times or ratios such as "3:14"
            text = text[:i+1] + ' ' + text[i+1:]
elif text[i] in ['(', '[', '{'] and text[i + 1] == ' ':
text = text[:i+1] + text[i+2:]
        elif text[i] == ' ' and text[i+1] in ['.', ';', ':', '?', '!', '…', ',', '”', ')', ']']:
            text = text[:i] + text[i+1:]
        elif i > 0 and text[i] == ' ' and text[i - 1] in ['$', '£', '€'] and text[i + 1].isnumeric():
            text = text[:i] + text[i+1:]
        elif i > 0 and text[i] == ' ' and text[i - 1].isnumeric() and text[i + 1] == '%':
            text = text[:i] + text[i+1:]
    # Straight quotes are ambiguous, so they are paired up by parity.
    text = fix_quotes(text, '"')
    text = fix_quotes(text, "'")
    # Recover, for every input token, its character span in the rebuilt text.
    spans = []
    word_offset, char_offset = 0, 0
for i, ch in enumerate(text):
        if ch == ' ':
            # A space either separates two tokens or belongs to a token that
            # itself contains one; only the latter consumes a token character.
            if tokens[word_offset][char_offset] == ' ':
                char_offset += 1
            continue
assert ch == tokens[word_offset][char_offset], f"{text}\n{' '.join(tokens)}\n{tokens[word_offset]}\n{char_offset} {ch}"
if char_offset == 0:
start = i
if char_offset == len(tokens[word_offset]) - 1:
end = i + 1
spans.append((start, end))
word_offset += 1
char_offset = 0
else:
char_offset += 1
return text, spans
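# Illustrative example (not part of the original file):
#   detokenize(["She", "'s", "here", "."])
#   -> ("She's here.", [(0, 3), (3, 5), (6, 10), (10, 11)])
# Each span is the (start, end) character range of one input token in the
# detokenized text.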
def calculate_spans(original_spans, encoding_offsets):
    # Map every original word span to the subwords overlapping it.
    # `encoding_offsets` excludes the special tokens, so `i + 1` converts an
    # offset index back into a position in the full `input_ids` sequence.
    span_id = 0
    subword_spans = [[] for _ in original_spans]
    for i, (_, end) in enumerate(encoding_offsets):
        subword_spans[span_id].append(i + 1)
        # Advance past every word that this subword closes off.
        while original_spans[span_id][1] <= end:
            span_id += 1
            if span_id == len(original_spans):
                return subword_spans
            # A subword can straddle a word boundary; credit it to the next
            # word as well.
            if end > original_spans[span_id][0]:
                subword_spans[span_id].append(i + 1)
    return subword_spans
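# Illustrative example (not part of the original file): with word spans for
# "She" + "'s" and three subwords covering "She", "'", "s",
#   calculate_spans([(0, 3), (3, 5)], [(0, 3), (3, 4), (4, 5)])
#   -> [[1], [2, 3]]
# i.e. input_ids position 1 covers "She" and positions 2-3 cover "'s".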
def subtokenize(tokens, tokenizer, compact_dashes=False):
    text, spans = detokenize(tokens, compact_dashes=compact_dashes)
    encoding = tokenizer(text, return_offsets_mapping=True)
    # Strip the leading/trailing special tokens from the offset mapping;
    # calculate_spans re-applies the +1 shift for the leading one.
    spans = calculate_spans(spans, encoding["offset_mapping"][1:-1])
    subwords = encoding["input_ids"]
    # Boolean alignment matrix: subword_mask[i, j] is True iff subword i
    # belongs to input token j.
    subword_mask = torch.zeros(len(subwords), len(spans), dtype=torch.bool)
    for word_id, subword_ids in enumerate(spans):
        for subword_id in subword_ids:
            # calculate_spans already returns positions in `input_ids`,
            # so the subword index is used directly.
            subword_mask[subword_id, word_id] = True
    return subwords, subword_mask
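
if __name__ == "__main__":
    # Minimal usage sketch (not part of the original file). It assumes the
    # `transformers` library is installed; "bert-base-cased" is used purely
    # as an example, and any fast tokenizer with offset mappings should work.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
    tokens = ["I", "do", "n't", "like", "green", "eggs", "."]
    subwords, subword_mask = subtokenize(tokens, tokenizer)
    print(tokenizer.decode(subwords))  # e.g. "[CLS] I don't like green eggs. [SEP]"
    # Column j of the mask selects the subwords that make up tokens[j].
    print(subword_mask.shape)  # (number of subwords, len(tokens))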