from models.bert_classifier import MyTinyBERT from models.lstm_attention import LSTMAttention from models.text_preprocessor import MyCustomTextPreprocessor import streamlit as st from sklearn.utils.class_weight import compute_class_weight import torch.nn.functional as F import torch.optim as optim import joblib from torch import nn from sklearn.base import BaseEstimator, TransformerMixin from transformers import AutoTokenizer, AutoModel from sklearn.metrics import confusion_matrix, f1_score, accuracy_score from sklearn.linear_model import LogisticRegression from sklearn.model_selection import train_test_split from torch.utils.data import DataLoader, TensorDataset from time import time from sklearn.feature_extraction.text import TfidfVectorizer import pymorphy3 import string import re import pandas as pd import numpy as np import torch import sklearn import matplotlib.pyplot as plt import warnings warnings.simplefilter("ignore") # Metrics # custom # ======= Глобальная инициализация токенизатора ======= tokenizer = AutoTokenizer.from_pretrained( "cointegrated/rubert-tiny2") # Для LSTM и BERT # ======= Инициализация обработчика текста ======= preprocessor = MyCustomTextPreprocessor() # ======= Загрузка моделей и векторизатора ======= # @st.cache_resource def load_resources(): # Загрузка TF-IDF векторизатора vectorizer = joblib.load('models/vectorizer.pkl') # TF-IDF # Загрузка модели логистической регрессии # Логистическая регрессия model1 = joblib.load('models/Sasha_logistic_model2.pkl') # Настройка модели LSTM # Используем уже загруженный токенизатор VOCAB_SIZE = len(tokenizer.get_vocab()) EMBEDDING_DIM = 128 HIDDEN_DIM = 256 OUTPUT_DIM = 10 model2 = LSTMAttention(VOCAB_SIZE, EMBEDDING_DIM, HIDDEN_DIM, OUTPUT_DIM) model2.load_state_dict(torch.load( 'models/Sasha_best_lstm_model3.pth', map_location=torch.device('cpu'))) model2.eval() # Настройка модели BERT model3 = MyTinyBERT() model3.load_state_dict(torch.load( 'models/Sasha_best_model_bert.pth', map_location=torch.device('cpu'))) model3.eval() return model1, model2, model3, vectorizer # Загружаем ресурсы model1, model2, model3, vectorizer = load_resources() # ======= Предобработка текста ======= def preprocess_for_model1(text): """TF-IDF векторизация для логистической регрессии""" processed_text = preprocessor.preprocess( text, lemmatize=True) # Лемматизация включена return vectorizer.transform([processed_text]) def preprocess_for_model2_and_model3(text): """Общая обработка для LSTM и BERT моделей (без лемматизации)""" processed_text = preprocessor.preprocess( text, lemmatize=False) # Лемматизация выключена return processed_text def preprocess_for_model2(text, tokenizer): """Токенизация для LSTM модели""" processed_text = preprocess_for_model2_and_model3(text) tokenized_data = tokenizer( [processed_text], padding=True, truncation=True, return_tensors="pt", max_length=256 ) return tokenized_data["input_ids"], tokenized_data["attention_mask"] def preprocess_for_model3(text, tokenizer): """Токенизация для BERT модели""" processed_text = preprocess_for_model2_and_model3(text) tokenized_data = tokenizer( [processed_text], padding=True, truncation=True, return_tensors="pt", max_length=256 ) return tokenized_data # ======= Прогноз и визуализация ======= def predict_and_visualize(text): # ======= Модель 1 (Logistic Regression) ======= start_time = time() # Начало времени предсказания vectorized_text = preprocess_for_model1(text) probs1 = model1.predict_proba(vectorized_text)[0] model1_time = time() - start_time # Рассчитываем время предсказания для модели 1 # ======= Модель 2 (LSTM & Attention) ======= start_time = time() # Начало времени предсказания input_ids, _ = preprocess_for_model2( text, tokenizer) # Получаем только input_ids with torch.no_grad(): logits2, attn_weights = model2(input_ids) # Передаём только input_ids probs2 = torch.softmax(logits2, dim=1).numpy()[0] attention_vector = attn_weights.cpu().numpy()[0] model2_time = time() - start_time # Рассчитываем время предсказания для модели 2 # ======= Модель 3 (BERT) ======= start_time = time() # Начало времени предсказания tokenized_text = preprocess_for_model3(text, tokenizer) with torch.no_grad(): logits3 = model3(tokenized_text) probs3 = torch.softmax(logits3, dim=1).numpy()[0] model3_time = time() - start_time # Рассчитываем время предсказания для модели 3 # ======= Финальное предсказание ======= final_probs = (probs1 + probs2 + probs3) / 3 final_class = np.argmax(final_probs) # ======= Визуализация ======= st.subheader("Распределение вероятностей") for probs, model_name in zip([probs1, probs2, probs3], ['Model 1 (Logistic Regression)', 'Model 2 (LSTM)', 'Model 3 (BERT)']): fig, ax = plt.subplots() ax.bar(range(1, len(probs) + 1), probs) # Сдвиг индекса на +1 ax.set_title(f'{model_name} Probabilities') ax.set_xlabel('Class (1-10)') ax.set_ylabel('Probability') st.pyplot(fig) # ======= Визуализация внимания (LSTM) ======= st.subheader("Веса внимания (LSTM)") # Проверяем наличие attention weights tokens = tokenizer.convert_ids_to_tokens(input_ids[0]) tokens = tokens[:len(attention_vector)] attention_vector = attention_vector[:len(tokens)] fig, ax = plt.subplots(figsize=(12, 6)) ax.bar(range(len(tokens)), attention_vector, align="center") ax.set_xticks(range(len(tokens))) ax.set_xticklabels(tokens, rotation=45, ha="right") ax.set_title("Attention Weights (LSTM)") ax.set_xlabel("Токены") ax.set_ylabel("Вес внимания") st.pyplot(fig) # Итоговое предсказание st.subheader("Итоговое предсказание") # Смещение на +1 st.write(f"Наиболее вероятный класс: **{final_class + 1}**") # Вывод времени выполнения st.subheader("Время выполнения моделей") st.write(f"Модель 1 (Logistic Regression): {model1_time:.4f} секунд") st.write(f"Модель 2 (LSTM): {model2_time:.4f} секунд") st.write(f"Модель 3 (BERT): {model3_time:.4f} секунд") return final_class # ======= Streamlit UI ======= st.title("Классификация текстов с 3 моделями") st.write("Введите текст отзыва, чтобы получить результаты классификации от трёх моделей.") # Ввод текста пользователем user_input = st.text_area("Введите текст отзыва:", "") if st.button("Классифицировать"): if user_input.strip(): # Прогноз и визуализация predict_and_visualize(user_input) else: st.warning("Введите текст для анализа.") st.subheader("F1 macro, валидационная выборка") st.write(f'f1 macro valid logreg=0.2516') st.write(f'f1 macro valid lstm=0.2515') st.write(f'f1 macro valid bert=0.2709')