File size: 3,318 Bytes
64b4c39
7e8fde4
64b4c39
aaa5966
64b4c39
83510a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6b9c12e
83510a7
 
 
 
eade3ec
 
64b4c39
 
 
 
 
 
 
 
 
cb22199
 
64b4c39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import gradio as gr
from transformers import BertTokenizer, BertForSequenceClassification, BertModel, BertConfig
import torch
import os

class MultiTaskBertModel(torch.nn.Module):
    def __init__(self, bert_model, num_labels_task1, num_labels_task2):
        super(MultiTaskBertModel, self).__init__()
        self.bert = bert_model
        self.classifier_task1 = torch.nn.Linear(self.bert.config.hidden_size, num_labels_task1)
        self.classifier_task2 = torch.nn.Linear(self.bert.config.hidden_size, num_labels_task2)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        pooled_output = outputs.pooler_output

        logits_task1 = self.classifier_task1(pooled_output)
        logits_task2 = self.classifier_task2(pooled_output)

        return logits_task1, logits_task2

    def save_pretrained(self, save_directory):
        # Создаем директорию, если она не существует
        os.makedirs(save_directory, exist_ok=True)
        
        # Сохраняем веса модели
        model_path = os.path.join(save_directory, 'pytorch_model.bin')
        torch.save(self.state_dict(), model_path)
        
        # Сохраняем конфигурацию модели
        config = self.bert.config
        config.save_pretrained(save_directory)
    
    @classmethod
    def from_pretrained(cls, load_directory, num_labels_task1, num_labels_task2):
        # Загружаем конфигурацию BERT
        config = BertConfig.from_pretrained(load_directory)
        
        # Загружаем BERT модель
        bert_model = BertModel.from_pretrained(load_directory, config=config)
        
        # Создаем экземпляр кастомной модели
        model = cls(bert_model, num_labels_task1, num_labels_task2)
        
        # Загружаем сохраненные веса
        model_path = os.path.join('pytorch_model.bin')
        model.load_state_dict(torch.load(model_path))
        
        return model

model = MultiTaskBertModel.from_pretrained("DmitriySv/ticket_classifer", 28, 3)
tokenizer = BertTokenizer.from_pretrained("DmitriySv/ticket_classifer")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
model.eval()

def classify(text):
    inputs = tokenizer(text, padding=True, truncation=True, max_length=512, return_tensors="pt").to(device)
    
    with torch.no_grad():
        outputs = model(**inputs)
        print(outputs)
        logits_task1, logits_task2 = model(**inputs)

    pred_task1 = torch.argmax(logits_task1, dim=1).item()
    pred_task2 = torch.argmax(logits_task2, dim=1).item()

    return {"Тип": pred_task1, "Приоритет": pred_task2}

interface = gr.Interface(
    fn=classify,  
    inputs=gr.Textbox(label="Введите запрос для классификации"),
    outputs=[gr.Label(label="Тип"), gr.Label(label="Приоритет")],
    title="Классификация запроса по типу и приоритету",
    description="Классификация запроса по типу и приоритету."
)

interface.launch()