Remove unnecessary funcs and processing
sbert-punc-case-ru/sbertpunccase.py
@@ -8,7 +8,6 @@ import numpy as np
 from transformers import AutoTokenizer, AutoModelForTokenClassification
 import re
 import string
-from typing import List, Optional
 
 
 TOKEN_RE = re.compile(r'-?\d*\.\d+|[a-zа-яё]+|-?[\d\+\(\)\-]+|\S', re.I)
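For intuition about the tokenizer regex: TOKEN_RE splits a sentence into words, numbers, and single punctuation characters, which is exactly what the simplified preprocessing below relies on. A quick REPL check (output shown as a comment):

import re

TOKEN_RE = re.compile(r'-?\d*\.\d+|[a-zа-яё]+|-?[\d\+\(\)\-]+|\S', re.I)

print(TOKEN_RE.findall('Привет, как дела?'))
# -> ['Привет', ',', 'как', 'дела', '?']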
@@ -84,49 +83,6 @@ def decode_label(label, classes='all'):
     return INVERSE_LABELS[label]
 
 
-def make_labeling(text: str):
-    # Split the sentence into words and punctuation marks
-    tokens = TOKEN_RE.findall(text)
-    # Preprocess the words: strip punctuation marks and assign labels
-
-    preprocessed_tokens = []
-    token_labels: List[List[str]] = []
-
-    # Strip any punctuation at the start of the sentence
-    while tokens[0] in string.punctuation:
-        tokens.pop(0)
-
-    for token in tokens:
-        if token in string.punctuation:
-            # If this punctuation mark is one we predict, update the previous word's label; otherwise ignore it
-            if token in PUNK_MAPPING:
-                token_labels[-1][1] = PUNK_MAPPING[token]
-        else:
-            # If this is a word, record its case label and add it to the preprocessed list in lower case
-            if sum(char.isupper() for char in token) > 1:
-                token_labels.append(['UPPER_TOTAL', 'O'])
-            elif token[0].isupper():
-                token_labels.append(['UPPER', 'O'])
-            else:
-                token_labels.append(['LOWER', 'O'])
-            preprocessed_tokens.append(token.lower())
-    token_labels_merged = ['_'.join(label) for label in token_labels]
-    token_labels_ids = [LABELS[label] for label in token_labels_merged]
-    return dict(words=preprocessed_tokens, labels=token_labels_merged, label_ids=token_labels_ids)
-
-
-def align_labels(label_ids: list[int], word_ids: list[Optional[int]]):
-    aligned_label_ids = []
-    previous_id = None
-    for word_id in word_ids:
-        if word_id is None or word_id == previous_id:
-            aligned_label_ids.append(LABELS['O'])
-        else:
-            aligned_label_ids.append(label_ids.pop(0))
-        previous_id = word_id
-    return aligned_label_ids
-
-
 MODEL_REPO = "kontur-ai/sbert-punc-case-ru"
 
 
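Context for the removal: both helpers served training-data preparation. make_labeling turned annotated text into lowercased words plus combined case-and-punctuation labels, and align_labels spread those word-level labels across subword tokens; inference needs neither. For reference, the alignment idea in isolation as a runnable sketch, with a toy LABELS stand-in (the real dict is defined earlier in this file):

from typing import List, Optional

LABELS = {'O': 0}  # toy stand-in for the real LABELS dict in sbertpunccase.py

def align_labels(label_ids: List[int], word_ids: List[Optional[int]]):
    # The first subword of each word keeps the word's label;
    # special tokens (None) and continuation subwords get 'O'.
    aligned_label_ids = []
    previous_id = None
    for word_id in word_ids:
        if word_id is None or word_id == previous_id:
            aligned_label_ids.append(LABELS['O'])
        else:
            aligned_label_ids.append(label_ids.pop(0))
        previous_id = word_id
    return aligned_label_ids

# Two words, the first split into two subwords, framed by [CLS]/[SEP]:
print(align_labels([5, 7], [None, 0, 0, 1, None]))  # -> [0, 5, 0, 7, 0]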
@@ -151,22 +107,18 @@ class SbertPuncCase(nn.Module):
     def punctuate(self, text):
         text = text.strip().lower()
 
-        # …
-        …
-        …
-        …
+        # Split the sentence into words and punctuation marks
+        tokens = TOKEN_RE.findall(text)
+        # Drop the punctuation marks
+        words = [token for token in tokens if token not in string.punctuation]
 
         tokenizer_output = self.tokenizer(words, is_split_into_words=True)
-        aligned_label_ids = [align_labels(label_ids, tokenizer_output.word_ids())]
-
-        result = dict(tokenizer_output)
-        result.update({'labels': aligned_label_ids})
 
-        if len(…
+        if len(tokenizer_output.input_ids) > 512:
             return ' '.join([self.punctuate(' '.join(text_part)) for text_part in np.array_split(words, 2)])
 
-        predictions = self(torch.tensor([…
-            torch.tensor([…
+        predictions = self(torch.tensor([tokenizer_output.input_ids], device=self.model.device),
+                           torch.tensor([tokenizer_output.attention_mask], device=self.model.device)).logits.cpu().data.numpy()
         predictions = np.argmax(predictions, axis=2)
 
         # decode punctuation and casing
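The length guard handles the encoder's 512-token limit: if the tokenized input is too long, the word list is halved with np.array_split and each half is punctuated recursively, so arbitrarily long inputs still work. A standalone illustration of the halving step (the sample words are throwaway):

import numpy as np

# Mirror the fallback: halve the word list, rejoin each half into a string.
words = ['раз', 'два', 'три', 'четыре', 'пять']
print([' '.join(part) for part in np.array_split(words, 2)])
# -> ['раз два три', 'четыре пять']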
@@ -183,7 +135,7 @@ class SbertPuncCase(nn.Module):
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser("Punctuation and case restoration model sbert-punc-case-ru")
-    parser.add_argument("-i", "--input", type=str, help="text to restore", default='…
+    parser.add_argument("-i", "--input", type=str, help="text to restore", default='sbert punc case расставляет точки запятые и знаки вопроса вам нравится')
     parser.add_argument("-d", "--device", type=str, help="run model on cpu or gpu", choices=['cpu', 'cuda'], default='cpu')
     args = parser.parse_args()
     print(f"Source text: {args.input}\n")
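With the new default input, the script can be smoke-tested directly from the command line; a plausible invocation, assuming it is run from the sbert-punc-case-ru directory with the defaults:

python sbertpunccase.py -d cpu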