Spaces:
Sleeping
Sleeping
File size: 4,824 Bytes
64b4c39 bfa01ee 3259969 64b4c39 3259969 83510a7 3259969 83510a7 3259969 5656ec0 64b4c39 3259969 64b4c39 3259969 09de64d 3259969 64b4c39 3259969 b009f2e 3259969 09de64d 549ac53 09de64d 64b4c39 7efd2f4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 |
import torch
import torch.nn as nn
from transformers import BertTokenizer, BertModel
import gradio as gr
class MultiTaskBertModel(nn.Module):
def __init__(self, num_labels_task1, num_labels_task2):
super(MultiTaskBertModel, self).__init__()
self.bert = BertModel.from_pretrained('bert-base-uncased')
self.classifier_task1 = nn.Linear(self.bert.config.hidden_size, num_labels_task1)
self.classifier_task2 = nn.Linear(self.bert.config.hidden_size, num_labels_task2)
def forward(self, input_ids, attention_mask):
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
pooled_output = outputs.pooler_output
logits_task1 = self.classifier_task1(pooled_output)
logits_task2 = self.classifier_task2(pooled_output)
return logits_task1, logits_task2
# Загрузка сохраненной модели
model = MultiTaskBertModel(num_labels_task1=3, num_labels_task2=4)
model = torch.load('ticket_classifier.pth', weights_only=False, map_location=torch.device('cpu'))
model.eval()
# Загрузка токенизатора
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# Функция для предсказания
label_mapping_type = {'варианты доставки': 0,
'варианты оплаты': 1,
'возврат средств': 2,
'восстановление пароля': 3,
'время доставки': 4,
'выбор адреса доставки': 5,
'жалоба': 6,
'изменение адреса доставки': 7,
'изменение заказа': 8,
'отзыв': 9,
'отмена заказа': 10,
'отслеживание возврата средств': 11,
'отслеживание заказа': 12,
'подписка на новостную рассылку': 13,
'политика возврата': 14,
'получение информации': 15,
'проблемы с оплатой': 16,
'проблемы с регистрацией': 17,
'проверка платы за отмену': 18,
'проверка счета': 19,
'проверка счетов': 20,
'размещение заказа': 21,
'редактирование учетной записи': 22,
'связь с человеком': 23,
'связь со службой поддержки': 24,
'смена учетной записи': 25,
'создание учетной записи': 26,
'удаление аккаунта': 27}
label_mapping_priority = {'высокий': 0, 'низкий': 1, 'средний': 2}
def get_key_by_value(dictionary, value):
reverse_dict = {v: k for k, v in dictionary.items()}
return reverse_dict.get(value)
def predict(text):
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]
# Получаем предсказания для двух задач
logits_task1, logits_task2 = model(input_ids=input_ids, attention_mask=attention_mask)
# Преобразование логитов в предсказания классов
pred_task1 = torch.argmax(logits_task1, dim=1).item()
pred_task2 = torch.argmax(logits_task2, dim=1).item()
return {"Тема": get_key_by_value(label_mapping_type, pred_task1)}
iface = gr.Interface(
fn=predict,
inputs=gr.Textbox(lines=2, placeholder="Введите запрос для анализа..."), # Обновлено на gr.Textbox
outputs=gr.JSON(), # Обновлено на gr.JSON
title="Классификация запроса",
description="'варианты доставки', 'варианты оплаты', 'возврат средств', 'восстановление пароля', 'время доставки', 'выбор адреса доставки', 'жалоба', 'изменение адреса доставки', 'изменение заказа', 'отзыв', 'отмена заказа', 'отслеживание возврата средств', 'отслеживание заказа', 'подписка на новостную рассылку', 'политика возврата', 'получение информации', 'проблемы с оплатой', 'проблемы с регистрацией', 'проверка платы за отмену', 'проверка счета', 'размещение заказа', 'редактирование учетной записи', 'связь с человеком', 'связь со службой поддержки', 'смена учетной записи', 'создание учетной записи', 'удаление аккаунта'",
)
iface.launch(share=True)
|