DmitriySv's picture
Update app.py
bfa01ee verified
raw
history blame
2.29 kB
import torch
import torch.nn as nn
from transformers import BertTokenizer, BertModel
import gradio as gr
class MultiTaskBertModel(nn.Module):
def __init__(self, num_labels_task1, num_labels_task2):
super(MultiTaskBertModel, self).__init__()
self.bert = BertModel.from_pretrained('bert-base-uncased')
self.classifier_task1 = nn.Linear(self.bert.config.hidden_size, num_labels_task1)
self.classifier_task2 = nn.Linear(self.bert.config.hidden_size, num_labels_task2)
def forward(self, input_ids, attention_mask):
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
pooled_output = outputs.pooler_output
logits_task1 = self.classifier_task1(pooled_output)
logits_task2 = self.classifier_task2(pooled_output)
return logits_task1, logits_task2
# Загрузка сохраненной модели
model = MultiTaskBertModel(num_labels_task1=3, num_labels_task2=4)
model.load_state_dict(torch.load("ticket_classifier.pth"))
model.eval()
# Загрузка токенизатора
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# Функция для предсказания
def predict(text):
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]
# Получаем предсказания для двух задач
logits_task1, logits_task2 = model(input_ids=input_ids, attention_mask=attention_mask)
# Преобразование логитов в предсказания классов
pred_task1 = torch.argmax(logits_task1, dim=1).item()
pred_task2 = torch.argmax(logits_task2, dim=1).item()
return {"Task 1 Prediction": pred_task1, "Task 2 Prediction": pred_task2}
# Создание интерфейса с Gradio
iface = gr.Interface(
fn=predict,
inputs=gr.inputs.Textbox(lines=2, placeholder="Введите текст для анализа..."),
outputs="json",
title="Multi-Task BERT Model",
description="Модель BERT для одновременного решения двух задач: тональность текста и тема.",
)
iface.launch()