Almira commited on
Commit
2b720c7
·
1 Parent(s): 37f0390

Remove unnecessary funcs and processing

Browse files
Files changed (1) hide show
  1. sbert-punc-case-ru/sbertpunccase.py +8 -56
sbert-punc-case-ru/sbertpunccase.py CHANGED
@@ -8,7 +8,6 @@ import numpy as np
8
  from transformers import AutoTokenizer, AutoModelForTokenClassification
9
  import re
10
  import string
11
- from typing import List, Optional
12
 
13
 
14
  TOKEN_RE = re.compile(r'-?\d*\.\d+|[a-zа-яё]+|-?[\d\+\(\)\-]+|\S', re.I)
@@ -84,49 +83,6 @@ def decode_label(label, classes='all'):
84
  return INVERSE_LABELS[label]
85
 
86
 
87
- def make_labeling(text: str):
88
- # Разобъем предложение на слова и знаки препинания
89
- tokens = TOKEN_RE.findall(text)
90
- # Предобработаем слова, удалим знаки препинания и зададим метки
91
-
92
- preprocessed_tokens = []
93
- token_labels: List[List[str]] = []
94
-
95
- # Убираем всю пунктуацию в начале предложения
96
- while tokens[0] in string.punctuation:
97
- tokens.pop(0)
98
-
99
- for token in tokens:
100
- if token in string.punctuation:
101
- # Если встретился знак препинания который мы прогнозируем изменим метку предыдущего слова, иначе проигнорируем его
102
- if token in PUNK_MAPPING:
103
- token_labels[-1][1] = PUNK_MAPPING[token]
104
- else:
105
- # Если встретилось слово, то укажем метку регистра и добавим в список предобработанных слов в нижнем регистре
106
- if sum(char.isupper() for char in token) > 1:
107
- token_labels.append(['UPPER_TOTAL', 'O'])
108
- elif token[0].isupper():
109
- token_labels.append(['UPPER', 'O'])
110
- else:
111
- token_labels.append(['LOWER', 'O'])
112
- preprocessed_tokens.append(token.lower())
113
- token_labels_merged = ['_'.join(label) for label in token_labels]
114
- token_labels_ids = [LABELS[label] for label in token_labels_merged]
115
- return dict(words=preprocessed_tokens, labels=token_labels_merged, label_ids=token_labels_ids)
116
-
117
-
118
- def align_labels(label_ids: list[int], word_ids: list[Optional[int]]):
119
- aligned_label_ids = []
120
- previous_id = None
121
- for word_id in word_ids:
122
- if word_id is None or word_id == previous_id:
123
- aligned_label_ids.append(LABELS['O'])
124
- else:
125
- aligned_label_ids.append(label_ids.pop(0))
126
- previous_id = word_id
127
- return aligned_label_ids
128
-
129
-
130
  MODEL_REPO = "kontur-ai/sbert-punc-case-ru"
131
 
132
 
@@ -151,22 +107,18 @@ class SbertPuncCase(nn.Module):
151
  def punctuate(self, text):
152
  text = text.strip().lower()
153
 
154
- # preprocess
155
- words_with_labels = make_labeling(text)
156
- words = words_with_labels['words']
157
- label_ids = words_with_labels['label_ids']
158
 
159
  tokenizer_output = self.tokenizer(words, is_split_into_words=True)
160
- aligned_label_ids = [align_labels(label_ids, tokenizer_output.word_ids())]
161
-
162
- result = dict(tokenizer_output)
163
- result.update({'labels': aligned_label_ids})
164
 
165
- if len(result['input_ids']) > 512:
166
  return ' '.join([self.punctuate(' '.join(text_part)) for text_part in np.array_split(words, 2)])
167
 
168
- predictions = self(torch.tensor([result['input_ids']], device=self.model.device),
169
- torch.tensor([result['attention_mask']], device=self.model.device)).logits.cpu().data.numpy()
170
  predictions = np.argmax(predictions, axis=2)
171
 
172
  # decode punctuation and casing
@@ -183,7 +135,7 @@ class SbertPuncCase(nn.Module):
183
 
184
  if __name__ == '__main__':
185
  parser = argparse.ArgumentParser("Punctuation and case restoration model sbert-punc-case-ru")
186
- parser.add_argument("-i", "--input", type=str, help="text to restore", default='SbertPuncCase расставляет точки запятые и знаки вопроса вам нравится')
187
  parser.add_argument("-d", "--device", type=str, help="run model on cpu or gpu", choices=['cpu', 'cuda'], default='cpu')
188
  args = parser.parse_args()
189
  print(f"Source text: {args.input}\n")
 
8
  from transformers import AutoTokenizer, AutoModelForTokenClassification
9
  import re
10
  import string
 
11
 
12
 
13
  TOKEN_RE = re.compile(r'-?\d*\.\d+|[a-zа-яё]+|-?[\d\+\(\)\-]+|\S', re.I)
 
83
  return INVERSE_LABELS[label]
84
 
85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  MODEL_REPO = "kontur-ai/sbert-punc-case-ru"
87
 
88
 
 
107
  def punctuate(self, text):
108
  text = text.strip().lower()
109
 
110
+ # Разобъем предложение на слова и знаки препинания
111
+ tokens = TOKEN_RE.findall(text)
112
+ # Удалим знаки препинания
113
+ words = [token for token in tokens if token not in string.punctuation]
114
 
115
  tokenizer_output = self.tokenizer(words, is_split_into_words=True)
 
 
 
 
116
 
117
+ if len(tokenizer_output.input_ids) > 512:
118
  return ' '.join([self.punctuate(' '.join(text_part)) for text_part in np.array_split(words, 2)])
119
 
120
+ predictions = self(torch.tensor([tokenizer_output.input_ids], device=self.model.device),
121
+ torch.tensor([tokenizer_output.attention_mask], device=self.model.device)).logits.cpu().data.numpy()
122
  predictions = np.argmax(predictions, axis=2)
123
 
124
  # decode punctuation and casing
 
135
 
136
  if __name__ == '__main__':
137
  parser = argparse.ArgumentParser("Punctuation and case restoration model sbert-punc-case-ru")
138
+ parser.add_argument("-i", "--input", type=str, help="text to restore", default='sbert punc case расставляет точки запятые и знаки вопроса вам нравится')
139
  parser.add_argument("-d", "--device", type=str, help="run model on cpu or gpu", choices=['cpu', 'cuda'], default='cpu')
140
  args = parser.parse_args()
141
  print(f"Source text: {args.input}\n")