DanilO0o commited on
Commit
8899279
·
1 Parent(s): 4232110
Files changed (1) hide show
  1. Владик/model.py +229 -0
Владик/model.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image, ImageFilter, ImageDraw
2
+ import streamlit as st
3
+
4
+ import pickle
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ from torch import Tensor
9
+ from dataclasses import dataclass
10
+ from typing import Union
11
+ import re
12
+ import string
13
+ import pymorphy3
14
+ import nltk
15
+ from nltk.corpus import stopwords
16
+ # stop_words = set(stopwords.words("english"))
17
+ stop_words = nltk.download('stopwords')
18
+
19
+
20
+ # ------------------------------------------------------------#
21
+ # Упрощенный метод создания класса
22
+
23
+ @dataclass
24
+ class ConfigRNN:
25
+ vocab_size: int # сколько слов - столько embedding-ов; для инициализации embedding параметров
26
+ device: str
27
+ n_layers: int
28
+ embedding_dim: int # чем больше, тем сложнее можно закодировать слово
29
+ hidden_size: int
30
+ seq_len: int
31
+ bidirectional: Union[bool, int]
32
+
33
+
34
+ net_config = ConfigRNN(
35
+ vocab_size=17259 + 1, # -> hand
36
+ device="cpu",
37
+ n_layers=1,
38
+ embedding_dim=8, # не лучшее значение, но в рамках задачи сойдет
39
+ hidden_size=16,
40
+ seq_len=30, # -> hand
41
+ bidirectional=False,
42
+ )
43
+ # ------------------------------------------------------------#
44
+
45
+
46
+ class LSTMClassifier(nn.Module):
47
+ def __init__(self, rnn_conf=net_config) -> None:
48
+ super().__init__()
49
+
50
+ self.embedding_dim = rnn_conf.embedding_dim
51
+ self.hidden_size = rnn_conf.hidden_size
52
+ self.bidirectional = rnn_conf.bidirectional
53
+ self.n_layers = rnn_conf.n_layers
54
+
55
+ self.embedding = nn.Embedding(rnn_conf.vocab_size, self.embedding_dim)
56
+ self.lstm = nn.LSTM(
57
+ input_size=self.embedding_dim,
58
+ hidden_size=self.hidden_size,
59
+ bidirectional=self.bidirectional,
60
+ batch_first=True,
61
+ num_layers=self.n_layers,
62
+ dropout=0.5
63
+ )
64
+ self.bidirect_factor = 2 if self.bidirectional else 1
65
+ self.clf = nn.Sequential(
66
+ nn.Linear(self.hidden_size * self.bidirect_factor, 32),
67
+ nn.Dropout(),
68
+ nn.Tanh(),
69
+ nn.Dropout(),
70
+ nn.Linear(32, 5) # len(df['label'].unique())
71
+ )
72
+
73
+ def model_description(self):
74
+ direction = "bidirect" if self.bidirectional else "onedirect"
75
+ return f"lstm_{direction}_{self.n_layers}"
76
+
77
+ def forward(self, x: torch.Tensor):
78
+ embeddings = self.embedding(x)
79
+ out, _ = self.lstm(embeddings)
80
+ # print(out.shape)
81
+ # [все элементы батча, последний h_n, все элементы последнего h_n]
82
+ out = out[:, -1, :]
83
+ # print(out.shape)
84
+ out = self.clf(out)
85
+ return out
86
+ # ------------------------------------------------------------#
87
+ # Загрузка модели
88
+
89
+
90
+ @st.cache_resource
91
+ def load_model():
92
+ model = LSTMClassifier(net_config)
93
+ model.load_state_dict(torch.load(
94
+ "models/lstm_weights.pth", map_location=torch.device("cpu")))
95
+ model.eval()
96
+ return model
97
+
98
+
99
+ model_lstm = load_model()
100
+ # ------------------------------------------------------------#
101
+
102
+
103
+ def padding(text_int: list, seq_len: int) -> np.ndarray:
104
+ """Make left-sided padding for input list of tokens
105
+
106
+ Args:
107
+ review_int (list): input list of tokens
108
+ seq_len (int): max length of sequence, it len(review_int[i]) > seq_len it will be trimmed, else it will be padded by zeros
109
+
110
+ Returns:
111
+ np.array: padded sequences
112
+ """
113
+ features = np.zeros((len(text_int), seq_len), dtype=int)
114
+ for i, review in enumerate(text_int):
115
+ if len(review) <= seq_len:
116
+ zeros = list(np.zeros(seq_len - len(review)))
117
+ new = zeros + review
118
+ else:
119
+ new = review[:seq_len]
120
+ features[i, :] = np.array(new)
121
+ return features
122
+
123
+
124
+ morph = pymorphy3.MorphAnalyzer()
125
+
126
+
127
+ def lemmatize(text):
128
+ # Разбиваем текст на слова
129
+ words = text.split()
130
+
131
+ # Лемматизируем каждое слово и убираем стоп-слова
132
+ lemmatized_words = [morph.parse(word)[0].normal_form for word in words]
133
+
134
+ # Собираем текст из лемматизированных слов
135
+ lemmatized_text = ' '.join(lemmatized_words)
136
+ return lemmatized_text
137
+
138
+
139
+ def data_preprocessing(text):
140
+ # From Phase 1
141
+ text = re.sub(r':[a-zA-Z]+:', '', text) # Убираем смайлики
142
+ text = text.lower() # Переводим текст в нижний регистр
143
+ text = re.sub(r'@[\w_-]+', '', text) # Убираем упоминания пользователей
144
+ text = re.sub(r'#(\w+)', '', text) # Убираем хэштеги
145
+ text = re.sub(r'\d+', '', text) # Убираем цифры
146
+ # Убираем ссылки
147
+ text = re.sub(
148
+ r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+', '', text)
149
+ text = re.sub(r'\s+', ' ', text) # Убираем лишние пробелы
150
+ # Удаление английских слов
151
+ text = ' '.join(re.findall(r'\b[а-яА-ЯёЁ]+\b', text))
152
+ # From Phase 2
153
+ text = re.sub("<.*?>", "", text) # html tags
154
+ text = "".join([c for c in text if c not in string.punctuation])
155
+ splitted_text = [word for word in text.split() if word not in stop_words]
156
+ text = " ".join(splitted_text)
157
+ return text.strip()
158
+
159
+
160
+ def preprocess_single_string(
161
+ input_string: str,
162
+ seq_len: int,
163
+ vocab_to_int: dict,
164
+ verbose: bool = False
165
+ ) -> Tensor:
166
+ """Function for all preprocessing steps on a single string
167
+
168
+ Args:
169
+ input_string (str): input single string for preprocessing
170
+ seq_len (int): max length of sequence, it len(review_int[i]) > seq_len it will be trimmed, else it will be padded by zeros
171
+ vocab_to_int (dict, optional): word corpus {'word' : int index}. Defaults to vocab_to_int.
172
+
173
+ Returns:
174
+ list: preprocessed string
175
+ """
176
+ preprocessed_string = lemmatize(input_string)
177
+ preprocessed_string = data_preprocessing(input_string)
178
+ result_list = []
179
+ for word in preprocessed_string.split():
180
+ try:
181
+ result_list.append(vocab_to_int[word])
182
+ except KeyError as e:
183
+ if verbose:
184
+ print(f'{e}: not in dictionary!')
185
+ pass
186
+ result_padded = padding([result_list], seq_len)[0]
187
+
188
+ return Tensor(result_padded)
189
+ # ------------------------------------------------------------#
190
+
191
+
192
+ st.title("Классификация тематики новостей из телеграм каналов")
193
+ # st.write('Model summary:')
194
+ text = st.text_input('Input some news')
195
+ text_4_test = text
196
+
197
+ # Загрузка словаря из файла
198
+ with open('model/vocab_to_int.pkl', 'rb') as f:
199
+ vocab_to_int = pickle.load(f)
200
+
201
+ if text != '':
202
+ test_review = preprocess_single_string(
203
+ text_4_test, net_config.seq_len, vocab_to_int)
204
+ test_review = torch.tensor(test_review, dtype=torch.int64)
205
+ result = torch.sigmoid(model_lstm(test_review.unsqueeze(0)))
206
+ num = result.argmax().item()
207
+
208
+ st.write('---')
209
+ st.write('Initial text:')
210
+ st.write(text)
211
+ st.write('---')
212
+ st.write('Preprocessing:')
213
+ st.write(data_preprocessing(text))
214
+ st.write('---')
215
+ st.write('Classes:')
216
+ classes = ['крипта', 'мода', 'спорт', 'технологии', 'финансы']
217
+ st.write('крипта *', 'мода *', 'спорт *', 'технологии *', 'финансы')
218
+ st.write('---')
219
+
220
+ st.write('Predict:')
221
+ if text != '':
222
+ st.write('Classification: ', classes[num])
223
+ st.write('Label num: ', num)
224
+
225
+ # Загружаем изображение через PIL
226
+ image = Image.open("images/tg_metrics.png")
227
+
228
+ # Отображение
229
+ st.image(image, caption="Кошмареус переобучения", use_column_width=True)