File size: 4,824 Bytes
64b4c39
bfa01ee
3259969
 
64b4c39
3259969
 
83510a7
3259969
 
 
 
 
 
83510a7
 
 
 
 
3259969
 
5656ec0
64b4c39
 
3259969
 
64b4c39
3259969
09de64d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3259969
 
 
 
 
 
 
 
 
64b4c39
 
3259969
b009f2e
3259969
 
 
09de64d
549ac53
09de64d
 
64b4c39
 
7efd2f4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import torch
import torch.nn as nn
from transformers import BertTokenizer, BertModel
import gradio as gr

class MultiTaskBertModel(nn.Module):
    def __init__(self, num_labels_task1, num_labels_task2):
        super(MultiTaskBertModel, self).__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.classifier_task1 = nn.Linear(self.bert.config.hidden_size, num_labels_task1)
        self.classifier_task2 = nn.Linear(self.bert.config.hidden_size, num_labels_task2)
    
    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        logits_task1 = self.classifier_task1(pooled_output)
        logits_task2 = self.classifier_task2(pooled_output)
        return logits_task1, logits_task2

# Загрузка сохраненной модели
model = MultiTaskBertModel(num_labels_task1=3, num_labels_task2=4)
model = torch.load('ticket_classifier.pth', weights_only=False, map_location=torch.device('cpu'))
model.eval()

# Загрузка токенизатора
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Функция для предсказания


label_mapping_type = {'варианты доставки': 0,
 'варианты оплаты': 1,
 'возврат средств': 2,
 'восстановление пароля': 3,
 'время доставки': 4,
 'выбор адреса доставки': 5,
 'жалоба': 6,
 'изменение адреса доставки': 7,
 'изменение заказа': 8,
 'отзыв': 9,
 'отмена заказа': 10,
 'отслеживание возврата средств': 11,
 'отслеживание заказа': 12,
 'подписка на новостную рассылку': 13,
 'политика возврата': 14,
 'получение информации': 15,
 'проблемы с оплатой': 16,
 'проблемы с регистрацией': 17,
 'проверка платы за отмену': 18,
 'проверка счета': 19,
 'проверка счетов': 20,
 'размещение заказа': 21,
 'редактирование учетной записи': 22,
 'связь с человеком': 23,
 'связь со службой поддержки': 24,
 'смена учетной записи': 25,
 'создание учетной записи': 26,
 'удаление аккаунта': 27}

label_mapping_priority = {'высокий': 0, 'низкий': 1, 'средний': 2}

def get_key_by_value(dictionary, value):
    reverse_dict = {v: k for k, v in dictionary.items()}
    return reverse_dict.get(value)

def predict(text):
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]
    
    # Получаем предсказания для двух задач
    logits_task1, logits_task2 = model(input_ids=input_ids, attention_mask=attention_mask)
    
    # Преобразование логитов в предсказания классов
    pred_task1 = torch.argmax(logits_task1, dim=1).item()
    pred_task2 = torch.argmax(logits_task2, dim=1).item()
    
    return {"Тема": get_key_by_value(label_mapping_type, pred_task1)}

iface = gr.Interface(
    fn=predict, 
    inputs=gr.Textbox(lines=2, placeholder="Введите запрос для анализа..."),  # Обновлено на gr.Textbox
    outputs=gr.JSON(),  # Обновлено на gr.JSON
    title="Классификация запроса",
    description="'варианты доставки', 'варианты оплаты', 'возврат средств', 'восстановление пароля', 'время доставки', 'выбор адреса доставки', 'жалоба', 'изменение адреса доставки', 'изменение заказа', 'отзыв', 'отмена заказа', 'отслеживание возврата средств', 'отслеживание заказа', 'подписка на новостную рассылку', 'политика возврата', 'получение информации', 'проблемы с оплатой', 'проблемы с регистрацией', 'проверка платы за отмену', 'проверка счета', 'размещение заказа', 'редактирование учетной записи', 'связь с человеком', 'связь со службой поддержки', 'смена учетной записи', 'создание учетной записи', 'удаление аккаунта'",
)

iface.launch(share=True)