File size: 2,397 Bytes
64b4c39
bfa01ee
3259969
 
64b4c39
3259969
 
83510a7
3259969
 
 
 
 
 
83510a7
 
 
 
 
3259969
 
5656ec0
64b4c39
 
3259969
 
64b4c39
3259969
 
 
 
 
 
 
 
 
 
64b4c39
 
3259969
 
 
 
 
 
549ac53
 
3259969
 
64b4c39
 
3259969
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import torch
import torch.nn as nn
from transformers import BertTokenizer, BertModel
import gradio as gr

class MultiTaskBertModel(nn.Module):
    def __init__(self, num_labels_task1, num_labels_task2):
        super(MultiTaskBertModel, self).__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.classifier_task1 = nn.Linear(self.bert.config.hidden_size, num_labels_task1)
        self.classifier_task2 = nn.Linear(self.bert.config.hidden_size, num_labels_task2)
    
    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        logits_task1 = self.classifier_task1(pooled_output)
        logits_task2 = self.classifier_task2(pooled_output)
        return logits_task1, logits_task2

# Загрузка сохраненной модели
model = MultiTaskBertModel(num_labels_task1=3, num_labels_task2=4)
model = torch.load('ticket_classifier.pth', weights_only=False, map_location=torch.device('cpu'))
model.eval()

# Загрузка токенизатора
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Функция для предсказания
def predict(text):
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]
    
    # Получаем предсказания для двух задач
    logits_task1, logits_task2 = model(input_ids=input_ids, attention_mask=attention_mask)
    
    # Преобразование логитов в предсказания классов
    pred_task1 = torch.argmax(logits_task1, dim=1).item()
    pred_task2 = torch.argmax(logits_task2, dim=1).item()
    
    return {"Task 1 Prediction": pred_task1, "Task 2 Prediction": pred_task2}

# Создание интерфейса с Gradio
iface = gr.Interface(
    fn=predict, 
    inputs=gr.Textbox(lines=2, placeholder="Введите текст для анализа..."),  # Обновлено на gr.Textbox
    outputs=gr.JSON(),  # Обновлено на gr.JSON
    title="Multi-Task BERT Model",
    description="Модель BERT для одновременного решения двух задач: тональность текста и тема.",
)

iface.launch()