DmitriySv commited on
Commit
3259969
·
verified ·
1 Parent(s): cd1dc6d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -64
app.py CHANGED
@@ -1,79 +1,51 @@
1
- import gradio as gr
2
- from transformers import BertTokenizer, BertForSequenceClassification, BertModel, BertConfig
3
  import torch
4
- import os
 
5
 
6
- class MultiTaskBertModel(torch.nn.Module):
7
- def __init__(self, bert_model, num_labels_task1, num_labels_task2):
8
  super(MultiTaskBertModel, self).__init__()
9
- self.bert = bert_model
10
- self.classifier_task1 = torch.nn.Linear(self.bert.config.hidden_size, num_labels_task1)
11
- self.classifier_task2 = torch.nn.Linear(self.bert.config.hidden_size, num_labels_task2)
12
-
13
- def forward(self, input_ids, attention_mask=None, token_type_ids=None):
14
- outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
15
  pooled_output = outputs.pooler_output
16
-
17
  logits_task1 = self.classifier_task1(pooled_output)
18
  logits_task2 = self.classifier_task2(pooled_output)
19
-
20
  return logits_task1, logits_task2
21
 
22
- def save_pretrained(self, save_directory):
23
- # Создаем директорию, если она не существует
24
- os.makedirs(save_directory, exist_ok=True)
25
-
26
- # Сохраняем веса модели
27
- model_path = os.path.join(save_directory, 'pytorch_model.bin')
28
- torch.save(self.state_dict(), model_path)
29
-
30
- # Сохраняем конфигурацию модели
31
- config = self.bert.config
32
- config.save_pretrained(save_directory)
33
-
34
- @classmethod
35
- def from_pretrained(cls, load_directory, num_labels_task1, num_labels_task2):
36
- # Загружаем конфигурацию BERT
37
- config = BertConfig.from_pretrained(load_directory)
38
-
39
- # Загружаем BERT модель
40
- bert_model = BertModel.from_pretrained(load_directory, config=config)
41
-
42
- # Создаем экземпляр кастомной модели
43
- model = cls(bert_model, num_labels_task1, num_labels_task2)
44
-
45
- # Загружаем сохраненные веса
46
- model_path = os.path.join('pytorch_model.bin')
47
- model.load_state_dict(torch.load(model_path))
48
-
49
- return model
50
-
51
- model = MultiTaskBertModel.from_pretrained("DmitriySv/ticket_classifer", 28, 3)
52
- tokenizer = BertTokenizer.from_pretrained("DmitriySv/ticket_classifer")
53
-
54
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
55
- model = model.to(device)
56
  model.eval()
57
 
58
- def classify(text):
59
- inputs = tokenizer(text, padding=True, truncation=True, max_length=512, return_tensors="pt").to(device)
60
-
61
- with torch.no_grad():
62
- outputs = model(**inputs)
63
- print(outputs)
64
- logits_task1, logits_task2 = model(**inputs)
65
 
 
 
 
 
 
 
 
 
 
 
66
  pred_task1 = torch.argmax(logits_task1, dim=1).item()
67
  pred_task2 = torch.argmax(logits_task2, dim=1).item()
68
-
69
- return {"Тип": pred_task1, "Приоритет": pred_task2}
70
-
71
- interface = gr.Interface(
72
- fn=classify,
73
- inputs=gr.Textbox(label="Введите запрос для классификации"),
74
- outputs=[gr.Label(label="Тип"), gr.Label(label="Приоритет")],
75
- title="Классификация запроса по типу и приоритету",
76
- description="Классификация запроса по типу и приоритету."
 
77
  )
78
 
79
- interface.launch()
 
 
 
1
  import torch
2
+ from transformers import BertTokenizer, BertModel
3
+ import gradio as gr
4
 
5
+ class MultiTaskBertModel(nn.Module):
6
+ def __init__(self, num_labels_task1, num_labels_task2):
7
  super(MultiTaskBertModel, self).__init__()
8
+ self.bert = BertModel.from_pretrained('bert-base-uncased')
9
+ self.classifier_task1 = nn.Linear(self.bert.config.hidden_size, num_labels_task1)
10
+ self.classifier_task2 = nn.Linear(self.bert.config.hidden_size, num_labels_task2)
11
+
12
+ def forward(self, input_ids, attention_mask):
13
+ outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
14
  pooled_output = outputs.pooler_output
 
15
  logits_task1 = self.classifier_task1(pooled_output)
16
  logits_task2 = self.classifier_task2(pooled_output)
 
17
  return logits_task1, logits_task2
18
 
19
+ # Загрузка сохраненной модели
20
+ model = MultiTaskBertModel(num_labels_task1=3, num_labels_task2=4)
21
+ model.load_state_dict(torch.load("ticket_classifier.pth"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  model.eval()
23
 
24
+ # Загрузка токенизатора
25
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
 
 
 
 
 
26
 
27
+ # Функция для предсказания
28
+ def predict(text):
29
+ inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
30
+ input_ids = inputs["input_ids"]
31
+ attention_mask = inputs["attention_mask"]
32
+
33
+ # Получаем предсказания для двух задач
34
+ logits_task1, logits_task2 = model(input_ids=input_ids, attention_mask=attention_mask)
35
+
36
+ # Преобразование логитов в предсказания классов
37
  pred_task1 = torch.argmax(logits_task1, dim=1).item()
38
  pred_task2 = torch.argmax(logits_task2, dim=1).item()
39
+
40
+ return {"Task 1 Prediction": pred_task1, "Task 2 Prediction": pred_task2}
41
+
42
+ # Создание интерфейса с Gradio
43
+ iface = gr.Interface(
44
+ fn=predict,
45
+ inputs=gr.inputs.Textbox(lines=2, placeholder="Введите текст для анализа..."),
46
+ outputs="json",
47
+ title="Multi-Task BERT Model",
48
+ description="Модель BERT для одновременного решения двух задач: тональность текста и тема.",
49
  )
50
 
51
+ iface.launch()