File size: 3,789 Bytes
1867879
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import streamlit as st
from transformers import AutoTokenizer, AutoModel
import torch
from torch import nn


# Загрузка модели и токенизатора (кешируем для ускорения)
@st.cache_resource
def load_model():
    MODEL_NAME = "cointegrated/rubert-tiny2"
    model = AutoModel.from_pretrained(MODEL_NAME, num_labels=5)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    return model, tokenizer


PATH = "models/model_weight_bert.pt"


class MyTinyBERT(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.bert = model
        for param in self.bert.parameters():
            param.requires_grad = False
        self.linear = nn.Sequential(
            nn.Linear(312, 256), nn.Dropout(0.3), nn.ReLU(), nn.Linear(256, 5)
        )

    def forward(self, input_ids, attention_mask):
        bert_out = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        normed_bert_out = bert_out.last_hidden_state[:, 0, :]
        out = self.linear(normed_bert_out)
        return out


def classification_myBERT(text, model, tokenizer):
    model = MyTinyBERT(model)
    model.load_state_dict(torch.load(PATH, weights_only=True))
    model.eval()
    my_classes = {0: "Крипта", 1: "Мода", 2: "Спорт", 3: "Технологии", 4: "Финансы"}
    t = tokenizer(text, padding=True, truncation=True, return_tensors="pt")
    return f'Хоть я и не ChatGPT, осмелюсь предположить, что данный текст относится к следующему классу:\n{my_classes[torch.argmax(model(t["input_ids"], t["attention_mask"])).item()]}'


# Интерфейс Streamlit
def main():
    st.markdown(
        "<h1 style='text-align: center;'>Классификация тематики новостей из телеграм каналов.</h1>",
        unsafe_allow_html=True,
    )
    st.markdown("---")

    col1, col2, col3 = st.columns([1, 8, 1])  # Центральная колонка шире остальных
    with col2:
        st.markdown(
            "<h5 style='text-align: center;'>Использование классического алгоритма</h5>",
            unsafe_allow_html=True,
        )
        # st.text("Использование классического алгоритма")
        st.image("./images/Struct.png", width=500)
        st.image("./images/L_A.png", width=800)
        st.image("./images/C_M.png", width=800)

        st.markdown(
            "<h5 style='text-align: center;'>Стандартный rubert_tiny2</h5>",
            unsafe_allow_html=True,
        )
        # st.text("Использование классического алгоритма")
        st.image("./images/LogReg.png", width=800)

        st.markdown(
            "<h5 style='text-align: center;'>rubert_tiny2 с обучаемым fc слоем</h5>",
            unsafe_allow_html=True,
        )
        # st.text("Использование классического алгоритма")
        st.image("./images/myTinyBERT.png", width=800)

    # Загрузка модели
    model, tokenizer = load_model()

    # Параметры генерации
    with st.sidebar:
        st.header("Настройки генерации")
        prompt = st.text_area("Введите начальный текст:", height=100)

    # Кнопка генерации
    if st.sidebar.button("Сгенерировать текст"):
        if not prompt:
            st.warning("Введите начальный текст!")
            return

        st.subheader("Результаты:")
        st.text(classification_myBERT(prompt, model, tokenizer))


if __name__ == "__main__":
    main()