Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn as nn | |
from transformers import BertTokenizer, BertModel | |
import gradio as gr | |
class MultiTaskBertModel(nn.Module): | |
def __init__(self, num_labels_task1, num_labels_task2): | |
super(MultiTaskBertModel, self).__init__() | |
self.bert = BertModel.from_pretrained('bert-base-uncased') | |
self.classifier_task1 = nn.Linear(self.bert.config.hidden_size, num_labels_task1) | |
self.classifier_task2 = nn.Linear(self.bert.config.hidden_size, num_labels_task2) | |
def forward(self, input_ids, attention_mask): | |
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask) | |
pooled_output = outputs.pooler_output | |
logits_task1 = self.classifier_task1(pooled_output) | |
logits_task2 = self.classifier_task2(pooled_output) | |
return logits_task1, logits_task2 | |
# Загрузка сохраненной модели | |
model = MultiTaskBertModel(num_labels_task1=3, num_labels_task2=4) | |
model = torch.load('ticket_classifier.pth', weights_only=False, map_location=torch.device('cpu')) | |
model.eval() | |
# Загрузка токенизатора | |
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') | |
# Функция для предсказания | |
def predict(text): | |
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512) | |
input_ids = inputs["input_ids"] | |
attention_mask = inputs["attention_mask"] | |
# Получаем предсказания для двух задач | |
logits_task1, logits_task2 = model(input_ids=input_ids, attention_mask=attention_mask) | |
# Преобразование логитов в предсказания классов | |
pred_task1 = torch.argmax(logits_task1, dim=1).item() | |
pred_task2 = torch.argmax(logits_task2, dim=1).item() | |
return {"Task 1 Prediction": pred_task1, "Task 2 Prediction": pred_task2} | |
# Создание интерфейса с Gradio | |
iface = gr.Interface( | |
fn=predict, | |
inputs=gr.Textbox(lines=2, placeholder="Введите текст для анализа..."), # Обновлено на gr.Textbox | |
outputs=gr.JSON(), # Обновлено на gr.JSON | |
title="Multi-Task BERT Model", | |
description="Модель BERT для одновременного решения двух задач: тональность текста и тема.", | |
) | |
iface.launch() | |