DmitriySv's picture
Update app.py
549ac53 verified
raw
history blame
2.4 kB
import torch
import torch.nn as nn
from transformers import BertTokenizer, BertModel
import gradio as gr
class MultiTaskBertModel(nn.Module):
def __init__(self, num_labels_task1, num_labels_task2):
super(MultiTaskBertModel, self).__init__()
self.bert = BertModel.from_pretrained('bert-base-uncased')
self.classifier_task1 = nn.Linear(self.bert.config.hidden_size, num_labels_task1)
self.classifier_task2 = nn.Linear(self.bert.config.hidden_size, num_labels_task2)
def forward(self, input_ids, attention_mask):
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
pooled_output = outputs.pooler_output
logits_task1 = self.classifier_task1(pooled_output)
logits_task2 = self.classifier_task2(pooled_output)
return logits_task1, logits_task2
# Загрузка сохраненной модели
model = MultiTaskBertModel(num_labels_task1=3, num_labels_task2=4)
model = torch.load('ticket_classifier.pth', weights_only=False, map_location=torch.device('cpu'))
model.eval()
# Загрузка токенизатора
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# Функция для предсказания
def predict(text):
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]
# Получаем предсказания для двух задач
logits_task1, logits_task2 = model(input_ids=input_ids, attention_mask=attention_mask)
# Преобразование логитов в предсказания классов
pred_task1 = torch.argmax(logits_task1, dim=1).item()
pred_task2 = torch.argmax(logits_task2, dim=1).item()
return {"Task 1 Prediction": pred_task1, "Task 2 Prediction": pred_task2}
# Создание интерфейса с Gradio
iface = gr.Interface(
fn=predict,
inputs=gr.Textbox(lines=2, placeholder="Введите текст для анализа..."), # Обновлено на gr.Textbox
outputs=gr.JSON(), # Обновлено на gr.JSON
title="Multi-Task BERT Model",
description="Модель BERT для одновременного решения двух задач: тональность текста и тема.",
)
iface.launch()