import streamlit as st from transformers import AutoTokenizer, AutoModel import torch from torch import nn # Загрузка модели и токенизатора (кешируем для ускорения) @st.cache_resource def load_model(): MODEL_NAME = "cointegrated/rubert-tiny2" model = AutoModel.from_pretrained(MODEL_NAME, num_labels=5) tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) return model, tokenizer PATH = "models/model_weight_bert.pt" class MyTinyBERT(nn.Module): def __init__(self, model): super().__init__() self.bert = model for param in self.bert.parameters(): param.requires_grad = False self.linear = nn.Sequential( nn.Linear(312, 256), nn.Dropout(0.3), nn.ReLU(), nn.Linear(256, 5) ) def forward(self, input_ids, attention_mask): bert_out = self.bert(input_ids=input_ids, attention_mask=attention_mask) normed_bert_out = bert_out.last_hidden_state[:, 0, :] out = self.linear(normed_bert_out) return out def classification_myBERT(text, model, tokenizer): model = MyTinyBERT(model) model.load_state_dict(torch.load(PATH, weights_only=True)) model.eval() my_classes = {0: "Крипта", 1: "Мода", 2: "Спорт", 3: "Технологии", 4: "Финансы"} t = tokenizer(text, padding=True, truncation=True, return_tensors="pt") return f'Хоть я и не ChatGPT, осмелюсь предположить, что данный текст относится к следующему классу:\n{my_classes[torch.argmax(model(t["input_ids"], t["attention_mask"])).item()]}' # Интерфейс Streamlit def main(): st.markdown( "

Классификация тематики новостей из телеграм каналов.

", unsafe_allow_html=True, ) st.markdown("---") col1, col2, col3 = st.columns([1, 8, 1]) # Центральная колонка шире остальных with col2: st.markdown( "
Использование классического алгоритма
", unsafe_allow_html=True, ) # st.text("Использование классического алгоритма") st.image("./images/Struct.png", width=500) st.image("./images/L_A.png", width=800) st.image("./images/C_M.png", width=800) st.markdown( "
Стандартный rubert_tiny2
", unsafe_allow_html=True, ) # st.text("Использование классического алгоритма") st.image("./images/LogReg.png", width=800) st.markdown( "
rubert_tiny2 с обучаемым fc слоем
", unsafe_allow_html=True, ) # st.text("Использование классического алгоритма") st.image("./images/myTinyBERT.png", width=800) # Загрузка модели model, tokenizer = load_model() # Параметры генерации with st.sidebar: st.header("Настройки генерации") prompt = st.text_area("Введите начальный текст:", height=100) # Кнопка генерации if st.sidebar.button("Сгенерировать текст"): if not prompt: st.warning("Введите начальный текст!") return st.subheader("Результаты:") st.text(classification_myBERT(prompt, model, tokenizer)) if __name__ == "__main__": main()