DmitriySv's picture
Update app.py
6b9c12e verified
raw
history blame
3.32 kB
import gradio as gr
from transformers import BertTokenizer, BertForSequenceClassification, BertModel, BertConfig
import torch
import os
class MultiTaskBertModel(torch.nn.Module):
def __init__(self, bert_model, num_labels_task1, num_labels_task2):
super(MultiTaskBertModel, self).__init__()
self.bert = bert_model
self.classifier_task1 = torch.nn.Linear(self.bert.config.hidden_size, num_labels_task1)
self.classifier_task2 = torch.nn.Linear(self.bert.config.hidden_size, num_labels_task2)
def forward(self, input_ids, attention_mask=None, token_type_ids=None):
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
pooled_output = outputs.pooler_output
logits_task1 = self.classifier_task1(pooled_output)
logits_task2 = self.classifier_task2(pooled_output)
return logits_task1, logits_task2
def save_pretrained(self, save_directory):
# Создаем директорию, если она не существует
os.makedirs(save_directory, exist_ok=True)
# Сохраняем веса модели
model_path = os.path.join(save_directory, 'pytorch_model.bin')
torch.save(self.state_dict(), model_path)
# Сохраняем конфигурацию модели
config = self.bert.config
config.save_pretrained(save_directory)
@classmethod
def from_pretrained(cls, load_directory, num_labels_task1, num_labels_task2):
# Загружаем конфигурацию BERT
config = BertConfig.from_pretrained(load_directory)
# Загружаем BERT модель
bert_model = BertModel.from_pretrained(load_directory, config=config)
# Создаем экземпляр кастомной модели
model = cls(bert_model, num_labels_task1, num_labels_task2)
# Загружаем сохраненные веса
model_path = os.path.join('pytorch_model.bin')
model.load_state_dict(torch.load(model_path))
return model
model = MultiTaskBertModel.from_pretrained("DmitriySv/ticket_classifer", 28, 3)
tokenizer = BertTokenizer.from_pretrained("DmitriySv/ticket_classifer")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
model.eval()
def classify(text):
inputs = tokenizer(text, padding=True, truncation=True, max_length=512, return_tensors="pt").to(device)
with torch.no_grad():
outputs = model(**inputs)
print(outputs)
logits_task1, logits_task2 = model(**inputs)
pred_task1 = torch.argmax(logits_task1, dim=1).item()
pred_task2 = torch.argmax(logits_task2, dim=1).item()
return {"Тип": pred_task1, "Приоритет": pred_task2}
interface = gr.Interface(
fn=classify,
inputs=gr.Textbox(label="Введите запрос для классификации"),
outputs=[gr.Label(label="Тип"), gr.Label(label="Приоритет")],
title="Классификация запроса по типу и приоритету",
description="Классификация запроса по типу и приоритету."
)
interface.launch()