DanilO0o's picture
fgeegdfr
8899279
from PIL import Image, ImageFilter, ImageDraw
import streamlit as st
import pickle
import numpy as np
import torch
import torch.nn as nn
from torch import Tensor
from dataclasses import dataclass
from typing import Union
import re
import string
import pymorphy3
import nltk
from nltk.corpus import stopwords
# stop_words = set(stopwords.words("english"))
stop_words = nltk.download('stopwords')
# ------------------------------------------------------------#
# Упрощенный метод создания класса
@dataclass
class ConfigRNN:
vocab_size: int # сколько слов - столько embedding-ов; для инициализации embedding параметров
device: str
n_layers: int
embedding_dim: int # чем больше, тем сложнее можно закодировать слово
hidden_size: int
seq_len: int
bidirectional: Union[bool, int]
net_config = ConfigRNN(
vocab_size=17259 + 1, # -> hand
device="cpu",
n_layers=1,
embedding_dim=8, # не лучшее значение, но в рамках задачи сойдет
hidden_size=16,
seq_len=30, # -> hand
bidirectional=False,
)
# ------------------------------------------------------------#
class LSTMClassifier(nn.Module):
def __init__(self, rnn_conf=net_config) -> None:
super().__init__()
self.embedding_dim = rnn_conf.embedding_dim
self.hidden_size = rnn_conf.hidden_size
self.bidirectional = rnn_conf.bidirectional
self.n_layers = rnn_conf.n_layers
self.embedding = nn.Embedding(rnn_conf.vocab_size, self.embedding_dim)
self.lstm = nn.LSTM(
input_size=self.embedding_dim,
hidden_size=self.hidden_size,
bidirectional=self.bidirectional,
batch_first=True,
num_layers=self.n_layers,
dropout=0.5
)
self.bidirect_factor = 2 if self.bidirectional else 1
self.clf = nn.Sequential(
nn.Linear(self.hidden_size * self.bidirect_factor, 32),
nn.Dropout(),
nn.Tanh(),
nn.Dropout(),
nn.Linear(32, 5) # len(df['label'].unique())
)
def model_description(self):
direction = "bidirect" if self.bidirectional else "onedirect"
return f"lstm_{direction}_{self.n_layers}"
def forward(self, x: torch.Tensor):
embeddings = self.embedding(x)
out, _ = self.lstm(embeddings)
# print(out.shape)
# [все элементы батча, последний h_n, все элементы последнего h_n]
out = out[:, -1, :]
# print(out.shape)
out = self.clf(out)
return out
# ------------------------------------------------------------#
# Загрузка модели
@st.cache_resource
def load_model():
model = LSTMClassifier(net_config)
model.load_state_dict(torch.load(
"models/lstm_weights.pth", map_location=torch.device("cpu")))
model.eval()
return model
model_lstm = load_model()
# ------------------------------------------------------------#
def padding(text_int: list, seq_len: int) -> np.ndarray:
"""Make left-sided padding for input list of tokens
Args:
review_int (list): input list of tokens
seq_len (int): max length of sequence, it len(review_int[i]) > seq_len it will be trimmed, else it will be padded by zeros
Returns:
np.array: padded sequences
"""
features = np.zeros((len(text_int), seq_len), dtype=int)
for i, review in enumerate(text_int):
if len(review) <= seq_len:
zeros = list(np.zeros(seq_len - len(review)))
new = zeros + review
else:
new = review[:seq_len]
features[i, :] = np.array(new)
return features
morph = pymorphy3.MorphAnalyzer()
def lemmatize(text):
# Разбиваем текст на слова
words = text.split()
# Лемматизируем каждое слово и убираем стоп-слова
lemmatized_words = [morph.parse(word)[0].normal_form for word in words]
# Собираем текст из лемматизированных слов
lemmatized_text = ' '.join(lemmatized_words)
return lemmatized_text
def data_preprocessing(text):
# From Phase 1
text = re.sub(r':[a-zA-Z]+:', '', text) # Убираем смайлики
text = text.lower() # Переводим текст в нижний регистр
text = re.sub(r'@[\w_-]+', '', text) # Убираем упоминания пользователей
text = re.sub(r'#(\w+)', '', text) # Убираем хэштеги
text = re.sub(r'\d+', '', text) # Убираем цифры
# Убираем ссылки
text = re.sub(
r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+', '', text)
text = re.sub(r'\s+', ' ', text) # Убираем лишние пробелы
# Удаление английских слов
text = ' '.join(re.findall(r'\b[а-яА-ЯёЁ]+\b', text))
# From Phase 2
text = re.sub("<.*?>", "", text) # html tags
text = "".join([c for c in text if c not in string.punctuation])
splitted_text = [word for word in text.split() if word not in stop_words]
text = " ".join(splitted_text)
return text.strip()
def preprocess_single_string(
input_string: str,
seq_len: int,
vocab_to_int: dict,
verbose: bool = False
) -> Tensor:
"""Function for all preprocessing steps on a single string
Args:
input_string (str): input single string for preprocessing
seq_len (int): max length of sequence, it len(review_int[i]) > seq_len it will be trimmed, else it will be padded by zeros
vocab_to_int (dict, optional): word corpus {'word' : int index}. Defaults to vocab_to_int.
Returns:
list: preprocessed string
"""
preprocessed_string = lemmatize(input_string)
preprocessed_string = data_preprocessing(input_string)
result_list = []
for word in preprocessed_string.split():
try:
result_list.append(vocab_to_int[word])
except KeyError as e:
if verbose:
print(f'{e}: not in dictionary!')
pass
result_padded = padding([result_list], seq_len)[0]
return Tensor(result_padded)
# ------------------------------------------------------------#
st.title("Классификация тематики новостей из телеграм каналов")
# st.write('Model summary:')
text = st.text_input('Input some news')
text_4_test = text
# Загрузка словаря из файла
with open('model/vocab_to_int.pkl', 'rb') as f:
vocab_to_int = pickle.load(f)
if text != '':
test_review = preprocess_single_string(
text_4_test, net_config.seq_len, vocab_to_int)
test_review = torch.tensor(test_review, dtype=torch.int64)
result = torch.sigmoid(model_lstm(test_review.unsqueeze(0)))
num = result.argmax().item()
st.write('---')
st.write('Initial text:')
st.write(text)
st.write('---')
st.write('Preprocessing:')
st.write(data_preprocessing(text))
st.write('---')
st.write('Classes:')
classes = ['крипта', 'мода', 'спорт', 'технологии', 'финансы']
st.write('крипта *', 'мода *', 'спорт *', 'технологии *', 'финансы')
st.write('---')
st.write('Predict:')
if text != '':
st.write('Classification: ', classes[num])
st.write('Label num: ', num)
# Загружаем изображение через PIL
image = Image.open("images/tg_metrics.png")
# Отображение
st.image(image, caption="Кошмареус переобучения", use_column_width=True)