DmitriySv commited on
Commit
09de64d
·
verified ·
1 Parent(s): 549ac53

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -5
app.py CHANGED
@@ -26,6 +26,43 @@ model.eval()
26
  tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
27
 
28
  # Функция для предсказания
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  def predict(text):
30
  inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
31
  input_ids = inputs["input_ids"]
@@ -38,15 +75,14 @@ def predict(text):
38
  pred_task1 = torch.argmax(logits_task1, dim=1).item()
39
  pred_task2 = torch.argmax(logits_task2, dim=1).item()
40
 
41
- return {"Task 1 Prediction": pred_task1, "Task 2 Prediction": pred_task2}
42
 
43
- # Создание интерфейса с Gradio
44
  iface = gr.Interface(
45
  fn=predict,
46
- inputs=gr.Textbox(lines=2, placeholder="Введите текст для анализа..."), # Обновлено на gr.Textbox
47
  outputs=gr.JSON(), # Обновлено на gr.JSON
48
- title="Multi-Task BERT Model",
49
- description="Модель BERT для одновременного решения двух задач: тональность текста и тема.",
50
  )
51
 
52
  iface.launch()
 
26
  tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
27
 
28
  # Функция для предсказания
29
+
30
+
31
+ label_mapping_type = {'варианты доставки': 0,
32
+ 'варианты оплаты': 1,
33
+ 'возврат средств': 2,
34
+ 'восстановление пароля': 3,
35
+ 'время доставки': 4,
36
+ 'выбор адреса доставки': 5,
37
+ 'жалоба': 6,
38
+ 'изменение адреса доставки': 7,
39
+ 'изменение заказа': 8,
40
+ 'отзыв': 9,
41
+ 'отмена заказа': 10,
42
+ 'отслеживание возврата средств': 11,
43
+ 'отслеживание заказа': 12,
44
+ 'подписка на новостную рассылку': 13,
45
+ 'политика возврата': 14,
46
+ 'получение информации': 15,
47
+ 'проблемы с оплатой': 16,
48
+ 'проблемы с регистрацией': 17,
49
+ 'проверка платы за отмену': 18,
50
+ 'проверка счета': 19,
51
+ 'проверка счетов': 20,
52
+ 'размещение заказа': 21,
53
+ 'редактирование учетной записи': 22,
54
+ 'связь с человеком': 23,
55
+ 'связь со службой поддержки': 24,
56
+ 'смена учетной записи': 25,
57
+ 'создание учетной записи': 26,
58
+ 'удаление аккаунта': 27}
59
+
60
+ label_mapping_priority = {'высокий': 0, 'низкий': 1, 'средний': 2}
61
+
62
+ def get_key_by_value(dictionary, value):
63
+ reverse_dict = {v: k for k, v in dictionary.items()}
64
+ return reverse_dict.get(value)
65
+
66
  def predict(text):
67
  inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
68
  input_ids = inputs["input_ids"]
 
75
  pred_task1 = torch.argmax(logits_task1, dim=1).item()
76
  pred_task2 = torch.argmax(logits_task2, dim=1).item()
77
 
78
+ return {"Тема": get_key_by_value(label_mapping_type, pred_task1), "Приоритет": get_key_by_value(label_mapping_priority, pred_task2)}
79
 
 
80
  iface = gr.Interface(
81
  fn=predict,
82
+ inputs=gr.Textbox(lines=2, placeholder="Введите запрос для анализа..."), # Обновлено на gr.Textbox
83
  outputs=gr.JSON(), # Обновлено на gr.JSON
84
+ title="Классификация запроса",
85
+ description="'варианты доставки', 'варианты оплаты', 'возврат средств', 'восстановление пароля', 'время доставки', 'выбор адреса доставки', 'жалоба', 'изменение адреса доставки', 'изменение заказа', 'отзыв', 'отмена заказа', 'отслеживание возврата средств', 'отслеживание заказа', 'подписка на новостную рассылку', 'политика возврата', 'получение информации', 'проблемы с оплатой', 'проблемы с регистрацией', 'проверка платы за отмену', 'проверка счета', 'размещение заказа', 'редактирование учетной записи', 'связь с человеком', 'связь со службой поддержки', 'смена учетной записи', 'создание учетной записи', 'удаление аккаунта'",
86
  )
87
 
88
  iface.launch()