DmitriySv commited on
Commit
83510a7
·
verified ·
1 Parent(s): cb22199

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -2
app.py CHANGED
@@ -1,8 +1,53 @@
1
  import gradio as gr
2
- from transformers import BertTokenizer, BertForSequenceClassification
3
  import torch
4
 
5
- model = BertForSequenceClassification.from_pretrained("DmitriySv/ticket_classifer")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  tokenizer = BertTokenizer.from_pretrained("DmitriySv/ticket_classifer")
7
 
8
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
1
  import gradio as gr
2
+ from transformers import BertTokenizer, BertForSequenceClassification, BertModel
3
  import torch
4
 
5
+ class MultiTaskBertModel(torch.nn.Module):
6
+ def __init__(self, bert_model, num_labels_task1, num_labels_task2):
7
+ super(MultiTaskBertModel, self).__init__()
8
+ self.bert = bert_model
9
+ self.classifier_task1 = torch.nn.Linear(self.bert.config.hidden_size, num_labels_task1)
10
+ self.classifier_task2 = torch.nn.Linear(self.bert.config.hidden_size, num_labels_task2)
11
+
12
+ def forward(self, input_ids, attention_mask=None, token_type_ids=None):
13
+ outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
14
+ pooled_output = outputs.pooler_output
15
+
16
+ logits_task1 = self.classifier_task1(pooled_output)
17
+ logits_task2 = self.classifier_task2(pooled_output)
18
+
19
+ return logits_task1, logits_task2
20
+
21
+ def save_pretrained(self, save_directory):
22
+ # Создаем директорию, если она не существует
23
+ os.makedirs(save_directory, exist_ok=True)
24
+
25
+ # Сохраняем веса модели
26
+ model_path = os.path.join(save_directory, 'pytorch_model.bin')
27
+ torch.save(self.state_dict(), model_path)
28
+
29
+ # Сохраняем конфигурацию модели
30
+ config = self.bert.config
31
+ config.save_pretrained(save_directory)
32
+
33
+ @classmethod
34
+ def from_pretrained(cls, load_directory, num_labels_task1, num_labels_task2):
35
+ # Загружаем конфигурацию BERT
36
+ config = BertConfig.from_pretrained(load_directory)
37
+
38
+ # Загружаем BERT модель
39
+ bert_model = BertModel.from_pretrained(load_directory, config=config)
40
+
41
+ # Создаем экземпляр кастомной модели
42
+ model = cls(bert_model, num_labels_task1, num_labels_task2)
43
+
44
+ # Загружаем сохраненные веса
45
+ model_path = os.path.join(load_directory, 'pytorch_model.bin')
46
+ model.load_state_dict(torch.load(model_path))
47
+
48
+ return model
49
+
50
+ model = MultiTaskBertModel.from_pretrained("DmitriySv/ticket_classifer")
51
  tokenizer = BertTokenizer.from_pretrained("DmitriySv/ticket_classifer")
52
 
53
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')